{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db9125f1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Change into the parent directory exactly once. The target is remembered as an absolute\n",
    "# PROJECT_ROOT, so re-running this cell does not keep climbing the directory tree\n",
    "# (later cells open relative paths such as 'results/...').\n",
    "print(os.getcwd())\n",
    "if 'PROJECT_ROOT' not in globals():\n",
    "    PROJECT_ROOT = os.path.abspath('../')\n",
    "os.chdir(PROJECT_ROOT)\n",
    "print(os.getcwd())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3684c48a",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfc6853d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import MultipleLocator, AutoMinorLocator\n",
    "\n",
    "import matplotlib.gridspec as gridspec\n",
    "\n",
    "from matplotlib.patches import Patch\n",
    "from matplotlib.lines import Line2D\n",
    "\n",
    "from matplotlib import cm\n",
    "from matplotlib.colors import ListedColormap, LinearSegmentedColormap\n",
    "import matplotlib.ticker as ticker\n",
    "import matplotlib as mpl\n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
    "\n",
    "import seaborn as sns\n",
    "\n",
    "import wandb\n",
    "api = wandb.Api()\n",
    "from cycler import cycler\n",
    "\n",
    "import cv2\n",
    "\n",
    "# ColorBrewer palettes, source:\n",
    "# https://github.com/dsc/colorbrewer-python/blob/master/colorbrewer.py\n",
    "# Each dict maps a palette size n to a list of n [R, G, B] colors (0-255 per channel).\n",
    "# Every color is a list so all palettes can be handled uniformly.\n",
    "\n",
    "Set1 = {\n",
    "    3: [[228,26,28], [55,126,184], [77,175,74]],\n",
    "    4: [[228,26,28], [55,126,184], [77,175,74], [152,78,163]],\n",
    "    5: [[228,26,28], [55,126,184], [77,175,74], [152,78,163], [255,127,0]],\n",
    "    6: [[228,26,28], [55,126,184], [77,175,74], [152,78,163], [255,127,0], [255,255,51]],\n",
    "    7: [[228,26,28], [55,126,184], [77,175,74], [152,78,163], [255,127,0], [255,255,51], [166,86,40]],\n",
    "    8: [[228,26,28], [55,126,184], [77,175,74], [152,78,163], [255,127,0], [255,255,51], [166,86,40], [247,129,191]],\n",
    "    9: [[228,26,28], [55,126,184], [77,175,74], [152,78,163], [255,127,0], [255,255,51], [166,86,40], [247,129,191], [153,153,153]],\n",
    "}\n",
    "\n",
    "Paired = {\n",
    "    3: [[166,206,227], [31,120,180], [178,223,138]],\n",
    "    4: [[166,206,227], [31,120,180], [178,223,138], [51,160,44]],\n",
    "    5: [[166,206,227], [31,120,180], [178,223,138], [51,160,44], [251,154,153]],\n",
    "    6: [[166,206,227], [31,120,180], [178,223,138], [51,160,44], [251,154,153], [227,26,28]],\n",
    "    7: [[166,206,227], [31,120,180], [178,223,138], [51,160,44], [251,154,153], [227,26,28], [253,191,111]],\n",
    "    8: [[166,206,227], [31,120,180], [178,223,138], [51,160,44], [251,154,153], [227,26,28], [253,191,111], [255,127,0]],\n",
    "    9: [[166,206,227], [31,120,180], [178,223,138], [51,160,44], [251,154,153], [227,26,28], [253,191,111], [255,127,0], [202,178,214]],\n",
    "    10: [[166,206,227], [31,120,180], [178,223,138], [51,160,44], [251,154,153], [227,26,28], [253,191,111], [255,127,0], [202,178,214], [106,61,154]],\n",
    "    11: [[166,206,227], [31,120,180], [178,223,138], [51,160,44], [251,154,153], [227,26,28], [253,191,111], [255,127,0], [202,178,214], [106,61,154], [255,255,153]],\n",
    "    12: [[166,206,227], [31,120,180], [178,223,138], [51,160,44], [251,154,153], [227,26,28], [253,191,111], [255,127,0], [202,178,214], [106,61,154], [255,255,153], [177,89,40]]\n",
    "}\n",
    "\n",
    "# Legend style: square (non-fancy) box, white edge color, fully transparent frame.\n",
    "plt.rcParams.update({\n",
    "    'legend.fancybox': False,\n",
    "    'legend.edgecolor': '1.0',\n",
    "    'legend.framealpha': 0,\n",
    "})\n",
    "\n",
    "# Scratch figure/axes pair; plt.clf() then clears the current figure, i.e. fig_temp.\n",
    "# NOTE(review): the purpose is not visible here (possibly a backend warm-up) -- confirm.\n",
    "fig_temp = plt.figure()\n",
    "ax_temp = fig_temp.add_subplot()\n",
    "plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42791191",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Backbones whose results are loaded below; each name maps to its own (empty) config list.\n",
    "backbone_type_config_dict = {\n",
    "    backbone_name: []\n",
    "    for backbone_name in ('vit_base_patch16_224', 'vit_small_patch16_224')\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98e54041",
   "metadata": {},
   "source": [
    "# Read all data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b3533da",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def get_retraining_status(backbone_type, dataset_name, api_dir=\"ch6845/transformer_interpretability_project_retraining\"):\n",
    "    \"\"\"Collect test metrics and masking settings of W&B retraining runs into a DataFrame.\n",
    "\n",
    "    A run is kept only if run.config[\"datasets\"] == dataset_name and\n",
    "    run.config[\"backbone_type\"] == backbone_type. A run whose config or summary\n",
    "    access raises is reported together with its error and skipped, instead of\n",
    "    aborting the whole scan.\n",
    "\n",
    "    Args:\n",
    "        backbone_type: value compared against run.config[\"backbone_type\"].\n",
    "        dataset_name: value compared against run.config[\"datasets\"].\n",
    "        api_dir: W&B project path passed to api.runs().\n",
    "\n",
    "    Returns:\n",
    "        pd.DataFrame with one row per kept run; missing test metrics are NaN.\n",
    "        The column set is fixed, so an empty result still carries the columns.\n",
    "    \"\"\"\n",
    "    # `api` is the module-level wandb.Api() created in the imports cell.\n",
    "    api.flush()  # clear the Api's local run cache before listing runs\n",
    "    runs = api.runs(api_dir)\n",
    "\n",
    "    # Masking settings recorded per split; this order defines the column order.\n",
    "    explanation_keys = [f\"explanation_{field}_{split}\"\n",
    "                        for split in (\"train\", \"val\", \"test\")\n",
    "                        for field in (\"location\", \"mask_amount\", \"mask_ascending\")]\n",
    "    columns = [\"name\", \"accuracy\", \"cohenkappa\", \"backbone_type\", \"datasets\",\n",
    "               \"classifier_enable_pos_embed\", *explanation_keys]\n",
    "\n",
    "    result_dict_list = []\n",
    "    for run in runs:\n",
    "        try:\n",
    "            # Read config inside the try so a run missing these keys is skipped, not fatal.\n",
    "            print(run.config[\"datasets\"], run.config[\"backbone_type\"])\n",
    "            print(run.name)\n",
    "            if run.config[\"datasets\"] != dataset_name or run.config[\"backbone_type\"] != backbone_type:\n",
    "                print(run.name, \"pass\")\n",
    "                continue\n",
    "            summary_keys = run.summary.keys()\n",
    "            record = {\n",
    "                \"name\": run.name,\n",
    "                \"accuracy\": run.summary[\"test/accuracy\"] if \"test/accuracy\" in summary_keys else np.nan,\n",
    "                \"cohenkappa\": run.summary[\"test/cohenkappa\"] if \"test/cohenkappa\" in summary_keys else np.nan,\n",
    "                # NOTE(review): reported from \"classifier_backbone_type\" while the filter above\n",
    "                # uses \"backbone_type\" -- confirm the two are meant to differ.\n",
    "                \"backbone_type\": run.config[\"classifier_backbone_type\"],\n",
    "                \"datasets\": run.config[\"datasets\"],\n",
    "                # Runs without this key are treated as having positional embeddings enabled.\n",
    "                \"classifier_enable_pos_embed\": run.config[\"classifier_enable_pos_embed\"] if \"classifier_enable_pos_embed\" in run.config else True,\n",
    "            }\n",
    "            record.update({key: run.config[key] for key in explanation_keys})\n",
    "        except Exception as error:  # best-effort: report and skip the malformed run\n",
    "            print(run.name, 'error', repr(error))\n",
    "        else:\n",
    "            result_dict_list.append(record)\n",
    "\n",
    "    return pd.DataFrame(result_dict_list, columns=columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c909dcc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_retraining_status(backbone_type=backbone_type, \n",
    "#                       dataset_name=dataset_name,\n",
    "#                       api_dir=\"ch6845/transformer_interpretability_project_retraining\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b78881c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data_loaded_all={}\n",
    "\n",
    "for evaluation_stage in [\"1_classifier_evaluate\",\n",
    "                         \"2_surrogate_evaluate\",\n",
    "                         \"3_explanation_generate\",\n",
    "                         \"4_insert_delete\",\n",
    "                         \"5_sensitivity\",\n",
    "                         \"6_noretraining\",\n",
    "                         \"7_classifiermasked\",\n",
    "                         \"8_elapsedtime\",\n",
    "                         \"9_estimationerror\",\n",
    "                         \"9_retraining\",\n",
    "                         \"10_retrainingnopos\"\n",
    "                        ]:\n",
    "    data_loaded_all.setdefault(evaluation_stage,{})\n",
    "    \n",
    "    for dataset_name in [\"ImageNette\", \"MURA\"]:\n",
    "        print(evaluation_stage, dataset_name, end=\" -- \")\n",
    "        \n",
    "        data_loaded_all[evaluation_stage].setdefault(dataset_name,{})        \n",
    "        \n",
    "        for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "            data_loaded_all[evaluation_stage][dataset_name].setdefault(backbone_type, {})\n",
    "            \n",
    "            if evaluation_stage==\"1_classifier_evaluate\":                       \n",
    "                classifier_result_list_path=f'results/1_classifier_evaluate/{dataset_name}/{backbone_type}_test.pickle'\n",
    "                \n",
    "                if os.path.isfile(classifier_result_list_path):\n",
    "                    with open(classifier_result_list_path, \"rb\") as f:\n",
    "                        classifier_result_list=pickle.load(f)\n",
    "                else:\n",
    "                    classifier_result_list={}\n",
    "                print(backbone_type, len(classifier_result_list))\n",
    "                data_loaded_all[evaluation_stage][dataset_name][backbone_type]=classifier_result_list\n",
    "                print(' ')\n",
    "            elif evaluation_stage==\"2_surrogate_evaluate\":\n",
    "                surrogate_result_path=f'results/2_surrogate_evaluate/{dataset_name}/{backbone_type}_test.csv'\n",
    "                if os.path.isfile(surrogate_result_path):\n",
    "                    surrogate_result=pd.read_csv(surrogate_result_path)\n",
    "                else:\n",
    "                    surrogate_result=[]\n",
    "                print(backbone_type, len(surrogate_result))\n",
    "                data_loaded_all[evaluation_stage][dataset_name][backbone_type]=surrogate_result\n",
    "                print(' ')  \n",
    "                    \n",
    "        if evaluation_stage==\"3_explanation_generate\":\n",
    "            print(' ')\n",
    "            explanation_save_dict={}\n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                explanation_save_dict_backbone={\"attention_rollout\":{},\n",
    "                                                \"attention_last\":{},\n",
    "                                                \"LRP\":{},\n",
    "                                                \"gradcam\":{},\n",
    "                                                \"gradcamgithub\": {},\n",
    "                                                \"vanillapixel\": {},\n",
    "                                                \"vanillaembedding\": {},\n",
    "                                                \"sgpixel\": {},\n",
    "                                                \"sgembedding\": {},\n",
    "                                                \"vargradpixel\": {},\n",
    "                                                \"vargradembedding\": {},               \n",
    "                                                \"igpixel\": {},\n",
    "                                                \"igembedding\": {},\n",
    "                                                \"leaveoneoutclassifier\": {},\n",
    "                                                \"leaveoneoutsurrogate\": {},\n",
    "                                                \"riseclassifier\": {},\n",
    "                                                \"risesurrogate\": {},\n",
    "                                                \"ours\": {},\n",
    "                                                \"kernelshap\": {}\n",
    "                                                }\n",
    "                explanation_save_dict[backbone_type]=explanation_save_dict_backbone            \n",
    "            \n",
    "            \n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                print(backbone_type,'\\n')\n",
    "                for explanation_method, explanation_save_dict_backbone_method in explanation_save_dict[backbone_type].items():\n",
    "                    explanation_save_dict_path=f'results/3_explanation_generate/{dataset_name}/{backbone_type}_{explanation_method}_test.pickle'\n",
    "\n",
    "                    if os.path.isfile(explanation_save_dict_path):\n",
    "                        with open(explanation_save_dict_path, 'rb') as f:\n",
    "                            explanation_save_dict_loaded=pickle.load(f)\n",
    "                    else:\n",
    "                        explanation_save_dict_loaded={}\n",
    "\n",
    "                    len_original=len(explanation_save_dict_backbone_method)            \n",
    "                    len_loaded=len(explanation_save_dict_loaded)\n",
    "                    explanation_save_dict_backbone_method.update(explanation_save_dict_loaded)\n",
    "                    len_updated=len(explanation_save_dict_backbone_method)                    \n",
    "                    len_unique=len(set([i.replace('l0','').replace('l1lambda','').replace('l2lambda','').replace('l3','').replace('deeper','') for i in explanation_save_dict_backbone_method.keys()]))\n",
    "\n",
    "                    print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6} unique: {len_unique:6}')                                \n",
    "            data_loaded_all[evaluation_stage][dataset_name]=explanation_save_dict\n",
    "            print(' ')    \n",
    "            \n",
    "        if evaluation_stage==\"4_insert_delete\":\n",
    "            print(' ')\n",
    "            insertdelete_save_dict={}\n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                insertdelete_save_dict_backbone={\"random\":{},\n",
    "                                                 \"attention_rollout\":{},\n",
    "                                                 \"attention_last\":{},\n",
    "                                                 \"LRP\":{},\n",
    "                                                 \"gradcam\":{},\n",
    "                                                 \"gradcamgithub\": {},\n",
    "                                                 \"vanillapixel\": {},\n",
    "                                                 \"vanillaembedding\": {},\n",
    "                                                 \"sgpixel\": {},\n",
    "                                                 \"sgembedding\": {},\n",
    "                                                 \"vargradpixel\": {},\n",
    "                                                 \"vargradembedding\": {},               \n",
    "                                                 \"igpixel\": {},\n",
    "                                                 \"igembedding\": {},\n",
    "                                                 \"leaveoneoutclassifier\": {},\n",
    "                                                 \"leaveoneoutsurrogate\": {},\n",
    "                                                 \"riseclassifier\": {},\n",
    "                                                 \"risesurrogate\": {},\n",
    "                                                 \"ours\": {},\n",
    "                                                 \"kernelshap\": {}\n",
    "                                                }\n",
    "                insertdelete_save_dict[backbone_type]=insertdelete_save_dict_backbone     \n",
    "            \n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                print(backbone_type,'\\n')\n",
    "                for explanation_method, insertdelete_save_dict_backbone_method in insertdelete_save_dict[backbone_type].items():\n",
    "                    insertdelete_save_dict_path=f'results/4_insert_delete/{dataset_name}/{backbone_type}_{explanation_method}_test.pickle'\n",
    "\n",
    "                    if os.path.isfile(insertdelete_save_dict_path):\n",
    "                        with open(insertdelete_save_dict_path, 'rb') as f:\n",
    "                            insertdelete_save_dict_loaded=pickle.load(f)\n",
    "                    else:\n",
    "                        insertdelete_save_dict_loaded={}\n",
    "\n",
    "                    len_original=len(insertdelete_save_dict_backbone_method)            \n",
    "                    len_loaded=len(insertdelete_save_dict_loaded)\n",
    "                    insertdelete_save_dict_backbone_method.update(insertdelete_save_dict_loaded)\n",
    "                    len_updated=len(insertdelete_save_dict_backbone_method)\n",
    "                    len_unique=len(set([i.replace('l0','').replace('l1lambda','').replace('l2lambda','').replace('l3','').replace('deeper','') for i in insertdelete_save_dict_backbone_method.keys()]))\n",
    "\n",
    "                    print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6} unique: {len_unique:6}')                    \n",
    "            data_loaded_all[evaluation_stage][dataset_name]=insertdelete_save_dict\n",
    "            print(' ')\n",
    "            \n",
    "        if evaluation_stage==\"5_sensitivity\":\n",
    "            print(' ')\n",
    "            sensitivity_save_dit={}\n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                print(backbone_type,'\\n')\n",
    "                sensitivity_save_dit_backbone={\"attention_rollout\":{},\n",
    "                                               \"attention_last\":{},\n",
    "                                               \"LRP\":{},\n",
    "                                               \"gradcam\":{},\n",
    "                                               \"gradcamgithub\": {},\n",
    "                                               \"vanillapixel\": {},\n",
    "                                               \"vanillaembedding\": {},\n",
    "                                               \"sgpixel\": {},\n",
    "                                               \"sgembedding\": {},\n",
    "                                               \"vargradpixel\": {},\n",
    "                                               \"vargradembedding\": {},               \n",
    "                                               \"igpixel\": {},\n",
    "                                               \"igembedding\": {},\n",
    "                                               \"leaveoneoutclassifier\": {},\n",
    "                                               \"leaveoneoutsurrogate\": {},\n",
    "                                               \"riseclassifier\": {},\n",
    "                                               \"risesurrogate\": {},\n",
    "                                               \"ours\": {},\n",
    "                                               }\n",
    "                sensitivity_save_dit[backbone_type]=sensitivity_save_dit_backbone       \n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                print(backbone_type,'\\n')\n",
    "                for explanation_method, sensitivity_save_dit_backbone_method in sensitivity_save_dit[backbone_type].items():\n",
    "                    sensitivity_save_dit_path=f'results/5_sensitivity/{dataset_name}/{backbone_type}_{explanation_method}_test.pickle'\n",
    "\n",
    "                    if os.path.isfile(sensitivity_save_dit_path):\n",
    "                        with open(sensitivity_save_dit_path, 'rb') as f:\n",
    "                            sensitivity_save_dit_loaded=pickle.load(f)\n",
    "                    else:\n",
    "                        sensitivity_save_dit_loaded={}\n",
    "\n",
    "                    len_original=len(sensitivity_save_dit_backbone_method)            \n",
    "                    len_loaded=len(sensitivity_save_dit_loaded)\n",
    "                    sensitivity_save_dit_backbone_method.update(sensitivity_save_dit_loaded)\n",
    "                    len_updated=len(sensitivity_save_dit_backbone_method)\n",
    "                    len_unique=len(set([i.replace('l0','').replace('l1lambda','').replace('l2lambda','').replace('l3','').replace('deeper','') for i in sensitivity_save_dit_backbone_method.keys()]))\n",
    "\n",
    "                    print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6} unique: {len_unique:6}')            \n",
    "            data_loaded_all[evaluation_stage][dataset_name]=sensitivity_save_dit\n",
    "            print(' ')\n",
    "            \n",
    "        if evaluation_stage==\"6_noretraining\":\n",
    "            print(' ')\n",
    "            noretraining_save_dict={}\n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                noretraining_save_dict_backbone={\"random\":{},\n",
    "                                                 \"attention_rollout\":{},\n",
    "                                                 \"attention_last\":{},\n",
    "                                                 \"LRP\":{},\n",
    "                                                 \"gradcam\":{},\n",
    "                                                 \"gradcamgithub\": {},\n",
    "                                                 \"vanillapixel\": {},\n",
    "                                                 \"vanillaembedding\": {},\n",
    "                                                 \"sgpixel\": {},\n",
    "                                                 \"sgembedding\": {},\n",
    "                                                 \"vargradpixel\": {},\n",
    "                                                 \"vargradembedding\": {},               \n",
    "                                                 \"igpixel\": {},\n",
    "                                                 \"igembedding\": {},\n",
    "                                                 \"leaveoneoutclassifier\": {},\n",
    "                                                 \"leaveoneoutsurrogate\": {},\n",
    "                                                 \"riseclassifier\": {},\n",
    "                                                 \"risesurrogate\": {},\n",
    "                                                 \"ours\": {},\n",
    "                                                }\n",
    "                noretraining_save_dict[backbone_type]=noretraining_save_dict_backbone     \n",
    "            \n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                print(backbone_type,'\\n')\n",
    "                for explanation_method, noretraining_save_dict_backbone_method in noretraining_save_dict[backbone_type].items():\n",
    "                    noretraining_save_dict_path=f'results/6_noretraining/{dataset_name}/{backbone_type}_{explanation_method}_test.pickle'\n",
    "\n",
    "                    if os.path.isfile(noretraining_save_dict_path):\n",
    "                        with open(noretraining_save_dict_path, 'rb') as f:\n",
    "                            noretraining_save_dict_loaded=pickle.load(f)\n",
    "                    else:\n",
    "                        noretraining_save_dict_loaded={}\n",
    "\n",
    "                    len_original=len(noretraining_save_dict_backbone_method)            \n",
    "                    len_loaded=len(noretraining_save_dict_loaded)\n",
    "                    noretraining_save_dict_backbone_method.update(noretraining_save_dict_loaded)\n",
    "                    len_updated=len(noretraining_save_dict_backbone_method)\n",
    "                    len_unique=len(set([i.replace('l0','').replace('l1lambda','').replace('l2lambda','').replace('l3','').replace('deeper','') for i in noretraining_save_dict_backbone_method.keys()]))\n",
    "\n",
    "                    print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6} unique: {len_unique:6}')                    \n",
    "            data_loaded_all[evaluation_stage][dataset_name]=noretraining_save_dict\n",
    "            print(' ')\n",
    "            \n",
    "        if evaluation_stage==\"7_classifiermasked\":\n",
    "            print(' ')\n",
    "            classifiermasked_save_dict={}\n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                classifiermasked_save_dict_backbone={\"random\":{},\n",
    "                                                 \"attention_rollout\":{},\n",
    "                                                 \"attention_last\":{},\n",
    "                                                 \"LRP\":{},\n",
    "                                                 \"gradcam\":{},\n",
    "                                                 \"gradcamgithub\": {},\n",
    "                                                 \"vanillapixel\": {},\n",
    "                                                 \"vanillaembedding\": {},\n",
    "                                                 \"sgpixel\": {},\n",
    "                                                 \"sgembedding\": {},\n",
    "                                                 \"vargradpixel\": {},\n",
    "                                                 \"vargradembedding\": {},               \n",
    "                                                 \"igpixel\": {},\n",
    "                                                 \"igembedding\": {},\n",
    "                                                 \"leaveoneoutclassifier\": {},\n",
    "                                                 \"leaveoneoutsurrogate\": {},\n",
    "                                                 \"riseclassifier\": {},\n",
    "                                                 \"risesurrogate\": {},\n",
    "                                                 \"ours\": {},\n",
    "                                                }\n",
    "                classifiermasked_save_dict[backbone_type]=classifiermasked_save_dict_backbone     \n",
    "            \n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                print(backbone_type,'\\n')\n",
    "                for explanation_method, classifiermasked_save_dict_backbone_method in classifiermasked_save_dict[backbone_type].items():\n",
    "                    classifiermasked_save_dict_path=f'results/7_classifiermasked/{dataset_name}/{backbone_type}_{explanation_method}_test.pickle'\n",
    "\n",
    "                    if os.path.isfile(classifiermasked_save_dict_path):\n",
    "                        with open(classifiermasked_save_dict_path, 'rb') as f:\n",
    "                            classifiermasked_save_dict_loaded=pickle.load(f)\n",
    "                    else:\n",
    "                        classifiermasked_save_dict_loaded={}\n",
    "\n",
    "                    len_original=len(classifiermasked_save_dict_backbone_method)            \n",
    "                    len_loaded=len(classifiermasked_save_dict_loaded)\n",
    "                    classifiermasked_save_dict_backbone_method.update(classifiermasked_save_dict_loaded)\n",
    "                    len_updated=len(classifiermasked_save_dict_backbone_method)\n",
    "                    len_unique=len(set([i.replace('l0','').replace('l1lambda','').replace('l2lambda','').replace('l3','').replace('deeper','') for i in classifiermasked_save_dict_backbone_method.keys()]))\n",
    "\n",
    "                    print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6} unique: {len_unique:6}')                    \n",
    "            data_loaded_all[evaluation_stage][dataset_name]=classifiermasked_save_dict            \n",
    "            print(' ')\n",
    "            \n",
    "        if evaluation_stage==\"8_elapsedtime\":\n",
    "            print(' ')\n",
    "            elapsedtime_save_dict={}\n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                elapsedtime_save_dict_backbone={\"random\":{},\n",
    "                                                 \"attention_rollout\":{},\n",
    "                                                 \"attention_last\":{},\n",
    "                                                 \"LRP\":{},\n",
    "                                                 \"gradcam\":{},\n",
    "                                                 \"gradcamgithub\": {},\n",
    "                                                 \"vanillapixel\": {},\n",
    "                                                 \"vanillaembedding\": {},\n",
    "                                                 \"sgpixel\": {},\n",
    "                                                 \"sgembedding\": {},\n",
    "                                                 \"vargradpixel\": {},\n",
    "                                                 \"vargradembedding\": {},               \n",
    "                                                 \"igpixel\": {},\n",
    "                                                 \"igembedding\": {},\n",
    "                                                 \"leaveoneoutclassifier\": {},\n",
    "                                                 \"leaveoneoutsurrogate\": {},\n",
    "                                                 \"riseclassifier\": {},\n",
    "                                                 \"risesurrogate\": {},\n",
    "                                                 \"ours\": {},\n",
    "                                                }\n",
    "                elapsedtime_save_dict[backbone_type]=elapsedtime_save_dict_backbone \n",
    "            \n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                print(backbone_type,'\\n')\n",
    "                for explanation_method, elapsedtime_save_dict_backbone_method in elapsedtime_save_dict[backbone_type].items():\n",
    "                    elapsedtime_save_dict_path=f'results/8_elapsedtime/{dataset_name}/{backbone_type}_{explanation_method}_test.pickle'\n",
    "\n",
    "                    if os.path.isfile(elapsedtime_save_dict_path):\n",
    "                        with open(elapsedtime_save_dict_path, 'rb') as f:\n",
    "                            elapsedtime_save_dict_loaded=pickle.load(f)\n",
    "                    else:\n",
    "                        elapsedtime_save_dict_loaded={}\n",
    "\n",
    "                    len_original=len(elapsedtime_save_dict_backbone_method)            \n",
    "                    len_loaded=len(elapsedtime_save_dict_loaded)\n",
    "                    elapsedtime_save_dict_backbone_method.update(elapsedtime_save_dict_loaded)\n",
    "                    len_updated=len(elapsedtime_save_dict_backbone_method)\n",
    "\n",
    "                    print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')                                                                                                                                    \n",
    "            data_loaded_all[evaluation_stage][dataset_name]=elapsedtime_save_dict\n",
    "            print(' ')\n",
    "            \n",
    "            \n",
    "        if evaluation_stage==\"9_estimationerror\":\n",
    "            print(' ')\n",
    "            elapsedtime_save_dict={}\n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                elapsedtime_save_dict_backbone={\"kernelshap\":{},\n",
    "                                                \"kernelshapnopair\":{},\n",
    "                                                \"ours\": {},\n",
    "                                                }\n",
    "                elapsedtime_save_dict[backbone_type]=elapsedtime_save_dict_backbone \n",
    "            \n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                print(backbone_type,'\\n')\n",
    "                for explanation_method, elapsedtime_save_dict_backbone_method in elapsedtime_save_dict[backbone_type].items():\n",
    "                    elapsedtime_save_dict_path=f'results/9_estimationerror/{dataset_name}/{backbone_type}_{explanation_method}_test.pickle'\n",
    "\n",
    "                    if os.path.isfile(elapsedtime_save_dict_path):\n",
    "                        with open(elapsedtime_save_dict_path, 'rb') as f:\n",
    "                            elapsedtime_save_dict_loaded=pickle.load(f)\n",
    "                    else:\n",
    "                        elapsedtime_save_dict_loaded={}\n",
    "\n",
    "                    len_original=len(elapsedtime_save_dict_backbone_method)            \n",
    "                    len_loaded=len(elapsedtime_save_dict_loaded)\n",
    "                    elapsedtime_save_dict_backbone_method.update(elapsedtime_save_dict_loaded)\n",
    "                    len_updated=len(elapsedtime_save_dict_backbone_method)\n",
    "\n",
    "                    print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')                                                                                                                                    \n",
    "            data_loaded_all[evaluation_stage][dataset_name]=elapsedtime_save_dict\n",
    "            print(' ')            \n",
    "            \n",
    "            \n",
    "        if evaluation_stage==\"9_retraining\":\n",
    "            print(' ')\n",
    "            retraining_save_dict={}\n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                retraining_save_dict[backbone_type]=get_retraining_status(backbone_type=backbone_type, \n",
    "                                                                          dataset_name=dataset_name,\n",
    "                                                                          api_dir=\"ch6845/transformer_interpretability_project_retraining\")\n",
    "                print(len(retraining_save_dict[backbone_type]))\n",
    "            data_loaded_all[evaluation_stage][dataset_name]=retraining_save_dict\n",
    "            print(' ')\n",
    "        if evaluation_stage==\"10_retrainingnopos\":\n",
    "            print(' ')\n",
    "            retrainingnopos_save_dict={}\n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                retrainingnopos_save_dict[backbone_type]=get_retraining_status(backbone_type=backbone_type,\n",
    "                                                                               dataset_name=dataset_name,\n",
    "                                                                               api_dir=\"ch6845/transformer_interpretability_project_retraining_nopos\")\n",
    "                print(len(retrainingnopos_save_dict[backbone_type]))\n",
    "            data_loaded_all[evaluation_stage][dataset_name]=retrainingnopos_save_dict            \n",
    "            print(' ')\n",
    "            \n",
    "    print('--------------------------')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "832207e0",
   "metadata": {},
   "source": [
    "35k 3min\n",
    "\n",
    "64\n",
    "\n",
    "500 batches/min\n",
    "\n",
    "640k 1hour\n",
    "\n",
    "1203k 1hour"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ba3c533",
   "metadata": {},
   "source": [
    "# Formatting tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a4067ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "def backbone_type_mapper(backbone_type):\n",
    "    \"\"\"Return the human-readable label for a backbone identifier.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``backbone_type`` is not a known backbone.\n",
    "    \"\"\"\n",
    "    if backbone_type == \"vit_base_patch16_224\":\n",
    "        return \"ViT-Base\"\n",
    "    if backbone_type == \"vit_small_patch16_224\":\n",
    "        return \"ViT-Small\"\n",
    "    # A bare `raise` with no active exception only produces an unrelated\n",
    "    # RuntimeError; report the offending value like mask_location_mapper does.\n",
    "    raise ValueError(backbone_type)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36859ac0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mask_location_mapper(mask_location):\n",
    "    \"\"\"Translate a mask-location key into its plot label.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``mask_location`` is not one of the five known keys.\n",
    "    \"\"\"\n",
    "    known_labels = (\n",
    "        (\"pre-softmax\", \"Masking (Pre-softmax)\"),\n",
    "        (\"post-softmax\", \"Masking (Post-softmax)\"),\n",
    "        (\"zero-input\", \"Zeros at input\"),\n",
    "        (\"zero-embedding\", \"Zeros at embedding\"),\n",
    "        (\"random-sampling\", \"Random sampling\"),\n",
    "    )\n",
    "    for key, label in known_labels:\n",
    "        if mask_location == key:\n",
    "            return label\n",
    "    raise ValueError(mask_location)\n",
    "\n",
    "# Legend labels for each (mask_location_parameter, mask_location_model) pair.\n",
    "# A mask_location_model other than \"original\" is labelled \"(Fine-tuned)\",\n",
    "# i.e. presumably a model fine-tuned with that masking scheme (TODO: confirm).\n",
    "_MASK_LOCATION_PARAMETER_MODEL_LABELS = {\n",
    "    (\"pre-softmax\", \"original\"): \"Masking\",\n",
    "    (\"pre-softmax\", \"pre-softmax\"): \"Masking (Fine-tuned)\",\n",
    "    (\"post-softmax\", \"original\"): \"Masking (Post-softmax)\",\n",
    "    (\"random-sampling\", \"original\"): \"Random imputation\",\n",
    "    (\"zero-input\", \"zero-input\"): \"Mask token (Fine-tuned)\",\n",
    "    (\"zero-input\", \"original\"): \"Zeros at input\",\n",
    "    (\"zero-embedding\", \"original\"): \"Zeros at embedding\",\n",
    "}\n",
    "\n",
    "\n",
    "def mask_location_parameter_model_mapper(mask_location_parameter, mask_location_model):\n",
    "    \"\"\"Return the legend label for a (mask location parameter, mask location model) pair.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the pair is not in _MASK_LOCATION_PARAMETER_MODEL_LABELS;\n",
    "            the exception args are (mask_location_parameter, mask_location_model).\n",
    "    \"\"\"\n",
    "    label = _MASK_LOCATION_PARAMETER_MODEL_LABELS.get((mask_location_parameter, mask_location_model))\n",
    "    if label is None:\n",
    "        raise ValueError(mask_location_parameter, mask_location_model)\n",
    "    return label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30e74929",
   "metadata": {},
   "outputs": [],
   "source": [
    "def insert_delete_mapper(insert_delete, verbose=True):\n",
    "    \"\"\"Return the label for an insertion/deletion evaluation mode.\n",
    "\n",
    "    Args:\n",
    "        insert_delete: \"insert\" or \"delete\".\n",
    "        verbose: if truthy return the full word (\"Insertion\"/\"Deletion\"),\n",
    "            otherwise the abbreviation (\"Ins.\"/\"Del.\").\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``insert_delete`` is neither \"insert\" nor \"delete\"\n",
    "            (a bare ``raise`` here would only produce an unrelated RuntimeError).\n",
    "    \"\"\"\n",
    "    labels = {\"insert\": (\"Insertion\", \"Ins.\"),\n",
    "              \"delete\": (\"Deletion\", \"Del.\")}\n",
    "    if insert_delete not in labels:\n",
    "        raise ValueError(insert_delete)\n",
    "    verbose_label, short_label = labels[insert_delete]\n",
    "    return verbose_label if verbose else short_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d9d07a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def target_non_target_mapper(target_non_target):\n",
    "    \"\"\"Return the label for a target / non-target selection.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for any value other than \"target\" or \"non-target\"\n",
    "            (the function used to fall through and return None).\n",
    "    \"\"\"\n",
    "    if target_non_target == \"target\":\n",
    "        return \"Target\"\n",
    "    if target_non_target == \"non-target\":\n",
    "        return \"Non-target\"\n",
    "    raise ValueError(target_non_target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bba08c47",
   "metadata": {},
   "outputs": [],
   "source": [
    "def explanation_method_mapper(explanation_method, subset_mode=\"main\"):\n",
    "    \"\"\"Return the display name of an explanation method under a naming scheme.\n",
    "\n",
    "    Args:\n",
    "        explanation_method: internal method key, e.g. \"igembedding\" or \"ours\".\n",
    "        subset_mode: naming scheme to use:\n",
    "            \"main\"        - names without (Pixel)/(Embed.) tags,\n",
    "            \"supple\"      - also covers the *pixel variants and tags (Pixel)/(Embed.),\n",
    "            \"qualitative\" - like \"main\", but GradCAM is tagged (LN), \"ours\" is\n",
    "                            \"ViT Shapley (Ours)\" and KernelSHAP is absent.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if ``explanation_method`` has no name under ``subset_mode``.\n",
    "        ValueError: if ``subset_mode`` is unknown (a bare ``raise`` here would\n",
    "            only produce an unrelated RuntimeError).\n",
    "    \"\"\"\n",
    "    name_tables = {\n",
    "        \"main\": {\"attention_last\": \"Attention last\",\n",
    "                 \"attention_rollout\": \"Attention rollout\",\n",
    "                 \"LRP\": \"LRP\",\n",
    "                 \"gradcam\": \"GradCAM (Attn)\",\n",
    "                 \"gradcamgithub\": \"GradCAM\",\n",
    "                 \"igembedding\": \"IntGrad\",\n",
    "                 \"vanillaembedding\": \"Vanilla\",\n",
    "                 \"sgembedding\": \"SmoothGrad\",\n",
    "                 \"vargradembedding\": \"VarGrad\",\n",
    "                 \"leaveoneoutclassifier\": \"Leave-one-out\",\n",
    "                 \"riseclassifier\": \"RISE\",\n",
    "                 \"ours\": \"ViT Shapley\",\n",
    "                 \"kernelshap\": \"KernelSHAP\",\n",
    "                 \"random\": \"Random\"},\n",
    "        \"supple\": {\"attention_last\": \"Attention last\",\n",
    "                   \"attention_rollout\": \"Attention rollout\",\n",
    "                   \"LRP\": \"LRP\",\n",
    "                   \"gradcam\": \"GradCAM (Attn)\",\n",
    "                   \"gradcamgithub\": \"GradCAM (LN)\",\n",
    "                   \"igpixel\": \"IntGrad (Pixel)\",\n",
    "                   \"igembedding\": \"IntGrad (Embed.)\",\n",
    "                   \"vanillapixel\": \"Vanilla (Pixel)\",\n",
    "                   \"vanillaembedding\": \"Vanilla (Embed.)\",\n",
    "                   \"sgpixel\": \"SmoothGrad (Pixel)\",\n",
    "                   \"sgembedding\": \"SmoothGrad (Embed.)\",\n",
    "                   \"vargradpixel\": \"VarGrad (Pixel)\",\n",
    "                   \"vargradembedding\": \"VarGrad (Embed.)\",\n",
    "                   \"leaveoneoutclassifier\": \"Leave-one-out\",\n",
    "                   \"riseclassifier\": \"RISE\",\n",
    "                   \"ours\": \"ViT Shapley\",\n",
    "                   \"kernelshap\": \"KernelSHAP\",\n",
    "                   \"random\": \"Random\"},\n",
    "        \"qualitative\": {\"attention_last\": \"Attention last\",\n",
    "                        \"attention_rollout\": \"Attention rollout\",\n",
    "                        \"LRP\": \"LRP\",\n",
    "                        \"gradcam\": \"GradCAM (Attn)\",\n",
    "                        \"gradcamgithub\": \"GradCAM (LN)\",\n",
    "                        \"igembedding\": \"IntGrad\",\n",
    "                        \"vanillaembedding\": \"Vanilla\",\n",
    "                        \"sgembedding\": \"SmoothGrad\",\n",
    "                        \"vargradembedding\": \"VarGrad\",\n",
    "                        \"leaveoneoutclassifier\": \"Leave-one-out\",\n",
    "                        \"riseclassifier\": \"RISE\",\n",
    "                        \"ours\": \"ViT Shapley (Ours)\",\n",
    "                        \"random\": \"Random\"},\n",
    "    }\n",
    "    if subset_mode not in name_tables:\n",
    "        raise ValueError(subset_mode)\n",
    "    return name_tables[subset_mode][explanation_method]\n",
    "\n",
    "# Explanation-method keys in plotting order, split into groups (presumably\n",
    "# attention-based / gradient-based / removal-based methods -- confirm against the plots).\n",
    "# *_random variants append a [\"random\"] group; *_kernelshap variants insert\n",
    "# \"kernelshap\" just before \"ours\".\n",
    "explanation_method_main = [\n",
    "    [\"attention_last\", \"attention_rollout\"],\n",
    "    [\"gradcamgithub\", \"igembedding\", \"vanillaembedding\", \"sgembedding\",\n",
    "     \"vargradembedding\", \"LRP\"],\n",
    "    [\"leaveoneoutclassifier\", \"riseclassifier\", \"ours\"],\n",
    "]\n",
    "\n",
    "explanation_method_main_random = [\n",
    "    [\"attention_last\", \"attention_rollout\"],\n",
    "    [\"gradcamgithub\", \"igembedding\", \"vanillaembedding\", \"sgembedding\",\n",
    "     \"vargradembedding\", \"LRP\"],\n",
    "    [\"leaveoneoutclassifier\", \"riseclassifier\", \"ours\"],\n",
    "    [\"random\"],\n",
    "]\n",
    "\n",
    "explanation_method_main_kernelshap = [\n",
    "    [\"attention_last\", \"attention_rollout\"],\n",
    "    [\"gradcamgithub\", \"igembedding\", \"vanillaembedding\", \"sgembedding\",\n",
    "     \"vargradembedding\", \"LRP\"],\n",
    "    [\"leaveoneoutclassifier\", \"riseclassifier\", \"kernelshap\", \"ours\"],\n",
    "]\n",
    "\n",
    "explanation_method_main_random_kernelshap = [\n",
    "    [\"attention_last\", \"attention_rollout\"],\n",
    "    [\"gradcamgithub\", \"igembedding\", \"vanillaembedding\", \"sgembedding\",\n",
    "     \"vargradembedding\", \"LRP\"],\n",
    "    [\"leaveoneoutclassifier\", \"riseclassifier\", \"kernelshap\", \"ours\"],\n",
    "    [\"random\"],\n",
    "]\n",
    "\n",
    "explanation_method_qualitative_main = [[\"attention_rollout\"], [\"LRP\"], [\"ours\"]]\n",
    "\n",
    "# Supplementary variants also list the \"*pixel\" gradient variants and \"gradcam\".\n",
    "# NOTE(review): the *_supple_random* lists order \"gradcam\" before \"gradcamgithub\",\n",
    "# the reverse of explanation_method_supple(_kernelshap); confirm this is intended.\n",
    "explanation_method_supple = [\n",
    "    [\"attention_last\", \"attention_rollout\"],\n",
    "    [\"gradcamgithub\", \"gradcam\", \"igpixel\", \"igembedding\", \"vanillapixel\",\n",
    "     \"vanillaembedding\", \"sgpixel\", \"sgembedding\", \"vargradpixel\",\n",
    "     \"vargradembedding\", \"LRP\"],\n",
    "    [\"leaveoneoutclassifier\", \"riseclassifier\", \"ours\"],\n",
    "]\n",
    "\n",
    "explanation_method_supple_random = [\n",
    "    [\"attention_last\", \"attention_rollout\"],\n",
    "    [\"gradcam\", \"gradcamgithub\", \"igpixel\", \"igembedding\", \"vanillapixel\",\n",
    "     \"vanillaembedding\", \"sgpixel\", \"sgembedding\", \"vargradpixel\",\n",
    "     \"vargradembedding\", \"LRP\"],\n",
    "    [\"leaveoneoutclassifier\", \"riseclassifier\", \"ours\"],\n",
    "    [\"random\"],\n",
    "]\n",
    "\n",
    "explanation_method_supple_kernelshap = [\n",
    "    [\"attention_last\", \"attention_rollout\"],\n",
    "    [\"gradcamgithub\", \"gradcam\", \"igpixel\", \"igembedding\", \"vanillapixel\",\n",
    "     \"vanillaembedding\", \"sgpixel\", \"sgembedding\", \"vargradpixel\",\n",
    "     \"vargradembedding\", \"LRP\"],\n",
    "    [\"leaveoneoutclassifier\", \"riseclassifier\", \"kernelshap\", \"ours\"],\n",
    "]\n",
    "\n",
    "explanation_method_supple_random_kernelshap = [\n",
    "    [\"attention_last\", \"attention_rollout\"],\n",
    "    [\"gradcam\", \"gradcamgithub\", \"igpixel\", \"igembedding\", \"vanillapixel\",\n",
    "     \"vanillaembedding\", \"sgpixel\", \"sgembedding\", \"vargradpixel\",\n",
    "     \"vargradembedding\", \"LRP\"],\n",
    "    [\"leaveoneoutclassifier\", \"riseclassifier\", \"kernelshap\", \"ours\"],\n",
    "    [\"random\"],\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c4ce608",
   "metadata": {},
   "outputs": [],
   "source": [
    "def adapt_path(path_original, dict_keys):\n",
    "    \"\"\"Rewrite ``path_original`` to the variant that is present in ``dict_keys``.\n",
    "\n",
    "    For every location segment of ``path_list`` that occurs in ``path_original``\n",
    "    (tried in ``path_list`` order), each segment of ``path_list`` is substituted\n",
    "    in turn and the first rewritten path found in ``dict_keys`` is returned.\n",
    "    If nothing matches, ``path_original`` is returned unchanged (no error).\n",
    "    \"\"\"\n",
    "    path_list = ['l0.cs.hostname', 'l1lambda.cs.hostname', 'l2lambda.cs.hostname',\n",
    "                 'l3.cs.hostname', 'deeper.cs.hostname', 'sync']\n",
    "\n",
    "    # A set gives O(1) lookups for the up to len(path_list)**2 candidate paths\n",
    "    # instead of scanning the whole key list for each candidate.\n",
    "    known_paths = set(dict_keys)\n",
    "\n",
    "    for path1 in path_list:\n",
    "        if path1 in path_original:\n",
    "            for path2 in path_list:\n",
    "                path_replaced = path_original.replace(path1, path2)\n",
    "                if path_replaced in known_paths:\n",
    "                    return path_replaced\n",
    "    return path_original"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c628e87",
   "metadata": {},
   "source": [
    "# KL divergence plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aac7c172",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def _mask_location_color(mask_location_parameter):\n",
    "    \"\"\"Return the color assigned to a masking scheme from the active color cycle.\"\"\"\n",
    "    color_order = [\"pre-softmax\", \"post-softmax\", \"zero-input\", \"zero-embedding\", \"random-sampling\"]\n",
    "    return plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][color_order.index(mask_location_parameter)]\n",
    "\n",
    "\n",
    "def _plot_metric_panel(panel, result_df, metric, mask_location_parameter_model_list,\n",
    "                       ylabel, ylim, y_minor_step, title, xlabel=None):\n",
    "    \"\"\"Draw one `metric` vs. `num_excluded` panel with one line per masking configuration.\"\"\"\n",
    "    for mask_location_parameter, mask_location_model in mask_location_parameter_model_list:\n",
    "        selected = result_df[(result_df[\"mask_location_parameter\"] == mask_location_parameter) &\n",
    "                             (result_df[\"mask_location_model\"] == mask_location_model)]\n",
    "        # Select the metric before averaging: a whole-frame .mean() also tries to\n",
    "        # average the string mask-location columns, which raises in pandas >= 2.0.\n",
    "        curve = selected.groupby('num_excluded')[metric].mean()\n",
    "        panel.plot(curve,\n",
    "                   linestyle='--' if mask_location_model != \"original\" else '-',\n",
    "                   c=_mask_location_color(mask_location_parameter),\n",
    "                   linewidth=3)\n",
    "\n",
    "    panel.spines['right'].set_visible(False)\n",
    "    panel.spines['top'].set_visible(False)\n",
    "\n",
    "    panel.xaxis.set_major_locator(MultipleLocator(28))\n",
    "    panel.xaxis.set_minor_locator(MultipleLocator(14))\n",
    "    panel.xaxis.grid(True, which='major', linewidth=0.4, alpha=0.6)\n",
    "    panel.xaxis.grid(True, which='minor', linewidth=0.4, alpha=0.2)\n",
    "\n",
    "    panel.yaxis.set_minor_locator(MultipleLocator(y_minor_step))\n",
    "    panel.yaxis.grid(True, which='major', linewidth=0.4, alpha=0.6)\n",
    "    panel.yaxis.grid(True, which='minor', linewidth=0.4, alpha=0.2)\n",
    "\n",
    "    if xlabel is not None:\n",
    "        panel.set_xlabel(xlabel)\n",
    "    panel.set_ylabel(ylabel)\n",
    "\n",
    "    for spine in ['top', 'bottom', 'left', 'right']:\n",
    "        panel.spines[spine].set_linewidth(2)\n",
    "\n",
    "    panel.set_title(title, pad=10)\n",
    "    panel.set_ylim(*ylim)\n",
    "    panel.set_xlim(-2, 200)\n",
    "\n",
    "\n",
    "def draw_kl_divergence_accuracy(result_df_dict):\n",
    "    \"\"\"Plot surrogate KL divergence (top row) and accuracy (bottom row) against the\n",
    "    number of deleted patches, one column per dataset (\"ImageNette\", \"MURA\").\n",
    "\n",
    "    Each line is one (mask_location_parameter, mask_location_model) configuration;\n",
    "    its color encodes the masking scheme and it is dashed when the model is not\n",
    "    \"original\".\n",
    "\n",
    "    Args:\n",
    "        result_df_dict: maps \"ImageNette\" and \"MURA\" to DataFrames with columns\n",
    "            \"num_mask\", \"mask_location_parameter\", \"mask_location_model\",\n",
    "            \"kl_divergence\" and \"accuracy\". A \"num_excluded\" column is added to\n",
    "            each DataFrame in place.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib Figure (not saved or closed here).\n",
    "\n",
    "    Side effects:\n",
    "        Overwrites plt.rcParams \"axes.prop_cycle\", \"font.family\" and \"font.size\".\n",
    "    \"\"\"\n",
    "    plt.rcParams[\"axes.prop_cycle\"] = cycler('color', [np.array(i)/256 for i in [Paired[10][1],\n",
    "                                                                                 Paired[10][3],\n",
    "                                                                                 Paired[10][5],\n",
    "                                                                                 Paired[10][7],\n",
    "                                                                                 Paired[10][9]]])\n",
    "    plt.rcParams['font.family'] = 'PT Sans'\n",
    "    plt.rcParams[\"font.size\"] = 18\n",
    "\n",
    "    # Use the argument, not the notebook globals result_df_ImageNette/result_df_MURA.\n",
    "    # 196 is presumably the total patch count (14x14 for a 224px, patch-16 ViT) -- confirm.\n",
    "    for result_df in result_df_dict.values():\n",
    "        result_df[\"num_excluded\"] = 196 - result_df[\"num_mask\"]\n",
    "\n",
    "    mask_location_parameter_model_list = [(\"pre-softmax\", \"original\"),\n",
    "                                          (\"pre-softmax\", \"pre-softmax\"),\n",
    "                                          (\"zero-input\", \"original\"),\n",
    "                                          (\"zero-input\", \"zero-input\"),\n",
    "                                          (\"post-softmax\", \"original\"),\n",
    "                                          (\"zero-embedding\", \"original\"),\n",
    "                                          (\"random-sampling\", \"original\")]\n",
    "\n",
    "    # Rows: KL divergence, an almost zero-height invisible spacer, accuracy.\n",
    "    fig, ax = plt.subplots(3, 2, gridspec_kw={\"height_ratios\": [1, 0.0001, 1]}, figsize=(22, 12))\n",
    "    for spacer in ax[1]:\n",
    "        spacer.set_xticks([])\n",
    "        spacer.set_yticks([])\n",
    "        for spine in ['top', 'bottom', 'left', 'right']:\n",
    "            spacer.spines[spine].set_linewidth(0)\n",
    "\n",
    "    for column, dataset_name in enumerate([\"ImageNette\", \"MURA\"]):\n",
    "        result_df = result_df_dict[dataset_name]\n",
    "        _plot_metric_panel(ax[0][column], result_df, 'kl_divergence', mask_location_parameter_model_list,\n",
    "                           ylabel='KL divergence', ylim=(0, 3.125), y_minor_step=0.25,\n",
    "                           title=dataset_name)\n",
    "        _plot_metric_panel(ax[2][column], result_df, 'accuracy', mask_location_parameter_model_list,\n",
    "                           ylabel='Accuracy', ylim=(0, 1.05), y_minor_step=0.1,\n",
    "                           title=dataset_name, xlabel='# of Deleted Patches')\n",
    "\n",
    "    # The legend is shared by all panels, so build it once after plotting.\n",
    "    legend_elements = [Line2D([0], [0],\n",
    "                              color=_mask_location_color(mask_location_parameter),\n",
    "                              linewidth=5,\n",
    "                              linestyle='--' if mask_location_model != \"original\" else '-',\n",
    "                              label=mask_location_parameter_model_mapper(mask_location_parameter, mask_location_model))\n",
    "                       for mask_location_parameter, mask_location_model in mask_location_parameter_model_list]\n",
    "    fig.legend(handles=legend_elements,\n",
    "               ncol=5,\n",
    "               handletextpad=0.6,\n",
    "               columnspacing=1,\n",
    "               loc='lower center', bbox_to_anchor=(0.5, -0.02))\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1081289",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Surrogate evaluation figure (KL divergence + accuracy) for each backbone.\n",
    "os.makedirs(\"results/plots\", exist_ok=True)\n",
    "for backbone_type, backbone_type_config in {'vit_base_patch16_224': []}.items():\n",
    "    print(backbone_type)\n",
    "\n",
    "    # Kept as module-level names: draw_kl_divergence reads these globals directly.\n",
    "    result_df_ImageNette = data_loaded_all[\"2_surrogate_evaluate\"][\"ImageNette\"][backbone_type]\n",
    "    result_df_MURA = data_loaded_all[\"2_surrogate_evaluate\"][\"MURA\"][backbone_type]\n",
    "\n",
    "    fig = draw_kl_divergence_accuracy(result_df_dict={\"ImageNette\": result_df_ImageNette,\n",
    "                                                      \"MURA\": result_df_MURA})\n",
    "\n",
    "    output_stem = f\"results/plots/surrogate_evaluate_kl_divergence_accuracy_{backbone_type}\"\n",
    "    for extension in (\"png\", \"svg\", \"jpg\", \"pdf\"):\n",
    "        fig.savefig(f\"{output_stem}.{extension}\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e2602a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_kl_divergence(result_df_dict):\n",
    "    \"\"\"Plot surrogate KL divergence against the number of deleted patches.\n",
    "\n",
    "    Draws one panel per dataset ('ImageNette', 'MURA'); each line is one\n",
    "    (mask_location_parameter, mask_location_model) combination, averaged per number of deleted patches.\n",
    "\n",
    "    Args:\n",
    "        result_df_dict: dict with keys 'ImageNette' and 'MURA'; each value is a DataFrame with (at least)\n",
    "            the columns 'num_mask', 'kl_divergence', 'mask_location_parameter' and 'mask_location_model'.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib Figure.\n",
    "\n",
    "    Side effects: sets global plt.rcParams (color cycle, font) and adds a 'num_excluded'\n",
    "    column (196 - num_mask) in place to both DataFrames in result_df_dict.\n",
    "    \"\"\"\n",
    "    # Every second color of the 10-class Paired palette (indices 1, 3, 5, 7, 9).\n",
    "    plt.rcParams[\"axes.prop_cycle\"] = cycler('color', [np.array(Paired[10][i])/256 for i in (1, 3, 5, 7, 9)])\n",
    "    plt.rcParams['font.family'] = 'PT Sans'\n",
    "    plt.rcParams[\"font.size\"] = 18\n",
    "\n",
    "    dataset_names = [\"ImageNette\", \"MURA\"]\n",
    "    # Derive the x-axis column on the frames passed in (previously it silently used notebook globals).\n",
    "    for dataset_name in dataset_names:\n",
    "        result_df_dict[dataset_name][\"num_excluded\"] = 196 - result_df_dict[dataset_name][\"num_mask\"]\n",
    "\n",
    "    mask_location_parameter_model_list = [(\"pre-softmax\", \"original\"),\n",
    "                                          (\"pre-softmax\", \"pre-softmax\"),\n",
    "                                          (\"zero-input\", \"original\"),\n",
    "                                          (\"zero-input\", \"zero-input\"),\n",
    "                                          (\"post-softmax\", \"original\"),\n",
    "                                          (\"zero-embedding\", \"original\"),\n",
    "                                          (\"random-sampling\", \"original\"),\n",
    "                                         ]\n",
    "\n",
    "    # Color is keyed by masking location; the line is dashed when the model is not 'original'.\n",
    "    color_cycle = plt.rcParams[\"axes.prop_cycle\"].by_key()['color']\n",
    "    color_order = [\"pre-softmax\", \"post-softmax\", \"zero-input\", \"zero-embedding\", \"random-sampling\"]\n",
    "\n",
    "    def line_style(mask_location_parameter, mask_location_model):\n",
    "        return {'color': color_cycle[color_order.index(mask_location_parameter)],\n",
    "                'linestyle': '--' if mask_location_model != \"original\" else '-'}\n",
    "\n",
    "    # The bottom row has (almost) zero height and is hidden; it only adds space below the panels.\n",
    "    fig, ax = plt.subplots(2, 2, gridspec_kw={\"height_ratios\": [1, 0.0001]}, figsize=(22, 6))\n",
    "    for empty_ax in ax[1]:\n",
    "        empty_ax.set_xticks([])\n",
    "        empty_ax.set_yticks([])\n",
    "        for side in ['top', 'bottom', 'left', 'right']:\n",
    "            empty_ax.spines[side].set_linewidth(0)\n",
    "\n",
    "    for panel_ax, dataset_name in zip(ax[0], dataset_names):\n",
    "        result_df = result_df_dict[dataset_name]\n",
    "        for mask_location_parameter, mask_location_model in mask_location_parameter_model_list:\n",
    "            selected = result_df[(result_df[\"mask_location_parameter\"] == mask_location_parameter) &\n",
    "                                 (result_df[\"mask_location_model\"] == mask_location_model)]\n",
    "            # Average only the plotted metric: a whole-frame mean() fails on string columns in pandas >= 2.0.\n",
    "            kl_by_excluded = selected.groupby('num_excluded')['kl_divergence'].mean()\n",
    "            panel_ax.plot(kl_by_excluded, linewidth=3,\n",
    "                          **line_style(mask_location_parameter, mask_location_model))\n",
    "\n",
    "        panel_ax.spines['right'].set_visible(False)\n",
    "        panel_ax.spines['top'].set_visible(False)\n",
    "\n",
    "        panel_ax.xaxis.set_major_locator(MultipleLocator(28))\n",
    "        panel_ax.xaxis.set_minor_locator(MultipleLocator(14))\n",
    "        panel_ax.xaxis.grid(True, which='major', linewidth=0.4, alpha=0.6)\n",
    "        panel_ax.xaxis.grid(True, which='minor', linewidth=0.4, alpha=0.2)\n",
    "\n",
    "        panel_ax.yaxis.set_minor_locator(MultipleLocator(0.25))\n",
    "        panel_ax.yaxis.grid(True, which='major', linewidth=0.4, alpha=0.6)\n",
    "        panel_ax.yaxis.grid(True, which='minor', linewidth=0.4, alpha=0.2)\n",
    "\n",
    "        panel_ax.set_xlabel('# of Deleted Patches', labelpad=10)\n",
    "        panel_ax.set_ylabel('KL divergence')\n",
    "\n",
    "        for side in ['top', 'bottom', 'left', 'right']:\n",
    "            panel_ax.spines[side].set_linewidth(2)\n",
    "\n",
    "        panel_ax.set_title(dataset_name, pad=10)\n",
    "        panel_ax.set_ylim(0, 3.125)\n",
    "        panel_ax.set_xlim(-2, 200)\n",
    "\n",
    "    legend_elements = [Line2D([0], [0], linewidth=5,\n",
    "                              label=mask_location_parameter_model_mapper(mask_location_parameter, mask_location_model),\n",
    "                              **line_style(mask_location_parameter, mask_location_model))\n",
    "                       for mask_location_parameter, mask_location_model in mask_location_parameter_model_list]\n",
    "    fig.legend(handles=legend_elements, ncol=5, handletextpad=0.6, columnspacing=1,\n",
    "               loc='lower center', bbox_to_anchor=(0.5, -0.1)).set_zorder(100)\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd0d4f99",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Draw the KL-divergence figure for each backbone and export it in several file formats.\n",
    "for backbone_type in ['vit_base_patch16_224']:\n",
    "    print(backbone_type)\n",
    "\n",
    "    surrogate_results = data_loaded_all[\"2_surrogate_evaluate\"]\n",
    "    result_df_ImageNette = surrogate_results[\"ImageNette\"][backbone_type]\n",
    "    result_df_MURA = surrogate_results[\"MURA\"][backbone_type]\n",
    "\n",
    "    fig = draw_kl_divergence(result_df_dict=dict(ImageNette=result_df_ImageNette,\n",
    "                                                 MURA=result_df_MURA))\n",
    "\n",
    "    output_stem = f\"results/plots/surrogate_evaluate_kl_divergence_{backbone_type}\"\n",
    "    for extension in [\"png\", \"svg\", \"jpg\", \"pdf\"]:\n",
    "        fig.savefig(f\"{output_stem}.{extension}\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba30e889",
   "metadata": {},
   "source": [
    "# Process ROC, AUC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab2acc55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the result groups loaded into data_loaded_all (used below: '1_classifier_evaluate', '4_insert_delete', ...).\n",
    "data_loaded_all.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5893095e",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Insertion/deletion metrics: for every explanation method and image, keep the curve ('roc', 196+1 points)\n",
    "# and its mean ('auc') for the predicted ('target') class and, where available, for the other classes.\n",
    "def _roc_auc_record(explanation_method, metric_mode, path, target_roc, non_target_roc):\n",
    "    \"\"\"Build one result row; each AUC is the mean of its curve, or None when the curve is None.\"\"\"\n",
    "    return {'explanation_method': explanation_method,\n",
    "            'metric_mode': metric_mode,\n",
    "            'path': path,\n",
    "            'target_roc': target_roc,\n",
    "            'target_auc': None if target_roc is None else target_roc.mean(axis=-1),\n",
    "            'non-target_roc': non_target_roc,\n",
    "            'non-target_auc': None if non_target_roc is None else non_target_roc.mean(axis=-1),\n",
    "           }\n",
    "\n",
    "\n",
    "roc_auc_result_dict = {}\n",
    "for dataset_name in [\"ImageNette\", \"MURA\"]:\n",
    "    roc_auc_result_dict.setdefault(dataset_name, {})\n",
    "    for backbone_type in backbone_type_config_dict:\n",
    "        print(f'{dataset_name}   {backbone_type}')\n",
    "        # Classifier outputs per image; invariant for this dataset/backbone pair.\n",
    "        classifier_prob_data = data_loaded_all['1_classifier_evaluate'][dataset_name][backbone_type]\n",
    "\n",
    "        roc_auc_dict_list = []\n",
    "        for explanation_method, insertdelete_save_dict in data_loaded_all[\"4_insert_delete\"][dataset_name][backbone_type].items():\n",
    "            for path, insertdelete_dict in insertdelete_save_dict.items():\n",
    "                classifier_prob = classifier_prob_data[adapt_path(path, list(classifier_prob_data.keys()))]['prob']\n",
    "                num_classes = len(classifier_prob)\n",
    "\n",
    "                if num_classes == 1:  # MURA: a single probability, thresholded at 0.5\n",
    "                    assert 0 <= classifier_prob[0] <= 1\n",
    "                    is_target = classifier_prob[0] >= 0.5\n",
    "                    for metric_mode in ['insert', 'delete']:\n",
    "                        roc = insertdelete_dict[metric_mode]\n",
    "                        if explanation_method == \"random\":\n",
    "                            # The 'random' baseline has a leading axis of size 10, which is averaged.\n",
    "                            assert roc.shape == (10, 1, 196+1)\n",
    "                            curve = roc.mean(axis=0)[0]\n",
    "                        else:\n",
    "                            assert roc.shape == (1, 196+1)\n",
    "                            curve = roc[0]\n",
    "                        roc_auc_dict_list.append(_roc_auc_record(explanation_method, metric_mode, path,\n",
    "                                                                 target_roc=curve if is_target else None,\n",
    "                                                                 non_target_roc=None if is_target else curve))\n",
    "\n",
    "                else:  # ImageNette: one probability per class; the target is the argmax class\n",
    "                    classifier_prob_argmax = np.argmax(classifier_prob)\n",
    "                    assert 0 <= classifier_prob_argmax < num_classes\n",
    "                    non_target_mask = np.arange(num_classes) != classifier_prob_argmax\n",
    "                    for metric_mode in ['insert', 'delete']:\n",
    "                        roc = insertdelete_dict[metric_mode]\n",
    "                        if explanation_method == \"random\":\n",
    "                            assert roc.shape == (10, num_classes, 196+1)\n",
    "                            roc = roc.mean(axis=0)\n",
    "                        else:\n",
    "                            assert roc.shape == (num_classes, 196+1)\n",
    "                        # No non-target curve is recorded for the attention-based methods.\n",
    "                        if explanation_method in [\"attention_rollout\", \"attention_last\"]:\n",
    "                            non_target_roc = None\n",
    "                        else:\n",
    "                            non_target_roc = roc[non_target_mask].mean(axis=0)\n",
    "                        roc_auc_dict_list.append(_roc_auc_record(explanation_method, metric_mode, path,\n",
    "                                                                 target_roc=roc[classifier_prob_argmax],\n",
    "                                                                 non_target_roc=non_target_roc))\n",
    "\n",
    "        roc_auc_result_dict[dataset_name][backbone_type] = roc_auc_dict_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca6cc426",
   "metadata": {},
   "source": [
    "# Process sensitivity-n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3764e416",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Sensitivity-n: for every explanation method, image and subset size, record the score for the predicted\n",
    "# ('target') class and, where available, the mean over the other ('non-target') classes.\n",
    "sensitivity_result_dict = {}\n",
    "for dataset_name in [\"ImageNette\", \"MURA\"]:\n",
    "    sensitivity_result_dict.setdefault(dataset_name, {})\n",
    "    for backbone_type in backbone_type_config_dict:\n",
    "        print(f'{dataset_name}   {backbone_type}')\n",
    "        # Classifier outputs per image; invariant for this dataset/backbone pair.\n",
    "        classifier_prob_data = data_loaded_all['1_classifier_evaluate'][dataset_name][backbone_type]\n",
    "\n",
    "        sensitivity_dict_list = []\n",
    "        for explanation_method, sensitivity_save_dict in data_loaded_all[\"5_sensitivity\"][dataset_name][backbone_type].items():\n",
    "            for path, sensitivity in sensitivity_save_dict.items():\n",
    "                classifier_prob = classifier_prob_data[adapt_path(path, list(classifier_prob_data.keys()))]['prob']\n",
    "                num_classes = len(classifier_prob)\n",
    "\n",
    "                if num_classes == 1:  # MURA: a single probability, thresholded at 0.5\n",
    "                    assert 0 <= classifier_prob[0] <= 1\n",
    "                    is_target = classifier_prob[0] >= 0.5\n",
    "                else:  # ImageNette: the target is the argmax class\n",
    "                    classifier_prob_argmax = np.argmax(classifier_prob)\n",
    "                    # A valid class index is strictly below num_classes.\n",
    "                    assert 0 <= classifier_prob_argmax < num_classes\n",
    "                    non_target_mask = np.arange(num_classes) != classifier_prob_argmax\n",
    "\n",
    "                # 'all' plus subset sizes 14, 28, ..., 182\n",
    "                for num_included_players in [\"all\"] + list(range(14, 196, 14)):\n",
    "                    score = sensitivity[num_included_players]\n",
    "                    assert score.shape == (num_classes,)\n",
    "                    if num_classes == 1:\n",
    "                        target = score[0] if is_target else None\n",
    "                        non_target = None if is_target else score[0]\n",
    "                    elif explanation_method in [\"attention_rollout\", \"attention_last\"]:\n",
    "                        # No non-target score is recorded for the attention-based methods.\n",
    "                        target = score[classifier_prob_argmax]\n",
    "                        non_target = None\n",
    "                    else:\n",
    "                        target = score[classifier_prob_argmax]\n",
    "                        non_target = score[non_target_mask].mean()\n",
    "                    sensitivity_dict_list.append({'explanation_method': explanation_method,\n",
    "                                                  'path': path,\n",
    "                                                  'num_included_players': num_included_players,\n",
    "                                                  'target': target,\n",
    "                                                  'non-target': non_target,\n",
    "                                                 })\n",
    "        sensitivity_result_dict[dataset_name][backbone_type] = sensitivity_dict_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01cb35ba",
   "metadata": {},
   "source": [
    "# Process estimation error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcab8610",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "148c6f64",
   "metadata": {},
   "outputs": [],
   "source": [
    "estimationerror_result_dict={}\n",
    "for dataset_name in [\"ImageNette\", \"MURA\"]:\n",
    "    estimationerror_result_dict.setdefault(dataset_name, {})\n",
    "    for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "        estimationerror_result_dict[dataset_name].setdefault(backbone_type, {})\n",
    "        print(f'{dataset_name}   {backbone_type}')\n",
    "        \n",
    "        estimationerror_dict_list=[]\n",
    "        for explanation_method, estimationerror_save_dict in data_loaded_all[\"9_estimationerror\"][dataset_name][backbone_type].items():\n",
    "            for path, estimationerror in estimationerror_save_dict.items():        \n",
    "                classifier_prob_data=data_loaded_all['1_classifier_evaluate'][dataset_name][backbone_type]\n",
    "                classifier_prob=classifier_prob_data[adapt_path(path, list(classifier_prob_data.keys()))]['prob']                \n",
    "                if len(classifier_prob)==1: # MURA\n",
    "                    pass\n",
    "#                     assert classifier_prob[0]>=0 and classifier_prob[0]<=1\n",
    "#                     for num_included_players in [\"all\"] + list(range(14, 196, 14)):\n",
    "#                         if explanation_method in [\"attention_rollout\", \"attention_last\"]:\n",
    "#                             assert estimationerror[num_included_players].shape==(1,)\n",
    "#                             estimationerror_dict_list.append({'explanation_method': explanation_method,\n",
    "#                                                           'path': path,\n",
    "#                                                           'num_included_players': num_included_players,\n",
    "#                                                           'target': estimationerror[num_included_players][0] if classifier_prob[0]>=0.5 else None,\n",
    "#                                                           'non-target': estimationerror[num_included_players][0] if classifier_prob[0]<0.5 else None,\n",
    "#                                                          })\n",
    "\n",
    "#                         else:\n",
    "#                             assert estimationerror[num_included_players].shape==(1,)\n",
    "#                             estimationerror_dict_list.append({'explanation_method': explanation_method,\n",
    "#                                                           'path': path,\n",
    "#                                                           'num_included_players': num_included_players,                                                      \n",
    "#                                                           'target': estimationerror[num_included_players][0] if classifier_prob[0]>=0.5 else None,\n",
    "#                                                           'non-target': estimationerror[num_included_players][0] if classifier_prob[0]<0.5 else None,\n",
    "#                                                          })                       \n",
    "                    \n",
    "                else:\n",
    "                    classifier_prob_argmax=np.argmax(classifier_prob)\n",
    "                    assert classifier_prob_argmax>=0 and classifier_prob_argmax<=len(classifier_prob)\n",
    "                    if path in data_loaded_all[\"9_estimationerror\"][dataset_name][backbone_type][\"kernelshap\"].keys():\n",
    "                        ground_truth=data_loaded_all[\"9_estimationerror\"][dataset_name][backbone_type][\"kernelshap\"][path][\"estimation\"][0].values\n",
    "                    else:\n",
    "                        continue\n",
    "                        \n",
    "                    \n",
    "                    if explanation_method==\"ours\":\n",
    "                        value=estimationerror['estimation'].T\n",
    "                        assert value.shape==(196, len(classifier_prob))\n",
    "                        \n",
    "                        estimationerror_dict_list.append({'explanation_method': explanation_method,\n",
    "                                                          'path': path,\n",
    "                                                          'all': ((value.T-ground_truth.T)**2).sum(axis=1)**(0.5),\n",
    "                                                          'target': ((value.T-ground_truth.T)[[(classifier_prob_argmax)]]**2).sum(axis=1)**(0.5),\n",
    "                                                          'non-target': ((value.T-ground_truth.T)[np.arange(ground_truth.shape[1])!=classifier_prob_argmax]**2).sum(axis=1)**(0.5),\n",
    "                                                          'target_spearman_r': [stats.spearmanr(a=value.T[class_idx], b=ground_truth.T[class_idx])[0] for class_idx in np.arange(ground_truth.shape[1]) if class_idx==classifier_prob_argmax],\n",
    "                                                          'non-target_spearman_r': [stats.spearmanr(a=value.T[class_idx], b=ground_truth.T[class_idx])[0] for class_idx in np.arange(ground_truth.shape[1]) if class_idx!=classifier_prob_argmax],\n",
    "                                                          'target_pearson_r': [stats.pearsonr(value.T[class_idx], ground_truth.T[class_idx])[0] for class_idx in np.arange(ground_truth.shape[1]) if class_idx==classifier_prob_argmax],\n",
    "                                                          'non-target_pearson_r': [stats.pearsonr(value.T[class_idx], ground_truth.T[class_idx])[0] for class_idx in np.arange(ground_truth.shape[1]) if class_idx!=classifier_prob_argmax],\n",
    "                                                         })\n",
    "                        \n",
    "                    elif explanation_method==\"kernelshap\":\n",
    "                        for iter_idx, (value, num_sample) in enumerate(zip(estimationerror['estimation'][1]['values'],\n",
    "                                                              estimationerror['estimation'][1]['iters'])):                        \n",
    "                            if num_sample<200000:\n",
    "                                assert value.shape==(196, len(classifier_prob))\n",
    "                                estimationerror_dict_list.append({'explanation_method': explanation_method,\n",
    "                                                                  'path': path,\n",
    "                                                                  'num_sample': num_sample,\n",
    "                                                                  'all': ((value.T-ground_truth.T)**2).sum(axis=1)**(0.5),\n",
    "                                                                  'target': ((value.T-ground_truth.T)[[(classifier_prob_argmax)]]**2).sum(axis=1)**(0.5),\n",
    "                                                                  'non-target': ((value.T-ground_truth.T)[np.arange(ground_truth.shape[1])!=classifier_prob_argmax]**2).sum(axis=1)**(0.5),\n",
    "                                                                  'target_spearman_r': [stats.spearmanr(a=value.T[class_idx], b=ground_truth.T[class_idx])[0] for class_idx in np.arange(ground_truth.shape[1]) if class_idx==classifier_prob_argmax],\n",
    "                                                                  'non-target_spearman_r': [stats.spearmanr(a=value.T[class_idx], b=ground_truth.T[class_idx])[0] for class_idx in np.arange(ground_truth.shape[1]) if class_idx!=classifier_prob_argmax],\n",
    "                                                                  'target_pearson_r': [stats.pearsonr(value.T[class_idx], ground_truth.T[class_idx])[0] for class_idx in np.arange(ground_truth.shape[1]) if class_idx==classifier_prob_argmax],\n",
    "                                                                  'non-target_pearson_r': [stats.pearsonr(value.T[class_idx], ground_truth.T[class_idx])[0] for class_idx in np.arange(ground_truth.shape[1]) if class_idx!=classifier_prob_argmax],\n",
    "                                                                 })\n",
    "                    elif explanation_method==\"kernelshapnopair\":\n",
    "                        if path==\"/homes/gws/username/.fastai/data/imagenette2-160/val/n03425413/n03425413_13231.JPEG\":\n",
    "                            continue                        \n",
    "                        print(estimationerror['estimation'][1]['iters'])\n",
    "                        for iter_idx, (value, num_sample) in enumerate(zip(estimationerror['estimation'][1]['values'],\n",
    "                                                              estimationerror['estimation'][1]['iters'])):                        \n",
    "                            \n",
    "                            if num_sample<200000:\n",
    "                                assert value.shape==(196, len(classifier_prob))\n",
    "                                estimationerror_dict_list.append({'explanation_method': explanation_method,\n",
    "                                                                  'path': path,\n",
    "                                                                  'num_sample': num_sample,\n",
    "                                                                  'all': ((value.T-ground_truth.T)**2).sum(axis=1)**(0.5),\n",
    "                                                                  'target': ((value.T-ground_truth.T)[[(classifier_prob_argmax)]]**2).sum(axis=1)**(0.5),\n",
    "                                                                  'non-target': ((value.T-ground_truth.T)[np.arange(ground_truth.shape[1])!=classifier_prob_argmax]**2).sum(axis=1)**(0.5),\n",
    "                                                                  'target_spearman_r': [stats.spearmanr(a=value.T[class_idx], b=ground_truth.T[class_idx])[0] for class_idx in np.arange(ground_truth.shape[1]) if class_idx==classifier_prob_argmax],\n",
    "                                                                  'non-target_spearman_r': [stats.spearmanr(a=value.T[class_idx], b=ground_truth.T[class_idx])[0] for class_idx in np.arange(ground_truth.shape[1]) if class_idx!=classifier_prob_argmax],\n",
    "                                                                  'target_pearson_r': [stats.pearsonr(value.T[class_idx], ground_truth.T[class_idx])[0] for class_idx in np.arange(ground_truth.shape[1]) if class_idx==classifier_prob_argmax],\n",
    "                                                                  'non-target_pearson_r': [stats.pearsonr(value.T[class_idx], ground_truth.T[class_idx])[0] for class_idx in np.arange(ground_truth.shape[1]) if class_idx!=classifier_prob_argmax],\n",
    "                                                                 })                              \n",
    "                    else:\n",
    "                        raise\n",
    "                        \n",
    "                    \n",
    "                    #print((len(ground_truth)))\n",
    "                    #dsds\n",
    "#                     for num_included_players in [\"all\"] + list(range(14, 196, 14)):\n",
    "#                         if explanation_method in [\"attention_rollout\", \"attention_last\"]:\n",
    "#                             assert estimationerror[num_included_players].shape==(len(classifier_prob),)\n",
    "#                             estimationerror_dict_list.append({'explanation_method': explanation_method,\n",
    "#                                                           'path': path,\n",
    "#                                                           'num_included_players': num_included_players,\n",
    "#                                                           'target': estimationerror[num_included_players][classifier_prob_argmax],\n",
    "#                                                           'non-target': None,\n",
    "#                                                          })\n",
    "\n",
    "#                         else:\n",
    "#                             assert estimationerror[num_included_players].shape==(len(classifier_prob),)\n",
    "#                             estimationerror_dict_list.append({'explanation_method': explanation_method,\n",
    "#                                                           'path': path,\n",
    "#                                                           'num_included_players': num_included_players,                                                      \n",
    "#                                                           'target': estimationerror[num_included_players][classifier_prob_argmax],\n",
    "#                                                           'non-target': estimationerror[num_included_players][np.arange(len(estimationerror[num_included_players]))!=classifier_prob_argmax].mean(),\n",
    "#                                                          })                    \n",
    "        estimationerror_result_dict[dataset_name][backbone_type]=estimationerror_dict_list#print(pd.DataFrame(estimationerror_dict_list).groupby(['explanation_method']).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a544411e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop every record for the image that the kernelshapnopair branch of the results loop also skips.\n",
    "new_list = [\n",
    "    record\n",
    "    for record in estimationerror_result_dict[\"ImageNette\"][\"vit_base_patch16_224\"]\n",
    "    if record[\"path\"] != \"/homes/gws/username/.fastai/data/imagenette2-160/val/n03425413/n03425413_13231.JPEG\"\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "617d08be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of estimation-error records left after the exclusion filter.\n",
    "len(new_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5627ee6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the ImageNette / ViT-Base estimation-error records with the filtered list.\n",
    "estimationerror_result_dict[\"ImageNette\"].update(vit_base_patch16_224=new_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cba75d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the excluded image must be gone from the ImageNette / ViT-Base results.\n",
    "excluded_path = \"/homes/gws/username/.fastai/data/imagenette2-160/val/n03425413/n03425413_13231.JPEG\"\n",
    "assert all(\n",
    "    record[\"path\"] != excluded_path\n",
    "    for record in estimationerror_result_dict[\"ImageNette\"][\"vit_base_patch16_224\"]\n",
    "), f\"{excluded_path} is still present in the ImageNette / vit_base_patch16_224 results\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18b56ec7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spearman rank correlation between estimated and ground-truth attributions of the predicted class.\n",
    "stats.spearmanr(value[:, classifier_prob_argmax], ground_truth[:, classifier_prob_argmax])[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afba5d43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target-class Spearman correlation, built with the same comprehension filter as the results loop.\n",
    "[stats.spearmanr(value[:, class_idx], ground_truth[:, class_idx])[0]\n",
    " for class_idx in range(ground_truth.shape[1])\n",
    " if class_idx == classifier_prob_argmax]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c07dd6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean Spearman correlation over every non-target class.\n",
    "np.mean([stats.spearmanr(value[:, class_idx], ground_truth[:, class_idx])[0]\n",
    "         for class_idx in range(ground_truth.shape[1])\n",
    "         if class_idx != classifier_prob_argmax])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3d41daf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of spearmanr's result when the non-target classes are passed as 2-D arrays (one row per class).\n",
    "stats.spearmanr(value[:, np.arange(ground_truth.shape[1]) != classifier_prob_argmax].T,\n",
    "                ground_truth[:, np.arange(ground_truth.shape[1]) != classifier_prob_argmax].T)[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dac299a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the transposed estimate restricted to the non-target classes.\n",
    "value[:, np.arange(ground_truth.shape[1]) != classifier_prob_argmax].T.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b86320c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# L2 distance between estimate and ground truth for the predicted class (mean of a one-element array).\n",
    "(((value.T - ground_truth.T)[[classifier_prob_argmax]] ** 2).sum(axis=1) ** 0.5).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76e57821",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fields exposed by entry 1 of the stored kernelshap 'estimation' record for the current image.\n",
    "data_loaded_all[\"9_estimationerror\"][dataset_name][backbone_type][\"kernelshap\"][path][\"estimation\"][1].keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b89851db",
   "metadata": {},
   "source": [
    "# Average ROC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f25d17af",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Mean insertion / deletion curves per explanation method, one figure per (backbone, dataset).\n",
    "# ImageNette gets a target row and a non-target row; MURA only a target row.\n",
    "# Every figure is written to results/plots/ as png, svg, jpg and pdf.\n",
    "\n",
    "# ColorBrewer \"Paired\" entries are 0-255 RGB triples; matplotlib expects floats in [0, 1].\n",
    "plt.rcParams[\"axes.prop_cycle\"] = cycler('color', [np.array(Paired[12][k]) / 255 for k in (1, 3, 5, 7, 9, 11)])\n",
    "plt.rcParams['font.family'] = 'PT Sans'\n",
    "plt.rcParams[\"font.size\"] = 18\n",
    "roc_palette = plt.rcParams[\"axes.prop_cycle\"].by_key()['color']\n",
    "\n",
    "roc_linestyles = [':', '-.', '-', '-']    # indexed by explanation-method category\n",
    "roc_random_color = [0, 0, 0, 0.8]         # RGBA: the \"random\" baseline is drawn in translucent black\n",
    "roc_patch_counts = np.arange(0, 196 + 1)  # x values: 0..196 inserted / deleted patches\n",
    "roc_ylabels = {\n",
    "    \"ImageNette\": {\"target\": \"Probability (Target)\", \"non-target\": \"Probability (Non-target)\"},\n",
    "    \"MURA\": {\"target\": \"Probability (Abnormal)\"},\n",
    "}\n",
    "roc_xlabels = {\"insert\": \"# of Inserted Patches\", \"delete\": \"# of Deleted Patches\"}\n",
    "\n",
    "for backbone_type in backbone_type_config_dict:\n",
    "    for dataset_name in [\"ImageNette\", \"MURA\"] if backbone_type == \"vit_base_patch16_224\" else [\"ImageNette\"]:\n",
    "        if dataset_name == \"ImageNette\":\n",
    "            row_kinds = [\"target\", \"non-target\"]\n",
    "            fig = plt.figure(constrained_layout=True, figsize=(18, 10))\n",
    "        elif dataset_name == \"MURA\":\n",
    "            row_kinds = [\"target\"]\n",
    "            fig = plt.figure(constrained_layout=True, figsize=(18, 6))\n",
    "        else:\n",
    "            raise ValueError(f\"no average-ROC layout defined for dataset {dataset_name!r}\")\n",
    "\n",
    "        # Left sub-figure: insertion curves; right sub-figure: deletion curves.\n",
    "        # Each column has one row per entry of row_kinds.\n",
    "        axd = {}\n",
    "        for subfig, insert_delete in zip(fig.subfigures(1, 2), ['insert', 'delete']):\n",
    "            column_axes = subfig.subplots(len(row_kinds), 1, squeeze=False)[:, 0]\n",
    "            for ax, target_non_target in zip(column_axes, row_kinds):\n",
    "                axd[f'{insert_delete}_{target_non_target}'] = ax\n",
    "\n",
    "        roc_df = pd.DataFrame(roc_auc_result_dict[dataset_name][backbone_type])\n",
    "        for target_non_target in row_kinds:\n",
    "            # Element-wise mean curve over images for each (metric_mode, explanation_method) pair.\n",
    "            roc_summarized = (roc_df\n",
    "                              .groupby(['metric_mode', 'explanation_method'])[f'{target_non_target}_roc']\n",
    "                              .apply(lambda x: np.mean(x, axis=0)))\n",
    "\n",
    "            for insert_delete in ['insert', 'delete']:\n",
    "                ax = axd[f'{insert_delete}_{target_non_target}']\n",
    "                for idx1, explanation_methods_category in enumerate(explanation_method_main_random):\n",
    "                    for idx2, explanation_method in enumerate(explanation_methods_category):\n",
    "                        roc_data = roc_summarized.loc[insert_delete].loc[explanation_method]\n",
    "                        # A scalar (presumably NaN: no stored curves) cannot be plotted against the x grid.\n",
    "                        # np.ndim catches numpy scalars too, unlike a `type(x) == float` test.\n",
    "                        if np.ndim(roc_data) == 0:\n",
    "                            print(dataset_name, target_non_target, insert_delete, explanation_method, 'None!!!!!!')\n",
    "                            continue\n",
    "                        ax.plot(roc_patch_counts,\n",
    "                                roc_data,\n",
    "                                linestyle=roc_linestyles[idx1],\n",
    "                                linewidth=3,\n",
    "                                c=roc_random_color if explanation_method == \"random\" else roc_palette[idx2])\n",
    "\n",
    "                if target_non_target == \"target\":\n",
    "                    ax.set_title(f'{insert_delete_mapper(insert_delete)}')\n",
    "                    ax.set_ylim(0, 1.05)\n",
    "                # y labels only on the insertion (left) column, x labels only on the bottom row.\n",
    "                if insert_delete == \"insert\":\n",
    "                    ax.set_ylabel(roc_ylabels[dataset_name][target_non_target])\n",
    "                if target_non_target == row_kinds[-1]:\n",
    "                    ax.set_xlabel(roc_xlabels[insert_delete])\n",
    "                ax.set_xlim(-2, 200)\n",
    "\n",
    "                ax.xaxis.set_major_locator(MultipleLocator(28))\n",
    "                ax.xaxis.set_minor_locator(MultipleLocator(14))\n",
    "                ax.yaxis.set_minor_locator(MultipleLocator(0.1))\n",
    "                for axis in (ax.xaxis, ax.yaxis):\n",
    "                    axis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "                    axis.grid(True, which='minor', linewidth=0.6, alpha=0.2)\n",
    "\n",
    "                ax.spines['right'].set_visible(False)\n",
    "                ax.spines['top'].set_visible(False)\n",
    "                for side in ['top', 'bottom', 'left', 'right']:\n",
    "                    ax.spines[side].set_linewidth(2)\n",
    "\n",
    "        # Legend: line style encodes the method category, colour the position inside it.\n",
    "        legend_elements = [Line2D([0], [0],\n",
    "                                  linestyle=roc_linestyles[idx1],\n",
    "                                  color=roc_random_color if explanation_method == \"random\" else roc_palette[idx2],\n",
    "                                  linewidth=5,\n",
    "                                  label=explanation_method_mapper(explanation_method))\n",
    "                           for idx1, explanation_method_category in enumerate(explanation_method_main_random)\n",
    "                           for idx2, explanation_method in enumerate(explanation_method_category)]\n",
    "        # Invisible spacer entries; presumably they align categories across the 5 legend columns.\n",
    "        legend_elements.insert(2, Line2D([0], [0], linewidth=0))\n",
    "        legend_elements.insert(13, Line2D([0], [0], linewidth=0))\n",
    "        legend_elements.insert(13, Line2D([0], [0], linewidth=0))\n",
    "\n",
    "        fig.legend(handles=legend_elements,\n",
    "                   ncol=5,\n",
    "                   handlelength=3,\n",
    "                   handletextpad=0.6,\n",
    "                   columnspacing=1.5,\n",
    "                   loc='lower center',\n",
    "                   bbox_to_anchor=(0.5, -0.15) if dataset_name == \"ImageNette\" else (0.5, -0.25))\n",
    "\n",
    "        for extension in (\"png\", \"svg\", \"jpg\", \"pdf\"):\n",
    "            fig.savefig(f\"results/plots/average_roc_{backbone_type}_{dataset_name}.{extension}\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff3bbdb9",
   "metadata": {},
   "source": [
    "# Sensitivity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65c2b1d0",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Sensitivity-n: correlation between summed attributions and model output, averaged over images,\n",
    "# as a function of the number n of included patches. One figure per (backbone, target / non-target);\n",
    "# every figure is written to results/plots/ as png, svg, jpg and pdf.\n",
    "\n",
    "# ColorBrewer \"Paired\" entries are 0-255 RGB triples; matplotlib expects floats in [0, 1].\n",
    "plt.rcParams[\"axes.prop_cycle\"] = cycler('color', [np.array(Paired[12][k]) / 255 for k in (1, 3, 5, 7, 9, 11)])\n",
    "plt.rcParams['font.family'] = 'PT Sans'\n",
    "plt.rcParams[\"font.size\"] = 18\n",
    "sensitivity_palette = plt.rcParams[\"axes.prop_cycle\"].by_key()['color']\n",
    "\n",
    "show_confidence_band = False                 # True: shade mean +/- 1.96 * standard error around each curve\n",
    "sensitivity_linestyles = [':', '-.', '-']    # indexed by explanation-method category\n",
    "cardinality_list = list(range(14, 196, 14))  # n = 14, 28, ..., 182 included patches\n",
    "sensitivity_ylims = {\"target\": (-0.1, 0.73), \"non-target\": (-0.3, 0.73)}\n",
    "sensitivity_ylabels = {\"target\": \"Correlation\", \"non-target\": \"Correlation (Non-target only)\"}\n",
    "\n",
    "for backbone_type in backbone_type_config_dict:\n",
    "    for target_non_target in [\"target\", \"non-target\"]:\n",
    "        # Only the ViT-Base target figure gets a MURA panel next to the ImageNette one.\n",
    "        if target_non_target == \"target\" and backbone_type == \"vit_base_patch16_224\":\n",
    "            dataset_names = [\"ImageNette\", \"MURA\"]\n",
    "        elif backbone_type == \"vit_small_patch16_224\" or target_non_target == \"non-target\":\n",
    "            dataset_names = [\"ImageNette\"]\n",
    "        else:\n",
    "            raise ValueError(f\"no sensitivity layout defined for backbone {backbone_type!r} ({target_non_target})\")\n",
    "\n",
    "        fig = plt.figure(constrained_layout=True, figsize=(9 * len(dataset_names), 6))\n",
    "        subfigs = fig.subfigures(1, len(dataset_names), squeeze=False)[0]\n",
    "        axd = {f\"{dataset_name}_{target_non_target}\": subfig.subplots(1, 1)\n",
    "               for subfig, dataset_name in zip(subfigs, dataset_names)}\n",
    "\n",
    "        for dataset_name in dataset_names:\n",
    "            grouped = (pd.DataFrame(sensitivity_result_dict[dataset_name][backbone_type])\n",
    "                       .groupby([\"explanation_method\", \"num_included_players\"])[target_non_target])\n",
    "            sensitivity_summarized_mean = grouped.mean()\n",
    "            # Half-width of a normal-approximation 95% confidence interval for the mean.\n",
    "            sensitivity_summarized_std = grouped.apply(lambda x: 1.96 * x.std() / len(x) ** 0.5)\n",
    "\n",
    "            ax = axd[f'{dataset_name}_{target_non_target}']\n",
    "            for idx1, explanation_methods_category in enumerate(explanation_method_main):\n",
    "                for idx2, explanation_method in enumerate(explanation_methods_category):\n",
    "                    mean_curve = sensitivity_summarized_mean.loc[explanation_method].loc[cardinality_list].values\n",
    "                    ax.plot(cardinality_list,\n",
    "                            mean_curve,\n",
    "                            linestyle=sensitivity_linestyles[idx1],\n",
    "                            linewidth=3,\n",
    "                            c=sensitivity_palette[idx2])\n",
    "                    if show_confidence_band:\n",
    "                        half_width = sensitivity_summarized_std.loc[explanation_method].loc[cardinality_list].values\n",
    "                        ax.fill_between(cardinality_list,\n",
    "                                        mean_curve - half_width,\n",
    "                                        mean_curve + half_width,\n",
    "                                        color=sensitivity_palette[idx2],\n",
    "                                        alpha=0.2)\n",
    "\n",
    "            ax.set_ylim(*sensitivity_ylims[target_non_target])\n",
    "            ax.set_xlim(14 - 2, 182 + 2)\n",
    "\n",
    "            ax.xaxis.set_major_locator(MultipleLocator(28))\n",
    "            ax.xaxis.set_minor_locator(MultipleLocator(14))\n",
    "            ax.yaxis.set_minor_locator(MultipleLocator(0.1))\n",
    "            for axis in (ax.xaxis, ax.yaxis):\n",
    "                axis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "                axis.grid(True, which='minor', linewidth=0.6, alpha=0.2)\n",
    "\n",
    "            ax.spines['right'].set_visible(False)\n",
    "            ax.spines['top'].set_visible(False)\n",
    "            for side in ['top', 'bottom', 'left', 'right']:\n",
    "                ax.spines[side].set_linewidth(2)\n",
    "\n",
    "            ax.set_ylabel(sensitivity_ylabels[target_non_target], labelpad=-10)\n",
    "            ax.set_xlabel(\"# Patches\", labelpad=10)\n",
    "            ax.set_title(dataset_name)\n",
    "\n",
    "        # Legend: line style encodes the method category, colour the position inside it.\n",
    "        legend_elements = [Line2D([0], [0],\n",
    "                                  linestyle=sensitivity_linestyles[idx1],\n",
    "                                  color=sensitivity_palette[idx2],\n",
    "                                  linewidth=5,\n",
    "                                  label=explanation_method_mapper(explanation_method))\n",
    "                           for idx1, explanation_method_category in enumerate(explanation_method_main)\n",
    "                           for idx2, explanation_method in enumerate(explanation_method_category)]\n",
    "        legend_elements.insert(2, Line2D([0], [0], linewidth=0))  # invisible spacer entry\n",
    "\n",
    "        fig.legend(handles=legend_elements,\n",
    "                   ncol=4,\n",
    "                   handlelength=3,\n",
    "                   handletextpad=0.6,\n",
    "                   columnspacing=1.5,\n",
    "                   loc='lower center',\n",
    "                   bbox_to_anchor=(0.5, -0.25))\n",
    "\n",
    "        for extension in (\"png\", \"svg\", \"jpg\", \"pdf\"):\n",
    "            fig.savefig(f\"results/plots/sensitivity_n_{backbone_type}_{target_non_target}.{extension}\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4c5cd85",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):    \n",
    "    #plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(i)/256 for i in Set1[6]])    \n",
    "    for target_non_target in [\"target\", \"non-target\"]:\n",
    "        plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], \n",
    "                                                                                    Paired[12][3],\n",
    "                                                                                    Paired[12][5],\n",
    "                                                                                    Paired[12][7],\n",
    "                                                                                    Paired[12][9],\n",
    "                                                                                    Paired[12][11]\n",
    "                                                                                    ]])     \n",
    "\n",
    "        plt.rcParams['font.family'] = 'PT Sans'\n",
    "        plt.rcParams[\"font.size\"] = 18\n",
    "\n",
    "        if target_non_target==\"target\" and backbone_type==\"vit_base_patch16_224\":\n",
    "\n",
    "            fig = plt.figure(constrained_layout=True, figsize=(18, 5.5))\n",
    "            subfigs = fig.subfigures(1, 2)\n",
    "\n",
    "            ax_ImageNette = subfigs[0].subplots(1,1)\n",
    "            #subfigs[0].supylabel(\"Correlation between sum of attributions and output\", fontsize=18)\n",
    "\n",
    "            ax_MURA = subfigs[1].subplots(1,1)\n",
    "\n",
    "            axd={f\"ImageNette_{target_non_target}\": ax_ImageNette,\n",
    "                 f\"MURA_{target_non_target}\": ax_MURA,\n",
    "                }\n",
    "        elif backbone_type==\"vit_small_patch16_224\" or target_non_target==\"non-target\":\n",
    "            fig = plt.figure(constrained_layout=True, figsize=(9, 5.5))\n",
    "            subfigs = fig.subfigures(1, 1)\n",
    "\n",
    "            ax_ImageNette = subfigs.subplots(1,1)\n",
    "\n",
    "            axd={f\"ImageNette_{target_non_target}\": ax_ImageNette,\n",
    "                }\n",
    "\n",
    "        else:\n",
    "            raise\n",
    "\n",
    "        cardinality_list=list(range(14, 196, 14))\n",
    "\n",
    "        for dataset_name in [\"ImageNette\", \"MURA\"] if (target_non_target==\"target\" and backbone_type==\"vit_base_patch16_224\") else [\"ImageNette\"]:\n",
    "            \n",
    "            sensitivity_summarized_mean=pd.DataFrame(sensitivity_result_dict[dataset_name][backbone_type]).groupby([\"explanation_method\", \"num_included_players\"])[target_non_target].mean()\n",
    "            sensitivity_summarized_std=pd.DataFrame(sensitivity_result_dict[dataset_name][backbone_type]).groupby([\"explanation_method\", \"num_included_players\"])[target_non_target].apply(lambda x: 1.96*x.std()/((len(x))**(0.5)))\n",
    "            \n",
    "            plot_key=f'{dataset_name}_{target_non_target}'\n",
    "            for idx1, explanation_methods_category in enumerate(explanation_method_main):\n",
    "                for idx2, explanation_method in enumerate(explanation_methods_category):    \n",
    "                    \n",
    "                    axd[plot_key].plot(cardinality_list,\n",
    "                                       sensitivity_summarized_mean.loc[explanation_method].loc[cardinality_list].values,\n",
    "                                       linestyle=[':','-.','-'][idx1],\n",
    "                                       linewidth=3,\n",
    "                                       c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx2])\n",
    "                    \n",
    "#                     axd[plot_key].fill_between(cardinality_list,\n",
    "#                                                sensitivity_summarized_mean.loc[explanation_method].loc[cardinality_list].values - sensitivity_summarized_std.loc[explanation_method].loc[cardinality_list].values,\n",
    "#                                                sensitivity_summarized_mean.loc[explanation_method].loc[cardinality_list].values + sensitivity_summarized_std.loc[explanation_method].loc[cardinality_list].values,\n",
    "#                                                color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx2], \n",
    "#                                                alpha=0.2)                    \n",
    "            if target_non_target==\"target\":\n",
    "                axd[plot_key].set_ylim(-0.1, 0.73)\n",
    "            elif target_non_target==\"non-target\":\n",
    "                axd[plot_key].set_ylim(-0.3, 0.73)\n",
    "            else:\n",
    "                raise\n",
    "            axd[plot_key].set_xlim(14-2, 182+2)                          \n",
    "                    \n",
    "            axd[plot_key].spines['right'].set_visible(False)\n",
    "            axd[plot_key].spines['top'].set_visible(False)\n",
    "            #axd[plot_key].spines['bottom'].set_visible(False)\n",
    "\n",
    "\n",
    "            axd[plot_key].xaxis.set_major_locator(MultipleLocator(28))\n",
    "            axd[plot_key].xaxis.set_minor_locator(MultipleLocator(14))    \n",
    "            axd[plot_key].xaxis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "            axd[plot_key].xaxis.grid(True, which='minor', linewidth=0.6, alpha=0.2)   \n",
    "\n",
    "#             if target_non_target==\"target\":\n",
    "#                 axd[plot_key].tick_params(axis='x', which='both',\n",
    "#                                 bottom=False)       \n",
    "\n",
    "            axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.1))\n",
    "            axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.1))\n",
    "            axd[plot_key].yaxis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "            axd[plot_key].yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.2)                \n",
    "            \n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[plot_key].spines[axis].set_linewidth(2) \n",
    "            if target_non_target==\"target\":\n",
    "                axd[plot_key].set_ylabel(\"Correlation\", labelpad=2)\n",
    "            elif target_non_target==\"non-target\":\n",
    "                axd[plot_key].set_ylabel(\"Correlation (Non-target only)\", labelpad=2)\n",
    "            else:\n",
    "                raise\n",
    "            \n",
    "\n",
    "            axd[plot_key].set_xlabel(\"# Patches\", labelpad=10)\n",
    "            \n",
    "            axd[plot_key].set_title(dataset_name)\n",
    "            \n",
    "        legend_elements = [Line2D([0], [0],\n",
    "                                  linestyle=[':','-.','-'][idx1],\n",
    "                                  color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx2],\n",
    "                                  linewidth=5,\n",
    "                                  label=explanation_method_mapper(explanation_method))\n",
    "                             for idx1, explanation_method_category in enumerate(explanation_method_main) for idx2, explanation_method in enumerate(explanation_method_category)]\n",
    "        legend_elements.insert(2, Line2D([0], [0], linewidth=0))\n",
    "\n",
    "        fig.legend(handles=legend_elements, \n",
    "                    ncol=4, \n",
    "                    handlelength=3,\n",
    "                    handletextpad=0.6, \n",
    "                    columnspacing=1.5,\n",
    "                    loc='lower center', bbox_to_anchor=(0.5, -0.25))              \n",
    "\n",
    "        fig.savefig(f\"results/plots/sensitivity_n_{backbone_type}_{target_non_target}.png\", bbox_inches='tight')\n",
    "        fig.savefig(f\"results/plots/sensitivity_n_{backbone_type}_{target_non_target}.svg\", bbox_inches='tight')\n",
    "        fig.savefig(f\"results/plots/sensitivity_n_{backbone_type}_{target_non_target}.jpg\", bbox_inches='tight')\n",
    "        fig.savefig(f\"results/plots/sensitivity_n_{backbone_type}_{target_non_target}.pdf\", bbox_inches='tight')        \n",
    "            "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f5434d6",
   "metadata": {},
   "source": [
    "# Estimation error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a671a7c2",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def flatten_data(x_array):\n",
    "    \"\"\"Concatenate a 1-D collection of equal-length sequences into one NumPy array.\n",
    "\n",
    "    Used in this notebook on result columns (pandas Series) and on `groupby(...)` groups,\n",
    "    so that `.mean()`, `.std()` and `len()` act on individual values rather than on lists.\n",
    "\n",
    "    Args:\n",
    "        x_array: iterable whose elements are sequences (lists/arrays) of equal length.\n",
    "            A single scalar element is also accepted.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray with the elements of every sequence, in order (one level of nesting\n",
    "        removed). Empty input gives an empty array; a lone scalar gives a 1-element array.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the sequences have different lengths.\n",
    "        TypeError: if there are several elements and some of them have no len().\n",
    "    \"\"\"\n",
    "    items = list(x_array)\n",
    "    if not items:\n",
    "        return np.array([])\n",
    "    try:\n",
    "        lengths = {len(item) for item in items}\n",
    "    except TypeError:\n",
    "        if len(items) == 1:\n",
    "            # a single scalar value: nothing to flatten\n",
    "            return np.array(items)\n",
    "        raise\n",
    "    if len(lengths) != 1:\n",
    "        raise ValueError(f\"flatten_data expects equal-length entries, got lengths {sorted(lengths)}\")\n",
    "    # one level of flattening; a single row holding a list is flattened as well,\n",
    "    # so callers' .mean()/.std() always see numbers\n",
    "    return np.array([value for item in items for value in item])\n",
    "\n",
    "from matplotlib.ticker import StrMethodFormatter\n",
    "\n",
    "# Estimation-error plot settings (shared by every figure drawn by the loop below).\n",
    "# Colors: entries 1, 3, 5, 7, 9, 11 of Paired[12], scaled by 1/256 like the other plots in this notebook.\n",
    "ESTIMATION_ERROR_COLORS = [np.array(Paired[12][i])/256 for i in (1, 3, 5, 7, 9, 11)]\n",
    "# y-limits (log scale) and y-axis labels, keyed by which attributions are compared.\n",
    "ESTIMATION_ERROR_YLIM = {\"target\": (1e-2, 1e+0), \"non-target\": (1e-3, 1e+0)}\n",
    "ESTIMATION_ERROR_YLABEL = {\"target\": \"L2 distance (Target)\", \"non-target\": \"L2 distance (Non-target)\"}\n",
    "# (color index, line width) per KernelSHAP variant; curves are drawn in this order.\n",
    "KERNELSHAP_CURVE_STYLES = {\"kernelshap\": (1, 2), \"kernelshapnopair\": (0, 1)}\n",
    "# Set to True to write results/plots/estimationerror_<backbone>_<target>.{png,svg,jpg,pdf}.\n",
    "SAVE_ESTIMATION_ERROR_PLOTS = False\n",
    "\n",
    "\n",
    "def summarize_error_by_num_sample(results_df, explanation_method, column):\n",
    "    \"\"\"Per-`num_sample` mean and 95% CI half-width of `column` for one explanation method.\n",
    "\n",
    "    Rows are filtered on `results_df[\"explanation_method\"] == explanation_method`, grouped by\n",
    "    `num_sample`, and each group is flattened with `flatten_data` before aggregating.\n",
    "\n",
    "    Returns:\n",
    "        (mean, ci): pandas Series indexed by `num_sample`; ci = 1.96 * std / sqrt(n).\n",
    "    \"\"\"\n",
    "    grouped = results_df[results_df[\"explanation_method\"]==explanation_method].groupby([\"num_sample\"])[column]\n",
    "    mean = grouped.apply(lambda x: flatten_data(x).mean())\n",
    "    ci = grouped.apply(lambda x: 1.96*flatten_data(x).std()/np.sqrt(len(flatten_data(x))))\n",
    "    return mean, ci\n",
    "\n",
    "\n",
    "# One figure per (backbone, target/non-target): L2 distance (log scale) against # Evals (`num_sample`).\n",
    "for backbone_type in ['vit_base_patch16_224']:\n",
    "    for target_non_target in [\"target\", \"non-target\"]:\n",
    "        if target_non_target not in ESTIMATION_ERROR_YLIM:\n",
    "            raise ValueError(f\"Unknown target_non_target: {target_non_target!r}\")\n",
    "\n",
    "        plt.rcParams[\"axes.prop_cycle\"] = cycler('color', ESTIMATION_ERROR_COLORS)\n",
    "        plt.rcParams['font.family'] = 'PT Sans'\n",
    "        plt.rcParams[\"font.size\"] = 18\n",
    "\n",
    "        fig = plt.figure(constrained_layout=True, figsize=(9,6))\n",
    "        subfigs = fig.subfigures(1, 1)\n",
    "        ax_ImageNette = subfigs.subplots(1,1)\n",
    "        axd = {f\"ImageNette_{target_non_target}\": ax_ImageNette}\n",
    "\n",
    "        for dataset_name in [\"ImageNette\"]:\n",
    "            estimationerror_summarized = pd.DataFrame(estimationerror_result_dict[dataset_name][backbone_type])\n",
    "            plot_key = f'{dataset_name}_{target_non_target}'\n",
    "            ax = axd[plot_key]\n",
    "\n",
    "            summaries = {method: summarize_error_by_num_sample(estimationerror_summarized, method, target_non_target)\n",
    "                         for method in KERNELSHAP_CURVE_STYLES}\n",
    "\n",
    "            # ours: a single overall mean, drawn as a flat line over KernelSHAP's num_sample values\n",
    "            num_samples = summaries[\"kernelshap\"][0].index.values\n",
    "            ours_mean = flatten_data(estimationerror_summarized[estimationerror_summarized[\"explanation_method\"]==\"ours\"][target_non_target]).mean()\n",
    "            ax.plot(num_samples, [ours_mean]*len(num_samples), linestyle='-', linewidth=3, c=ESTIMATION_ERROR_COLORS[2])\n",
    "\n",
    "            # KernelSHAP with / without paired sampling: mean curve plus a shaded mean +/- 95% CI band\n",
    "            for method, (color_idx, line_width) in KERNELSHAP_CURVE_STYLES.items():\n",
    "                error_mean, error_ci = summaries[method]\n",
    "                ax.plot(error_mean.index.values, error_mean.values,\n",
    "                        linestyle='-', linewidth=line_width, c=ESTIMATION_ERROR_COLORS[color_idx])\n",
    "                ax.fill_between(error_mean.index.values,\n",
    "                                error_mean.values-error_ci.values,\n",
    "                                error_mean.values+error_ci.values,\n",
    "                                linestyle='-', linewidth=3, alpha=0.3,\n",
    "                                color=ESTIMATION_ERROR_COLORS[color_idx])\n",
    "\n",
    "            ax.set_ylim(*ESTIMATION_ERROR_YLIM[target_non_target])\n",
    "            ax.set_xlim(-1, None)\n",
    "\n",
    "            ax.spines['right'].set_visible(False)\n",
    "            ax.spines['top'].set_visible(False)\n",
    "            for side in ['top','bottom','left','right']:\n",
    "                ax.spines[side].set_linewidth(2)\n",
    "\n",
    "            ax.xaxis.set_major_locator(MultipleLocator(50000))\n",
    "            ax.xaxis.set_minor_locator(MultipleLocator(10000))\n",
    "            ax.xaxis.set_major_formatter(StrMethodFormatter('{x:,.0f}'))\n",
    "            ax.xaxis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "            ax.xaxis.grid(True, which='minor', linewidth=0.6, alpha=0.2)\n",
    "            # No custom y tick locator: set_yscale('log') below replaces the y locators with\n",
    "            # the log-scale defaults, so a MultipleLocator set here would be discarded.\n",
    "            ax.yaxis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "            ax.yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.2)\n",
    "\n",
    "            ax.set_ylabel(ESTIMATION_ERROR_YLABEL[target_non_target], labelpad=0)\n",
    "            ax.set_xlabel(\"# Evals\", labelpad=10)\n",
    "            ax.set_title(dataset_name)\n",
    "            ax.set_yscale('log')\n",
    "\n",
    "        legend_elements = [Line2D([0], [0], linestyle='-', linewidth=5,\n",
    "                                  color=ESTIMATION_ERROR_COLORS[color_idx], label=label)\n",
    "                           for color_idx, label in [(1, \"KernelSHAP + Paired Sampling\"),\n",
    "                                                    (0, \"KernelSHAP\"),\n",
    "                                                    (2, explanation_method_mapper(\"ours\"))]]\n",
    "        fig.legend(handles=legend_elements,\n",
    "                   ncol=4,\n",
    "                   handlelength=3,\n",
    "                   handletextpad=0.6,\n",
    "                   columnspacing=1.5,\n",
    "                   loc='lower center', bbox_to_anchor=(0.5, -0.1))\n",
    "\n",
    "        if SAVE_ESTIMATION_ERROR_PLOTS:\n",
    "            for ext in [\"png\", \"svg\", \"jpg\", \"pdf\"]:\n",
    "                fig.savefig(f\"results/plots/estimationerror_{backbone_type}_{target_non_target}.{ext}\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44dc5bbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "for _, (backbone_type, backbone_type_config) in enumerate({'vit_base_patch16_224': []}.items()):    \n",
    "    #plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(i)/256 for i in Set1[6]])    \n",
    "    for target_non_target in [\"target\", \"non-target\"]:\n",
    "        plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], \n",
    "                                                                                    Paired[12][3],\n",
    "                                                                                    Paired[12][5],\n",
    "                                                                                    Paired[12][7],\n",
    "                                                                                    Paired[12][9],\n",
    "                                                                                    Paired[12][11]\n",
    "                                                                                    ]])     \n",
    "\n",
    "        plt.rcParams['font.family'] = 'PT Sans'\n",
    "        plt.rcParams[\"font.size\"] = 18\n",
    "\n",
    "        if backbone_type==\"vit_base_patch16_224\" or target_non_target==\"target\":\n",
    "\n",
    "            fig = plt.figure(constrained_layout=True, figsize=(9,6))\n",
    "            subfigs = fig.subfigures(1, 1)\n",
    "\n",
    "            ax_ImageNette = subfigs.subplots(1,1)\n",
    "            #subfigs[0].supylabel(\"Correlation between sum of attributions and output\", fontsize=18)\n",
    "\n",
    "            #ax_MURA = subfigs[1].subplots(1,1)\n",
    "\n",
    "            axd={f\"ImageNette_{target_non_target}\": ax_ImageNette,\n",
    "                }\n",
    "        elif backbone_type==\"vit_base_patch16_224\" or target_non_target==\"non-target\":\n",
    "            fig = plt.figure(constrained_layout=True, figsize=(9,6))\n",
    "            subfigs = fig.subfigures(1, 1)\n",
    "\n",
    "            ax_ImageNette = subfigs.subplots(1,1)\n",
    "\n",
    "            axd={f\"ImageNette_{target_non_target}\": ax_ImageNette,\n",
    "                }\n",
    "\n",
    "        else:\n",
    "            raise\n",
    "\n",
    "        for dataset_name in [\"ImageNette\"]:\n",
    "            estimationerror_summarized=pd.DataFrame(estimationerror_result_dict[dataset_name][backbone_type])\n",
    "            \n",
    "            plot_key=f'{dataset_name}_{target_non_target}'\n",
    "\n",
    "            \n",
    "            \n",
    "            \n",
    "            # ours\n",
    "            error_mean=estimationerror_summarized[estimationerror_summarized[\"explanation_method\"]==\"kernelshap\"].groupby([\"num_sample\"])[target_non_target+\"_pearson_r\"].apply(lambda x: (flatten_data(x)).mean())\n",
    "            ours_mean=flatten_data(estimationerror_summarized[estimationerror_summarized[\"explanation_method\"]==\"ours\"][target_non_target+\"_pearson_r\"]).mean()\n",
    "            axd[plot_key].plot(error_mean.index.values,\n",
    "                               [ours_mean]*len(error_mean.index.values),\n",
    "                               linestyle=[':','-.','-'][2],\n",
    "                               linewidth=3,\n",
    "                               c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][2])             \n",
    "            \n",
    "            # kernelshap\n",
    "            error_mean=estimationerror_summarized[estimationerror_summarized[\"explanation_method\"]==\"kernelshap\"].groupby([\"num_sample\"])[target_non_target+\"_pearson_r\"].apply(lambda x: (flatten_data(x)).mean())\n",
    "            error_std=estimationerror_summarized[estimationerror_summarized[\"explanation_method\"]==\"kernelshap\"].groupby([\"num_sample\"])[target_non_target+\"_pearson_r\"].apply(lambda x: 1.96*(flatten_data(x)).std()/np.sqrt(len(flatten_data(x))))\n",
    "            axd[plot_key].plot(error_mean.index.values,\n",
    "                               error_mean.values,\n",
    "                               linestyle=[':','-.','-'][2],\n",
    "                               linewidth=2,\n",
    "                               c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][1])  \n",
    "            \n",
    "            axd[plot_key].fill_between(error_mean.index.values,\n",
    "                               error_mean.values-error_std.values,\n",
    "                               error_mean.values+error_std.values,                                       \n",
    "                               linestyle=[':','-.','-'][2],\n",
    "                               linewidth=3,\n",
    "                               alpha=0.3,\n",
    "                               color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][1])            \n",
    "            \n",
    "#             sns.lineplot(data=estimationerror_summarized[estimationerror_summarized[\"explanation_method\"]==\"kernelshap\"],\n",
    "#                         x=\"num_sample\", y=\"target\", ax=axd[plot_key])\n",
    "            \n",
    "            # kernelshapnopair\n",
    "            error_mean=estimationerror_summarized[estimationerror_summarized[\"explanation_method\"]==\"kernelshapnopair\"].groupby([\"num_sample\"])[target_non_target+\"_pearson_r\"].apply(lambda x: (flatten_data(x)).mean())\n",
    "            error_std=estimationerror_summarized[estimationerror_summarized[\"explanation_method\"]==\"kernelshapnopair\"].groupby([\"num_sample\"])[target_non_target+\"_pearson_r\"].apply(lambda x: 1.96*(flatten_data(x)).std()/np.sqrt(len(flatten_data(x))))\n",
    "            axd[plot_key].plot(error_mean.index.values,\n",
    "                               error_mean.values,\n",
    "                               linestyle=[':','-.','-'][2],\n",
    "                               linewidth=1,\n",
    "                               c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][0])  \n",
    "            \n",
    "            axd[plot_key].fill_between(error_mean.index.values,\n",
    "                               error_mean.values-error_std.values,\n",
    "                               error_mean.values+error_std.values,                                       \n",
    "                               linestyle=[':','-.','-'][2],\n",
    "                               linewidth=3,\n",
    "                               alpha=0.3,\n",
    "                               color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][0])             \n",
    "            \n",
    "            \n",
    "            \n",
    "            \n",
    "            \n",
    "            \n",
    "            \n",
    "            \n",
    "#             # ours\n",
    "#             estimationerror_summarized_select=estimationerror_summarized[estimationerror_summarized[\"explanation_method\"]==\"kernelshap\"].groupby([\"num_sample\"]).mean()\n",
    "#             ours_mean=estimationerror_summarized[estimationerror_summarized[\"explanation_method\"]==\"ours\"][target_non_target+\"_pearson_r\"].mean()\n",
    "#             axd[plot_key].plot(estimationerror_summarized_select.index.values,\n",
    "#                                [ours_mean]*len(estimationerror_summarized_select.index.values),\n",
    "#                                linestyle=[':','-.','-'][2],\n",
    "#                                linewidth=3,\n",
    "#                                c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][2])             \n",
    "            \n",
    "#             # kernelshap\n",
    "#             error_mean=estimationerror_summarized[estimationerror_summarized[\"explanation_method\"]==\"kernelshap\"].groupby([\"num_sample\"])[target_non_target+\"_pearson_r\"].apply(lambda x: x.mean())\n",
    "#             error_std=estimationerror_summarized[estimationerror_summarized[\"explanation_method\"]==\"kernelshap\"].groupby([\"num_sample\"])[target_non_target+\"_pearson_r\"].apply(lambda x: 1.96*x.std()/np.sqrt(len(x)))\n",
    "#             axd[plot_key].plot(error_mean.index.values,\n",
    "#                                error_mean.values,\n",
    "#                                linestyle=[':','-.','-'][2],\n",
    "#                                linewidth=2,\n",
    "#                                c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][1])  \n",
    "            \n",
    "#             axd[plot_key].fill_between(error_mean.index.values,\n",
    "#                                error_mean.values-error_std.values,\n",
    "#                                error_mean.values+error_std.values,                                       \n",
    "#                                linestyle=[':','-.','-'][2],\n",
    "#                                linewidth=3,\n",
    "#                                alpha=0.3,\n",
    "#                                color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][1])            \n",
    "            \n",
    "# #             sns.lineplot(data=estimationerror_summarized[estimationerror_summarized[\"explanation_method\"]==\"kernelshap\"],\n",
    "# #                         x=\"num_sample\", y=\"target\", ax=axd[plot_key])\n",
    "            \n",
    "#             # kernelshapnopair\n",
    "#             error_mean=estimationerror_summarized[estimationerror_summarized[\"explanation_method\"]==\"kernelshapnopair\"].groupby([\"num_sample\"])[target_non_target+\"_pearson_r\"].apply(lambda x: x.mean())\n",
    "#             error_std=estimationerror_summarized[estimationerror_summarized[\"explanation_method\"]==\"kernelshapnopair\"].groupby([\"num_sample\"])[target_non_target+\"_pearson_r\"].apply(lambda x: 1.96*x.std()/np.sqrt(len(x)))\n",
    "#             axd[plot_key].plot(error_mean.index.values,\n",
    "#                                error_mean.values,\n",
    "#                                linestyle=[':','-.','-'][2],\n",
    "#                                linewidth=1,\n",
    "#                                c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][0])  \n",
    "            \n",
    "#             axd[plot_key].fill_between(error_mean.index.values,\n",
    "#                                error_mean.values-error_std.values,\n",
    "#                                error_mean.values+error_std.values,                                       \n",
    "#                                linestyle=[':','-.','-'][2],\n",
    "#                                linewidth=3,\n",
    "#                                alpha=0.3,\n",
    "#                                color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][0])  \n",
    "            \n",
    "            for idx1, explanation_methods_category in enumerate(explanation_method_main):\n",
    "                for idx2, explanation_method in enumerate(explanation_methods_category):    \n",
    "                    pass\n",
    "\n",
    "                    \n",
    "#                     axd[plot_key].fill_between(cardinality_list,\n",
    "#                                                sensitivity_summarized_mean.loc[explanation_method].loc[cardinality_list].values - sensitivity_summarized_std.loc[explanation_method].loc[cardinality_list].values,\n",
    "#                                                sensitivity_summarized_mean.loc[explanation_method].loc[cardinality_list].values + sensitivity_summarized_std.loc[explanation_method].loc[cardinality_list].values,\n",
    "#                                                color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx2], \n",
    "#                                                alpha=0.2)                    \n",
    "            if target_non_target==\"target\":\n",
    "                #axd[plot_key].set_ylim(-0.1, 1)\n",
    "                axd[plot_key].set_ylim(0, 1e+0)\n",
    "                pass\n",
    "            elif target_non_target==\"non-target\":\n",
    "                #axd[plot_key].set_ylim(-0.3, 1)\n",
    "                axd[plot_key].set_ylim(0, 1e+0)\n",
    "                pass\n",
    "            else:\n",
    "                raise\n",
    "            axd[plot_key].set_xlim(-1, None)\n",
    "            #axd[plot_key].set_xlim(80000, 100000)\n",
    "                    \n",
    "            axd[plot_key].spines['right'].set_visible(False)\n",
    "            axd[plot_key].spines['top'].set_visible(False)\n",
    "            #axd[plot_key].spines['bottom'].set_visible(False)\n",
    "\n",
    "\n",
    "            axd[plot_key].xaxis.set_major_locator(MultipleLocator(50000))\n",
    "            axd[plot_key].xaxis.set_minor_locator(MultipleLocator(10000))    \n",
    "            axd[plot_key].xaxis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "            axd[plot_key].xaxis.grid(True, which='minor', linewidth=0.6, alpha=0.2)   \n",
    "            from matplotlib.ticker import FuncFormatter\n",
    "            from matplotlib.ticker import StrMethodFormatter\n",
    "            axd[plot_key].xaxis.set_major_formatter(StrMethodFormatter('{x:,.0f}'))\n",
    "            #axd[plot_key].xaxis.set_major_formatter(FuncFormatter(lambda x, pos: f'{int(x/1000)}'))\n",
    "\n",
    "#             if target_non_target==\"target\":\n",
    "#                 axd[plot_key].tick_params(axis='x', which='both',\n",
    "#                                 bottom=False)       \n",
    "            \n",
    "            axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.1))\n",
    "            axd[plot_key].yaxis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "            axd[plot_key].yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.2)                \n",
    "            \n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[plot_key].spines[axis].set_linewidth(2) \n",
    "            if target_non_target==\"target\":\n",
    "                axd[plot_key].set_ylabel(\"Pearson correlation (Target)\", labelpad=0)\n",
    "            elif target_non_target==\"non-target\":\n",
    "                axd[plot_key].set_ylabel(\"Pearson correlation (Non-target)\", labelpad=0)\n",
    "            else:\n",
    "                raise\n",
    "            \n",
    "\n",
    "            axd[plot_key].set_xlabel(\"# Evals\", labelpad=10)\n",
    "            \n",
    "            axd[plot_key].set_title(dataset_name)\n",
    "            \n",
    "            #axd[plot_key].set_yscale('log')\n",
    "            \n",
    "        legend_elements = [Line2D([0], [0],\n",
    "                                  linestyle=[':','-.','-'][2],\n",
    "                                  color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][1],\n",
    "                                  linewidth=5,\n",
    "                                  label=\"KernelSHAP + Paired Sampling\"),\n",
    "                           Line2D([0], [0],\n",
    "                                  linestyle=[':','-.','-'][2],\n",
    "                                  color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][0],\n",
    "                                  linewidth=5,\n",
    "                                  label=\"KernelSHAP\"),\n",
    "                          Line2D([0], [0],\n",
    "                                  linestyle=[':','-.','-'][2],\n",
    "                                  color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][2],\n",
    "                                  linewidth=5,\n",
    "                                  label=explanation_method_mapper(\"ours\"))\n",
    "                          ]\n",
    "        #legend_elements.insert(2, Line2D([0], [0], linewidth=0))\n",
    "\n",
    "        fig.legend(handles=legend_elements, \n",
    "                    ncol=4, \n",
    "                    handlelength=3,\n",
    "                    handletextpad=0.6, \n",
    "                    columnspacing=1.5,\n",
    "                    loc='lower center', bbox_to_anchor=(0.5, -0.1))              \n",
    "\n",
    "        #fig.savefig(f\"results/plots/estimationerror_{backbone_type}_{target_non_target}.png\", bbox_inches='tight')\n",
    "        #fig.savefig(f\"results/plots/estimationerror_{backbone_type}_{target_non_target}.svg\", bbox_inches='tight')\n",
    "        #fig.savefig(f\"results/plots/estimationerror_{backbone_type}_{target_non_target}.jpg\", bbox_inches='tight')\n",
    "        #fig.savefig(f\"results/plots/estimationerror_{backbone_type}_{target_non_target}.pdf\", bbox_inches='tight')        \n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f75bc3ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.ticker import StrMethodFormatter\n",
    "\n",
    "# Plot the `<split>_spearman_r` metric against num_sample (# evals), one single-panel figure per split.\n",
    "# KernelSHAP with / without paired sampling are drawn as mean curves with 1.96*SE bands;\n",
    "# \"ours\" is a single mean drawn as a horizontal reference line.\n",
    "ylabel_by_split = {\"target\": \"Spearman correlation (Target)\",\n",
    "                   \"non-target\": \"Spearman correlation (Non-target)\"}\n",
    "# (explanation_method, colour index in the prop cycle, line width) for the two KernelSHAP curves.\n",
    "kernelshap_curves = [(\"kernelshap\", 1, 2), (\"kernelshapnopair\", 0, 1)]\n",
    "\n",
    "for backbone_type in ['vit_base_patch16_224']:\n",
    "    for target_non_target in [\"target\", \"non-target\"]:\n",
    "        plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(Paired[12][i])/256 for i in (1, 3, 5, 7, 9, 11)])\n",
    "        plt.rcParams['font.family'] = 'PT Sans'\n",
    "        plt.rcParams[\"font.size\"] = 18\n",
    "        colors = plt.rcParams[\"axes.prop_cycle\"].by_key()['color']\n",
    "\n",
    "        # Both splits use the same single-panel layout.\n",
    "        fig = plt.figure(constrained_layout=True, figsize=(9,6))\n",
    "        subfigs = fig.subfigures(1, 1)\n",
    "        axd={f\"ImageNette_{target_non_target}\": subfigs.subplots(1,1)}\n",
    "\n",
    "        for dataset_name in [\"ImageNette\"]:\n",
    "            estimationerror_summarized=pd.DataFrame(estimationerror_result_dict[dataset_name][backbone_type])\n",
    "            method_col = estimationerror_summarized[\"explanation_method\"]\n",
    "            value_col = target_non_target+\"_spearman_r\"\n",
    "            plot_key=f'{dataset_name}_{target_non_target}'\n",
    "            ax = axd[plot_key]\n",
    "\n",
    "            # Per-num_sample mean and 1.96 * standard-error half-width of the flattened values.\n",
    "            curve_stats = {}\n",
    "            for method, _color_idx, _linewidth in kernelshap_curves:\n",
    "                grouped = estimationerror_summarized[method_col==method].groupby([\"num_sample\"])[value_col]\n",
    "                curve_stats[method] = (grouped.apply(lambda x: flatten_data(x).mean()),\n",
    "                                       grouped.apply(lambda x: 1.96*flatten_data(x).std()/np.sqrt(len(flatten_data(x)))))\n",
    "\n",
    "            # ours: one mean over all of its rows, drawn flat across KernelSHAP's num_sample grid\n",
    "            num_samples = curve_stats[\"kernelshap\"][0].index.values\n",
    "            ours_mean = flatten_data(estimationerror_summarized[method_col==\"ours\"][value_col]).mean()\n",
    "            ax.plot(num_samples, [ours_mean]*len(num_samples), linestyle='-', linewidth=3, c=colors[2])\n",
    "\n",
    "            for method, color_idx, linewidth in kernelshap_curves:\n",
    "                error_mean, error_ci = curve_stats[method]\n",
    "                ax.plot(error_mean.index.values, error_mean.values, linestyle='-', linewidth=linewidth, c=colors[color_idx])\n",
    "                ax.fill_between(error_mean.index.values,\n",
    "                                error_mean.values-error_ci.values,\n",
    "                                error_mean.values+error_ci.values,\n",
    "                                linestyle='-', linewidth=3, alpha=0.3, color=colors[color_idx])\n",
    "\n",
    "            ax.set_ylim(0, 1e+0)\n",
    "            ax.set_xlim(-1, None)\n",
    "\n",
    "            ax.spines['right'].set_visible(False)\n",
    "            ax.spines['top'].set_visible(False)\n",
    "\n",
    "            ax.xaxis.set_major_locator(MultipleLocator(50000))\n",
    "            ax.xaxis.set_minor_locator(MultipleLocator(10000))\n",
    "            ax.xaxis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "            ax.xaxis.grid(True, which='minor', linewidth=0.6, alpha=0.2)\n",
    "            ax.xaxis.set_major_formatter(StrMethodFormatter('{x:,.0f}'))\n",
    "\n",
    "            ax.yaxis.set_minor_locator(MultipleLocator(0.1))\n",
    "            ax.yaxis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "            ax.yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.2)\n",
    "\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                ax.spines[axis].set_linewidth(2)\n",
    "            # KeyError here means an unexpected split name (replaces the former bare `raise`).\n",
    "            ax.set_ylabel(ylabel_by_split[target_non_target], labelpad=0)\n",
    "            ax.set_xlabel(\"# Evals\", labelpad=10)\n",
    "            ax.set_title(dataset_name)\n",
    "\n",
    "        legend_elements = [Line2D([0], [0], linestyle='-', color=colors[1], linewidth=5, label=\"KernelSHAP + Paired Sampling\"),\n",
    "                           Line2D([0], [0], linestyle='-', color=colors[0], linewidth=5, label=\"KernelSHAP\"),\n",
    "                           Line2D([0], [0], linestyle='-', color=colors[2], linewidth=5, label=explanation_method_mapper(\"ours\"))]\n",
    "        fig.legend(handles=legend_elements,\n",
    "                   ncol=len(legend_elements),\n",
    "                   handlelength=3,\n",
    "                   handletextpad=0.6,\n",
    "                   columnspacing=1.5,\n",
    "                   loc='lower center', bbox_to_anchor=(0.5, -0.1))\n",
    "\n",
    "        # Saving is disabled for this exploratory figure; uncomment to write it out:\n",
    "        # for ext in (\"png\", \"svg\", \"jpg\", \"pdf\"):\n",
    "        #     fig.savefig(f\"results/plots/estimationerror_{backbone_type}_{target_non_target}.{ext}\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e49c04b",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from matplotlib.ticker import StrMethodFormatter\n",
    "\n",
    "dataset_name=\"ImageNette\"\n",
    "\n",
    "# 3x2 summary figure: rows are metric stages, columns are the target / non-target splits.\n",
    "# KernelSHAP with / without paired sampling are drawn as mean curves with 1.96*SE bands;\n",
    "# \"ours\" is a single mean drawn as a horizontal reference line.\n",
    "stages = [\"l2distance\", \"pearson\", \"spearman\"]\n",
    "splits = [\"target\", \"non-target\"]\n",
    "# Column suffix in the summary frame per stage; the L2 distance is stored under the bare split name.\n",
    "stage_key_map = {\"l2distance\": \"\", \"pearson\": \"_pearson_r\", \"spearman\": \"_spearman_r\"}\n",
    "# (ylabel, extra set_ylabel kwargs) per stage.\n",
    "stage_ylabel_map = {\"l2distance\": (\"L2 distance\", {}),\n",
    "                    \"pearson\": (\"Pearson correlation\", {\"labelpad\": 15}),\n",
    "                    \"spearman\": (\"Spearman correlation\", {\"labelpad\": 15})}\n",
    "# y-limits per (stage, split). Only the L2 row is log-scaled, so only it needs a positive lower bound;\n",
    "# both correlation rows share the same linear 0..1 range.\n",
    "ylim_map = {(\"l2distance\", \"target\"): (1e-2, 1e+0),\n",
    "            (\"l2distance\", \"non-target\"): (1e-3, 1e+0),\n",
    "            (\"pearson\", \"target\"): (0, 1e+0),\n",
    "            (\"pearson\", \"non-target\"): (0, 1e+0),\n",
    "            (\"spearman\", \"target\"): (0, 1e+0),\n",
    "            (\"spearman\", \"non-target\"): (0, 1e+0)}\n",
    "log_scale_stages = {\"l2distance\"}\n",
    "# (explanation_method, colour index in the prop cycle, line width) for the two KernelSHAP curves.\n",
    "kernelshap_curves = [(\"kernelshap\", 1, 2), (\"kernelshapnopair\", 0, 1)]\n",
    "\n",
    "for backbone_type in ['vit_base_patch16_224']:\n",
    "    plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(Paired[12][i])/256 for i in (1, 3, 5, 7, 9, 11)])\n",
    "    plt.rcParams['font.family'] = 'PT Sans'\n",
    "    plt.rcParams[\"font.size\"] = 15\n",
    "    colors = plt.rcParams[\"axes.prop_cycle\"].by_key()['color']\n",
    "\n",
    "    fig = plt.figure(figsize=(18, 5.5*len(stages)))\n",
    "    box1 = gridspec.GridSpec(len(stages), 1, wspace=0, hspace=0.3)\n",
    "\n",
    "    axd={}\n",
    "    for idx1, stage in enumerate(stages):\n",
    "        box2 = gridspec.GridSpecFromSubplotSpec(1, len(splits), subplot_spec=box1[idx1], wspace=0.15, hspace=0)\n",
    "        for idx2, target_non_target in enumerate(splits):\n",
    "            axd[f\"{stage}_{target_non_target}\"] = fig.add_subplot(box2[idx2])\n",
    "\n",
    "    estimationerror_summarized=pd.DataFrame(estimationerror_result_dict[dataset_name][backbone_type])\n",
    "    method_col = estimationerror_summarized[\"explanation_method\"]\n",
    "    for stage in stages:\n",
    "        for target_non_target in splits:\n",
    "            plot_key=f\"{stage}_{target_non_target}\"\n",
    "            ax = axd[plot_key]\n",
    "            value_col = target_non_target+stage_key_map[stage]\n",
    "\n",
    "            # Per-num_sample mean and 1.96 * standard-error half-width of the flattened values.\n",
    "            curve_stats = {}\n",
    "            for method, _color_idx, _linewidth in kernelshap_curves:\n",
    "                grouped = estimationerror_summarized[method_col==method].groupby([\"num_sample\"])[value_col]\n",
    "                curve_stats[method] = (grouped.apply(lambda x: flatten_data(x).mean()),\n",
    "                                       grouped.apply(lambda x: 1.96*flatten_data(x).std()/np.sqrt(len(flatten_data(x)))))\n",
    "\n",
    "            # ours: one mean over all of its rows, drawn flat across KernelSHAP's num_sample grid\n",
    "            num_samples = curve_stats[\"kernelshap\"][0].index.values\n",
    "            ours_mean = flatten_data(estimationerror_summarized[method_col==\"ours\"][value_col]).mean()\n",
    "            ax.plot(num_samples, [ours_mean]*len(num_samples), linestyle='-', linewidth=3, c=colors[2])\n",
    "\n",
    "            for method, color_idx, linewidth in kernelshap_curves:\n",
    "                error_mean, error_ci = curve_stats[method]\n",
    "                ax.plot(error_mean.index.values, error_mean.values, linestyle='-', linewidth=linewidth, c=colors[color_idx])\n",
    "                ax.fill_between(error_mean.index.values,\n",
    "                                error_mean.values-error_ci.values,\n",
    "                                error_mean.values+error_ci.values,\n",
    "                                linestyle='-', linewidth=3, alpha=0.3, color=colors[color_idx])\n",
    "\n",
    "            if stage==\"l2distance\":\n",
    "                ax.set_title(target_non_target_mapper(target_non_target), pad=15)\n",
    "\n",
    "            ax.set_ylim(*ylim_map[(stage, target_non_target)])\n",
    "            ax.set_xlim(-1, None)\n",
    "            if stage in log_scale_stages:\n",
    "                ax.set_yscale('log')\n",
    "\n",
    "            ax.spines['right'].set_visible(False)\n",
    "            ax.spines['top'].set_visible(False)\n",
    "\n",
    "            ax.xaxis.set_major_locator(MultipleLocator(50000))\n",
    "            ax.xaxis.set_minor_locator(MultipleLocator(10000))\n",
    "            ax.xaxis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "            ax.xaxis.grid(True, which='minor', linewidth=0.6, alpha=0.2)\n",
    "            ax.xaxis.set_major_formatter(StrMethodFormatter('{x:,.0f}'))\n",
    "\n",
    "            # A linear 0.1 minor step is only meaningful on linear axes; log axes keep the\n",
    "            # LogLocator minor ticks that set_yscale('log') installs.\n",
    "            if stage not in log_scale_stages:\n",
    "                ax.yaxis.set_minor_locator(MultipleLocator(0.1))\n",
    "            ax.yaxis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "            ax.yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.2)\n",
    "\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                ax.spines[axis].set_linewidth(2)\n",
    "\n",
    "            ylabel, ylabel_kwargs = stage_ylabel_map[stage]\n",
    "            ax.set_ylabel(ylabel, **ylabel_kwargs)\n",
    "            ax.set_xlabel(\"# Evals\", labelpad=5)\n",
    "\n",
    "    legend_elements = [Line2D([0], [0], linestyle='-', color=colors[1], linewidth=5, label=\"KernelSHAP + Paired Sampling\"),\n",
    "                       Line2D([0], [0], linestyle='-', color=colors[0], linewidth=5, label=\"KernelSHAP\"),\n",
    "                       Line2D([0], [0], linestyle='-', color=colors[2], linewidth=5, label=explanation_method_mapper(\"ours\"))]\n",
    "    fig.legend(handles=legend_elements,\n",
    "               ncol=len(legend_elements),\n",
    "               handlelength=3,\n",
    "               handletextpad=0.6,\n",
    "               columnspacing=1.5,\n",
    "               loc='lower center', bbox_to_anchor=(0.5, 0.05))\n",
    "\n",
    "    # savefig fails if the output directory is missing (relative to the cwd set in the first cell).\n",
    "    # The figure itself is rendered by the inline backend at the end of the cell, so no fig.show().\n",
    "    os.makedirs(\"results/plots\", exist_ok=True)\n",
    "    for ext in (\"png\", \"svg\", \"jpg\", \"pdf\"):\n",
    "        fig.savefig(f\"results/plots/estimationerror_{backbone_type}_{dataset_name}.{ext}\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1beb27d",
   "metadata": {},
   "source": [
    "# Table"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a42d7ad",
   "metadata": {},
   "source": [
    "## AUC restricted to paths that have KernelSHAP results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0092198c",
   "metadata": {},
   "outputs": [],
   "source": [
    "auc_df=pd.DataFrame(roc_auc_result_dict[\"ImageNette\"][\"vit_base_patch16_224\"])\n",
    "# Keep only rows whose `path` also appears among the KernelSHAP rows.\n",
    "auc_df_100=auc_df.loc[auc_df[\"path\"].isin(auc_df.loc[auc_df[\"explanation_method\"].eq(\"kernelshap\"), \"path\"].unique())]\n",
    "(len(auc_df), len(auc_df_100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2116df31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean target / non-target AUC per explanation method for metric_mode == 'delete', on the KernelSHAP subset.\n",
    "(auc_df_100\n",
    "    .groupby(['metric_mode', 'explanation_method'])[['target_auc', 'non-target_auc']]\n",
    "    .mean()\n",
    "    .loc[[\"delete\"]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dba9d18b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# KernelSHAP was only run on a subset of the images. To compare methods on the same images,\n",
    "# every method's AUCs are restricted to the paths that have a KernelSHAP result (auc_df_100)\n",
    "# before aggregating. Columns suffixed _std hold 1.96 * std / sqrt(n), i.e. a\n",
    "# normal-approximation 95% confidence-interval half-width.\n",
    "auc_sensitivity_table_dict={}\n",
    "for backbone_type in [\"vit_base_patch16_224\"]:\n",
    "    auc_sensitivity_table_dict.setdefault(backbone_type, {})\n",
    "\n",
    "    for dataset_name in [\"ImageNette\"]:\n",
    "        auc_sensitivity_table_dict[backbone_type].setdefault(dataset_name, {})\n",
    "\n",
    "        auc_df=pd.DataFrame(roc_auc_result_dict[dataset_name][backbone_type])\n",
    "        kernelshap_paths=auc_df.loc[auc_df[\"explanation_method\"]==\"kernelshap\", \"path\"].unique()\n",
    "        auc_df_100=auc_df[auc_df[\"path\"].isin(kernelshap_paths)]\n",
    "\n",
    "        # AUC mean and CI half-width, computed on the KernelSHAP image subset.\n",
    "        roc_auc_grouped=auc_df_100.groupby(['metric_mode', 'explanation_method'])[['target_auc', 'non-target_auc']]\n",
    "        roc_auc_table_mean=roc_auc_grouped.mean().add_suffix('_mean')\n",
    "        roc_auc_table_std=roc_auc_grouped.apply(lambda x: 1.96*x.std()/((len(x))**(0.5))).add_suffix('_std')\n",
    "\n",
    "        auc_sensitivity_table=pd.concat([roc_auc_table_mean.loc['insert'].add_prefix('insert_'), roc_auc_table_std.loc['insert'].add_prefix('insert_'),\n",
    "                                         roc_auc_table_mean.loc['delete'].add_prefix('delete_'), roc_auc_table_std.loc['delete'].add_prefix('delete_'),\n",
    "                                        ], axis=1)\n",
    "        for subset_mode in [\"main\", \"supple\"]:\n",
    "            auc_sensitivity_table_dict[backbone_type][dataset_name].setdefault(subset_mode, {})\n",
    "            if subset_mode==\"main\":\n",
    "                explanation_method_groups=explanation_method_main_random_kernelshap\n",
    "            elif subset_mode==\"supple\":\n",
    "                explanation_method_groups=explanation_method_supple_random_kernelshap\n",
    "            else:\n",
    "                raise ValueError(f\"unknown subset_mode: {subset_mode}\")\n",
    "            # Keep this subset's methods (flattened across categories) and rename them for display.\n",
    "            auc_sensitivity_table_select=auc_sensitivity_table.loc[[j for i in explanation_method_groups for j in i]]\n",
    "            auc_sensitivity_table_select.index=auc_sensitivity_table_select.index.map(lambda x: explanation_method_mapper(x, subset_mode))\n",
    "            for target_non_target in [\"target\", \"non-target\"]:\n",
    "                # One 'mean (CI)' string column per metric; '-' where a method has no result.\n",
    "                formatted_columns=[]\n",
    "                for metric_mode in [\"insert\", \"delete\"]:\n",
    "                    mean_col=f\"{metric_mode}_{target_non_target}_auc_mean\"\n",
    "                    std_col=f\"{metric_mode}_{target_non_target}_auc_std\"\n",
    "                    formatted_columns.append(auc_sensitivity_table_select.apply(\n",
    "                        lambda x: '-' if np.isnan(x[mean_col]) else '{:.3f} ({:.3f})'.format(x[mean_col], x[std_col]), axis=1))\n",
    "                auc_sensitivity_table_select_format=pd.concat(formatted_columns, axis=1)\n",
    "                auc_sensitivity_table_select_format.columns=pd.MultiIndex.from_tuples([(dataset_name, metric) for metric in [\"Ins. (↑)\", \"Del. (↓)\"]])\n",
    "                auc_sensitivity_table_dict[backbone_type][dataset_name][subset_mode][target_non_target]=auc_sensitivity_table_select_format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1954185e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def df_to_latex(table_df, caption, explanation_methods_category_list):\n",
    "    \"\"\"Render a table of pre-formatted string cells as a LaTeX table environment.\n",
    "\n",
    "    Args:\n",
    "        table_df: DataFrame whose cells are pre-formatted strings. Cells may contain the\n",
    "            markers bfs/bfe (bold start/end) and grays/graye (gray start/end), which are\n",
    "            expanded into LaTeX bold and gray-text commands.\n",
    "        caption: text placed inside the table caption.\n",
    "        explanation_methods_category_list: row labels grouped by category; a midrule is\n",
    "            inserted before the first label of every group except the first.\n",
    "\n",
    "    Returns:\n",
    "        The LaTeX source as a string.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if line 3 of the generated LaTeX (the first column-header row)\n",
    "            contains no multicolumn cell.\n",
    "    \"\"\"\n",
    "    import re\n",
    "\n",
    "    latex_output=table_df.style.to_latex(hrules=True, multicol_align='c')\n",
    "    # Separate the method categories with a horizontal rule.\n",
    "    for explanation_methods_category in explanation_methods_category_list[1:]:\n",
    "        latex_output=latex_output.replace('\\n'+explanation_methods_category[0],\n",
    "                                          '\\\\midrule\\n'+explanation_methods_category[0], 1)\n",
    "    latex_output_split=latex_output.split('\\n')\n",
    "    # Add one cmidrule under each column group, spanning exactly the columns of its\n",
    "    # multicolumn header (column 1 holds the row labels).\n",
    "    group_widths=[int(width) for width in re.findall(r'\\\\multicolumn\\{(\\d+)\\}', latex_output_split[2])]\n",
    "    if not group_widths:\n",
    "        raise ValueError(\"expected a multicolumn header on line 3 of the LaTeX output\")\n",
    "    cmidrules=[]\n",
    "    first_column=2\n",
    "    for width in group_widths:\n",
    "        cmidrules.append('\\\\cmidrule(lr){{{}-{}}}'.format(first_column, first_column+width-1))\n",
    "        first_column+=width\n",
    "    latex_output_split.insert(3, ' '.join(cmidrules))\n",
    "    latex_output='\\n'.join(latex_output_split)\n",
    "    latex_output=latex_output.replace(\"ViT Shapley\", \"\\\\textbf{ViT Shapley}\")\n",
    "    latex_output=latex_output.replace(\"explanation\\\\_method\",\"\")\n",
    "    latex_output=latex_output.replace(\"{l}\",\"{c}\")\n",
    "    # Expand the bold / gray markers embedded in the cell strings.\n",
    "    latex_output=latex_output.replace(\"bfs\",\"\\\\textbf{\")\n",
    "    latex_output=latex_output.replace(\"bfe\",\"}\")\n",
    "    latex_output=latex_output.replace(\"grays\",\"\\\\textcolor{gray}{\")\n",
    "    latex_output=latex_output.replace(\"graye\",\"}\")\n",
    "    latex_output='% \\\\begin{scriptsize}\\n' + latex_output + '% \\\\end{scriptsize}\\n'\n",
    "    latex_output='\\\\begin{small}\\n' + latex_output + '\\\\end{small}\\n'\n",
    "    latex_output='\\\\begin{center}\\n' + latex_output + '\\\\end{center}\\n'\n",
    "    latex_output='\\\\vskip 0.01in\\n' + latex_output\n",
    "    latex_output='\\\\caption{{{}}}\\n'.format(caption) + latex_output\n",
    "    latex_output='\\\\begin{table}\\n' + latex_output + '\\\\end{table}\\n'\n",
    "    # For 7-column tables: left-align the row labels and center the value columns.\n",
    "    latex_output=latex_output.replace('{lllllll}','{lcccccc}')\n",
    "    return latex_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84e10dd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# KernelSHAP-subset table for ViT-Base on ImageNette: supple target columns, then non-target columns.\n",
    "table_df=pd.concat([auc_sensitivity_table_dict[\"vit_base_patch16_224\"][\"ImageNette\"][\"supple\"][\"target\"],\n",
    "           auc_sensitivity_table_dict[\"vit_base_patch16_224\"][\"ImageNette\"][\"supple\"][\"non-target\"]], axis=1)\n",
    "\n",
    "table_df.index.name = None\n",
    "\n",
    "explanation_methods_category_list=[[explanation_method_mapper(j, \"supple\") for j in i] for i in explanation_method_supple_random_kernelshap]\n",
    "\n",
    "print(df_to_latex(table_df, \n",
    "                  caption=\"supple target and non-target\",\n",
    "                  explanation_methods_category_list = explanation_methods_category_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1090947d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display table_df (built in the previous cell) as a rich DataFrame.\n",
    "table_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c1c1587",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b06b5b2b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36d8f798",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): repeats an earlier display of table_df; likely a leftover that can be removed.\n",
    "table_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06508daa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74d1e6ad",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "c6e2dcb9",
   "metadata": {},
   "source": [
    "## classifiermask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "096a0104",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Build formatted insertion/deletion AUC and faithfulness (sensitivity) tables for every backbone\n",
    "# in backbone_type_config_dict on ImageNette, stored as\n",
    "# auc_sensitivity_table_dict[backbone][dataset][subset_mode][target or non-target].\n",
    "# The AUCs use only images that have a KernelSHAP result (auc_df_100).\n",
    "# NOTE(review): the sensitivity numbers are NOT restricted to that image subset - confirm this is intended.\n",
    "# Columns suffixed _std hold 1.96 * std / sqrt(n), a normal-approximation 95% CI half-width.\n",
    "auc_sensitivity_table_dict={}\n",
    "for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    auc_sensitivity_table_dict.setdefault(backbone_type, {})\n",
    "    \n",
    "    for dataset_name in [\"ImageNette\"]:\n",
    "        auc_sensitivity_table_dict[backbone_type].setdefault(dataset_name, {})\n",
    "        \n",
    "        auc_df=pd.DataFrame(roc_auc_result_dict[dataset_name][backbone_type])\n",
    "        auc_df_100=auc_df[auc_df[\"path\"].isin(auc_df[auc_df[\"explanation_method\"]==\"kernelshap\"]['path'].unique())]\n",
    "        # AUC mean, std\n",
    "        roc_auc_table_mean=auc_df_100\\\n",
    "        .groupby(['metric_mode', 'explanation_method'])[['target_auc', 'non-target_auc']].mean()\n",
    "        roc_auc_table_mean=roc_auc_table_mean.add_suffix('_mean')\n",
    "        \n",
    "        roc_auc_table_std=auc_df_100\\\n",
    "        .groupby(['metric_mode', 'explanation_method'])[['target_auc', 'non-target_auc']].apply(lambda x: 1.96*x.std()/((len(x))**(0.5)))\n",
    "        roc_auc_table_std=roc_auc_table_std.add_suffix('_std')\n",
    "        \n",
    "        # Sensitivity mean, std\n",
    "        sensitivity_table_mean=pd.DataFrame(sensitivity_result_dict[dataset_name][backbone_type])\\\n",
    "        .groupby(['num_included_players', 'explanation_method'])[['target','non-target']].mean()\n",
    "        sensitivity_table_mean=sensitivity_table_mean.add_suffix('_mean')\n",
    "        \n",
    "        sensitivity_table_std=pd.DataFrame(sensitivity_result_dict[dataset_name][backbone_type])\\\n",
    "        .groupby(['num_included_players', 'explanation_method'])[['target','non-target']].apply(lambda x: 1.96*x.std()/((len(x))**(0.5)))\n",
    "        sensitivity_table_std=sensitivity_table_std.add_suffix('_std')\n",
    "        \n",
    "        \n",
    "        # One row per explanation method; the sensitivity columns use the num_included_players == 'all' rows.\n",
    "        auc_sensitivity_table=pd.concat([roc_auc_table_mean.loc['insert'].add_prefix('insert_'), roc_auc_table_std.loc['insert'].add_prefix('insert_'),\n",
    "                                         roc_auc_table_mean.loc['delete'].add_prefix('delete_'), roc_auc_table_std.loc['delete'].add_prefix('delete_'),\n",
    "                                         sensitivity_table_mean.loc['all'].add_prefix('sensitivity_'), sensitivity_table_std.loc['all'].add_prefix('sensitivity_')\n",
    "                                        ], axis=1)\n",
    "        for subset_mode in [\"main\", \"supple\"]:\n",
    "            auc_sensitivity_table_dict[backbone_type][dataset_name].setdefault(subset_mode, {})\n",
    "            if subset_mode==\"main\":\n",
    "                #auc_sensitivity_table_select=auc_sensitivity_table.loc[explanation_methods_main]\n",
    "                auc_sensitivity_table_select=auc_sensitivity_table.loc[[j for i in explanation_method_main_random for j in i]]\n",
    "                auc_sensitivity_table_select.index=auc_sensitivity_table_select.index.map(lambda x: explanation_method_mapper(x, \"main\"))\n",
    "            elif subset_mode==\"supple\":\n",
    "                auc_sensitivity_table_select=auc_sensitivity_table.loc[[j for i in explanation_method_supple_random for j in i]]\n",
    "                auc_sensitivity_table_select.index=auc_sensitivity_table_select.index.map(lambda x: explanation_method_mapper(x, \"supple\"))\n",
    "            else:\n",
    "                raise\n",
    "            #print(auc_sensitivity_table_select)\n",
    "            for target_non_target in [\"target\", \"non-target\"]:\n",
    "                # Ins./Del.: gray (grays...graye) when the method does not clearly beat the Random row, i.e. its\n",
    "                # CI is not strictly on the better side of Random's CI; otherwise bold (bfs...bfe) if it has the\n",
    "                # best mean (max for Ins., min for Del.). Faith.: bold if it has the max mean.\n",
    "                # df_to_latex turns these markers into LaTeX commands.\n",
    "                #print('sdsdsd')\n",
    "                #print(auc_sensitivity_table_select)\n",
    "                auc_sensitivity_table_select_format=pd.concat([auc_sensitivity_table_select.apply(lambda x: '-' if np.isnan(x[f\"insert_{target_non_target}_auc_mean\"]) else\n",
    "                                                                                                            'grays{:.3f} ({:.3f})graye'.format(x[f\"insert_{target_non_target}_auc_mean\"], x[f\"insert_{target_non_target}_auc_std\"]) if (x[f\"insert_{target_non_target}_auc_mean\"]-x[f\"insert_{target_non_target}_auc_std\"])<=((auc_sensitivity_table_select[f'insert_{target_non_target}_auc_mean']+auc_sensitivity_table_select[f'insert_{target_non_target}_auc_std'])[\"Random\"]) else\n",
    "                                                                                                            'bfs{:.3f} ({:.3f})bfe'.format(x[f\"insert_{target_non_target}_auc_mean\"], x[f\"insert_{target_non_target}_auc_std\"]) if x[f\"insert_{target_non_target}_auc_mean\"]>=(auc_sensitivity_table_select[f'insert_{target_non_target}_auc_mean'][~auc_sensitivity_table_select[f'insert_{target_non_target}_auc_mean'].isnull()]).max() else\n",
    "                                                                                                            '{:.3f} ({:.3f})'.format(x[f\"insert_{target_non_target}_auc_mean\"], x[f\"insert_{target_non_target}_auc_std\"]), axis=1),\n",
    "\n",
    "                                                               auc_sensitivity_table_select.apply(lambda x: '-' if np.isnan(x[f\"delete_{target_non_target}_auc_mean\"]) else\n",
    "                                                                                                            'grays{:.3f} ({:.3f})graye'.format(x[f\"delete_{target_non_target}_auc_mean\"], x[f\"delete_{target_non_target}_auc_std\"]) if (x[f\"delete_{target_non_target}_auc_mean\"]+x[f\"delete_{target_non_target}_auc_std\"])>=((auc_sensitivity_table_select[f'delete_{target_non_target}_auc_mean']-auc_sensitivity_table_select[f'delete_{target_non_target}_auc_std'])[\"Random\"]) else                                                                                                  \n",
    "                                                                                                            'bfs{:.3f} ({:.3f})bfe'.format(x[f\"delete_{target_non_target}_auc_mean\"], x[f\"delete_{target_non_target}_auc_std\"]) if x[f\"delete_{target_non_target}_auc_mean\"]<=((auc_sensitivity_table_select[f'delete_{target_non_target}_auc_mean'])[~auc_sensitivity_table_select[f'delete_{target_non_target}_auc_mean'].isnull()]).min() else\n",
    "                                                                                                            '{:.3f} ({:.3f})'.format(x[f\"delete_{target_non_target}_auc_mean\"], x[f\"delete_{target_non_target}_auc_std\"]), axis=1),                                        \n",
    "\n",
    "                                                               auc_sensitivity_table_select.apply(lambda x: '-' if np.isnan(x[f\"sensitivity_{target_non_target}_mean\"]) else\n",
    "                                                                                                            'bfs{:.3f} ({:.3f})bfe'.format(x[f\"sensitivity_{target_non_target}_mean\"], x[f\"sensitivity_{target_non_target}_std\"]) if x[f\"sensitivity_{target_non_target}_mean\"]>=(auc_sensitivity_table_select[f'sensitivity_{target_non_target}_mean'][~auc_sensitivity_table_select[f'sensitivity_{target_non_target}_mean'].isnull()]).max() else\n",
    "                                                                                                            '{:.3f} ({:.3f})'.format(x[f\"sensitivity_{target_non_target}_mean\"], x[f\"sensitivity_{target_non_target}_std\"]), axis=1)  \n",
    "                                                               ], axis=1)\n",
    "                \n",
    "                #whole_table_select.columns=['Insertion', 'Deletion', 'Sensitivity-n']        \n",
    "                auc_sensitivity_table_select_format.columns=pd.MultiIndex.from_tuples([(dataset_name, metric) for metric in [\"Ins. (↑)\", \"Del. (↓)\", \"Faith. (↑)\"]])\n",
    "                auc_sensitivity_table_dict[backbone_type][dataset_name][subset_mode][target_non_target]=\\\n",
    "                auc_sensitivity_table_select_format\n",
    "                #print(auc_sensitivity_table_select_format.add_suffix(target_non_target),\n",
    "                #     auc_sensitivity_table_select)\n",
    "                #print('----------------------------------------------------------------')\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "083ccac1",
   "metadata": {},
   "source": [
    "## all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70efca6b",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Build formatted insertion/deletion AUC and faithfulness (sensitivity) tables for every backbone\n",
    "# in backbone_type_config_dict, stored as\n",
    "# auc_sensitivity_table_dict[backbone][dataset][subset_mode][target or non-target].\n",
    "# All evaluated images are used (no KernelSHAP restriction). ViT-Base covers ImageNette and MURA;\n",
    "# the other backbones cover ImageNette only. NOTE(review): this rebuilds auc_sensitivity_table_dict from scratch.\n",
    "# Columns suffixed _std hold 1.96 * std / sqrt(n), a normal-approximation 95% CI half-width.\n",
    "auc_sensitivity_table_dict={}\n",
    "for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    auc_sensitivity_table_dict.setdefault(backbone_type, {})\n",
    "    \n",
    "    for dataset_name in [\"ImageNette\", \"MURA\"] if backbone_type==\"vit_base_patch16_224\" else [\"ImageNette\"]:\n",
    "        auc_sensitivity_table_dict[backbone_type].setdefault(dataset_name, {})\n",
    "        # AUC mean, std\n",
    "        roc_auc_table_mean=pd.DataFrame(roc_auc_result_dict[dataset_name][backbone_type])\\\n",
    "        .groupby(['metric_mode', 'explanation_method'])[['target_auc', 'non-target_auc']].mean()\n",
    "        roc_auc_table_mean=roc_auc_table_mean.add_suffix('_mean')\n",
    "        \n",
    "        roc_auc_table_std=pd.DataFrame(roc_auc_result_dict[dataset_name][backbone_type])\\\n",
    "        .groupby(['metric_mode', 'explanation_method'])[['target_auc', 'non-target_auc']].apply(lambda x: 1.96*x.std()/((len(x))**(0.5)))\n",
    "        roc_auc_table_std=roc_auc_table_std.add_suffix('_std')\n",
    "        \n",
    "        # Sensitivity mean, std\n",
    "        sensitivity_table_mean=pd.DataFrame(sensitivity_result_dict[dataset_name][backbone_type])\\\n",
    "        .groupby(['num_included_players', 'explanation_method'])[['target','non-target']].mean()\n",
    "        sensitivity_table_mean=sensitivity_table_mean.add_suffix('_mean')\n",
    "        \n",
    "        sensitivity_table_std=pd.DataFrame(sensitivity_result_dict[dataset_name][backbone_type])\\\n",
    "        .groupby(['num_included_players', 'explanation_method'])[['target','non-target']].apply(lambda x: 1.96*x.std()/((len(x))**(0.5)))\n",
    "        sensitivity_table_std=sensitivity_table_std.add_suffix('_std')\n",
    "        \n",
    "        \n",
    "        # One row per explanation method; the sensitivity columns use the num_included_players == 'all' rows.\n",
    "        auc_sensitivity_table=pd.concat([roc_auc_table_mean.loc['insert'].add_prefix('insert_'), roc_auc_table_std.loc['insert'].add_prefix('insert_'),\n",
    "                                         roc_auc_table_mean.loc['delete'].add_prefix('delete_'), roc_auc_table_std.loc['delete'].add_prefix('delete_'),\n",
    "                                         sensitivity_table_mean.loc['all'].add_prefix('sensitivity_'), sensitivity_table_std.loc['all'].add_prefix('sensitivity_')\n",
    "                                        ], axis=1)\n",
    "        for subset_mode in [\"main\", \"supple\"]:\n",
    "            auc_sensitivity_table_dict[backbone_type][dataset_name].setdefault(subset_mode, {})\n",
    "            if subset_mode==\"main\":\n",
    "                #auc_sensitivity_table_select=auc_sensitivity_table.loc[explanation_methods_main]\n",
    "                auc_sensitivity_table_select=auc_sensitivity_table.loc[[j for i in explanation_method_main_random for j in i]]\n",
    "                auc_sensitivity_table_select.index=auc_sensitivity_table_select.index.map(lambda x: explanation_method_mapper(x, \"main\"))\n",
    "            elif subset_mode==\"supple\":\n",
    "                auc_sensitivity_table_select=auc_sensitivity_table.loc[[j for i in explanation_method_supple_random for j in i]]\n",
    "                auc_sensitivity_table_select.index=auc_sensitivity_table_select.index.map(lambda x: explanation_method_mapper(x, \"supple\"))\n",
    "            else:\n",
    "                raise\n",
    "            #print(auc_sensitivity_table_select)\n",
    "            for target_non_target in [\"target\", \"non-target\"]:\n",
    "                # Ins./Del.: gray (grays...graye) when the method does not clearly beat the Random row, i.e. its\n",
    "                # CI is not strictly on the better side of Random's CI; otherwise bold (bfs...bfe) if it has the\n",
    "                # best mean (max for Ins., min for Del.). Faith.: bold if it has the max mean.\n",
    "                # df_to_latex turns these markers into LaTeX commands.\n",
    "                #print('sdsdsd')\n",
    "                #print(auc_sensitivity_table_select)\n",
    "                auc_sensitivity_table_select_format=pd.concat([auc_sensitivity_table_select.apply(lambda x: '-' if np.isnan(x[f\"insert_{target_non_target}_auc_mean\"]) else\n",
    "                                                                                                            'grays{:.3f} ({:.3f})graye'.format(x[f\"insert_{target_non_target}_auc_mean\"], x[f\"insert_{target_non_target}_auc_std\"]) if (x[f\"insert_{target_non_target}_auc_mean\"]-x[f\"insert_{target_non_target}_auc_std\"])<=((auc_sensitivity_table_select[f'insert_{target_non_target}_auc_mean']+auc_sensitivity_table_select[f'insert_{target_non_target}_auc_std'])[\"Random\"]) else\n",
    "                                                                                                            'bfs{:.3f} ({:.3f})bfe'.format(x[f\"insert_{target_non_target}_auc_mean\"], x[f\"insert_{target_non_target}_auc_std\"]) if x[f\"insert_{target_non_target}_auc_mean\"]>=(auc_sensitivity_table_select[f'insert_{target_non_target}_auc_mean'][~auc_sensitivity_table_select[f'insert_{target_non_target}_auc_mean'].isnull()]).max() else\n",
    "                                                                                                            '{:.3f} ({:.3f})'.format(x[f\"insert_{target_non_target}_auc_mean\"], x[f\"insert_{target_non_target}_auc_std\"]), axis=1),\n",
    "\n",
    "                                                               auc_sensitivity_table_select.apply(lambda x: '-' if np.isnan(x[f\"delete_{target_non_target}_auc_mean\"]) else\n",
    "                                                                                                            'grays{:.3f} ({:.3f})graye'.format(x[f\"delete_{target_non_target}_auc_mean\"], x[f\"delete_{target_non_target}_auc_std\"]) if (x[f\"delete_{target_non_target}_auc_mean\"]+x[f\"delete_{target_non_target}_auc_std\"])>=((auc_sensitivity_table_select[f'delete_{target_non_target}_auc_mean']-auc_sensitivity_table_select[f'delete_{target_non_target}_auc_std'])[\"Random\"]) else                                                                                                  \n",
    "                                                                                                            'bfs{:.3f} ({:.3f})bfe'.format(x[f\"delete_{target_non_target}_auc_mean\"], x[f\"delete_{target_non_target}_auc_std\"]) if x[f\"delete_{target_non_target}_auc_mean\"]<=((auc_sensitivity_table_select[f'delete_{target_non_target}_auc_mean'])[~auc_sensitivity_table_select[f'delete_{target_non_target}_auc_mean'].isnull()]).min() else\n",
    "                                                                                                            '{:.3f} ({:.3f})'.format(x[f\"delete_{target_non_target}_auc_mean\"], x[f\"delete_{target_non_target}_auc_std\"]), axis=1),                                        \n",
    "\n",
    "                                                               auc_sensitivity_table_select.apply(lambda x: '-' if np.isnan(x[f\"sensitivity_{target_non_target}_mean\"]) else\n",
    "                                                                                                            'bfs{:.3f} ({:.3f})bfe'.format(x[f\"sensitivity_{target_non_target}_mean\"], x[f\"sensitivity_{target_non_target}_std\"]) if x[f\"sensitivity_{target_non_target}_mean\"]>=(auc_sensitivity_table_select[f'sensitivity_{target_non_target}_mean'][~auc_sensitivity_table_select[f'sensitivity_{target_non_target}_mean'].isnull()]).max() else\n",
    "                                                                                                            '{:.3f} ({:.3f})'.format(x[f\"sensitivity_{target_non_target}_mean\"], x[f\"sensitivity_{target_non_target}_std\"]), axis=1)  \n",
    "                                                               ], axis=1)\n",
    "                \n",
    "                #whole_table_select.columns=['Insertion', 'Deletion', 'Sensitivity-n']        \n",
    "                auc_sensitivity_table_select_format.columns=pd.MultiIndex.from_tuples([(dataset_name, metric) for metric in [\"Ins. (↑)\", \"Del. (↓)\", \"Faith. (↑)\"]])\n",
    "                auc_sensitivity_table_dict[backbone_type][dataset_name][subset_mode][target_non_target]=\\\n",
    "                auc_sensitivity_table_select_format\n",
    "                #print(auc_sensitivity_table_select_format.add_suffix(target_non_target),\n",
    "                #     auc_sensitivity_table_select)\n",
    "                #print('----------------------------------------------------------------')\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e66c5f8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def df_to_latex(table_df, caption, explanation_methods_category_list):\n",
    "    \"\"\"Render a table of pre-formatted string cells as a LaTeX table environment.\n",
    "\n",
    "    Args:\n",
    "        table_df: DataFrame whose cells are pre-formatted strings. Cells may contain the\n",
    "            markers bfs/bfe (bold start/end) and grays/graye (gray start/end), which are\n",
    "            expanded into LaTeX bold and gray-text commands.\n",
    "        caption: text placed inside the table caption.\n",
    "        explanation_methods_category_list: row labels grouped by category; a midrule is\n",
    "            inserted before the first label of every group except the first.\n",
    "\n",
    "    Returns:\n",
    "        The LaTeX source as a string.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if line 3 of the generated LaTeX (the first column-header row)\n",
    "            contains no multicolumn cell.\n",
    "    \"\"\"\n",
    "    import re\n",
    "\n",
    "    latex_output=table_df.style.to_latex(hrules=True, multicol_align='c')\n",
    "    # Separate the method categories with a horizontal rule.\n",
    "    for explanation_methods_category in explanation_methods_category_list[1:]:\n",
    "        latex_output=latex_output.replace('\\n'+explanation_methods_category[0],\n",
    "                                          '\\\\midrule\\n'+explanation_methods_category[0], 1)\n",
    "    latex_output_split=latex_output.split('\\n')\n",
    "    # Add one cmidrule under each column group, spanning exactly the columns of its\n",
    "    # multicolumn header (column 1 holds the row labels).\n",
    "    group_widths=[int(width) for width in re.findall(r'\\\\multicolumn\\{(\\d+)\\}', latex_output_split[2])]\n",
    "    if not group_widths:\n",
    "        raise ValueError(\"expected a multicolumn header on line 3 of the LaTeX output\")\n",
    "    cmidrules=[]\n",
    "    first_column=2\n",
    "    for width in group_widths:\n",
    "        cmidrules.append('\\\\cmidrule(lr){{{}-{}}}'.format(first_column, first_column+width-1))\n",
    "        first_column+=width\n",
    "    latex_output_split.insert(3, ' '.join(cmidrules))\n",
    "    latex_output='\\n'.join(latex_output_split)\n",
    "    latex_output=latex_output.replace(\"ViT Shapley\", \"\\\\textbf{ViT Shapley}\")\n",
    "    latex_output=latex_output.replace(\"explanation\\\\_method\",\"\")\n",
    "    latex_output=latex_output.replace(\"{l}\",\"{c}\")\n",
    "    # Expand the bold / gray markers embedded in the cell strings.\n",
    "    latex_output=latex_output.replace(\"bfs\",\"\\\\textbf{\")\n",
    "    latex_output=latex_output.replace(\"bfe\",\"}\")\n",
    "    latex_output=latex_output.replace(\"grays\",\"\\\\textcolor{gray}{\")\n",
    "    latex_output=latex_output.replace(\"graye\",\"}\")\n",
    "    latex_output='% \\\\begin{scriptsize}\\n' + latex_output + '% \\\\end{scriptsize}\\n'\n",
    "    latex_output='\\\\begin{small}\\n' + latex_output + '\\\\end{small}\\n'\n",
    "    latex_output='\\\\begin{center}\\n' + latex_output + '\\\\end{center}\\n'\n",
    "    latex_output='\\\\vskip 0.01in\\n' + latex_output\n",
    "    latex_output='\\\\caption{{{}}}\\n'.format(caption) + latex_output\n",
    "    latex_output='\\\\begin{table}\\n' + latex_output + '\\\\end{table}\\n'\n",
    "    # For 7-column tables: left-align the row labels and center the value columns.\n",
    "    latex_output=latex_output.replace('{lllllll}','{lcccccc}')\n",
    "    return latex_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "771dc03f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Main-subset target table for ViT-Base, with ImageNette and MURA side by side.\n",
    "main_target_tables=[auc_sensitivity_table_dict[\"vit_base_patch16_224\"][dataset][\"main\"][\"target\"]\n",
    "                    for dataset in [\"ImageNette\", \"MURA\"]]\n",
    "table_df=pd.concat(main_target_tables, axis=1)\n",
    "table_df.index.name = None\n",
    "\n",
    "explanation_methods_category_list=[[explanation_method_mapper(method, \"main\") for method in category]\n",
    "                                   for category in explanation_method_main_random]\n",
    "\n",
    "latex_table=df_to_latex(table_df, caption=\"main target\",\n",
    "                        explanation_methods_category_list=explanation_methods_category_list)\n",
    "print(latex_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd160421",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Main-subset non-target table for ViT-Base on ImageNette.\n",
    "table_df=pd.concat([auc_sensitivity_table_dict[\"vit_base_patch16_224\"][\"ImageNette\"][\"main\"][\"non-target\"]], axis=1)\n",
    "table_df.index.name = None\n",
    "\n",
    "explanation_methods_category_list=[[explanation_method_mapper(method, \"main\") for method in category]\n",
    "                                   for category in explanation_method_main_random]\n",
    "\n",
    "latex_table=df_to_latex(table_df, caption=\"main non-target\",\n",
    "                        explanation_methods_category_list=explanation_methods_category_list)\n",
    "print(latex_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "334cb0b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Main-subset non-target table for ViT-Base on MURA.\n",
    "table_df=pd.concat([auc_sensitivity_table_dict[\"vit_base_patch16_224\"][\"MURA\"][\"main\"][\"non-target\"]], axis=1)\n",
    "\n",
    "table_df.index.name = None\n",
    "\n",
    "explanation_methods_category_list=[[explanation_method_mapper(j, \"main\") for j in i] for i in explanation_method_main_random]\n",
    "\n",
    "print(df_to_latex(table_df, \n",
    "                  caption=\"main non-target (MURA)\",\n",
    "                  explanation_methods_category_list = explanation_methods_category_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd616057",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Supplementary target table for ViT-Small on ImageNette. The midrule category list must use the\n",
    "# same subset (supple) and label mapping as the table rows, otherwise the midrules are misplaced.\n",
    "table_df=pd.concat([auc_sensitivity_table_dict[\"vit_small_patch16_224\"][\"ImageNette\"][\"supple\"][\"target\"]], axis=1)\n",
    "\n",
    "table_df.index.name = None\n",
    "\n",
    "explanation_methods_category_list=[[explanation_method_mapper(j, \"supple\") for j in i] for i in explanation_method_supple_random]\n",
    "\n",
    "print(df_to_latex(table_df, \n",
    "                  caption=\"supple target\",\n",
    "                  explanation_methods_category_list = explanation_methods_category_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4b1fb2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Supplementary non-target table for ViT-Small on ImageNette. The midrule category list must use\n",
    "# the same subset (supple) and label mapping as the table rows, otherwise the midrules are misplaced.\n",
    "table_df=pd.concat([auc_sensitivity_table_dict[\"vit_small_patch16_224\"][\"ImageNette\"][\"supple\"][\"non-target\"]], axis=1)\n",
    "\n",
    "table_df.index.name = None\n",
    "\n",
    "explanation_methods_category_list=[[explanation_method_mapper(j, \"supple\") for j in i] for i in explanation_method_supple_random]\n",
    "\n",
    "print(df_to_latex(table_df, \n",
    "                  caption=\"supple non-target\",\n",
    "                  explanation_methods_category_list = explanation_methods_category_list))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2098ccc",
   "metadata": {},
   "source": [
    "# Process ROAR-noretraining"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec734542",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-image accuracy curves for the \"no retraining\" setting (the \"6_noretraining\" results).\n",
    "# noretraining_result_dict[dataset][backbone] is a list of records with keys\n",
    "# explanation_method, metric_mode, path and accuracy (one value per masking step, 196 + 1).\n",
    "noretraining_result_dict = {}\n",
    "for dataset_name in [\"ImageNette\", \"MURA\"]:\n",
    "    noretraining_result_dict.setdefault(dataset_name, {})\n",
    "    for backbone_type in backbone_type_config_dict:\n",
    "        print(f'{dataset_name}   {backbone_type}')\n",
    "        # Unmasked classifier outputs: source of each image's label and number of outputs.\n",
    "        classifier_prob_data = data_loaded_all['1_classifier_evaluate'][dataset_name][backbone_type]\n",
    "        classifier_paths = list(classifier_prob_data.keys())\n",
    "        noretraining_dict_list = []\n",
    "        for explanation_method, noretraining_save_dict in data_loaded_all[\"6_noretraining\"][dataset_name][backbone_type].items():\n",
    "            for path, noretraining_dict in noretraining_save_dict.items():\n",
    "                classifier_entry = classifier_prob_data[adapt_path(path, classifier_paths)]\n",
    "                classifier_prob = classifier_entry['prob']\n",
    "                label = classifier_entry['label']\n",
    "                num_outputs = len(classifier_prob)\n",
    "                is_binary = num_outputs == 1  # MURA: one output, thresholded at 0.5\n",
    "                if is_binary and label != 1:\n",
    "                    continue  # only label-1 MURA images are evaluated\n",
    "                for metric_mode in ['insert', 'delete']:\n",
    "                    roc = noretraining_dict[metric_mode]\n",
    "                    if explanation_method == \"random\":\n",
    "                        # Extra leading axis of 10 (presumably repeated random orderings), averaged per step.\n",
    "                        assert roc.shape == (10, num_outputs, 196+1)\n",
    "                        if is_binary:\n",
    "                            correct = (roc[:, 0] > 0.5).astype(int) == label\n",
    "                        else:\n",
    "                            correct = roc.argmax(axis=1) == label\n",
    "                        accuracy = correct.astype(float).mean(axis=0)\n",
    "                    else:\n",
    "                        assert roc.shape == (num_outputs, 196+1)\n",
    "                        if is_binary:\n",
    "                            correct = (roc[0] > 0.5).astype(int) == label\n",
    "                        else:\n",
    "                            correct = roc.argmax(axis=0) == label\n",
    "                        accuracy = correct.astype(float)\n",
    "                    noretraining_dict_list.append({'explanation_method': explanation_method,\n",
    "                                                   'metric_mode': metric_mode,\n",
    "                                                   'path': path,\n",
    "                                                   'accuracy': accuracy})\n",
    "        noretraining_result_dict[dataset_name][backbone_type] = noretraining_dict_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c458b8bf",
   "metadata": {},
   "source": [
    "# Process ROAR-classifiermasked"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "783abbf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-image accuracy curves for the \"separate evaluator\" setting (the \"7_classifiermasked\" results).\n",
    "# classifiermasked_result_dict[dataset][backbone] is a list of records with keys\n",
    "# explanation_method, metric_mode, path and accuracy (one value per masking step, 196 + 1).\n",
    "classifiermasked_result_dict = {}\n",
    "for dataset_name in [\"ImageNette\", \"MURA\"]:\n",
    "    classifiermasked_result_dict.setdefault(dataset_name, {})\n",
    "    for backbone_type in backbone_type_config_dict:\n",
    "        print(f'{dataset_name}   {backbone_type}')\n",
    "        # Unmasked classifier outputs: source of each image's label and number of outputs.\n",
    "        classifier_prob_data = data_loaded_all['1_classifier_evaluate'][dataset_name][backbone_type]\n",
    "        classifier_paths = list(classifier_prob_data.keys())\n",
    "        classifiermasked_dict_list = []\n",
    "        for explanation_method, classifiermasked_save_dict in data_loaded_all[\"7_classifiermasked\"][dataset_name][backbone_type].items():\n",
    "            for path, classifiermasked_dict in classifiermasked_save_dict.items():\n",
    "                classifier_entry = classifier_prob_data[adapt_path(path, classifier_paths)]\n",
    "                classifier_prob = classifier_entry['prob']\n",
    "                label = classifier_entry['label']\n",
    "                num_outputs = len(classifier_prob)\n",
    "                is_binary = num_outputs == 1  # MURA: one output, thresholded at 0.5\n",
    "                if is_binary and label != 1:\n",
    "                    continue  # only label-1 MURA images are evaluated\n",
    "                for metric_mode in ['insert', 'delete']:\n",
    "                    roc = classifiermasked_dict[metric_mode]\n",
    "                    if explanation_method == \"random\":\n",
    "                        # Extra leading axis of 10 (presumably repeated random orderings), averaged per step.\n",
    "                        assert roc.shape == (10, num_outputs, 196+1)\n",
    "                        if is_binary:\n",
    "                            correct = (roc[:, 0] > 0.5).astype(int) == label\n",
    "                        else:\n",
    "                            correct = roc.argmax(axis=1) == label\n",
    "                        accuracy = correct.astype(float).mean(axis=0)\n",
    "                    else:\n",
    "                        assert roc.shape == (num_outputs, 196+1)\n",
    "                        if is_binary:\n",
    "                            correct = (roc[0] > 0.5).astype(int) == label\n",
    "                        else:\n",
    "                            correct = roc.argmax(axis=0) == label\n",
    "                        accuracy = correct.astype(float)\n",
    "                    classifiermasked_dict_list.append({'explanation_method': explanation_method,\n",
    "                                                       'metric_mode': metric_mode,\n",
    "                                                       'path': path,\n",
    "                                                       'accuracy': accuracy})\n",
    "        classifiermasked_result_dict[dataset_name][backbone_type] = classifiermasked_dict_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d10d074",
   "metadata": {},
   "source": [
    "# ROAR Plotting-three metrics-insertion/deletion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f343091",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# ROAR accuracy curves: one figure per backbone, with one row per evaluation setting\n",
    "# (fine-tuned classifier, separate evaluator, masked retraining) and one column per\n",
    "# metric mode (insertion, deletion). Each figure is saved to results/plots/ in four formats.\n",
    "dataset_name = \"ImageNette\"\n",
    "\n",
    "roar_stages = [\"noretraining\", \"classifiermask\", \"retraining\"]\n",
    "roar_stage_titles = {\"noretraining\": \"Fine-tuned Classifier\",\n",
    "                     \"classifiermask\": \"Separate Evaluator\",\n",
    "                     \"retraining\": \"Masked Retraining\"}\n",
    "# metric mode -> (title suffix, x-axis label)\n",
    "roar_mode_labels = {\"insert\": (\"Insertion\", \"# of Inserted Patches\"),\n",
    "                    \"delete\": (\"Deletion\", \"# of Deleted Patches\")}\n",
    "roar_linestyles = [':', '-.', '-', '-']  # indexed by explanation-method category\n",
    "roar_random_color = [0, 0, 0, 0.8]       # the \"random\" baseline is drawn in translucent black\n",
    "# Accuracies substituted at 0 / 196 patches when no retraining run exists there.\n",
    "# 1/10 is presumably chance level (10 ImageNette classes) and 0.99439632892608 the unmasked\n",
    "# classifier accuracy -- TODO confirm (the same value is used for every backbone).\n",
    "roar_retraining_endpoints = {\n",
    "    \"ImageNette\": {\"insert\": {0: 1/10, 196: 0.99439632892608},\n",
    "                   \"delete\": {0: 0.99439632892608, 196: 1/10}},\n",
    "}\n",
    "\n",
    "plt.rcParams[\"axes.prop_cycle\"] = cycler('color', [np.array(Paired[12][i])/256 for i in (1, 3, 5, 7, 9, 11)])\n",
    "plt.rcParams['font.family'] = 'PT Sans'\n",
    "plt.rcParams[\"font.size\"] = 15\n",
    "roar_colors = plt.rcParams[\"axes.prop_cycle\"].by_key()['color']\n",
    "\n",
    "\n",
    "def _roar_curve_color(method_idx, explanation_method):\n",
    "    \"\"\"Line color for one method: translucent black for the random baseline, else the palette.\"\"\"\n",
    "    return roar_random_color if explanation_method == \"random\" else roar_colors[method_idx]\n",
    "\n",
    "\n",
    "def _roar_mean_curves(records):\n",
    "    \"\"\"Average per-image accuracy curves per (metric_mode, explanation_method).\"\"\"\n",
    "    return (pd.DataFrame(records)\n",
    "            .groupby([\"metric_mode\", \"explanation_method\"])[\"accuracy\"]\n",
    "            .apply(lambda x: np.mean(x, axis=0)))\n",
    "\n",
    "\n",
    "def _roar_retraining_accuracy(retraining_df):\n",
    "    \"\"\"Mean retraining accuracy indexed by (metric_mode, explanation_method, num_stage).\n",
    "\n",
    "    num_stage is the x-axis value: 196 - explanation_mask_amount_train for \"insert\" runs,\n",
    "    explanation_mask_amount_train for \"delete\" runs. Annotations are added to a copy, so\n",
    "    the loaded results in data_loaded_all are not modified.\n",
    "    \"\"\"\n",
    "    annotated = retraining_df.assign(\n",
    "        explanation_method=retraining_df[\"explanation_location_train\"].map(\n",
    "            lambda x: '_'.join(x.split('/')[8].replace('_train.pickle', '').split('_')[4:]) if x is not None else None),\n",
    "        # `== True` (not `is True`) is kept on purpose so numpy booleans also match.\n",
    "        metric_mode=retraining_df[\"explanation_mask_ascending_train\"].map(\n",
    "            lambda x: \"insert\" if x == True else \"delete\"),  # noqa: E712\n",
    "    )\n",
    "\n",
    "    def _num_stage(row):\n",
    "        amount = row[\"explanation_mask_amount_train\"]\n",
    "        return 196 - amount if row[\"metric_mode\"] == \"insert\" else amount\n",
    "\n",
    "    annotated[\"num_stage\"] = annotated.apply(_num_stage, axis=1)\n",
    "    # Select the column before averaging: DataFrameGroupBy.mean() raises a TypeError on\n",
    "    # the non-numeric columns (e.g. the explanation paths) in pandas >= 2.0.\n",
    "    return annotated.groupby([\"metric_mode\", \"explanation_method\", \"num_stage\"])[\"accuracy\"].mean()\n",
    "\n",
    "\n",
    "def _roar_fill_endpoints(curve, dataset_name, insert_delete):\n",
    "    \"\"\"Return `curve` sorted by patch count, with missing 0 / 196 points filled in.\n",
    "\n",
    "    Raises ValueError when an endpoint is missing and roar_retraining_endpoints has no\n",
    "    fallback accuracies for `dataset_name`.\n",
    "    \"\"\"\n",
    "    missing = [n for n in (0, 196) if n not in curve.index.values]\n",
    "    if missing:\n",
    "        if dataset_name not in roar_retraining_endpoints:\n",
    "            raise ValueError(f\"No endpoint accuracies configured for dataset {dataset_name!r}\")\n",
    "        curve = curve.copy()  # never write into a slice of the summary table\n",
    "        for n in missing:\n",
    "            curve[n] = roar_retraining_endpoints[dataset_name][insert_delete][n]\n",
    "    return curve.sort_index()\n",
    "\n",
    "\n",
    "for backbone_type in backbone_type_config_dict:\n",
    "    fig = plt.figure(figsize=(18, 5*len(roar_stages)))\n",
    "    outer_grid = gridspec.GridSpec(len(roar_stages), 1, wspace=0, hspace=0.3)\n",
    "    axd = {}\n",
    "    for row, stage in enumerate(roar_stages):\n",
    "        row_grid = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=outer_grid[row], wspace=0.15, hspace=0)\n",
    "        for col, insert_delete in enumerate([\"insert\", \"delete\"]):\n",
    "            axd[f\"{stage}_{insert_delete}\"] = fig.add_subplot(row_grid[col])\n",
    "\n",
    "    # Accuracy tables indexed by (metric_mode, explanation_method[, num_stage]); computed once per backbone.\n",
    "    accuracy_by_stage = {\n",
    "        \"noretraining\": _roar_mean_curves(noretraining_result_dict[dataset_name][backbone_type]),\n",
    "        \"classifiermask\": _roar_mean_curves(classifiermasked_result_dict[dataset_name][backbone_type]),\n",
    "        \"retraining\": _roar_retraining_accuracy(data_loaded_all[\"9_retraining\"][dataset_name][backbone_type]),\n",
    "    }\n",
    "\n",
    "    for stage in roar_stages:\n",
    "        for insert_delete in [\"insert\", \"delete\"]:\n",
    "            ax = axd[f\"{stage}_{insert_delete}\"]\n",
    "            for category_idx, explanation_methods_category in enumerate(explanation_method_main_random):\n",
    "                for method_idx, explanation_method in enumerate(explanation_methods_category):\n",
    "                    accuracy_data = accuracy_by_stage[stage].loc[insert_delete].loc[explanation_method]\n",
    "                    if stage == \"retraining\":\n",
    "                        # Retraining was run only at some patch counts; x comes from the index.\n",
    "                        accuracy_data = _roar_fill_endpoints(accuracy_data, dataset_name, insert_delete)\n",
    "                        x_values = accuracy_data.index\n",
    "                    else:\n",
    "                        x_values = np.arange(0, 196+1)\n",
    "                    ax.plot(x_values,\n",
    "                            accuracy_data,\n",
    "                            linewidth=2,\n",
    "                            linestyle=roar_linestyles[category_idx],\n",
    "                            c=_roar_curve_color(method_idx, explanation_method))\n",
    "\n",
    "            title_suffix, xlabel = roar_mode_labels[insert_delete]\n",
    "            ax.set_title(f\"{roar_stage_titles[stage]} ({title_suffix})\")\n",
    "            ax.set_ylim(0, 1.05)\n",
    "            ax.set_xlim(-2, 200)\n",
    "            ax.spines['right'].set_visible(False)\n",
    "            ax.spines['top'].set_visible(False)\n",
    "            ax.xaxis.set_major_locator(MultipleLocator(28))\n",
    "            ax.xaxis.set_minor_locator(MultipleLocator(14))\n",
    "            ax.xaxis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "            ax.xaxis.grid(True, which='minor', linewidth=0.6, alpha=0.2)\n",
    "            ax.yaxis.set_minor_locator(MultipleLocator(0.1))\n",
    "            ax.yaxis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "            ax.yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.2)\n",
    "            for spine in ['top', 'bottom', 'left', 'right']:\n",
    "                ax.spines[spine].set_linewidth(2)\n",
    "            ax.set_ylabel(\"Accuracy\")\n",
    "            ax.set_xlabel(xlabel, labelpad=5)\n",
    "\n",
    "    legend_elements = [Line2D([0], [0],\n",
    "                              linestyle=roar_linestyles[category_idx],\n",
    "                              color=_roar_curve_color(method_idx, explanation_method),\n",
    "                              linewidth=3,\n",
    "                              label=explanation_method_mapper(explanation_method))\n",
    "                       for category_idx, explanation_methods_category in enumerate(explanation_method_main_random)\n",
    "                       for method_idx, explanation_method in enumerate(explanation_methods_category)]\n",
    "    # Invisible spacer handles that pad the 5-column legend grid; the insert positions are\n",
    "    # hard-coded for the current explanation_method_main_random layout.\n",
    "    legend_elements.insert(2, Line2D([0], [0], linewidth=0))\n",
    "    legend_elements.insert(13, Line2D([0], [0], linewidth=0))\n",
    "    legend_elements.insert(13, Line2D([0], [0], linewidth=0))\n",
    "\n",
    "    fig.legend(handles=legend_elements,\n",
    "               ncol=5,\n",
    "               handlelength=3,\n",
    "               handletextpad=0.6,\n",
    "               columnspacing=1.5,\n",
    "               fontsize=15,\n",
    "               loc='lower center', bbox_to_anchor=(0.5, -0.0))\n",
    "\n",
    "    for extension in [\"png\", \"svg\", \"jpg\", \"pdf\"]:\n",
    "        fig.savefig(f\"results/plots/ROAR_{backbone_type}_{dataset_name}.{extension}\", bbox_inches='tight')\n",
    "    # Display after saving; with default inline-backend settings this also closes the figure,\n",
    "    # so figures do not pile up across backbones.\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0387a4b",
   "metadata": {},
   "source": [
    "# ROAR Plotting-four metrics-insertion/deletion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79d27c0c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "dataset_name=\"ImageNette\"\n",
    "\n",
    "for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], \n",
    "                                                                                Paired[12][3],\n",
    "                                                                                Paired[12][5],\n",
    "                                                                                Paired[12][7],\n",
    "                                                                                Paired[12][9],\n",
    "                                                                                Paired[12][11]\n",
    "                                                                                ]])   \n",
    "    plt.rcParams['font.family'] = 'PT Sans'\n",
    "    plt.rcParams[\"font.size\"] = 15\n",
    "    \n",
    "    \n",
    "    fig = plt.figure(figsize=(18, 5*4))\n",
    "    box1 = gridspec.GridSpec(4, 1, wspace=0, hspace=0.3)\n",
    "    \n",
    "    axd={}\n",
    "    for idx1, stage in enumerate([\"noretraining\", \"classifiermask\", \"retraining\", \"retrainingnopos\"]):\n",
    "        box2 = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=box1[idx1], wspace=0.15, hspace=0)\n",
    "        for idx2, insert_delete in enumerate([\"insert\", \"delete\"]):\n",
    "            box3 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=box2[idx2], wspace=0, hspace=0)\n",
    "            ax = plt.Subplot(fig, box3[0])\n",
    "            fig.add_subplot(ax)\n",
    "            \n",
    "            plot_key=f\"{stage}_{insert_delete}\"\n",
    "            axd[plot_key]=ax\n",
    "            \n",
    "    for idx1, stage in enumerate([\"noretraining\", \"classifiermask\", \"retraining\", \"retrainingnopos\"]):\n",
    "        for idx2, insert_delete in enumerate([\"insert\", \"delete\"]):    \n",
    "            \n",
    "            plot_key=f\"{stage}_{insert_delete}\"\n",
    "            \n",
    "            if stage==\"noretraining\":\n",
    "                accuracy_summarized=pd.DataFrame(noretraining_result_dict[dataset_name][backbone_type])\\\n",
    "                .groupby([\"metric_mode\", \"explanation_method\"])[\"accuracy\"].apply(lambda x: np.mean(x, axis=0))                        \n",
    "                for idx3, explanation_methods_category in enumerate(explanation_method_main_random):\n",
    "                    for idx4, explanation_method in enumerate(explanation_methods_category):\n",
    "                        accuracy_data=accuracy_summarized.loc[insert_delete].loc[explanation_method]\n",
    "                        axd[plot_key].plot(np.arange(0,196+1),\n",
    "                                           accuracy_data,\n",
    "                                           linewidth=2,\n",
    "                                           linestyle=[':','-.','-','-'][idx3],\n",
    "                                           c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx4] if explanation_method!=\"random\" else [0,0,0,0.8])\n",
    "                \n",
    "            elif stage==\"classifiermask\":\n",
    "                accuracy_summarized=pd.DataFrame(classifiermasked_result_dict[dataset_name][backbone_type])\\\n",
    "                .groupby([\"metric_mode\", \"explanation_method\"])[\"accuracy\"].apply(lambda x: np.mean(x, axis=0))\n",
    "                for idx3, explanation_methods_category in enumerate(explanation_method_main_random):\n",
    "                    for idx4, explanation_method in enumerate(explanation_methods_category):\n",
    "                        accuracy_data=accuracy_summarized.loc[insert_delete].loc[explanation_method]\n",
    "                        axd[plot_key].plot(np.arange(0,196+1),\n",
    "                                           accuracy_data,\n",
    "                                           linewidth=2,\n",
    "                                           linestyle=[':','-.','-','-'][idx3],\n",
    "                                           c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx4] if explanation_method!=\"random\" else [0,0,0,0.8])                \n",
    "                \n",
    "            elif stage==\"retraining\":\n",
    "                retraining_backbone=data_loaded_all[\"9_retraining\"][dataset_name][backbone_type]\n",
    "                retraining_backbone[\"explanation_method\"]=retraining_backbone[\"explanation_location_train\"].map(lambda x: '_'.join(x.split('/')[8].replace('_train.pickle', '').split('_')[4:]) if x is not None else None)\n",
    "                retraining_backbone[\"metric_mode\"]=retraining_backbone[\"explanation_mask_ascending_train\"].map(lambda x: \"insert\" if x==True else \"delete\")\n",
    "                retraining_backbone[\"num_stage\"]=retraining_backbone.apply(lambda row: 196-row[\"explanation_mask_amount_train\"] if row[\"metric_mode\"]==\"insert\" else row[\"explanation_mask_amount_train\"], axis=1)\n",
    "                accuracy_summarized=retraining_backbone.groupby([\"metric_mode\", \"explanation_method\", \"num_stage\"]).mean()[\"accuracy\"]\n",
    "                for idx3, explanation_methods_category in enumerate(explanation_method_main_random):\n",
    "                    for idx4, explanation_method in enumerate(explanation_methods_category):\n",
    "                        accuracy_data=accuracy_summarized.loc[insert_delete].loc[explanation_method]\n",
    "                        if backbone_type==\"vit_base_patch16_224\":\n",
    "                            if 0 not in accuracy_data.index.values:\n",
    "                                if dataset_name==\"ImageNette\":\n",
    "                                    if insert_delete==\"insert\":\n",
    "                                        accuracy_data[0]=1/10\n",
    "                                    elif insert_delete==\"delete\":\n",
    "                                        accuracy_data[0]=0.99439632892608\n",
    "    #                             elif dataset_name==\"MURA\":\n",
    "    #                                 if insert_delete==\"insert\":\n",
    "    #                                     accuracy_data[0]=1/2\n",
    "    #                                 elif insert_delete==\"delete\":\n",
    "    #                                     pass\n",
    "                                else:\n",
    "                                    raise\n",
    "\n",
    "                            if 196 not in accuracy_data.index.values:\n",
    "                                if dataset_name==\"ImageNette\":\n",
    "                                    if insert_delete==\"insert\":\n",
    "                                        accuracy_data[196]=0.99439632892608\n",
    "                                    elif insert_delete==\"delete\":\n",
    "                                        accuracy_data[196]=1/10\n",
    "    #                             elif dataset_name==\"MURA\":\n",
    "    #                                 if insert_delete==\"insert\":\n",
    "    #                                     pass\n",
    "    #                                 elif insert_delete==\"delete\":\n",
    "    #                                     pass\n",
    "                                else:\n",
    "                                    raise  \n",
    "                        elif backbone_type==\"vit_small_patch16_224\":\n",
    "                            if 0 not in accuracy_data.index.values:\n",
    "                                if dataset_name==\"ImageNette\":\n",
    "                                    if insert_delete==\"insert\":\n",
    "                                        accuracy_data[0]=1/10\n",
    "                                    elif insert_delete==\"delete\":\n",
    "                                        accuracy_data[0]=0.9934\n",
    "    #                             elif dataset_name==\"MURA\":\n",
    "    #                                 if insert_delete==\"insert\":\n",
    "    #                                     accuracy_data[0]=1/2\n",
    "    #                                 elif insert_delete==\"delete\":\n",
    "    #                                     pass\n",
    "                                else:\n",
    "                                    raise\n",
    "\n",
    "                            if 196 not in accuracy_data.index.values:\n",
    "                                if dataset_name==\"ImageNette\":\n",
    "                                    if insert_delete==\"insert\":\n",
    "                                        accuracy_data[196]=0.9934\n",
    "                                    elif insert_delete==\"delete\":\n",
    "                                        accuracy_data[196]=1/10\n",
    "    #                             elif dataset_name==\"MURA\":\n",
    "    #                                 if insert_delete==\"insert\":\n",
    "    #                                     pass\n",
    "    #                                 elif insert_delete==\"delete\":\n",
    "    #                                     pass\n",
    "                                else:\n",
    "                                    raise                          \n",
    "                        else:\n",
    "                            raise\n",
    "                        #print(accuracy_data)\n",
    "                        accuracy_data=accuracy_data.sort_index()                           \n",
    "                        \n",
    "                        axd[plot_key].plot(accuracy_data.index,\n",
    "                                           accuracy_data,\n",
    "                                           linewidth=2,\n",
    "                                           linestyle=[':','-.','-','-'][idx3],\n",
    "                                           c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx4] if explanation_method!=\"random\" else [0,0,0,0.8])\n",
    "            elif stage==\"retrainingnopos\":\n",
    "                retraining_backbone=data_loaded_all[\"10_retrainingnopos\"][dataset_name][backbone_type]\n",
    "                retraining_backbone[\"explanation_method\"]=retraining_backbone[\"explanation_location_train\"].map(lambda x: '_'.join(x.split('/')[8].replace('_train.pickle', '').split('_')[4:]) if x is not None else None)\n",
    "                retraining_backbone[\"metric_mode\"]=retraining_backbone[\"explanation_mask_ascending_train\"].map(lambda x: \"insert\" if x==True else \"delete\")\n",
    "                retraining_backbone[\"num_stage\"]=retraining_backbone.apply(lambda row: 196-row[\"explanation_mask_amount_train\"] if row[\"metric_mode\"]==\"insert\" else row[\"explanation_mask_amount_train\"], axis=1)\n",
    "                accuracy_summarized=retraining_backbone.groupby([\"metric_mode\", \"explanation_method\", \"num_stage\"]).mean()[\"accuracy\"]\n",
    "                for idx3, explanation_methods_category in enumerate(explanation_method_main_random):\n",
    "                    for idx4, explanation_method in enumerate(explanation_methods_category):\n",
    "                        accuracy_data=accuracy_summarized.loc[insert_delete].loc[explanation_method]\n",
    "                        \n",
    "                        \n",
    "                        if backbone_type==\"vit_base_patch16_224\":\n",
    "                            if 0 not in accuracy_data.index.values:\n",
    "                                if dataset_name==\"ImageNette\":\n",
    "                                    if insert_delete==\"insert\":\n",
    "                                        accuracy_data[0]=1/10\n",
    "                                    elif insert_delete==\"delete\":\n",
    "                                        accuracy_data[0]=0.9363\n",
    "    #                             elif dataset_name==\"MURA\":\n",
    "    #                                 if insert_delete==\"insert\":\n",
    "    #                                     accuracy_data[0]=1/2\n",
    "    #                                 elif insert_delete==\"delete\":\n",
    "    #                                     pass\n",
    "                                else:\n",
    "                                    raise\n",
    "\n",
    "                            if 196 not in accuracy_data.index.values:\n",
    "                                if dataset_name==\"ImageNette\":\n",
    "                                    if insert_delete==\"insert\":\n",
    "                                        accuracy_data[196]=0.9363\n",
    "                                    elif insert_delete==\"delete\":\n",
    "                                        accuracy_data[196]=1/10\n",
    "    #                             elif dataset_name==\"MURA\":\n",
    "    #                                 if insert_delete==\"insert\":\n",
    "    #                                     pass\n",
    "    #                                 elif insert_delete==\"delete\":\n",
    "    #                                     pass\n",
    "                                else:\n",
    "                                    raise   \n",
    "                        elif backbone_type==\"vit_small_patch16_224\":\n",
    "                            if 0 not in accuracy_data.index.values:\n",
    "                                if dataset_name==\"ImageNette\":\n",
    "                                    if insert_delete==\"insert\":\n",
    "                                        accuracy_data[0]=1/10\n",
    "                                    elif insert_delete==\"delete\":\n",
    "                                        accuracy_data[0]=0.907\n",
    "    #                             elif dataset_name==\"MURA\":\n",
    "    #                                 if insert_delete==\"insert\":\n",
    "    #                                     accuracy_data[0]=1/2\n",
    "    #                                 elif insert_delete==\"delete\":\n",
    "    #                                     pass\n",
    "                                else:\n",
    "                                    raise\n",
    "\n",
    "                            if 196 not in accuracy_data.index.values:\n",
    "                                if dataset_name==\"ImageNette\":\n",
    "                                    if insert_delete==\"insert\":\n",
    "                                        accuracy_data[196]=0.907\n",
    "                                    elif insert_delete==\"delete\":\n",
    "                                        accuracy_data[196]=1/10\n",
    "    #                             elif dataset_name==\"MURA\":\n",
    "    #                                 if insert_delete==\"insert\":\n",
    "    #                                     pass\n",
    "    #                                 elif insert_delete==\"delete\":\n",
    "    #                                     pass\n",
    "                                else:\n",
    "                                    raise \n",
    "                        else:\n",
    "                            raise\n",
    "                            \n",
    "                            \n",
    "                            \n",
    "                            \n",
    "                        #print(accuracy_data)\n",
    "                        accuracy_data=accuracy_data.sort_index()                           \n",
    "                        \n",
    "                        axd[plot_key].plot(accuracy_data.index,\n",
    "                                           accuracy_data,\n",
    "                                           linewidth=2,\n",
    "                                           linestyle=[':','-.','-','-'][idx3],\n",
    "                                           c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx4] if explanation_method!=\"random\" else [0,0,0,0.8])\n",
    "            else:\n",
    "                raise\n",
    "            \n",
    "\n",
    "            if stage==\"noretraining\":\n",
    "                axd[plot_key].set_title(f\"Fine-tuned Classifier ({'Insertion' if insert_delete=='insert' else 'Deletion'})\")\n",
    "            elif stage==\"classifiermask\":\n",
    "                axd[plot_key].set_title(f\"Separate Evaluator ({'Insertion' if insert_delete=='insert' else 'Deletion'})\")\n",
    "                #axd[plot_key].set_title(f\"{'Insertion' if insert_delete=='insert' else 'Deletion'} (Random-mask classifier)\")\n",
    "            elif stage==\"retraining\":\n",
    "                axd[plot_key].set_title(f\"Masked Retraining w/ Positional Embed. ({'Insertion' if insert_delete=='insert' else 'Deletion'})\")\n",
    "                #axd[plot_key].set_title(f\"{'Insertion' if insert_delete=='insert' else 'Deletion'} (Retraining)\")\n",
    "            elif stage==\"retrainingnopos\":\n",
    "                axd[plot_key].set_title(f\"Masked Retraining w/o Positional Embed. ({'Insertion' if insert_delete=='insert' else 'Deletion'})\")                \n",
    "            else:\n",
    "                raise\n",
    "#             elif stage==\"classifiermask\":\n",
    "#                 if insert_delete==\"insert\":\n",
    "                    \n",
    "#                 elif insert_delete==\"delete\":\n",
    "#                     axd[plot_key].set_title(\"Deletion\")\n",
    "#                 else:\n",
    "#                     raise                            \n",
    "#             elif stage==\"retraining\":\n",
    "#                 if insert_delete==\"insert\":\n",
    "#                     axd[plot_key].set_title(\"Insertion\")\n",
    "#                 elif insert_delete==\"delete\":\n",
    "#                     axd[plot_key].set_title(\"Deletion\")\n",
    "#                 else:\n",
    "#                     raise            \n",
    "            \n",
    "            axd[plot_key].set_ylim(0, 1.05)\n",
    "            axd[plot_key].set_xlim(-2, 200)            \n",
    "\n",
    "            axd[plot_key].spines['right'].set_visible(False)\n",
    "            axd[plot_key].spines['top'].set_visible(False)\n",
    "\n",
    "            axd[plot_key].xaxis.set_major_locator(MultipleLocator(28))\n",
    "            axd[plot_key].xaxis.set_minor_locator(MultipleLocator(14))    \n",
    "            axd[plot_key].xaxis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "            axd[plot_key].xaxis.grid(True, which='minor', linewidth=0.6, alpha=0.2)   \n",
    "            \n",
    "            axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.1))\n",
    "            axd[plot_key].yaxis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "            axd[plot_key].yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.2)  \n",
    "            \n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[plot_key].spines[axis].set_linewidth(2) \n",
    "\n",
    "            #if insert_delete==\"insert\":\n",
    "            axd[plot_key].set_ylabel(\"Accuracy\")#, labelpad=-10)\n",
    "\n",
    "            if insert_delete==\"insert\":\n",
    "                axd[plot_key].set_xlabel(\"# of Inserted Patches\", labelpad=5)\n",
    "            elif insert_delete==\"delete\":\n",
    "                axd[plot_key].set_xlabel(\"# of Deleted Patches\", labelpad=5)\n",
    "            else: \n",
    "                raise\n",
    "                \n",
    "    legend_elements = [Line2D([0], [0],\n",
    "                              linestyle=[':','-.','-','-'][idx1],\n",
    "                              color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx2] if explanation_method!=\"random\" else [0,0,0,0.8],\n",
    "                              linewidth=3,\n",
    "                              label=explanation_method_mapper(explanation_method))\n",
    "                         for idx1, explanation_method_category in enumerate(explanation_method_main_random) for idx2, explanation_method in enumerate(explanation_method_category)]\n",
    "    legend_elements.insert(2, Line2D([0], [0], linewidth=0))\n",
    "    legend_elements.insert(13, Line2D([0], [0], linewidth=0))\n",
    "    legend_elements.insert(13, Line2D([0], [0], linewidth=0))\n",
    "    \n",
    "    fig.legend(handles=legend_elements, \n",
    "                ncol=5, \n",
    "                handlelength=3,\n",
    "                handletextpad=0.6, \n",
    "                columnspacing=1.5,\n",
    "                fontsize=15,\n",
    "                loc='lower center', bbox_to_anchor=(0.5, 0.035))    \n",
    "    \n",
    "    \n",
    "    fig.show()\n",
    "    \n",
    "    fig.savefig(f\"results/plots/ROAR_all_{backbone_type}_{dataset_name}.png\", bbox_inches='tight')\n",
    "    fig.savefig(f\"results/plots/ROAR_all_{backbone_type}_{dataset_name}.svg\", bbox_inches='tight')\n",
    "    fig.savefig(f\"results/plots/ROAR_all_{backbone_type}_{dataset_name}.jpg\", bbox_inches='tight')\n",
    "    fig.savefig(f\"results/plots/ROAR_all_{backbone_type}_{dataset_name}.pdf\", bbox_inches='tight')          \n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da7b0733",
   "metadata": {},
   "source": [
    "# ROAR Plotting: Four Evaluation Settings (Deletion Only)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7815fec",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "dataset_name=\"ImageNette\"\n",
    "\n",
    "for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], \n",
    "                                                                                Paired[12][3],\n",
    "                                                                                Paired[12][5],\n",
    "                                                                                Paired[12][7],\n",
    "                                                                                Paired[12][9],\n",
    "                                                                                Paired[12][11]\n",
    "                                                                                ]])   \n",
    "    plt.rcParams['font.family'] = 'PT Sans'\n",
    "    plt.rcParams[\"font.size\"] = 15\n",
    "    \n",
    "    \n",
    "    fig = plt.figure(figsize=(18, 5*2))\n",
    "    box1 = gridspec.GridSpec(2, 1, wspace=0, hspace=0.3)\n",
    "    \n",
    "    axd={}\n",
    "    insert_delete=\"delete\"\n",
    "    for idx1, stage_row in enumerate([[\"noretraining\", \"classifiermask\"], [\"retraining\", \"retrainingnopos\"]]):\n",
    "        box2 = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=box1[idx1], wspace=0.15, hspace=0)\n",
    "        for idx2, stage in enumerate(stage_row):\n",
    "            box3 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=box2[idx2], wspace=0, hspace=0)\n",
    "            ax = plt.Subplot(fig, box3[0])\n",
    "            fig.add_subplot(ax)\n",
    "            \n",
    "            plot_key=f\"{stage}_{insert_delete}\"\n",
    "            axd[plot_key]=ax\n",
    "            \n",
    "    for idx1, stage in enumerate([\"noretraining\", \"classifiermask\", \"retraining\", \"retrainingnopos\"]):\n",
    "        for idx2, insert_delete in enumerate([\"delete\"]):    \n",
    "            \n",
    "            plot_key=f\"{stage}_{insert_delete}\"\n",
    "            \n",
    "            if stage==\"noretraining\":\n",
    "                accuracy_summarized=pd.DataFrame(noretraining_result_dict[dataset_name][backbone_type])\\\n",
    "                .groupby([\"metric_mode\", \"explanation_method\"])[\"accuracy\"].apply(lambda x: np.mean(x, axis=0))                        \n",
    "                for idx3, explanation_methods_category in enumerate(explanation_method_main_random):\n",
    "                    for idx4, explanation_method in enumerate(explanation_methods_category):\n",
    "                        accuracy_data=accuracy_summarized.loc[insert_delete].loc[explanation_method]\n",
    "                        axd[plot_key].plot(np.arange(0,196+1),\n",
    "                                           accuracy_data,\n",
    "                                           linewidth=2,\n",
    "                                           linestyle=[':','-.','-','-'][idx3],\n",
    "                                           c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx4] if explanation_method!=\"random\" else [0,0,0,0.8])\n",
    "                \n",
    "            elif stage==\"classifiermask\":\n",
    "                accuracy_summarized=pd.DataFrame(classifiermasked_result_dict[dataset_name][backbone_type])\\\n",
    "                .groupby([\"metric_mode\", \"explanation_method\"])[\"accuracy\"].apply(lambda x: np.mean(x, axis=0))\n",
    "                for idx3, explanation_methods_category in enumerate(explanation_method_main_random):\n",
    "                    for idx4, explanation_method in enumerate(explanation_methods_category):\n",
    "                        accuracy_data=accuracy_summarized.loc[insert_delete].loc[explanation_method]\n",
    "                        axd[plot_key].plot(np.arange(0,196+1),\n",
    "                                           accuracy_data,\n",
    "                                           linewidth=2,\n",
    "                                           linestyle=[':','-.','-','-'][idx3],\n",
    "                                           c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx4] if explanation_method!=\"random\" else [0,0,0,0.8])                \n",
    "                \n",
    "            elif stage==\"retraining\":\n",
    "                retraining_backbone=data_loaded_all[\"9_retraining\"][dataset_name][backbone_type]\n",
    "                retraining_backbone[\"explanation_method\"]=retraining_backbone[\"explanation_location_train\"].map(lambda x: '_'.join(x.split('/')[8].replace('_train.pickle', '').split('_')[4:]) if x is not None else None)\n",
    "                retraining_backbone[\"metric_mode\"]=retraining_backbone[\"explanation_mask_ascending_train\"].map(lambda x: \"insert\" if x==True else \"delete\")\n",
    "                retraining_backbone[\"num_stage\"]=retraining_backbone.apply(lambda row: 196-row[\"explanation_mask_amount_train\"] if row[\"metric_mode\"]==\"insert\" else row[\"explanation_mask_amount_train\"], axis=1)\n",
    "                accuracy_summarized=retraining_backbone.groupby([\"metric_mode\", \"explanation_method\", \"num_stage\"]).mean()[\"accuracy\"]\n",
    "                for idx3, explanation_methods_category in enumerate(explanation_method_main_random):\n",
    "                    for idx4, explanation_method in enumerate(explanation_methods_category):\n",
    "                        accuracy_data=accuracy_summarized.loc[insert_delete].loc[explanation_method]\n",
    "                        if backbone_type==\"vit_base_patch16_224\":\n",
    "                            if 0 not in accuracy_data.index.values:\n",
    "                                if dataset_name==\"ImageNette\":\n",
    "                                    if insert_delete==\"insert\":\n",
    "                                        accuracy_data[0]=1/10\n",
    "                                    elif insert_delete==\"delete\":\n",
    "                                        accuracy_data[0]=0.99439632892608\n",
    "    #                             elif dataset_name==\"MURA\":\n",
    "    #                                 if insert_delete==\"insert\":\n",
    "    #                                     accuracy_data[0]=1/2\n",
    "    #                                 elif insert_delete==\"delete\":\n",
    "    #                                     pass\n",
    "                                else:\n",
    "                                    raise\n",
    "\n",
    "                            if 196 not in accuracy_data.index.values:\n",
    "                                if dataset_name==\"ImageNette\":\n",
    "                                    if insert_delete==\"insert\":\n",
    "                                        accuracy_data[196]=0.99439632892608\n",
    "                                    elif insert_delete==\"delete\":\n",
    "                                        accuracy_data[196]=1/10\n",
    "    #                             elif dataset_name==\"MURA\":\n",
    "    #                                 if insert_delete==\"insert\":\n",
    "    #                                     pass\n",
    "    #                                 elif insert_delete==\"delete\":\n",
    "    #                                     pass\n",
    "                                else:\n",
    "                                    raise  \n",
    "                        elif backbone_type==\"vit_small_patch16_224\":\n",
    "                            if 0 not in accuracy_data.index.values:\n",
    "                                if dataset_name==\"ImageNette\":\n",
    "                                    if insert_delete==\"insert\":\n",
    "                                        accuracy_data[0]=1/10\n",
    "                                    elif insert_delete==\"delete\":\n",
    "                                        accuracy_data[0]=0.9934\n",
    "    #                             elif dataset_name==\"MURA\":\n",
    "    #                                 if insert_delete==\"insert\":\n",
    "    #                                     accuracy_data[0]=1/2\n",
    "    #                                 elif insert_delete==\"delete\":\n",
    "    #                                     pass\n",
    "                                else:\n",
    "                                    raise\n",
    "\n",
    "                            if 196 not in accuracy_data.index.values:\n",
    "                                if dataset_name==\"ImageNette\":\n",
    "                                    if insert_delete==\"insert\":\n",
    "                                        accuracy_data[196]=0.9934\n",
    "                                    elif insert_delete==\"delete\":\n",
    "                                        accuracy_data[196]=1/10\n",
    "    #                             elif dataset_name==\"MURA\":\n",
    "    #                                 if insert_delete==\"insert\":\n",
    "    #                                     pass\n",
    "    #                                 elif insert_delete==\"delete\":\n",
    "    #                                     pass\n",
    "                                else:\n",
    "                                    raise                          \n",
    "                        else:\n",
    "                            raise\n",
    "                        #print(accuracy_data)\n",
    "                        accuracy_data=accuracy_data.sort_index()                           \n",
    "                        \n",
    "                        axd[plot_key].plot(accuracy_data.index,\n",
    "                                           accuracy_data,\n",
    "                                           linewidth=2,\n",
    "                                           linestyle=[':','-.','-','-'][idx3],\n",
    "                                           c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx4] if explanation_method!=\"random\" else [0,0,0,0.8])\n",
    "            elif stage==\"retrainingnopos\":\n",
    "                retraining_backbone=data_loaded_all[\"10_retrainingnopos\"][dataset_name][backbone_type]\n",
    "                retraining_backbone[\"explanation_method\"]=retraining_backbone[\"explanation_location_train\"].map(lambda x: '_'.join(x.split('/')[8].replace('_train.pickle', '').split('_')[4:]) if x is not None else None)\n",
    "                retraining_backbone[\"metric_mode\"]=retraining_backbone[\"explanation_mask_ascending_train\"].map(lambda x: \"insert\" if x==True else \"delete\")\n",
    "                retraining_backbone[\"num_stage\"]=retraining_backbone.apply(lambda row: 196-row[\"explanation_mask_amount_train\"] if row[\"metric_mode\"]==\"insert\" else row[\"explanation_mask_amount_train\"], axis=1)\n",
    "                accuracy_summarized=retraining_backbone.groupby([\"metric_mode\", \"explanation_method\", \"num_stage\"]).mean()[\"accuracy\"]\n",
    "                for idx3, explanation_methods_category in enumerate(explanation_method_main_random):\n",
    "                    for idx4, explanation_method in enumerate(explanation_methods_category):\n",
    "                        accuracy_data=accuracy_summarized.loc[insert_delete].loc[explanation_method]\n",
    "                        \n",
    "                        \n",
    "                        if backbone_type==\"vit_base_patch16_224\":\n",
    "                            if 0 not in accuracy_data.index.values:\n",
    "                                if dataset_name==\"ImageNette\":\n",
    "                                    if insert_delete==\"insert\":\n",
    "                                        accuracy_data[0]=1/10\n",
    "                                    elif insert_delete==\"delete\":\n",
    "                                        accuracy_data[0]=0.9363\n",
    "    #                             elif dataset_name==\"MURA\":\n",
    "    #                                 if insert_delete==\"insert\":\n",
    "    #                                     accuracy_data[0]=1/2\n",
    "    #                                 elif insert_delete==\"delete\":\n",
    "    #                                     pass\n",
    "                                else:\n",
    "                                    raise\n",
    "\n",
    "                            if 196 not in accuracy_data.index.values:\n",
    "                                if dataset_name==\"ImageNette\":\n",
    "                                    if insert_delete==\"insert\":\n",
    "                                        accuracy_data[196]=0.9363\n",
    "                                    elif insert_delete==\"delete\":\n",
    "                                        accuracy_data[196]=1/10\n",
    "    #                             elif dataset_name==\"MURA\":\n",
    "    #                                 if insert_delete==\"insert\":\n",
    "    #                                     pass\n",
    "    #                                 elif insert_delete==\"delete\":\n",
    "    #                                     pass\n",
    "                                else:\n",
    "                                    raise   \n",
    "                        elif backbone_type==\"vit_small_patch16_224\":\n",
    "                            if 0 not in accuracy_data.index.values:\n",
    "                                if dataset_name==\"ImageNette\":\n",
    "                                    if insert_delete==\"insert\":\n",
    "                                        accuracy_data[0]=1/10\n",
    "                                    elif insert_delete==\"delete\":\n",
    "                                        accuracy_data[0]=0.907\n",
    "    #                             elif dataset_name==\"MURA\":\n",
    "    #                                 if insert_delete==\"insert\":\n",
    "    #                                     accuracy_data[0]=1/2\n",
    "    #                                 elif insert_delete==\"delete\":\n",
    "    #                                     pass\n",
    "                                else:\n",
    "                                    raise\n",
    "\n",
    "                            if 196 not in accuracy_data.index.values:\n",
    "                                if dataset_name==\"ImageNette\":\n",
    "                                    if insert_delete==\"insert\":\n",
    "                                        accuracy_data[196]=0.907\n",
    "                                    elif insert_delete==\"delete\":\n",
    "                                        accuracy_data[196]=1/10\n",
    "    #                             elif dataset_name==\"MURA\":\n",
    "    #                                 if insert_delete==\"insert\":\n",
    "    #                                     pass\n",
    "    #                                 elif insert_delete==\"delete\":\n",
    "    #                                     pass\n",
    "                                else:\n",
    "                                    raise \n",
    "                        else:\n",
    "                            raise\n",
    "                            \n",
    "                            \n",
    "                            \n",
    "                            \n",
    "                        #print(accuracy_data)\n",
    "                        accuracy_data=accuracy_data.sort_index()                           \n",
    "                        \n",
    "                        axd[plot_key].plot(accuracy_data.index,\n",
    "                                           accuracy_data,\n",
    "                                           linewidth=2,\n",
    "                                           linestyle=[':','-.','-','-'][idx3],\n",
    "                                           c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx4] if explanation_method!=\"random\" else [0,0,0,0.8])\n",
    "            else:\n",
    "                raise\n",
    "            \n",
    "\n",
    "            if stage==\"noretraining\":\n",
    "                axd[plot_key].set_title(f\"Fine-tuned Classifier\")# ({'Insertion' if insert_delete=='insert' else 'Deletion'})\")\n",
    "            elif stage==\"classifiermask\":\n",
    "                axd[plot_key].set_title(f\"Separate Evaluator\")# ({'Insertion' if insert_delete=='insert' else 'Deletion'})\")\n",
    "                #axd[plot_key].set_title(f\"{'Insertion' if insert_delete=='insert' else 'Deletion'} (Random-mask classifier)\")\n",
    "            elif stage==\"retraining\":\n",
    "                axd[plot_key].set_title(f\"Masked Retraining w/ Positional Embed.\")# ({'Insertion' if insert_delete=='insert' else 'Deletion'})\")\n",
    "                #axd[plot_key].set_title(f\"{'Insertion' if insert_delete=='insert' else 'Deletion'} (Retraining)\")\n",
    "            elif stage==\"retrainingnopos\":\n",
    "                axd[plot_key].set_title(f\"Masked Retraining w/o Positional Embed.\")# ({'Insertion' if insert_delete=='insert' else 'Deletion'})\")                \n",
    "            else:\n",
    "                raise\n",
    "#             elif stage==\"classifiermask\":\n",
    "#                 if insert_delete==\"insert\":\n",
    "                    \n",
    "#                 elif insert_delete==\"delete\":\n",
    "#                     axd[plot_key].set_title(\"Deletion\")\n",
    "#                 else:\n",
    "#                     raise                            \n",
    "#             elif stage==\"retraining\":\n",
    "#                 if insert_delete==\"insert\":\n",
    "#                     axd[plot_key].set_title(\"Insertion\")\n",
    "#                 elif insert_delete==\"delete\":\n",
    "#                     axd[plot_key].set_title(\"Deletion\")\n",
    "#                 else:\n",
    "#                     raise            \n",
    "            \n",
    "            axd[plot_key].set_ylim(0, 1.05)\n",
    "            axd[plot_key].set_xlim(-2, 200)            \n",
    "\n",
    "            axd[plot_key].spines['right'].set_visible(False)\n",
    "            axd[plot_key].spines['top'].set_visible(False)\n",
    "\n",
    "            axd[plot_key].xaxis.set_major_locator(MultipleLocator(28))\n",
    "            axd[plot_key].xaxis.set_minor_locator(MultipleLocator(14))    \n",
    "            axd[plot_key].xaxis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "            axd[plot_key].xaxis.grid(True, which='minor', linewidth=0.6, alpha=0.2)   \n",
    "            \n",
    "            axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.1))\n",
    "            axd[plot_key].yaxis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "            axd[plot_key].yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.2)  \n",
    "            \n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[plot_key].spines[axis].set_linewidth(2) \n",
    "\n",
    "            #if insert_delete==\"insert\":\n",
    "            axd[plot_key].set_ylabel(\"Accuracy\")#, labelpad=-10)\n",
    "\n",
    "            if insert_delete==\"insert\":\n",
    "                axd[plot_key].set_xlabel(\"# of Inserted Patches\", labelpad=5)\n",
    "            elif insert_delete==\"delete\":\n",
    "                axd[plot_key].set_xlabel(\"# of Deleted Patches\", labelpad=5)\n",
    "            else: \n",
    "                raise\n",
    "                \n",
    "    legend_elements = [Line2D([0], [0],\n",
    "                              linestyle=[':','-.','-','-'][idx1],\n",
    "                              color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx2] if explanation_method!=\"random\" else [0,0,0,0.8],\n",
    "                              linewidth=3,\n",
    "                              label=explanation_method_mapper(explanation_method))\n",
    "                         for idx1, explanation_method_category in enumerate(explanation_method_main_random) for idx2, explanation_method in enumerate(explanation_method_category)]\n",
    "    legend_elements.insert(2, Line2D([0], [0], linewidth=0))\n",
    "    legend_elements.insert(13, Line2D([0], [0], linewidth=0))\n",
    "    legend_elements.insert(13, Line2D([0], [0], linewidth=0))\n",
    "    \n",
    "    fig.legend(handles=legend_elements, \n",
    "                ncol=5, \n",
    "                handlelength=3,\n",
    "                handletextpad=0.6, \n",
    "                columnspacing=1.5,\n",
    "                fontsize=15,\n",
    "                loc='lower center', bbox_to_anchor=(0.5, -0.06))    \n",
    "    \n",
    "    \n",
    "    fig.show()\n",
    "    \n",
    "    fig.savefig(f\"results/plots/ROAR_all_deletion_{backbone_type}_{dataset_name}.png\", bbox_inches='tight')\n",
    "    fig.savefig(f\"results/plots/ROAR_all_deletion_{backbone_type}_{dataset_name}.svg\", bbox_inches='tight')\n",
    "    fig.savefig(f\"results/plots/ROAR_all_deletion_{backbone_type}_{dataset_name}.jpg\", bbox_inches='tight')\n",
    "    fig.savefig(f\"results/plots/ROAR_all_deletion_{backbone_type}_{dataset_name}.pdf\", bbox_inches='tight')          \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73b8d210",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Top-level groups of loaded results (e.g. \"8_elapsedtime\", \"9_retraining\")\n",
    "data_loaded_all.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11799a46",
   "metadata": {},
   "source": [
    "# Plot ROAR retraining: with vs. without positional embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2319544",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Masked (ROAR-style) retraining: accuracy curves with vs. without positional embeddings.\n",
    "dataset_name=\"ImageNette\"\n",
    "\n",
    "ROAR_NUM_PATCHES=196  # total number of patches; x-axis values run from 0 to ROAR_NUM_PATCHES\n",
    "ROAR_LINESTYLES=[':','-.','-','-']  # one line style per explanation-method category\n",
    "\n",
    "# Reference accuracies used to pad a curve whose 0 / ROAR_NUM_PATCHES endpoint was not evaluated:\n",
    "# \"chance\" = no patch visible, \"full\" = every patch visible (differs per retraining stage).\n",
    "ROAR_ENDPOINT_ACCURACY={\n",
    "    \"ImageNette\": {\"chance\": 1/10,\n",
    "                   \"full\": {\"retraining\": 0.99439632892608,\n",
    "                            \"retrainingnopos\": 0.9363}},\n",
    "}\n",
    "\n",
    "# Stage -> key in data_loaded_all and panel title (figure rows, top to bottom).\n",
    "ROAR_STAGES={\n",
    "    \"retraining\": {\"data_key\": \"9_retraining\",\n",
    "                   \"title\": \"Masked Retraining w/ Positional Embed.\"},\n",
    "    \"retrainingnopos\": {\"data_key\": \"10_retrainingnopos\",\n",
    "                        \"title\": \"Masked Retraining w/o Positional Embed.\"},\n",
    "}\n",
    "\n",
    "\n",
    "def roar_summarize_accuracy(retraining_backbone):\n",
    "    \"\"\"Mean accuracy per (metric_mode, explanation_method, num_stage).\n",
    "\n",
    "    Adds the helper columns `explanation_method`, `metric_mode` and `num_stage` to\n",
    "    `retraining_backbone` in place (as before) and returns a Series indexed by a\n",
    "    3-level MultiIndex.\n",
    "    \"\"\"\n",
    "    retraining_backbone[\"explanation_method\"]=retraining_backbone[\"explanation_location_train\"].map(lambda x: '_'.join(x.split('/')[8].replace('_train.pickle', '').split('_')[4:]) if x is not None else None)\n",
    "    retraining_backbone[\"metric_mode\"]=retraining_backbone[\"explanation_mask_ascending_train\"].map(lambda x: \"insert\" if x==True else \"delete\")\n",
    "    # x-axis value: number of inserted patches (insert) or deleted patches (delete)\n",
    "    retraining_backbone[\"num_stage\"]=retraining_backbone.apply(lambda row: ROAR_NUM_PATCHES-row[\"explanation_mask_amount_train\"] if row[\"metric_mode\"]==\"insert\" else row[\"explanation_mask_amount_train\"], axis=1)\n",
    "    # Select the column before aggregating: a frame-wide .mean() raises on the\n",
    "    # non-numeric columns in pandas >= 2.0.\n",
    "    return retraining_backbone.groupby([\"metric_mode\", \"explanation_method\", \"num_stage\"])[\"accuracy\"].mean()\n",
    "\n",
    "\n",
    "def roar_pad_curve_endpoints(accuracy_data, insert_delete, dataset_name, stage):\n",
    "    \"\"\"Return a sorted copy of `accuracy_data` with missing 0 / ROAR_NUM_PATCHES points added.\n",
    "\n",
    "    Raises ValueError when an endpoint is missing and no reference accuracy is\n",
    "    defined for `dataset_name` in ROAR_ENDPOINT_ACCURACY.\n",
    "    \"\"\"\n",
    "    padded=accuracy_data.copy()  # never write into a slice of the summary Series\n",
    "    missing=[num_stage for num_stage in (0, ROAR_NUM_PATCHES) if num_stage not in padded.index.values]\n",
    "    if missing:\n",
    "        if dataset_name not in ROAR_ENDPOINT_ACCURACY:\n",
    "            raise ValueError(f\"No endpoint accuracies defined for dataset {dataset_name!r}\")\n",
    "        chance=ROAR_ENDPOINT_ACCURACY[dataset_name][\"chance\"]\n",
    "        full=ROAR_ENDPOINT_ACCURACY[dataset_name][\"full\"][stage]\n",
    "        # Insertion starts from nothing visible; deletion starts from the full image.\n",
    "        endpoint_value={0: chance, ROAR_NUM_PATCHES: full} if insert_delete==\"insert\" else {0: full, ROAR_NUM_PATCHES: chance}\n",
    "        for num_stage in missing:\n",
    "            padded.loc[num_stage]=endpoint_value[num_stage]\n",
    "    return padded.sort_index()\n",
    "\n",
    "\n",
    "def roar_method_color(explanation_method, idx_in_category):\n",
    "    \"\"\"Color from the active prop cycle; the random baseline is drawn in translucent black.\"\"\"\n",
    "    if explanation_method==\"random\":\n",
    "        return [0,0,0,0.8]\n",
    "    return plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx_in_category]\n",
    "\n",
    "\n",
    "def roar_style_axis(ax, insert_delete):\n",
    "    \"\"\"Apply the shared limits, grid, spine and label styling to one ROAR panel.\"\"\"\n",
    "    ax.set_ylim(0, 1.05)\n",
    "    ax.set_xlim(-2, 200)\n",
    "\n",
    "    ax.spines['right'].set_visible(False)\n",
    "    ax.spines['top'].set_visible(False)\n",
    "\n",
    "    ax.xaxis.set_major_locator(MultipleLocator(28))\n",
    "    ax.xaxis.set_minor_locator(MultipleLocator(14))\n",
    "    ax.xaxis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "    ax.xaxis.grid(True, which='minor', linewidth=0.6, alpha=0.2)\n",
    "\n",
    "    ax.yaxis.set_minor_locator(MultipleLocator(0.1))\n",
    "    ax.yaxis.grid(True, which='major', linewidth=0.8, alpha=0.6)\n",
    "    ax.yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.2)\n",
    "\n",
    "    for axis in ['top','bottom','left','right']:\n",
    "        ax.spines[axis].set_linewidth(2)\n",
    "\n",
    "    ax.set_ylabel(\"Accuracy\")\n",
    "    if insert_delete==\"insert\":\n",
    "        ax.set_xlabel(\"# of Inserted Patches\", labelpad=5)\n",
    "    elif insert_delete==\"delete\":\n",
    "        ax.set_xlabel(\"# of Deleted Patches\", labelpad=5)\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown insert_delete mode: {insert_delete!r}\")\n",
    "\n",
    "\n",
    "for backbone_type in backbone_type_config_dict:\n",
    "    plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(Paired[12][i])/256 for i in (1, 3, 5, 7, 9, 11)])\n",
    "    plt.rcParams['font.family'] = 'PT Sans'\n",
    "    plt.rcParams[\"font.size\"] = 15\n",
    "\n",
    "    fig = plt.figure(figsize=(18, 5*2))\n",
    "    box1 = gridspec.GridSpec(2, 1, wspace=0, hspace=0.3)\n",
    "\n",
    "    # One row per stage, one column per metric mode.\n",
    "    axd={}\n",
    "    for idx1, stage in enumerate(ROAR_STAGES):\n",
    "        box2 = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=box1[idx1], wspace=0.15, hspace=0)\n",
    "        for idx2, insert_delete in enumerate([\"insert\", \"delete\"]):\n",
    "            axd[f\"{stage}_{insert_delete}\"]=fig.add_subplot(box2[idx2])\n",
    "\n",
    "    for stage, stage_settings in ROAR_STAGES.items():\n",
    "        # Summarize once per stage (shared by the insert and delete panels).\n",
    "        accuracy_summarized=roar_summarize_accuracy(data_loaded_all[stage_settings[\"data_key\"]][dataset_name][backbone_type])\n",
    "        for insert_delete in [\"insert\", \"delete\"]:\n",
    "            ax=axd[f\"{stage}_{insert_delete}\"]\n",
    "            for idx3, explanation_methods_category in enumerate(explanation_method_main_random):\n",
    "                for idx4, explanation_method in enumerate(explanation_methods_category):\n",
    "                    accuracy_data=roar_pad_curve_endpoints(accuracy_summarized.loc[insert_delete].loc[explanation_method],\n",
    "                                                           insert_delete, dataset_name, stage)\n",
    "                    ax.plot(accuracy_data.index,\n",
    "                            accuracy_data,\n",
    "                            linewidth=2,\n",
    "                            linestyle=ROAR_LINESTYLES[idx3],\n",
    "                            c=roar_method_color(explanation_method, idx4))\n",
    "\n",
    "            ax.set_title(f\"{stage_settings['title']} ({'Insertion' if insert_delete=='insert' else 'Deletion'})\")\n",
    "            roar_style_axis(ax, insert_delete)\n",
    "\n",
    "    legend_elements = [Line2D([0], [0],\n",
    "                              linestyle=ROAR_LINESTYLES[idx1],\n",
    "                              color=roar_method_color(explanation_method, idx2),\n",
    "                              linewidth=3,\n",
    "                              label=explanation_method_mapper(explanation_method))\n",
    "                       for idx1, explanation_method_category in enumerate(explanation_method_main_random)\n",
    "                       for idx2, explanation_method in enumerate(explanation_method_category)]\n",
    "    # Label-less, zero-width handles used as spacers in the 5-column legend layout.\n",
    "    legend_elements.insert(2, Line2D([0], [0], linewidth=0))\n",
    "    legend_elements.insert(13, Line2D([0], [0], linewidth=0))\n",
    "    legend_elements.insert(13, Line2D([0], [0], linewidth=0))\n",
    "\n",
    "    fig.legend(handles=legend_elements,\n",
    "               ncol=5,\n",
    "               handlelength=3,\n",
    "               handletextpad=0.6,\n",
    "               columnspacing=1.5,\n",
    "               fontsize=15,\n",
    "               loc='lower center', bbox_to_anchor=(0.5, -0.06))\n",
    "\n",
    "    fig.show()\n",
    "\n",
    "    for ext in (\"png\", \"svg\", \"jpg\", \"pdf\"):\n",
    "        fig.savefig(f\"results/plots/ROAR_nopos_{backbone_type}_{dataset_name}.{ext}\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d048998b",
   "metadata": {},
   "source": [
    "# Running time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa25c950",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Mean elapsed time (msec) of each explanation method, per backbone and dataset.\n",
    "elaspedtime_table_dict={}\n",
    "for backbone_type in backbone_type_config_dict:\n",
    "    tables_for_backbone=elaspedtime_table_dict.setdefault(backbone_type, {})\n",
    "    for dataset_name in [\"ImageNette\", \"MURA\"]:\n",
    "        tables_for_backbone.setdefault(dataset_name, {})\n",
    "        print(f'{dataset_name}   {backbone_type}')\n",
    "\n",
    "        # One record per timed sample: method name and elapsed time in msec.\n",
    "        timing_records=[]\n",
    "        for explanation_method, timings_by_path in data_loaded_all[\"8_elapsedtime\"][dataset_name][backbone_type].items():\n",
    "            if len(timings_by_path)==0:\n",
    "                print(explanation_method, 'not exist')\n",
    "                continue\n",
    "            timing_records.extend({\"method\": explanation_method, \"time\": 1000*sample[\"time\"]}\n",
    "                                  for sample in timings_by_path.values())\n",
    "\n",
    "        time_by_method=pd.DataFrame(timing_records).groupby(\"method\")[[\"time\"]]\n",
    "        # The \"_std\" column holds the 95% confidence half-width 1.96*std/sqrt(n), not the raw std.\n",
    "        elaspedtime_table=pd.concat([time_by_method.mean().add_suffix('_mean'),\n",
    "                                     time_by_method.apply(lambda group: 1.96*group.std()/((len(group))**(0.5))).add_suffix(\"_std\")],\n",
    "                                    axis=1)\n",
    "\n",
    "        method_order=[method for category in explanation_method_main for method in category]\n",
    "        elapsedtime_table_select=elaspedtime_table.loc[method_order]\n",
    "        elapsedtime_table_select.index=elapsedtime_table_select.index.map(lambda method: explanation_method_mapper(method, \"main\"))\n",
    "\n",
    "        elapsedtime_table_select_format=elapsedtime_table_select[\"time_mean\"].map('{:.1f}'.format).to_frame(\"Time (msec)\")\n",
    "        elapsedtime_table_select_format.index.name=None\n",
    "        print(elapsedtime_table_select_format)\n",
    "\n",
    "        tables_for_backbone[dataset_name]=elapsedtime_table_select_format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b21ffb1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def df_to_latex(table_df, caption, explanation_methods_category_list):\n",
    "    \"\"\"Render `table_df` as a complete LaTeX `table` environment (returned as a string).\n",
    "\n",
    "    A midrule is inserted before the first row of every method category after the\n",
    "    first (categories come from `explanation_methods_category_list`; empty ones are\n",
    "    skipped). Cell markers are expanded: bfs...bfe -> textbf, grays...graye ->\n",
    "    textcolor{gray}. The tabular column spec {lll} is rewritten to {lrr}.\n",
    "\n",
    "    NOTE(review): a second `df_to_latex` with a different signature is defined in a\n",
    "    later cell and shadows this one once that cell has run.\n",
    "    \"\"\"\n",
    "    latex_output=table_df.style.to_latex(hrules=True,\n",
    "                                         multicol_align='c')\n",
    "    for explanation_methods_category in explanation_methods_category_list[1:]:\n",
    "        if not explanation_methods_category:\n",
    "            continue\n",
    "        # Replace only the first occurrence (the category's first row).\n",
    "        latex_output=latex_output.replace('\\n'+explanation_methods_category[0],\n",
    "                                          '\\\\midrule\\n'+explanation_methods_category[0], 1)\n",
    "    latex_output=latex_output.replace(\"ViT Shapley\", \"\\\\textbf{ViT Shapley}\")\n",
    "    # Raw string keeps the literal backslash-underscore without an invalid-escape warning.\n",
    "    latex_output=latex_output.replace(r\"explanation\\_method\",\"\")\n",
    "    latex_output=latex_output.replace(\"{l}\",\"{c}\")\n",
    "    latex_output=latex_output.replace(\"bfs\",\"\\\\textbf{\")\n",
    "    latex_output=latex_output.replace(\"bfe\",\"}\")\n",
    "    latex_output=latex_output.replace(\"grays\",\"\\\\textcolor{gray}{\")\n",
    "    latex_output=latex_output.replace(\"graye\",\"}\")\n",
    "    # Wrap inside-out: commented-out size switches, center, vskip, caption, table.\n",
    "    latex_output='% \\\\begin{scriptsize}\\n' + latex_output + '% \\\\end{scriptsize}\\n'\n",
    "    latex_output='% \\\\begin{small}\\n' + latex_output + '% \\\\end{small}\\n'\n",
    "    latex_output='\\\\begin{center}\\n' + latex_output + '\\\\end{center}\\n'\n",
    "    latex_output='\\\\vskip 0.01in\\n' + latex_output\n",
    "    latex_output='\\\\caption{{{}}}\\n'.format(caption) + latex_output\n",
    "    latex_output='\\\\begin{table}\\n' + latex_output + '\\\\end{table}\\n'\n",
    "    latex_output=latex_output.replace('{lll}','{lrr}')\n",
    "    return latex_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd33525a",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Running-time LaTeX table per backbone; one column group per dataset.\n",
    "explanation_methods_category_list=[[explanation_method_mapper(j, \"main\") for j in i] for i in explanation_method_main_random]\n",
    "\n",
    "for backbone_type in backbone_type_config_dict:\n",
    "    # keys= adds a dataset header level; without it both columns are an indistinguishable \"Time (msec)\".\n",
    "    elapsedtime_table_select_format=pd.concat([elaspedtime_table_dict[backbone_type][\"ImageNette\"], elaspedtime_table_dict[backbone_type][\"MURA\"]],\n",
    "                                              axis=1, keys=[\"ImageNette\", \"MURA\"])\n",
    "\n",
    "    caption='Running time comparison'\n",
    "\n",
    "    print(f\"% backbone: {backbone_type}\")\n",
    "    print(df_to_latex(elapsedtime_table_select_format,\n",
    "                      caption=caption,\n",
    "                      explanation_methods_category_list=explanation_methods_category_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bfe6f59",
   "metadata": {},
   "outputs": [],
   "source": [
    "def df_to_latex(table_df, last_col, caption):\n",
    "    \"\"\"Render `table_df` as a complete LaTeX `table` environment (returned as a string).\n",
    "\n",
    "    Variant without category midrules; rewrites the tabular spec {lllll} to {ccccc}\n",
    "    and expands the bfs...bfe / grays...graye cell markers. `last_col` is accepted\n",
    "    for compatibility but currently unused (its midrule insertion is disabled).\n",
    "    NOTE(review): this redefinition shadows the earlier three-argument df_to_latex.\n",
    "    \"\"\"\n",
    "    latex_output=table_df.style.to_latex(hrules=True,\n",
    "                                         multicol_align='c')\n",
    "    latex_output=latex_output.replace(\"ViT Shapley\", \"\\\\textbf{ViT Shapley}\")\n",
    "    # Raw string keeps the literal backslash-underscore without an invalid-escape warning.\n",
    "    latex_output=latex_output.replace(r\"explanation\\_method\",\"\")\n",
    "    latex_output=latex_output.replace(\"{l}\",\"{c}\")\n",
    "    latex_output=latex_output.replace(\"bfs\",\"\\\\textbf{\")\n",
    "    latex_output=latex_output.replace(\"bfe\",\"}\")\n",
    "    latex_output=latex_output.replace(\"grays\",\"\\\\textcolor{gray}{\")\n",
    "    latex_output=latex_output.replace(\"graye\",\"}\")\n",
    "    # Wrap inside-out: commented-out size switches, center, vskip, caption, table.\n",
    "    latex_output='% \\\\begin{scriptsize}\\n' + latex_output + '% \\\\end{scriptsize}\\n'\n",
    "    latex_output='% \\\\begin{small}\\n' + latex_output + '% \\\\end{small}\\n'\n",
    "    latex_output='\\\\begin{center}\\n' + latex_output + '\\\\end{center}\\n'\n",
    "    latex_output='\\\\vskip 0.01in\\n' + latex_output\n",
    "    latex_output='\\\\caption{{{}}}\\n'.format(caption) + latex_output\n",
    "    latex_output='\\\\begin{table}\\n' + latex_output + '\\\\end{table}\\n'\n",
    "    latex_output=latex_output.replace('{lllll}','{ccccc}')\n",
    "    return latex_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07137579",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ablation table: test loss for each explainer configuration, as LaTeX-ready strings.\n",
    "# NOTE(review): despite its name, `running_time_table` holds losses, not running times;\n",
    "# the name is kept in case later cells refer to it.\n",
    "running_time_table = pd.DataFrame([\n",
    "    {\"Freeze backbone\": False, \"# of Attn. blocks\": 0, \"Removal method\": \"Masking (fine-tuned)\", \"Loss\": 4.351},\n",
    "    {\"Freeze backbone\": True,  \"# of Attn. blocks\": 0, \"Removal method\": \"Masking (fine-tuned)\", \"Loss\": 4.351},\n",
    "    {\"Freeze backbone\": False, \"# of Attn. blocks\": 1, \"Removal method\": \"Masking (fine-tuned)\", \"Loss\": 4.318},\n",
    "    {\"Freeze backbone\": True,  \"# of Attn. blocks\": 1, \"Removal method\": \"Masking (fine-tuned)\", \"Loss\": 4.339},\n",
    "    {\"Freeze backbone\": False, \"# of Attn. blocks\": 1, \"Removal method\": \"Mask token (fine-tuned)\", \"Loss\": 4.388},\n",
    "])\n",
    "\n",
    "running_time_table=running_time_table[['# of Attn. blocks', 'Freeze backbone', 'Removal method', 'Loss']]\n",
    "# Configuration labels \\#1, \\#2, ... ('#' escaped for LaTeX; raw strings avoid invalid-escape warnings).\n",
    "running_time_table.index=running_time_table.index.map(lambda x: r\"\\#\"+str(x+1))\n",
    "running_time_table=running_time_table.T\n",
    "# regex=False: literal replacement regardless of the pandas version's default.\n",
    "running_time_table.index=running_time_table.index.str.replace(\"# of Attn. blocks\", \"Extra attention block\", regex=False)\n",
    "running_time_table.index=running_time_table.index.str.replace(\"#\", r\"\\#\", regex=False)\n",
    "# 0 / 1 extra attention blocks -> \"False\" / \"True\"\n",
    "running_time_table.loc[\"Extra attention block\"]=running_time_table.loc[\"Extra attention block\"].astype(bool).astype(str)\n",
    "\n",
    "running_time_table.loc[\"Freeze backbone\"]=running_time_table.loc[\"Freeze backbone\"].astype(str)\n",
    "\n",
    "# Wrap every configuration that attains the lowest loss in bfs...bfe (bold markers expanded by df_to_latex).\n",
    "lowest_loss=running_time_table.loc[\"Loss\"].min()\n",
    "running_time_table.loc[\"Loss\"]=running_time_table.loc[\"Loss\"].apply(lambda x: \"bfs{:.3f}bfe\".format(x) if x<=lowest_loss else \"{:.3f}\".format(x))\n",
    "running_time_table.index=running_time_table.index.str.replace(\"Loss\", \"Test loss\", regex=False)\n",
    "\n",
    "running_time_table=running_time_table.T"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8491ac1c",
   "metadata": {},
   "source": [
    "# Qualitative result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20dbeca5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from vit_shapley.config import config, dataset_ImageNette, dataset_MURA\n",
    "\n",
    "from vit_shapley.datamodules.ImageNette_datamodule import ImageNetteDataModule\n",
    "from vit_shapley.datamodules.MURA_datamodule import MURADataModule\n",
    "\n",
    "def set_datamodule(datasets,\n",
    "                   dataset_location,\n",
    "                   explanation_location_train,\n",
    "                   explanation_mask_amount_train,\n",
    "                   explanation_mask_ascending_train,\n",
    "\n",
    "                   explanation_location_val,\n",
    "                   explanation_mask_amount_val,\n",
    "                   explanation_mask_ascending_val,\n",
    "\n",
    "                   explanation_location_test,\n",
    "                   explanation_mask_amount_test,\n",
    "                   explanation_mask_ascending_test,\n",
    "\n",
    "                   transforms_train,\n",
    "                   transforms_val,\n",
    "                   transforms_test,\n",
    "                   num_workers,\n",
    "                   per_gpu_batch_size,\n",
    "                   test_data_split):\n",
    "    \"\"\"Build the datamodule for `datasets` (\"MURA\" or \"ImageNette\").\n",
    "\n",
    "    Every other argument is forwarded unchanged as a keyword argument to the\n",
    "    datamodule constructor. Raises ValueError for any other `datasets` value.\n",
    "    \"\"\"\n",
    "    dataset_parameters = {\n",
    "        \"dataset_location\": dataset_location,\n",
    "        \"explanation_location_train\": explanation_location_train,\n",
    "        \"explanation_mask_amount_train\": explanation_mask_amount_train,\n",
    "        \"explanation_mask_ascending_train\": explanation_mask_ascending_train,\n",
    "\n",
    "        \"explanation_location_val\": explanation_location_val,\n",
    "        \"explanation_mask_amount_val\": explanation_mask_amount_val,\n",
    "        \"explanation_mask_ascending_val\": explanation_mask_ascending_val,\n",
    "\n",
    "        \"explanation_location_test\": explanation_location_test,\n",
    "        \"explanation_mask_amount_test\": explanation_mask_amount_test,\n",
    "        \"explanation_mask_ascending_test\": explanation_mask_ascending_test,\n",
    "\n",
    "        \"transforms_train\": transforms_train,\n",
    "        \"transforms_val\": transforms_val,\n",
    "        \"transforms_test\": transforms_test,\n",
    "        \"num_workers\": num_workers,\n",
    "        \"per_gpu_batch_size\": per_gpu_batch_size,\n",
    "        \"test_data_split\": test_data_split\n",
    "    }\n",
    "\n",
    "    if datasets == \"MURA\":\n",
    "        datamodule = MURADataModule(**dataset_parameters)\n",
    "    elif datasets == \"ImageNette\":\n",
    "        datamodule = ImageNetteDataModule(**dataset_parameters)\n",
    "    else:\n",
    "        # Must be raised: merely constructing the exception let execution fall through\n",
    "        # to `return datamodule` and fail with an UnboundLocalError instead.\n",
    "        raise ValueError(\"Invalid 'datasets' configuration\")\n",
    "    return datamodule\n",
    "\n",
    "\n",
    "# Dataset name -> config function supplying its dataset-specific settings.\n",
    "DATASET_CONFIG_OVERRIDES={\"ImageNette\": dataset_ImageNette, \"MURA\": dataset_MURA}\n",
    "\n",
    "dataset_dict={}\n",
    "for dataset_name in [\"ImageNette\", \"MURA\"]:\n",
    "    _config=config()\n",
    "    _config.update(DATASET_CONFIG_OVERRIDES[dataset_name]())\n",
    "\n",
    "    datamodule = set_datamodule(datasets=_config[\"datasets\"],\n",
    "                                dataset_location=_config[\"dataset_location\"],\n",
    "\n",
    "                                explanation_location_train=_config[\"explanation_location_train\"],\n",
    "                                explanation_mask_amount_train=_config[\"explanation_mask_amount_train\"],\n",
    "                                explanation_mask_ascending_train=_config[\"explanation_mask_ascending_train\"],\n",
    "\n",
    "                                explanation_location_val=_config[\"explanation_location_val\"],\n",
    "                                explanation_mask_amount_val=_config[\"explanation_mask_amount_val\"],\n",
    "                                explanation_mask_ascending_val=_config[\"explanation_mask_ascending_val\"],\n",
    "\n",
    "                                explanation_location_test=_config[\"explanation_location_test\"],\n",
    "                                explanation_mask_amount_test=_config[\"explanation_mask_amount_test\"],\n",
    "                                explanation_mask_ascending_test=_config[\"explanation_mask_ascending_test\"],\n",
    "\n",
    "                                transforms_train=_config[\"transforms_train\"],\n",
    "                                transforms_val=_config[\"transforms_val\"],\n",
    "                                transforms_test=_config[\"transforms_test\"],\n",
    "                                num_workers=_config[\"num_workers\"],\n",
    "                                per_gpu_batch_size=_config[\"per_gpu_batch_size\"],\n",
    "                                test_data_split=_config[\"test_data_split\"])\n",
    "\n",
    "    # NOTE(review): only the test split is kept below; check whether the train/val\n",
    "    # setup calls are still needed here.\n",
    "    datamodule.set_train_dataset()\n",
    "    datamodule.set_val_dataset()\n",
    "    datamodule.set_test_dataset()\n",
    "    dataset_dict[dataset_name]=datamodule.test_dataset\n",
    "\n",
    "\n",
    "# Human-readable class names per dataset (presumably in label-index order).\n",
    "label_dict={}\n",
    "\n",
    "label_dict[\"ImageNette\"]=['Cassette player',\n",
    "                          'Garbage truck',\n",
    "                          'Tench',\n",
    "                          'English springer',\n",
    "                          'Church',\n",
    "                          'Parachute',\n",
    "                          'French horn',\n",
    "                          'Chain saw',\n",
    "                          'Golf ball',\n",
    "                          'Gas pump']\n",
    "label_dict[\"MURA\"]=[\"Normal\", \"Abnormal\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4ce07ca",
   "metadata": {},
   "source": [
    "# Main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96cc781d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_max_tick(x):\n",
    "    \"\"\"Return ``(colormap_max, colormap_tick)`` for a colorbar whose data maximum is ``x``.\n",
    "\n",
    "    Only a few hand-picked ranges are supported; any other value raises ``ValueError``.\n",
    "    \"\"\"\n",
    "    if 0.2 < x < 0.25:\n",
    "        return 0.2, 0.1\n",
    "    if 0.01 < x < 0.015:\n",
    "        return 0.015, 0.005\n",
    "    if 0.004 < x < 0.005:\n",
    "        # Must be a 2-tuple like the other branches so callers can unpack it.\n",
    "        # NOTE(review): the 0.002 tick step mirrors the other ranges; confirm.\n",
    "        return 0.004, 0.002\n",
    "    raise ValueError(f\"No colorbar tick range defined for max value {x}\")\n",
    "\n",
    "\n",
    "def custom_format(x, pos=None):\n",
    "    \"\"\"Tick-label function for ``matplotlib.ticker.FuncFormatter``.\n",
    "\n",
    "    Returns ``\"0\"`` for zero and one-decimal scientific notation otherwise; ``pos`` is unused.\n",
    "    \"\"\"\n",
    "    if x == 0:\n",
    "        return \"0\"\n",
    "    return \"{:.1e}\".format(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "105ca2be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the MURA test sample that the figure cells below use as the non-target example.\n",
    "dataset_dict[\"MURA\"][219]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a74e652",
   "metadata": {},
   "source": [
    "# Figure 1 (4-row layout: ImageNette and MURA, target and non-target rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "191775b6",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Per-channel normalization statistics shaped (3, 1, 1) so they broadcast over a channel-first image.\n",
    "# NOTE(review): these are the commonly quoted CIFAR-10 mean/std values; presumably they match the\n",
    "# normalization applied by the datamodule transforms -- confirm.\n",
    "img_mean = np.array([0.4914, 0.4822, 0.4465]).reshape(3, 1, 1)\n",
    "img_std = np.array([0.2023, 0.1994, 0.2010]).reshape(3, 1, 1)\n",
    "\n",
    "# Test-set indices of the samples shown in the figure, per dataset and row type.\n",
    "# Candidates noted while browsing:\n",
    "#   ImageNette: 97 (golf ball); 7, 11, 147 (gas pump)\n",
    "#   MURA: target 15, non-target 17\n",
    "# Earlier choice: ImageNette 43/43, MURA 8/9.\n",
    "# NOTE(review): the MURA non-target index used below (219) differs from the 17 noted above.\n",
    "dataset_idx_dict = {\n",
    "    \"ImageNette\": {\"target\": 7, \"non-target\": 7},  # gas pump for both rows\n",
    "    \"MURA\": {\"target\": 15, \"non-target\": 219},\n",
    "}\n"
    "\n",
    "for backbone_type, backbone_type_config in backbone_type_config_dict.items():\n",
    "    # Qualitative Figure 1 (one figure per backbone). Columns: input image | explanation map |\n",
    "    # colorbar | spacer | insertion curves | spacer | deletion curves.\n",
    "    # Rows: ImageNette (target, non-target), then MURA (target, non-target).\n",
    "    plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], \n",
    "                                                                               Paired[12][3],\n",
    "                                                                               Paired[12][5],\n",
    "                                                                               ]])\n",
    "    plt.rcParams['font.family'] = 'PT Sans'\n",
    "    plt.rcParams[\"font.size\"] = 12\n",
    "\n",
    "    # Per-method line colours (read back from the cycle above) and the flattened list of compared\n",
    "    # explanation methods; both are shared by the curves and the legend.\n",
    "    method_colors = plt.rcParams[\"axes.prop_cycle\"].by_key()['color']\n",
    "    qualitative_methods = [method for group in explanation_method_qualitative_main for method in group]\n",
    "\n",
    "    fig = plt.figure(figsize=(3*(1+1+0.1+0.01+1.7+0.01+1.7), 3*4))\n",
    "    box1 = gridspec.GridSpec(1, 7, wspace=0.3, hspace=0, width_ratios=[1, 1, 0.1, 0.01, 1.7, 0.01, 1.7])\n",
    "\n",
    "    # axd maps {dataset}_{target|non-target}_{plot_type} keys to their Axes.\n",
    "    axd={}\n",
    "    for idx1, plot_type in enumerate([\"image\", \"map\", \"colorbar\", \"empty1\", \"insert\", \"empty2\", \"delete\"]):\n",
    "        box2 = gridspec.GridSpecFromSubplotSpec(2, 1,\n",
    "                        subplot_spec=box1[idx1], wspace=0, hspace=0.15)\n",
    "        for idx2, dataset_name in enumerate([\"ImageNette\", \"MURA\"]):\n",
    "            if plot_type==\"image\" and dataset_name==\"ImageNette\":\n",
    "                # Both ImageNette rows share a single image axis spanning the dataset's cell.\n",
    "                box3 = gridspec.GridSpecFromSubplotSpec(1, 1,\n",
    "                                    subplot_spec=box2[idx2], wspace=0.1, hspace=0.1)\n",
    "                ax = fig.add_subplot(box3[0])\n",
    "                axd[f\"{dataset_name}_target_{plot_type}\"]=ax\n",
    "                axd[f\"{dataset_name}_non-target_{plot_type}\"]=ax\n",
    "            else:\n",
    "                box3 = gridspec.GridSpecFromSubplotSpec(2, 1,\n",
    "                                    subplot_spec=box2[idx2], wspace=0.1, hspace=0.3)\n",
    "                for idx3, target_non_target in enumerate([\"target\", \"non-target\"]):\n",
    "                    axd[f\"{dataset_name}_{target_non_target}_{plot_type}\"]=fig.add_subplot(box3[idx3])\n",
    "\n",
    "    # Spacer columns: hide ticks and spines so they render as blank space.\n",
    "    for plot_key, ax in axd.items():\n",
    "        if 'empty' in plot_key:\n",
    "            ax.set_xticks([])\n",
    "            ax.set_yticks([])\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                ax.spines[axis].set_linewidth(0)\n",
    "\n",
    "    for dataset_name in [\"ImageNette\", \"MURA\"]:\n",
    "        classifier_result_list = data_loaded_all[\"1_classifier_evaluate\"][dataset_name][backbone_type]\n",
    "        explanation_save_dict_backbone = data_loaded_all[\"3_explanation_generate\"][dataset_name][backbone_type]\n",
    "        insertdelete_save_dict_backbone = data_loaded_all[\"4_insert_delete\"][dataset_name][backbone_type]\n",
    "\n",
    "        for target_non_target in [\"target\", \"non-target\"]:\n",
    "            dataset_item=dataset_dict[dataset_name][dataset_idx_dict[dataset_name][target_non_target]]\n",
    "\n",
    "            image = dataset_item[\"images\"]\n",
    "            label = dataset_item[\"labels\"]\n",
    "            path = dataset_item[\"path\"]\n",
    "            # Undo the channel-first normalization and move channels last for imshow.\n",
    "            image_unnormlized=((image.numpy() * img_std) + img_mean).transpose(1,2,0)\n",
    "            assert image_unnormlized.min()>0 and image_unnormlized.max()<1\n",
    "            classifier_prob=classifier_result_list[adapt_path(path, list(classifier_result_list.keys()))][\"prob\"]\n",
    "\n",
    "            # Choose the class whose explanation and insertion/deletion curves are shown.\n",
    "            if len(classifier_prob)==1:\n",
    "                # Single-output classifier: its only output index is used for both rows.\n",
    "                class_idx=0\n",
    "            elif target_non_target==\"target\":\n",
    "                class_idx=label\n",
    "            elif target_non_target==\"non-target\":\n",
    "                # Best non-label class: highest mean probability over the first 10 steps of the\n",
    "                # insertion curves produced by the \"ours\" explanation.\n",
    "                insert_data=insertdelete_save_dict_backbone[\"ours\"][adapt_path(path, list(insertdelete_save_dict_backbone[\"ours\"].keys()))]['insert']\n",
    "                ranked_classes=pd.Series(insert_data[:,:10].mean(axis=1)).sort_values(ascending=False).index.tolist()\n",
    "                ranked_classes.remove(label)\n",
    "                class_idx=ranked_classes[0]\n",
    "            else:\n",
    "                raise ValueError(f\"Unknown row type: {target_non_target}\")\n",
    "\n",
    "            # Image column\n",
    "            plot_key=f\"{dataset_name}_{target_non_target}_image\"\n",
    "            axd[plot_key].imshow(image_unnormlized)\n",
    "            axd[plot_key].set_xticks([])\n",
    "            axd[plot_key].set_yticks([])\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[plot_key].spines[axis].set_linewidth(0)\n",
    "            if dataset_name not in (\"ImageNette\", \"MURA\"):\n",
    "                raise ValueError(f\"Unknown dataset: {dataset_name}\")\n",
    "            axd[plot_key].set_title(f\"True: {label_dict[dataset_name][label]}\", pad=7)\n",
    "\n",
    "            # Map column: the explanation is a 14x14 patch grid, upsampled bilinearly (x16) to 224x224.\n",
    "            explanation=explanation_save_dict_backbone[\"ours\"][adapt_path(path, list(explanation_save_dict_backbone[\"ours\"].keys()))]['explanation']\n",
    "            explanation_class=explanation[class_idx]\n",
    "            explanation_class_expanded=torch.nn.functional.interpolate(torch.Tensor(explanation_class.reshape(1, 1, 14, 14)),\n",
    "                                                                       scale_factor=16, align_corners=False, mode='bilinear').numpy().reshape(224, 224)\n",
    "\n",
    "            # Symmetric colour range around 0. colormap_max_base is the power of ten that rescales the\n",
    "            # colorbar ticks to small integers (rendered as x10^-base in the colorbar title).\n",
    "            colormap_max=np.max(np.abs(explanation_class_expanded))\n",
    "            colormap_max_base=int(np.round(-np.log10(colormap_max)+0.5))\n",
    "\n",
    "            explanation_class_expanded_normalized=0.5+explanation_class_expanded/colormap_max*0.5\n",
    "            cmap=sns.color_palette(\"icefire\", as_cmap=True)\n",
    "            explanation_class_expanded_heatmap=cmap(explanation_class_expanded_normalized)\n",
    "            explanation_class_expanded_heatmap[:,:,3]=0.6\n",
    "\n",
    "            # Greyscale backdrop: mean over RGB; 'Greys' maps 0 to white, so pass 1 - intensity.\n",
    "            # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9.\n",
    "            image_unnormlized_normalized=(image_unnormlized.sum(axis=2))/3\n",
    "            cmap=plt.get_cmap('Greys', 1000)\n",
    "            image_unnormlized_normalized=cmap(1-image_unnormlized_normalized)\n",
    "            image_unnormlized_normalized[:,:,3]=0.5\n",
    "\n",
    "            plot_key=f\"{dataset_name}_{target_non_target}_map\"\n",
    "            axd[plot_key].imshow(image_unnormlized_normalized, alpha=0.85)\n",
    "            axd[plot_key].imshow(explanation_class_expanded_heatmap, alpha=0.9)\n",
    "            axd[plot_key].set_xticks([])\n",
    "            axd[plot_key].set_yticks([])\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[plot_key].spines[axis].set_linewidth(0)\n",
    "            if dataset_name==\"ImageNette\":\n",
    "                axd[plot_key].set_title(label_dict[dataset_name][class_idx], pad=10)\n",
    "            elif dataset_name==\"MURA\":\n",
    "                axd[plot_key].set_title(\"Abnormal\", pad=10)\n",
    "            else:\n",
    "                raise ValueError(f\"Unknown dataset: {dataset_name}\")\n",
    "\n",
    "            # Colorbar column: drawn in an inset axis shifted left, towards the map.\n",
    "            plot_key=f\"{dataset_name}_{target_non_target}_colorbar\"\n",
    "            axd[plot_key+\"_inner\"] = inset_axes(axd[plot_key],\n",
    "                                                width=\"80%\",\n",
    "                                                height=\"90%\",\n",
    "                                                loc='lower center',\n",
    "                                                bbox_to_anchor=(-0.4,0,1,1),\n",
    "                                                bbox_transform=axd[plot_key].transAxes)\n",
    "\n",
    "            colorbar_limit=colormap_max*10**(colormap_max_base)\n",
    "            cbar=fig.colorbar(mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=-colorbar_limit, vmax=colorbar_limit),\n",
    "                                                    cmap=sns.color_palette(\"icefire\", as_cmap=True)),\n",
    "                              cax=axd[plot_key+\"_inner\"],\n",
    "                              format=ticker.FuncFormatter(lambda x, pos: f'{int(x):2d}'),\n",
    "                              orientation='vertical')\n",
    "            cbar.outline.set_linewidth(0.3)\n",
    "            if colorbar_limit>5:\n",
    "                cbar.set_ticks([-5,0,5])\n",
    "            elif colorbar_limit>4:\n",
    "                cbar.set_ticks([-4, -2, 0, 2, 4])\n",
    "                cbar.set_ticklabels([-4, -2, 0, 2, 4])\n",
    "            elif colorbar_limit>2:\n",
    "                cbar.set_ticks([-2, -1, 0, 1, 2])\n",
    "\n",
    "            # Raw f-string so the backslash in \\mathregular is not parsed as an (invalid) escape sequence.\n",
    "            axd[plot_key+\"_inner\"].set_title(rf\" x$\\mathregular{{10^{{{-colormap_max_base}}}}}$\",\n",
    "                                             loc=\"left\",\n",
    "                                             size=13,\n",
    "                                             )\n",
    "            axd[plot_key].set_xticks([])\n",
    "            axd[plot_key].set_yticks([])\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[plot_key].spines[axis].set_linewidth(0)\n",
    "\n",
    "            # Insertion / deletion columns: one curve per compared method for class_idx.\n",
    "            for insert_delete in [\"insert\", \"delete\"]:\n",
    "                plot_key = f\"{dataset_name}_{target_non_target}_{insert_delete}\"\n",
    "                ax = axd[plot_key]\n",
    "\n",
    "                for idx4, explanation_method in enumerate(qualitative_methods):\n",
    "                    insertdelete_value=insertdelete_save_dict_backbone[explanation_method][adapt_path(path, list(insertdelete_save_dict_backbone[explanation_method].keys()))]\n",
    "                    ax.plot(insertdelete_value[insert_delete][class_idx],\n",
    "                            linestyle='-',\n",
    "                            c=method_colors[idx4] if explanation_method!=\"random\" else [0,0,0,0.4])\n",
    "\n",
    "                ax.set_xlim(-2, 200)\n",
    "                ax.xaxis.set_major_locator(MultipleLocator(28))\n",
    "                ax.xaxis.set_minor_locator(MultipleLocator(14))\n",
    "                ax.xaxis.grid(True, which='major', linewidth=0.8, alpha=0.4)\n",
    "                ax.xaxis.grid(True, which='minor', linewidth=0.6, alpha=0.4)\n",
    "\n",
    "                # Coarser y grid when the autoscaled upper limit exceeds 0.5.\n",
    "                if ax.get_ylim()[1]>0.5:\n",
    "                    y_major, y_minor = 0.2, 0.1\n",
    "                else:\n",
    "                    y_major, y_minor = 0.1, 0.05\n",
    "                ax.yaxis.set_major_locator(MultipleLocator(y_major))\n",
    "                ax.yaxis.set_minor_locator(MultipleLocator(y_minor))\n",
    "                ax.yaxis.grid(True, which='major', linewidth=0.8, alpha=0.4)\n",
    "                ax.yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.4)\n",
    "\n",
    "                ax.spines['right'].set_visible(False)\n",
    "                ax.spines['top'].set_visible(False)\n",
    "                for axis in ['top','bottom','left','right']:\n",
    "                    ax.spines[axis].set_linewidth(2)\n",
    "\n",
    "                ax.tick_params(axis = 'y', which = 'major', labelsize = 10, pad=0)\n",
    "                ax.tick_params(axis = 'x', which = 'major', labelsize = 12)\n",
    "\n",
    "                # Only the bottom row (MURA non-target) carries x-axis labels.\n",
    "                if dataset_name==\"MURA\" and target_non_target==\"non-target\":\n",
    "                    ax.set_xlabel('# of Inserted Patches' if insert_delete==\"insert\" else '# of Deleted Patches')\n",
    "\n",
    "                ax.set_ylabel(\"Probability\", labelpad=3)\n",
    "                ax.set_title(f\"{insert_delete_mapper(insert_delete, verbose=True)}\", pad=0)\n",
    "\n",
    "    legend_elements = [Line2D([0], [0],\n",
    "                              linestyle='-',\n",
    "                              color=method_colors[idx4] if explanation_method!=\"random\" else [0,0,0,0.8],\n",
    "                              linewidth=3,\n",
    "                              label=explanation_method_mapper(explanation_method, subset_mode=\"qualitative\"))\n",
    "                       for idx4, explanation_method in enumerate(qualitative_methods)]\n",
    "\n",
    "    fig.legend(handles=legend_elements,\n",
    "               ncol=3,\n",
    "               handlelength=3,\n",
    "               handletextpad=0.6,\n",
    "               columnspacing=3,\n",
    "               fontsize=14,\n",
    "               loc='lower center', bbox_to_anchor=(0.67, 0.035))\n",
    "\n",
    "    # NOTE(review): the file names do not include backbone_type, so with several backbones each\n",
    "    # iteration overwrites the previous figure on disk -- confirm this is intended.\n",
    "    for extension in [\"jpg\", \"png\", \"svg\"]:\n",
    "        fig.savefig(f\"results/plots/qualitative.{extension}\", bbox_inches='tight')\n",
    "    # Save the PDF once, with image compositing disabled so the overlaid imshow layers\n",
    "    # are written as separate images rather than merged into one.\n",
    "    with plt.rc_context({'image.composite_image': False}):\n",
    "        fig.savefig(\"results/plots/qualitative.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "237311ce",
   "metadata": {},
   "source": [
    "# Figure 1 (3-row layout: as above, without the MURA non-target row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab4c0d80",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# dataset_idx_dict={\"ImageNette\": {\"target\": 43,\n",
    "#                             \"non-target\":43},\n",
    "#              \"MURA\": {\"target\": 8,\n",
    "#                       \"non-target\": 9},\n",
    "#             }\n",
    "img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]\n",
    "img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis] \n",
    "\"\"\"\n",
    "ImageNette\n",
    "97 golfball\n",
    "7 gaspump\n",
    "11 gaspump\n",
    "147 gas pump\n",
    "\n",
    "MURA\n",
    "target 15\n",
    "nontarget 17\n",
    "\"\"\"\n",
    "\n",
    "dataset_idx_dict={\"ImageNette\": {\"target\": 7, #gas pump\n",
    "                            \"non-target\": 7},\n",
    "                  \"MURA\": {\"target\": 15,\n",
    "                           \"non-target\": 219},\n",
    "                 }\n",
    "\n",
    "for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], \n",
    "                                                                               Paired[12][3],\n",
    "                                                                               Paired[12][5],\n",
    "                                                                               ]])\n",
    "    \n",
    "    plt.rcParams['font.family'] = 'PT Sans'\n",
    "    plt.rcParams[\"font.size\"] = 12\n",
    "\n",
    "    fig = plt.figure(figsize=(3*(1+1+0.1+0.01+1.7+0.01+1.7), 3*4))\n",
    "    box1 = gridspec.GridSpec(1, 7, wspace=0.3, hspace=0, width_ratios=[1, 1, 0.1, 0.01, 1.7, 0.01, 1.7])\n",
    "\n",
    "    axd={}\n",
    "    for idx1, plot_type in enumerate([\"image\", \"map\", \"colorbar\", \"empty1\", \"insert\", \"empty2\", \"delete\"]):\n",
    "        box2 = gridspec.GridSpecFromSubplotSpec(2, 1,\n",
    "                        subplot_spec=box1[idx1], wspace=0, hspace=0.15)\n",
    "        for idx2, dataset_name in enumerate([\"ImageNette\", \"MURA\"]):\n",
    "            if plot_type==\"image\" and dataset_name==\"ImageNette\":\n",
    "                box3 = gridspec.GridSpecFromSubplotSpec(1, 1,\n",
    "                                    subplot_spec=box2[idx2], wspace=0.1, hspace=0.1)\n",
    "                ax = plt.Subplot(fig, box3[0])\n",
    "                fig.add_subplot(ax)\n",
    "\n",
    "                plot_key=f\"{dataset_name}_target_{plot_type}\"\n",
    "                axd[plot_key]=ax            \n",
    "                plot_key=f\"{dataset_name}_non-target_{plot_type}\"\n",
    "                axd[plot_key]=ax                            \n",
    "            else:\n",
    "                box3 = gridspec.GridSpecFromSubplotSpec(2, 1,\n",
    "                                    subplot_spec=box2[idx2], wspace=0.1, hspace=0.3)\n",
    "                for idx3, target_non_target in enumerate([\"target\", \"non-target\"]):\n",
    "                    if dataset_name==\"MURA\" and target_non_target==\"non-target\":\n",
    "                        continue                     \n",
    "                    \n",
    "                    ax = plt.Subplot(fig, box3[idx3])\n",
    "                    fig.add_subplot(ax)\n",
    "\n",
    "                    plot_key=f\"{dataset_name}_{target_non_target}_{plot_type}\"\n",
    "                    axd[plot_key]=ax\n",
    "    \n",
    "    for plot_key in axd.keys():\n",
    "        #continue\n",
    "        if 'empty' in plot_key:\n",
    "            axd[plot_key].set_xticks([])\n",
    "            axd[plot_key].set_yticks([])\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[plot_key].spines[axis].set_linewidth(0) \n",
    "\n",
    "    for idx1, dataset_name in enumerate([\"ImageNette\", \"MURA\"]):\n",
    "        classifier_result_list = data_loaded_all[\"1_classifier_evaluate\"][dataset_name][backbone_type]\n",
    "        explanation_save_dict_backbone = data_loaded_all[\"3_explanation_generate\"][dataset_name][backbone_type]        \n",
    "        insertdelete_save_dict_backbone = data_loaded_all[\"4_insert_delete\"][dataset_name][backbone_type]        \n",
    "        \n",
    "        for idx2, target_non_target in enumerate([\"target\", \"non-target\"]):    \n",
    "            \n",
    "            if dataset_name==\"MURA\" and target_non_target==\"non-target\":\n",
    "                continue            \n",
    "            \n",
    "            dataset_item=dataset_dict[dataset_name][dataset_idx_dict[dataset_name][target_non_target]]\n",
    "\n",
    "            image = dataset_item[\"images\"]\n",
    "            label = dataset_item[\"labels\"]\n",
    "            path = dataset_item[\"path\"]\n",
    "            image_unnormlized=((image.numpy() * img_std) + img_mean).transpose(1,2,0)\n",
    "            assert image_unnormlized.min()>0 and image_unnormlized.max()<1\n",
    "            #image_unnormlized=cv2.applyColorMap(np.uint8(256 * image_unnormlized), cv2.COLOR_RGB2GRAY)/256\n",
    "            #image_unnormlized_scaled=(image_unnormlized-image_unnormlized.min())/(image_unnormlized.max()-image_unnormlized.min())\n",
    "            classifier_prob=classifier_result_list[adapt_path(path, list(classifier_result_list.keys()))][\"prob\"]\n",
    "            \n",
    "            if len(classifier_prob)==1:\n",
    "                class_idx=0\n",
    "            else:\n",
    "                #target_idx=label\n",
    "                if target_non_target==\"target\":\n",
    "                    class_idx=label\n",
    "                elif target_non_target==\"non-target\":\n",
    "                    insert_data=insertdelete_save_dict_backbone[\"ours\"][adapt_path(path, list(insertdelete_save_dict_backbone[\"ours\"].keys()))]['insert']\n",
    "                    class_idx=pd.Series(insert_data[:,:10].mean(axis=1)).sort_values(ascending=False).index.tolist()#[1]\n",
    "                    class_idx.remove(label)\n",
    "                    class_idx=class_idx[0]\n",
    "                else:\n",
    "                    raise\n",
    "            #print(class_idx)\n",
    "\n",
    "            # Image\n",
    "            plot_key=f\"{dataset_name}_{target_non_target}_image\"\n",
    "\n",
    "            axd[plot_key].imshow(image_unnormlized)\n",
    "\n",
    "            axd[plot_key].set_xticks([])\n",
    "            axd[plot_key].set_yticks([])\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[plot_key].spines[axis].set_linewidth(0)\n",
    "                \n",
    "            if dataset_name==\"ImageNette\":\n",
    "                #axd[plot_key].set_title(f\"True: {label_dict[dataset_name][label]}\\nPred: {label_dict[dataset_name][np.argmax(classifier_prob)]}\", pad=7)\n",
    "                axd[plot_key].set_title(f\"True: {label_dict[dataset_name][label]}\", pad=7)\n",
    "            elif dataset_name==\"MURA\":\n",
    "                #axd[plot_key].set_title(f\"True: {label_dict[dataset_name][label]}\\nPred: {label_dict[dataset_name][int(classifier_prob>0.5)]}\", pad=7)\n",
    "                axd[plot_key].set_title(f\"True: {label_dict[dataset_name][label]}\", pad=7)\n",
    "            else:\n",
    "                raise            \n",
    "                \n",
    "            # Map\n",
    "            explanation=explanation_save_dict_backbone[\"ours\"][adapt_path(path, list(explanation_save_dict_backbone[\"ours\"].keys()))]['explanation']\n",
    "            explanation_class=explanation[class_idx]\n",
    "\n",
    "            explanation_class_expanded=np.repeat(np.repeat(explanation_class.reshape(14, 14), 16, axis=0), 16, axis=1)\n",
    "            explanation_class_expanded=torch.nn.functional.interpolate(torch.Tensor(explanation_class.reshape(1, 1, 14, 14)), \n",
    "                                                                       scale_factor=16, align_corners=False, mode='bilinear').numpy().reshape(224, 224)\n",
    "    \n",
    "            #colormap_max, colormap_tick = get_max_tick(np.max(np.abs(explanation_class_expanded)))\n",
    "            #colormap_max=float(\"{:.1e}\".format(np.max(np.abs(explanation_class_expanded))))\n",
    "            colormap_max=np.max(np.abs(explanation_class_expanded))\n",
    "            colormap_max_base=int(np.round(-np.log10(colormap_max)+0.5))\n",
    "            #colormap_max_base=int(np.round(-np.log10(colormap_max)-np.log10(0.5)+0.5))\n",
    "            print(colormap_max_base)\n",
    "        \n",
    "            explanation_class_expanded_normalized=(0.5+(explanation_class_expanded)/colormap_max*0.5)\n",
    "            cmap=sns.color_palette(\"icefire\", as_cmap=True)#cmap=cmr.redshift#cmap=cm.get_cmap('seismic', 1000)\n",
    "            explanation_class_expanded_heatmap=cmap(explanation_class_expanded_normalized)#[:,:,:-1]\n",
    "            explanation_class_expanded_heatmap[:,:,3]=0.6\n",
    "            print(f'max {explanation_class_expanded.max():.3f} min {explanation_class_expanded.min():.3f}')\n",
    "            \n",
    "            image_unnormlized_normalized=(image_unnormlized.sum(axis=2))/3\n",
    "            cmap=cm.get_cmap('Greys', 1000)#cm.get_cmap('Greys', 1000)\n",
    "            image_unnormlized_normalized=cmap(1-image_unnormlized_normalized)#[:,:,:-1]\n",
    "            image_unnormlized_normalized[:,:,3]=0.5\n",
    "            \n",
    "\n",
    "            \n",
    "            plot_key=f\"{dataset_name}_{target_non_target}_map\"\n",
    "            \n",
    "            axd[plot_key].imshow(image_unnormlized_normalized, alpha=0.85)\n",
    "            #axd[plot_key].imshow(image_unnormlized, alpha=0.5)\n",
    "            axd[plot_key].imshow(explanation_class_expanded_heatmap, alpha=0.9)\n",
    "            \n",
    "            axd[plot_key].set_xticks([])\n",
    "            axd[plot_key].set_yticks([])\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[plot_key].spines[axis].set_linewidth(0)\n",
    "            if dataset_name==\"ImageNette\":\n",
    "                axd[plot_key].set_title(label_dict[dataset_name][class_idx], pad=10)\n",
    "            elif dataset_name==\"MURA\":\n",
    "                axd[plot_key].set_title(\"Abnormal\", pad=10)\n",
    "                pass\n",
    "            else:\n",
    "                raise\n",
    "                \n",
    "            # Colorbar\n",
    "            plot_key=f\"{dataset_name}_{target_non_target}_colorbar\"\n",
    "\n",
    "            axd[plot_key+\"_inner\"] = inset_axes(axd[plot_key],\n",
    "                                                width=\"80%\",  # width = 5% of parent_bbox width\n",
    "                                                height=\"90%\",\n",
    "                                                loc='lower center',\n",
    "                                                bbox_to_anchor=(-0.4,0,1,1),\n",
    "                                                bbox_transform=axd[plot_key].transAxes)\n",
    "            \n",
    "            cbar=fig.colorbar(mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=-colormap_max*10**(colormap_max_base), vmax=colormap_max*10**(colormap_max_base)), \n",
    "                                                    cmap=sns.color_palette(\"icefire\", as_cmap=True)),\n",
    "                              cax=axd[plot_key+\"_inner\"],\n",
    "                              format=ticker.FuncFormatter(lambda x, pos: f'{int(x):2d}'),\n",
    "                         orientation='vertical')\n",
    "            cbar.outline.set_linewidth(0.3)\n",
    "            if colormap_max*10**(colormap_max_base)>5:\n",
    "                cbar.set_ticks([-5,0,5])\n",
    "            elif colormap_max*10**(colormap_max_base)>4:\n",
    "                cbar.set_ticks([-4, -2, 0, 2, 4])\n",
    "                cbar.set_ticklabels([-4, -2, 0, 2, 4])\n",
    "            elif colormap_max*10**(colormap_max_base)>2:\n",
    "                cbar.set_ticks([-2, -1, 0, 1, 2])\n",
    "            else:\n",
    "                pass\n",
    "            #cbar.set_major_formatter(ticker.StrMethodFormatter(\"{x:.3f}\"))\n",
    "            \n",
    "            axd[plot_key+\"_inner\"].set_title(f\" x$\\mathregular{{10^{{{-colormap_max_base}}}}}$\",\n",
    "                                    loc=\"left\",\n",
    "                                    size=13,\n",
    "                                   )\n",
    "            #axd[plot_key].tick_params(axis=\"both\", labelsize=10, width=0.4)\n",
    "            \n",
    "            axd[plot_key].set_xticks([])\n",
    "            axd[plot_key].set_yticks([])\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[plot_key].spines[axis].set_linewidth(0)         \n",
    "            \n",
    "#             \n",
    "#             formatter.set_scientific(True)\n",
    "#             formatter.set_powerlimits((-2, 2))\n",
    "\n",
    "#             x = np.exp(np.random.uniform(size=(10, 10)) * 10)\n",
    "#             sns.heatmap(x, cbar_kws={\"format\": formatter})            \n",
    "            \n",
    "            # insert delete\n",
    "            \n",
    "            for insert_delete in [\"insert\", \"delete\"]:\n",
    "                plot_key = f\"{dataset_name}_{target_non_target}_{insert_delete}\"\n",
    "                \n",
    "                for idx4, explanation_method in enumerate([j for i in explanation_method_qualitative_main for j in i]):\n",
    "                    insertdelete_value=insertdelete_save_dict_backbone[explanation_method][adapt_path(path, list(insertdelete_save_dict_backbone[explanation_method].keys()))]\n",
    "                    axd[plot_key].plot(insertdelete_value[insert_delete][class_idx],\n",
    "                                       linestyle='-',\n",
    "                                       c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx4] if explanation_method!=\"random\" else [0,0,0,0.4])\n",
    "                \n",
    "                axd[plot_key].set_xlim(-2, 200)                                   \n",
    "\n",
    "                axd[plot_key].xaxis.set_major_locator(MultipleLocator(28))\n",
    "                axd[plot_key].xaxis.set_minor_locator(MultipleLocator(14))    \n",
    "                axd[plot_key].xaxis.grid(True, which='major', linewidth=0.8, alpha=0.4)\n",
    "                axd[plot_key].xaxis.grid(True, which='minor', linewidth=0.6, alpha=0.4)  \n",
    "\n",
    "\n",
    "                if axd[plot_key].get_ylim()[1]>0.5:\n",
    "                    axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.2))\n",
    "                    axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.1))\n",
    "                    axd[plot_key].yaxis.grid(True, which='major', linewidth=0.8, alpha=0.4)\n",
    "                    axd[plot_key].yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.4)                        \n",
    "                else:\n",
    "                    axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.1))\n",
    "                    axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.05))\n",
    "                    axd[plot_key].yaxis.grid(True, which='major', linewidth=0.8, alpha=0.4)\n",
    "                    axd[plot_key].yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.4)                \n",
    "                \n",
    "                \n",
    "                axd[plot_key].spines['right'].set_visible(False)\n",
    "                axd[plot_key].spines['top'].set_visible(False)   \n",
    "\n",
    "                for axis in ['top','bottom','left','right']:\n",
    "                    axd[plot_key].spines[axis].set_linewidth(2)  \n",
    "\n",
    "                axd[plot_key].tick_params(axis = 'y', which = 'major', labelsize = 10, pad=0)\n",
    "                axd[plot_key].tick_params(axis = 'x', which = 'major', labelsize = 12)\n",
    "                \n",
    "                if dataset_name==\"MURA\" and target_non_target==\"target\" and insert_delete==\"insert\":\n",
    "                    axd[plot_key].set_xlabel('# of Inserted Patches')\n",
    "                if dataset_name==\"MURA\" and target_non_target==\"target\" and insert_delete==\"delete\":\n",
    "                    axd[plot_key].set_xlabel('# of Deleted Patches')                    \n",
    "                \n",
    "                axd[plot_key].set_ylabel(f\"Probability\", labelpad=3)\n",
    "                axd[plot_key].set_title(f\"{insert_delete_mapper(insert_delete, verbose=True)}\", pad=0)\n",
    "                \n",
    "    legend_elements = [Line2D([0], [0],\n",
    "                              linestyle='-',\n",
    "                              color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx4] if explanation_method!=\"random\" else [0,0,0,0.8],\n",
    "                              linewidth=3,\n",
    "                              label=explanation_method_mapper(explanation_method, subset_mode=\"qualitative\"))\n",
    "                         for idx4, explanation_method in enumerate([j for i in explanation_method_qualitative_main for j in i])]    \n",
    "    \n",
    "    fig.legend(handles=legend_elements, \n",
    "                ncol=3, \n",
    "                handlelength=3,\n",
    "                handletextpad=0.6, \n",
    "                columnspacing=3,\n",
    "                fontsize=14,\n",
    "                loc='lower center', bbox_to_anchor=(0.67, 0.23))\n",
    "    \n",
    "    fig.savefig(f\"results/plots/qualitative.jpg\", bbox_inches='tight')\n",
    "    fig.savefig(f\"results/plots/qualitative.png\", bbox_inches='tight')\n",
    "    fig.savefig(f\"results/plots/qualitative.pdf\", bbox_inches='tight')\n",
    "    fig.savefig(f\"results/plots/qualitative.svg\", bbox_inches='tight')    \n",
    "    \n",
    "    with plt.rc_context({'image.composite_image': False}):\n",
    "        fig.savefig(f\"results/plots/qualitative.pdf\", bbox_inches='tight')        \n",
    "    #fig.tight_layout()                "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05b498a6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a98dd10",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b51b6c8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75b51809",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "498f692c",
   "metadata": {},
   "source": [
    "# Plot1 (ImageNette): KernelSHAP vs. ours, per-class explanations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95b1a5f8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Per-channel normalisation constants; tensors are mapped back to RGB via image * std + mean.\n",
    "img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]\n",
    "img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]\n",
    "\n",
    "dataset_name=\"ImageNette\"\n",
    "\n",
    "# One figure row per class below, in this order; also the titles of the input-image column.\n",
    "label_to_use=['Garbage truck',\n",
    "              'Tench',\n",
    "              'English springer',\n",
    "              'Golf ball',\n",
    "              'Gas pump']\n",
    "\n",
    "# Layout: input image | narrow spacer | one column per class, each class cell split\n",
    "# vertically into one panel per explanation method.\n",
    "explanation_methods_compared=[\"kernelshap\", \"ours\"]\n",
    "class_names=label_dict[dataset_name]\n",
    "plot_types=[\"image\", \"empty\"]+class_names\n",
    "\n",
    "# Loop-invariant colormaps. plt.get_cmap replaces matplotlib.cm.get_cmap, which was\n",
    "# deprecated in Matplotlib 3.7 and removed in 3.9.\n",
    "heatmap_cmap=sns.color_palette(\"icefire\", as_cmap=True)\n",
    "greys_cmap=plt.get_cmap('Greys', 1000)\n",
    "\n",
    "# Candidate sample indices for each requested class; independent of backbone and seed.\n",
    "label_data_list=np.array([i['label'] for i in dataset_dict[dataset_name].data])\n",
    "candidate_idx_per_class=[np.arange(len(label_data_list))[label_data_list==class_names.index(label)] for label in label_to_use]\n",
    "\n",
    "for backbone_type in backbone_type_config_dict:\n",
    "    explanation_save_dict_backbone = data_loaded_all[\"3_explanation_generate\"][dataset_name][backbone_type]\n",
    "\n",
    "    for random_seed in [2, 3, 4, 5]:\n",
    "        # One reproducible random sample per class in label_to_use.\n",
    "        sample_idx_list=[np.random.RandomState(random_seed).choice(candidates) for candidates in candidate_idx_per_class]\n",
    "\n",
    "        plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1],\n",
    "                                                                                   Paired[12][3],\n",
    "                                                                                   Paired[12][5],\n",
    "                                                                                   ]])\n",
    "        plt.rcParams['font.family'] = 'PT Sans'\n",
    "        plt.rcParams[\"font.size\"] = 17\n",
    "\n",
    "        fig = plt.figure(figsize=(2.7*(1+len(class_names)+0.3), 3*2*len(sample_idx_list)))\n",
    "        box1 = gridspec.GridSpec(1, len(plot_types),\n",
    "                                 wspace=0.1,\n",
    "                                 hspace=0,\n",
    "                                 width_ratios=[1]+[0.3]+[1]*len(class_names))\n",
    "\n",
    "        # axd keys: f\"{sample_idx}_image\" for the input column and\n",
    "        # f\"{sample_idx}_{column}_{method}\" for the spacer and class columns.\n",
    "        axd={}\n",
    "        for idx1, plot_type in enumerate(plot_types):\n",
    "            box2 = gridspec.GridSpecFromSubplotSpec(len(sample_idx_list), 1,\n",
    "                                                    subplot_spec=box1[idx1], wspace=0, hspace=0.15)\n",
    "            for idx2, sample_idx in enumerate(sample_idx_list):\n",
    "                if plot_type==\"image\":\n",
    "                    axd[f\"{sample_idx}_{plot_type}\"]=fig.add_subplot(box2[idx2])\n",
    "                else:\n",
    "                    box3 = gridspec.GridSpecFromSubplotSpec(2, 1,\n",
    "                                                            subplot_spec=box2[idx2], wspace=0, hspace=0.05)\n",
    "                    for idx3, explanation_method in enumerate(explanation_methods_compared):\n",
    "                        axd[f\"{sample_idx}_{plot_type}_{explanation_method}\"]=fig.add_subplot(box3[idx3])\n",
    "\n",
    "        # The spacer column only separates the input from the explanations: hide its decorations.\n",
    "        for plot_key, ax in axd.items():\n",
    "            if 'empty' in plot_key:\n",
    "                ax.set_xticks([])\n",
    "                ax.set_yticks([])\n",
    "                for axis in ['top','bottom','left','right']:\n",
    "                    ax.spines[axis].set_linewidth(0)\n",
    "\n",
    "        column_titles_drawn=False\n",
    "        for idx1, sample_idx in enumerate(sample_idx_list):\n",
    "            dataset_item=dataset_dict[dataset_name][sample_idx]\n",
    "            image = dataset_item[\"images\"]\n",
    "            path = dataset_item[\"path\"]\n",
    "\n",
    "            print(idx1, sample_idx)\n",
    "\n",
    "            image_unnormlized=((image.numpy() * img_std) + img_mean).transpose(1,2,0)\n",
    "            assert image_unnormlized.min()>0 and image_unnormlized.max()<1\n",
    "            image_unnormlized_scaled=(image_unnormlized-image_unnormlized.min())/(image_unnormlized.max()-image_unnormlized.min())\n",
    "\n",
    "            explanation_ours=explanation_save_dict_backbone[\"ours\"][adapt_path(path, list(explanation_save_dict_backbone[\"ours\"].keys()))]['explanation']\n",
    "            try:\n",
    "                explanation_kernel=explanation_save_dict_backbone[\"kernelshap\"][adapt_path(path, list(explanation_save_dict_backbone[\"kernelshap\"].keys()))]['explanation']\n",
    "            except Exception as err:\n",
    "                # No usable kernelshap explanation for this sample: leave the row empty but say why\n",
    "                # (a bare `except:` would also swallow KeyboardInterrupt and hide real errors).\n",
    "                print(f\"Skipping sample {sample_idx}: no kernelshap explanation ({err!r})\")\n",
    "                continue\n",
    "            explanation_by_method={\"kernelshap\": explanation_kernel, \"ours\": explanation_ours}\n",
    "\n",
    "            # Input image column.\n",
    "            plot_key=f\"{sample_idx}_image\"\n",
    "            axd[plot_key].imshow(image_unnormlized_scaled)\n",
    "            axd[plot_key].set_xticks([])\n",
    "            axd[plot_key].set_yticks([])\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[plot_key].spines[axis].set_linewidth(1)\n",
    "            axd[plot_key].set_title(f\"{label_to_use[idx1]}\", pad=7, zorder=10)\n",
    "\n",
    "            # Semi-transparent greyscale backdrop shared by all explanation panels of this sample.\n",
    "            image_grey_rgba=greys_cmap(1-image_unnormlized.sum(axis=2)/3)\n",
    "            image_grey_rgba[:,:,3]=0.5\n",
    "\n",
    "            for class_position, plot_type in enumerate(class_names):\n",
    "                for idx3, explanation_method in enumerate(explanation_methods_compared):\n",
    "                    explanation_class=explanation_by_method[explanation_method][class_position]\n",
    "                    # Bilinearly upsample the 14x14 patch attributions to the 224x224 image.\n",
    "                    explanation_class_expanded=torch.nn.functional.interpolate(torch.Tensor(explanation_class.reshape(1, 1, 14, 14)),\n",
    "                                                                               scale_factor=16, align_corners=False, mode='bilinear').numpy().reshape(224, 224)\n",
    "                    # Symmetric colour scale per panel: 0 -> colormap midpoint, +/-max|attribution| -> the ends.\n",
    "                    # Fall back to 1 for an all-zero attribution, which would otherwise give 0/0 = NaN.\n",
    "                    colormap_max=np.max(np.abs(explanation_class))\n",
    "                    if colormap_max==0:\n",
    "                        colormap_max=1.0\n",
    "                    explanation_class_expanded_heatmap=heatmap_cmap(0.5+explanation_class_expanded/colormap_max*0.5)\n",
    "                    explanation_class_expanded_heatmap[:,:,3]=0.6\n",
    "\n",
    "                    plot_key=f\"{sample_idx}_{plot_type}_{explanation_method}\"\n",
    "                    axd[plot_key].imshow(image_grey_rgba, alpha=0.85)\n",
    "                    axd[plot_key].imshow(explanation_class_expanded_heatmap, alpha=0.9)\n",
    "                    axd[plot_key].set_xticks([])\n",
    "                    axd[plot_key].set_yticks([])\n",
    "                    for axis in ['top','bottom','left','right']:\n",
    "                        axd[plot_key].spines[axis].set_linewidth(1)\n",
    "\n",
    "                    # Column titles go on the first row actually drawn (earlier rows may have been skipped).\n",
    "                    if not column_titles_drawn and idx3==0:\n",
    "                        axd[plot_key].set_title(plot_type, fontsize=17)\n",
    "                    if class_position==0:\n",
    "                        axd[plot_key].set_ylabel(explanation_method_mapper(explanation_method))\n",
    "            column_titles_drawn=True\n",
    "\n",
    "        # NOTE(review): file names omit backbone_type, so with several backbones the last one overwrites the others.\n",
    "        plot_stem=f\"results/plots/qualitative_kernelshap_{dataset_name}_{random_seed}\"\n",
    "        fig.savefig(f\"{plot_stem}.jpg\", bbox_inches='tight')\n",
    "        fig.savefig(f\"{plot_stem}.png\", bbox_inches='tight')\n",
    "        with plt.rc_context({'image.composite_image': False}):\n",
    "            # Keep each imshow layer as a separate image in the PDF instead of compositing them.\n",
    "            fig.savefig(f\"{plot_stem}.pdf\", bbox_inches='tight')\n",
    "        fig.savefig(f\"{plot_stem}.svg\", bbox_inches='tight')\n",
    "\n",
    "        plt.show()\n",
    "        # Release the figure explicitly; this loop creates one figure per backbone and seed.\n",
    "        plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ca85486",
   "metadata": {},
   "source": [
    "# Plot2: baseline explanation methods vs. ours (ImageNette and MURA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc9f49e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect `explanation_method_main` (defined elsewhere in the notebook); the cell below uses its own flat list.\n",
    "explanation_method_main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cfd4ee1",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Per-channel normalisation constants; tensors are mapped back to RGB via image * std + mean.\n",
    "img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]\n",
    "img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]\n",
    "\n",
    "# Explanation methods drawn left to right, one column each, after the input image and a spacer.\n",
    "explanation_method_main_flatten=[\"attention_last\",\n",
    "                                 \"attention_rollout\",\n",
    "                                 \"vanillaembedding\",\n",
    "                                 \"igembedding\",\n",
    "                                 \"sgembedding\",\n",
    "                                 \"LRP\",\n",
    "                                 \"leaveoneoutclassifier\",\n",
    "                                 \"ours\"]\n",
    "plot_types=[\"image\", \"empty\"]+explanation_method_main_flatten\n",
    "\n",
    "# Loop-invariant colormaps. plt.get_cmap replaces matplotlib.cm.get_cmap, which was\n",
    "# deprecated in Matplotlib 3.7 and removed in 3.9.\n",
    "heatmap_cmap=sns.color_palette(\"icefire\", as_cmap=True)\n",
    "greys_cmap=plt.get_cmap('Greys', 1000)\n",
    "\n",
    "for backbone_type in backbone_type_config_dict:\n",
    "    if backbone_type==\"vit_small_patch16_224\":\n",
    "        continue\n",
    "\n",
    "    for dataset_name in [\"ImageNette\", \"MURA\"]:\n",
    "        classifier_result_list = data_loaded_all[\"1_classifier_evaluate\"][dataset_name][backbone_type]\n",
    "        explanation_save_dict_backbone = data_loaded_all[\"3_explanation_generate\"][dataset_name][backbone_type]\n",
    "        label_data_list=np.array([i['label'] for i in dataset_dict[dataset_name].data])\n",
    "        # NOTE(review): assumes classifier_result_list iterates in the same order as dataset_dict[dataset_name].data,\n",
    "        # so that row i of prob_data_list and of label_data_list describe the same sample -- confirm.\n",
    "        prob_data_list=np.array([i['prob'] for i in classifier_result_list.values()])\n",
    "        assert len(prob_data_list)==len(label_data_list), \"classifier results and dataset differ in length\"\n",
    "\n",
    "        if dataset_name==\"ImageNette\":\n",
    "            # 8 figures, each with one correctly classified sample per class\n",
    "            # (class label_idx of figure random_seed is drawn with seed random_seed + label_idx).\n",
    "            sample_pool=np.arange(len(label_data_list))\n",
    "            predicted_label=np.argmax(prob_data_list, axis=1)\n",
    "            sample_idx_list_list=[]\n",
    "            for random_seed in [0, 1, 2, 3, 4, 5, 6, 7]:\n",
    "                sample_idx_list=[np.random.RandomState(random_seed+label_idx).choice(sample_pool[(label_data_list==label_idx)&(predicted_label==label_idx)]) for label_idx in range(len(label_dict[dataset_name]))]\n",
    "                sample_idx_list_list.append(sample_idx_list)\n",
    "        else:\n",
    "            # 10 figures x 10 distinct samples with label column 0 equal to 1 and probability column 0 above 0.5.\n",
    "            sample_idx_list_list=np.random.RandomState(42).choice(np.arange(len(label_data_list[:,0]))[(label_data_list[:,0]==1)&(prob_data_list[:,0]>0.5)], replace=False, size=(10, 10))\n",
    "\n",
    "        for random_seed, sample_idx_list in enumerate(sample_idx_list_list):\n",
    "            plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1],\n",
    "                                                                                       Paired[12][3],\n",
    "                                                                                       Paired[12][5],\n",
    "                                                                                       ]])\n",
    "            plt.rcParams['font.family'] = 'PT Sans'\n",
    "            plt.rcParams[\"font.size\"] = 14\n",
    "\n",
    "            fig = plt.figure(figsize=(2.3*(1+len(explanation_method_main_flatten)+0.2), 3*len(sample_idx_list)))\n",
    "            box1 = gridspec.GridSpec(1, len(plot_types),\n",
    "                                     wspace=0.1,\n",
    "                                     hspace=0,\n",
    "                                     width_ratios=[1]+[0.2]+[1]*len(explanation_method_main_flatten))\n",
    "\n",
    "            # axd keys: f\"{sample_idx}_{column}\" for every column (image, spacer, each method).\n",
    "            axd={}\n",
    "            for idx1, plot_type in enumerate(plot_types):\n",
    "                box2 = gridspec.GridSpecFromSubplotSpec(len(sample_idx_list), 1,\n",
    "                                                        subplot_spec=box1[idx1], wspace=0, hspace=0.3)\n",
    "                for idx2, sample_idx in enumerate(sample_idx_list):\n",
    "                    axd[f\"{sample_idx}_{plot_type}\"]=fig.add_subplot(box2[idx2])\n",
    "\n",
    "            # The spacer column only separates the input from the explanations: hide its decorations.\n",
    "            for plot_key, ax in axd.items():\n",
    "                if 'empty' in plot_key:\n",
    "                    ax.set_xticks([])\n",
    "                    ax.set_yticks([])\n",
    "                    for axis in ['top','bottom','left','right']:\n",
    "                        ax.spines[axis].set_linewidth(0)\n",
    "\n",
    "            for idx1, sample_idx in enumerate(sample_idx_list):\n",
    "                dataset_item=dataset_dict[dataset_name][sample_idx]\n",
    "                image = dataset_item[\"images\"]\n",
    "                label = dataset_item[\"labels\"]\n",
    "                path = dataset_item[\"path\"]\n",
    "\n",
    "                print(idx1, sample_idx)\n",
    "\n",
    "                image_unnormlized=((image.numpy() * img_std) + img_mean).transpose(1,2,0)\n",
    "                assert image_unnormlized.min()>0 and image_unnormlized.max()<1\n",
    "                image_unnormlized_scaled=(image_unnormlized-image_unnormlized.min())/(image_unnormlized.max()-image_unnormlized.min())\n",
    "                classifier_prob=classifier_result_list[adapt_path(path, list(classifier_result_list.keys()))][\"prob\"]\n",
    "\n",
    "                # A single-output classifier is explained through output 0; otherwise explain the true class.\n",
    "                class_idx=0 if len(classifier_prob)==1 else label\n",
    "\n",
    "                # Input image column; for ImageNette row idx1 holds a sample of class idx1.\n",
    "                plot_key=f\"{sample_idx}_image\"\n",
    "                axd[plot_key].imshow(image_unnormlized_scaled)\n",
    "                axd[plot_key].set_xticks([])\n",
    "                axd[plot_key].set_yticks([])\n",
    "                for axis in ['top','bottom','left','right']:\n",
    "                    axd[plot_key].spines[axis].set_linewidth(1)\n",
    "                if dataset_name==\"ImageNette\":\n",
    "                    axd[plot_key].set_title(f\"{label_dict[dataset_name][idx1]}\", pad=7, zorder=10)\n",
    "                else:\n",
    "                    axd[plot_key].set_title(\"Abnormal\", pad=7, zorder=10)\n",
    "\n",
    "                # Semi-transparent greyscale backdrop shared by all explanation panels of this sample.\n",
    "                image_grey_rgba=greys_cmap(1-image_unnormlized.sum(axis=2)/3)\n",
    "                image_grey_rgba[:,:,3]=0.5\n",
    "\n",
    "                for plot_type in explanation_method_main_flatten:\n",
    "                    explanation=explanation_save_dict_backbone[plot_type][adapt_path(path, list(explanation_save_dict_backbone[plot_type].keys()))]['explanation']\n",
    "                    # 2-D explanations hold one row per class output; pick the one being explained.\n",
    "                    explanation_class=explanation[class_idx] if len(explanation.shape)==2 else explanation\n",
    "\n",
    "                    # Bilinearly upsample the 14x14 patch attributions to the 224x224 image.\n",
    "                    explanation_class_expanded=torch.nn.functional.interpolate(torch.Tensor(explanation_class.reshape(1, 1, 14, 14)),\n",
    "                                                                               scale_factor=16, align_corners=False, mode='bilinear').numpy().reshape(224, 224)\n",
    "                    # Symmetric colour scale per panel: 0 -> colormap midpoint, +/-max|attribution| -> the ends.\n",
    "                    # Fall back to 1 for an all-zero attribution, which would otherwise give 0/0 = NaN.\n",
    "                    colormap_max=np.max(np.abs(explanation_class_expanded))\n",
    "                    if colormap_max==0:\n",
    "                        colormap_max=1.0\n",
    "                    explanation_class_expanded_heatmap=heatmap_cmap(0.5+explanation_class_expanded/colormap_max*0.5)\n",
    "                    explanation_class_expanded_heatmap[:,:,3]=0.6\n",
    "\n",
    "                    plot_key=f\"{sample_idx}_{plot_type}\"\n",
    "                    axd[plot_key].imshow(image_grey_rgba, alpha=0.85)\n",
    "                    axd[plot_key].imshow(explanation_class_expanded_heatmap, alpha=0.9)\n",
    "                    axd[plot_key].set_xticks([])\n",
    "                    axd[plot_key].set_yticks([])\n",
    "                    for axis in ['top','bottom','left','right']:\n",
    "                        axd[plot_key].spines[axis].set_linewidth(1)\n",
    "                    axd[plot_key].set_title(explanation_method_mapper(plot_type))\n",
    "\n",
    "            # NOTE(review): file names omit backbone_type, so with several backbones the last one overwrites the others.\n",
    "            plot_stem=f\"results/plots/qualitative_baselines_{dataset_name}_{random_seed}\"\n",
    "            fig.savefig(f\"{plot_stem}.jpg\", bbox_inches='tight')\n",
    "            fig.savefig(f\"{plot_stem}.png\", bbox_inches='tight')\n",
    "            with plt.rc_context({'image.composite_image': False}):\n",
    "                # Keep each imshow layer as a separate image in the PDF instead of compositing them.\n",
    "                fig.savefig(f\"{plot_stem}.pdf\", bbox_inches='tight')\n",
    "            fig.savefig(f\"{plot_stem}.svg\", bbox_inches='tight')\n",
    "\n",
    "            plt.show()\n",
    "            # Release the figure explicitly; this loop creates many figures.\n",
    "            plt.close(fig)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d01070f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f732296",
   "metadata": {},
   "outputs": [],
   "source": [
    "img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]\n",
    "img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]\n",
    "\n",
    "# Qualitative comparison of all baselines: one row per sample, one column per explanation method\n",
    "# (plus the input image and a narrow spacer column).\n",
    "explanation_method_main_flatten=[\"attention_last\",\n",
    "                                 \"attention_rollout\",\n",
    "                                 \"riseclassifier\",\n",
    "                                 \"vanillaembedding\",\n",
    "                                 \"igembedding\",\n",
    "                                 \"sgembedding\",\n",
    "                                 \"LRP\",\n",
    "                                 \"leaveoneoutclassifier\",\n",
    "                                 \"ours\"]\n",
    "plot_type_list=[\"image\", \"empty\"]+explanation_method_main_flatten\n",
    "\n",
    "# Colormaps are loop-invariant: diverging for signed attributions, greyscale for the background image.\n",
    "heatmap_cmap=sns.color_palette(\"icefire\", as_cmap=True)\n",
    "# plt.get_cmap still accepts `lut`; cm.get_cmap was removed in Matplotlib 3.9.\n",
    "grey_cmap=plt.get_cmap('Greys', 1000)\n",
    "\n",
    "os.makedirs(\"results/plots\", exist_ok=True)\n",
    "\n",
    "for backbone_type, backbone_type_config in backbone_type_config_dict.items():\n",
    "    if backbone_type==\"vit_small_patch16_224\":\n",
    "        continue\n",
    "\n",
    "    for dataset_name in [\"ImageNette\", \"MURA\"]:\n",
    "        classifier_result_list = data_loaded_all[\"1_classifier_evaluate\"][dataset_name][backbone_type]\n",
    "        explanation_save_dict_backbone = data_loaded_all[\"3_explanation_generate\"][dataset_name][backbone_type]\n",
    "        label_data_list=np.array([i['label'] for i in dataset_dict[dataset_name].data])\n",
    "        prob_data_list=np.array([i['prob'] for _,i in classifier_result_list.items()])\n",
    "\n",
    "        if dataset_name==\"ImageNette\":\n",
    "            # One correctly classified sample per class, each drawn with its own seed (random_seed+label_idx).\n",
    "            sample_idx_list_list=[]\n",
    "            for random_seed in [0]:\n",
    "                sample_idx_list=[np.random.RandomState(random_seed+label_idx).choice(np.arange(len(label_data_list))[(label_data_list==label_idx)&(np.argmax(prob_data_list, axis=1)==label_idx)]) for label_idx in range(len(label_dict[dataset_name]))]\n",
    "                sample_idx_list_list.append(sample_idx_list)\n",
    "        else:\n",
    "            # 10 figures x 10 rows of label-1 (\"Abnormal\") samples that the classifier scores above 0.5.\n",
    "            sample_idx_list_list=np.random.RandomState(42).choice(np.arange(len(label_data_list[:,0]))[(label_data_list[:,0]==1)&(prob_data_list[:,0]>0.5)], replace=False, size=(10, 10))\n",
    "\n",
    "        for random_seed, sample_idx_list in enumerate(sample_idx_list_list):\n",
    "\n",
    "            plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1],\n",
    "                                                                                       Paired[12][3],\n",
    "                                                                                       Paired[12][5],\n",
    "                                                                                       ]])\n",
    "\n",
    "            plt.rcParams['font.family'] = 'PT Sans'\n",
    "            plt.rcParams[\"font.size\"] = 14\n",
    "\n",
    "            fig = plt.figure(figsize=(2.3*(len([\"image\"]+explanation_method_main_flatten)+0.2*len([\"empty\"])), 3*len(sample_idx_list)))\n",
    "            box1 = gridspec.GridSpec(1, len(plot_type_list),\n",
    "                                     wspace=0.1,\n",
    "                                     hspace=0,\n",
    "                                     width_ratios=[1]+[0.2]+[1]*len(explanation_method_main_flatten))\n",
    "\n",
    "            # One grid column per plot type, one row per sample; axes are keyed \"<sample_idx>_<plot_type>\".\n",
    "            axd={}\n",
    "            for idx1, plot_type in enumerate(plot_type_list):\n",
    "                box2 = gridspec.GridSpecFromSubplotSpec(len(sample_idx_list),1,\n",
    "                                                        subplot_spec=box1[idx1], wspace=0, hspace=0.3)\n",
    "                for idx2, sample_idx in enumerate(sample_idx_list):\n",
    "                    box3 = gridspec.GridSpecFromSubplotSpec(1, 1,\n",
    "                                                        subplot_spec=box2[idx2], wspace=0, hspace=0)\n",
    "                    ax=plt.Subplot(fig, box3[0])\n",
    "                    fig.add_subplot(ax)\n",
    "                    axd[f\"{sample_idx}_{plot_type}\"]=ax\n",
    "\n",
    "            # Spacer column: no ticks, no frame.\n",
    "            for plot_key in axd.keys():\n",
    "                if 'empty' in plot_key:\n",
    "                    axd[plot_key].set_xticks([])\n",
    "                    axd[plot_key].set_yticks([])\n",
    "                    for axis in ['top','bottom','left','right']:\n",
    "                        axd[plot_key].spines[axis].set_linewidth(0)\n",
    "\n",
    "            for idx1, sample_idx in enumerate(sample_idx_list):\n",
    "                dataset_item=dataset_dict[dataset_name][sample_idx]\n",
    "\n",
    "                image = dataset_item[\"images\"]\n",
    "                label = dataset_item[\"labels\"]\n",
    "                path = dataset_item[\"path\"]\n",
    "\n",
    "                image_unnormlized=((image.numpy() * img_std) + img_mean).transpose(1,2,0)\n",
    "                assert image_unnormlized.min()>0 and image_unnormlized.max()<1\n",
    "                image_unnormlized_scaled=(image_unnormlized-image_unnormlized.min())/(image_unnormlized.max()-image_unnormlized.min())\n",
    "                classifier_prob=classifier_result_list[adapt_path(path, list(classifier_result_list.keys()))][\"prob\"]\n",
    "\n",
    "                # Single-output classifiers use attribution index 0; otherwise the true class.\n",
    "                if len(classifier_prob)==1:\n",
    "                    class_idx=0\n",
    "                else:\n",
    "                    class_idx=label\n",
    "\n",
    "                # Greyscale background shared by every explanation column of this sample.\n",
    "                image_unnormlized_normalized=grey_cmap(1-image_unnormlized.sum(axis=2)/3)\n",
    "                image_unnormlized_normalized[:,:,3]=0.5\n",
    "\n",
    "                for plot_type in plot_type_list:\n",
    "                    plot_key=f\"{sample_idx}_{plot_type}\"\n",
    "                    if plot_type==\"image\":\n",
    "                        axd[plot_key].imshow(image_unnormlized_scaled)\n",
    "\n",
    "                        axd[plot_key].set_xticks([])\n",
    "                        axd[plot_key].set_yticks([])\n",
    "                        for axis in ['top','bottom','left','right']:\n",
    "                            axd[plot_key].spines[axis].set_linewidth(1)\n",
    "                        if dataset_name==\"ImageNette\":\n",
    "                            # Title from the sample's own label rather than its row position.\n",
    "                            axd[plot_key].set_title(f\"{label_dict[dataset_name][label]}\", pad=7, zorder=10)\n",
    "                        else:\n",
    "                            axd[plot_key].set_title(f\"Abnormal\", pad=7, zorder=10)\n",
    "                    elif plot_type==\"empty\":\n",
    "                        pass\n",
    "                    else:\n",
    "                        explanation_store=explanation_save_dict_backbone[plot_type]\n",
    "                        explanation=explanation_store[adapt_path(path, list(explanation_store.keys()))]['explanation']\n",
    "\n",
    "                        if len(explanation.shape)==2:\n",
    "                            explanation_class=explanation[class_idx]\n",
    "                        else:\n",
    "                            explanation_class=explanation\n",
    "\n",
    "                        # Bilinearly upsample the 14x14 map to 224x224.\n",
    "                        explanation_class_expanded=torch.nn.functional.interpolate(torch.Tensor(explanation_class.reshape(1, 1, 14, 14)),\n",
    "                                                                                   scale_factor=16, align_corners=False, mode='bilinear').numpy().reshape(224, 224)\n",
    "\n",
    "                        # Symmetric scaling so 0 maps to the colormap centre; the floor avoids 0/0 (NaN) for all-zero maps.\n",
    "                        colormap_max=max(np.max(np.abs(explanation_class_expanded)), np.finfo(np.float32).tiny)\n",
    "                        explanation_class_expanded_normalized=(0.5+(explanation_class_expanded)/colormap_max*0.5)\n",
    "                        explanation_class_expanded_heatmap=heatmap_cmap(explanation_class_expanded_normalized)\n",
    "                        explanation_class_expanded_heatmap[:,:,3]=0.6\n",
    "\n",
    "                        axd[plot_key].imshow(image_unnormlized_normalized, alpha=0.85)\n",
    "                        axd[plot_key].imshow(explanation_class_expanded_heatmap, alpha=0.9)\n",
    "\n",
    "                        axd[plot_key].set_xticks([])\n",
    "                        axd[plot_key].set_yticks([])\n",
    "                        for axis in ['top','bottom','left','right']:\n",
    "                            axd[plot_key].spines[axis].set_linewidth(1)\n",
    "\n",
    "                        axd[plot_key].set_title(explanation_method_mapper(plot_type))\n",
    "\n",
    "            fig.savefig(f\"results/plots/qualitative_baselines_{dataset_name}_{random_seed}.jpg\", bbox_inches='tight')\n",
    "            fig.savefig(f\"results/plots/qualitative_baselines_{dataset_name}_{random_seed}.png\", bbox_inches='tight')\n",
    "            # Write each imshow layer as its own image in the PDF instead of one composite.\n",
    "            with plt.rc_context({'image.composite_image': False}):\n",
    "                fig.savefig(f\"results/plots/qualitative_baselines_{dataset_name}_{random_seed}.pdf\", bbox_inches='tight')\n",
    "            fig.savefig(f\"results/plots/qualitative_baselines_{dataset_name}_{random_seed}.svg\", bbox_inches='tight')\n",
    "\n",
    "            plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27edb788",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the stored RISE explanation for `path` (left over from the plotting loop above); resolve the key in the dict being indexed.\n",
    "explanation_store = explanation_save_dict_backbone[\"riseclassifier\"]\n",
    "explanation_store[adapt_path(path, list(explanation_store.keys()))]['explanation']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b63a37d5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f4319d3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e5e63dd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02b3e053",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "1db046bb",
   "metadata": {},
   "source": [
    "# Plot 3: Attribution maps and insertion/deletion curves (ImageNette)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4eec738e",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]\n",
    "img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]\n",
    "\n",
    "# Methods drawn as attribution maps (two mosaic rows each); the insertion/deletion panels\n",
    "# below plot every method in explanation_method_main.\n",
    "explanation_method_main_flatten=[\"attention_last\",\n",
    "                                 \"attention_rollout\",\n",
    "                                 \"vanillaembedding\",\n",
    "                                 \"igembedding\",\n",
    "                                 \"sgembedding\",\n",
    "                                 \"LRP\",\n",
    "                                 \"leaveoneoutclassifier\",\n",
    "                                 \"ours\"]\n",
    "N_SAMPLES=50  # number of dataset items rendered per backbone\n",
    "\n",
    "# Colormaps are loop-invariant: diverging for signed attributions, greyscale for the background image.\n",
    "heatmap_cmap=sns.color_palette(\"icefire\", as_cmap=True)\n",
    "# plt.get_cmap still accepts `lut`; cm.get_cmap was removed in Matplotlib 3.9.\n",
    "grey_cmap=plt.get_cmap('Greys', 1000)\n",
    "\n",
    "# 11, 20, 43  (NOTE(review): sample indices noted by the author; purpose not stated)\n",
    "\n",
    "for backbone_type, backbone_type_config in backbone_type_config_dict.items():\n",
    "    if backbone_type==\"vit_small_patch16_224\":\n",
    "        continue\n",
    "    for dataset_name in [\"ImageNette\"]:\n",
    "        classifier_result_list = data_loaded_all[\"1_classifier_evaluate\"][dataset_name][backbone_type]\n",
    "        explanation_save_dict_backbone = data_loaded_all[\"3_explanation_generate\"][dataset_name][backbone_type]\n",
    "        insertdelete_save_dict_backbone = data_loaded_all[\"4_insert_delete\"][dataset_name][backbone_type]\n",
    "        os.makedirs(f\"results/plots/attribution_map_{dataset_name}\", exist_ok=True)\n",
    "        for dataset_idx, dataset_item in enumerate(dataset_dict[dataset_name]):\n",
    "            print(dataset_idx)\n",
    "            if dataset_idx==N_SAMPLES:\n",
    "                break\n",
    "            image = dataset_item[\"images\"]\n",
    "            label = dataset_item[\"labels\"]\n",
    "            path = dataset_item[\"path\"]\n",
    "            image_unnormlized=((image.numpy() * img_std) + img_mean).transpose(1,2,0)\n",
    "            assert image_unnormlized.min()>0 and image_unnormlized.max()<1\n",
    "            image_unnormlized_scaled=(image_unnormlized-image_unnormlized.min())/(image_unnormlized.max()-image_unnormlized.min())\n",
    "            classifier_prob=classifier_result_list[adapt_path(path, list(classifier_result_list.keys()))][\"prob\"]\n",
    "\n",
    "            # Greyscale background shared by every attribution panel of this sample.\n",
    "            image_unnormlized_normalized=grey_cmap(1-image_unnormlized.sum(axis=2)/3)\n",
    "            image_unnormlized_normalized[:,:,3]=0.5\n",
    "\n",
    "            if len(classifier_prob)==1:\n",
    "                class_idx_list=[0]\n",
    "            else:\n",
    "                # Show the true class plus the other class with the highest mean over the\n",
    "                # first 10 columns of the \"ours\" insertion curve.\n",
    "                target_idx=label\n",
    "                insert_data=insertdelete_save_dict_backbone[\"ours\"][adapt_path(path, list(insertdelete_save_dict_backbone[\"ours\"].keys()))]['insert']\n",
    "\n",
    "                non_target_idx=pd.Series(insert_data[:,:10].mean(axis=1)).sort_values(ascending=False).index.tolist()\n",
    "                non_target_idx.remove(target_idx)\n",
    "                non_target_idx=non_target_idx[0]\n",
    "                class_idx_list=[target_idx, non_target_idx]\n",
    "\n",
    "            # Figure setting\n",
    "            plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1],\n",
    "                                                                                        Paired[12][3],\n",
    "                                                                                        Paired[12][5],\n",
    "                                                                                        Paired[12][7],\n",
    "                                                                                        Paired[12][9],\n",
    "                                                                                        Paired[12][11]\n",
    "                                                                                        ]])\n",
    "\n",
    "            plt.rcParams['font.family'] = 'PT Sans'\n",
    "            plt.rcParams[\"font.size\"] = 16\n",
    "\n",
    "            # Mosaic: image on top, two rows per attribution method, then insertion and deletion\n",
    "            # curves; every class in class_idx_list spans two columns.\n",
    "            mosaic_grid=[]\n",
    "            if len(classifier_prob)==1:\n",
    "                for i in range(2):\n",
    "                    mosaic_grid.append([\"image\", \"image\"])\n",
    "            else:\n",
    "                for i in range(2):\n",
    "                    mosaic_grid.append([\"empty_left\", \"image\", \"image\", \"empty_right\"])\n",
    "            mosaic_grid.append([\"empty_attribution_map\"]*2*len(class_idx_list))\n",
    "            for explanation_method in explanation_method_main_flatten:\n",
    "                for i in range(2):\n",
    "                    mosaic_grid.append([f\"{class_idx}_{explanation_method}\" for class_idx in class_idx_list for j in range(2)])\n",
    "            for insert_delete in [\"insert\", \"delete\"]:\n",
    "                mosaic_grid.append([f\"empty_{insert_delete}\"]*2*len(class_idx_list))\n",
    "                for i in range(2):\n",
    "                    mosaic_grid.append([f\"{class_idx}_{insert_delete}\" for class_idx in class_idx_list for j in range(2)])\n",
    "\n",
    "            fig, axd = plt.subplot_mosaic(mosaic_grid,\n",
    "                                          figsize=(3*len(class_idx_list),3*(3+len(explanation_method_main_flatten))),\n",
    "                                         gridspec_kw={\"height_ratios\": [1]*2 + [0.1] + [1]*2*len(explanation_method_main_flatten)+[0.05]+[1]*2+[0.05]+[1]*2})\n",
    "\n",
    "            # Spacer panels: no ticks, no frame.\n",
    "            for plot_key in axd.keys():\n",
    "                if 'empty' in plot_key:\n",
    "                    axd[plot_key].set_xticks([])\n",
    "                    axd[plot_key].set_yticks([])\n",
    "                    for axis in ['top','bottom','left','right']:\n",
    "                        axd[plot_key].spines[axis].set_linewidth(0)\n",
    "\n",
    "            axd[\"image\"].imshow(image_unnormlized_scaled)\n",
    "            axd[\"image\"].set_xticks([])\n",
    "            axd[\"image\"].set_yticks([])\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[\"image\"].spines[axis].set_linewidth(2)\n",
    "            axd[\"image\"].set_title(f\"True label: {label_dict[dataset_name][label]}\", pad=10, fontsize=15)\n",
    "\n",
    "            for idx1, class_idx in enumerate(class_idx_list):\n",
    "                for idx2, explanation_method in enumerate(explanation_method_main_flatten):\n",
    "                    explanation_store=explanation_save_dict_backbone[explanation_method]\n",
    "                    explanation=explanation_store[adapt_path(path, list(explanation_store.keys()))]['explanation']\n",
    "\n",
    "                    if len(explanation.shape)==2:\n",
    "                        explanation_class=explanation[class_idx]\n",
    "                    else:\n",
    "                        explanation_class=explanation\n",
    "\n",
    "                    plot_key=f\"{class_idx}_{explanation_method}\"\n",
    "\n",
    "                    # NOTE(review): deterministic 1e-40 jitter, presumably to keep colormap_max non-zero for all-zero maps.\n",
    "                    explanation_class=explanation_class+np.random.RandomState(42).uniform(low=0, high=1e-40, size=explanation_class.shape)\n",
    "\n",
    "                    # Bilinearly upsample the 14x14 map to 224x224.\n",
    "                    explanation_class_expanded=torch.nn.functional.interpolate(torch.Tensor(explanation_class.reshape(1, 1, 14, 14)),\n",
    "                                                                               scale_factor=16, align_corners=False, mode='bilinear').numpy().reshape(224, 224)\n",
    "\n",
    "                    # Symmetric scaling so 0 maps to the centre of the diverging colormap.\n",
    "                    colormap_max=np.max(np.abs(explanation_class_expanded))\n",
    "                    explanation_class_expanded_normalized=(0.5+(explanation_class_expanded)/colormap_max*0.5)\n",
    "                    explanation_class_expanded_heatmap=heatmap_cmap(explanation_class_expanded_normalized)\n",
    "                    explanation_class_expanded_heatmap[:,:,3]=0.6\n",
    "\n",
    "                    axd[plot_key].imshow(image_unnormlized_normalized, alpha=0.85)\n",
    "                    axd[plot_key].imshow(explanation_class_expanded_heatmap, alpha=0.9)\n",
    "\n",
    "                    axd[plot_key].set_xticks([])\n",
    "                    axd[plot_key].set_yticks([])\n",
    "\n",
    "                    # Class name above the first attribution row (multi-class models only).\n",
    "                    if idx2==0 and len(classifier_prob)!=1:\n",
    "                        axd[plot_key].set_title(label_dict[dataset_name][class_idx], pad=10, fontsize=15)\n",
    "\n",
    "                    for axis in ['top','bottom','left','right']:\n",
    "                        axd[plot_key].spines[axis].set_linewidth(2)\n",
    "\n",
    "                    if idx1==0:\n",
    "                        axd[plot_key].set_ylabel(explanation_method_mapper(explanation_method))\n",
    "\n",
    "                for idx2, insert_delete in enumerate(['insert', 'delete']):\n",
    "                    plot_key=f\"{class_idx}_{insert_delete}\"\n",
    "\n",
    "                    # Line style encodes the method category, colour the method within its category.\n",
    "                    for idx3, explanation_methods_category in enumerate(explanation_method_main):\n",
    "                        for idx4, explanation_method in enumerate(explanation_methods_category):\n",
    "                            path_mapped_insertdelete=adapt_path(path, list(insertdelete_save_dict_backbone[explanation_method].keys()))\n",
    "                            insertdelete_value=insertdelete_save_dict_backbone[explanation_method][path_mapped_insertdelete]\n",
    "\n",
    "                            axd[plot_key].plot(insertdelete_value[insert_delete][class_idx],\n",
    "                                                linestyle=[':','-.','-'][idx3],\n",
    "                                                c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx4])\n",
    "                    axd[plot_key].set_xlim(-2, 200)\n",
    "\n",
    "                    axd[plot_key].xaxis.set_major_locator(MultipleLocator(28))\n",
    "                    axd[plot_key].xaxis.set_minor_locator(MultipleLocator(14))\n",
    "                    axd[plot_key].xaxis.grid(True, which='major', linewidth=0.8, alpha=0.4)\n",
    "                    axd[plot_key].xaxis.grid(True, which='minor', linewidth=0.6, alpha=0.4)\n",
    "\n",
    "                    # Finer y ticks when the autoscaled upper limit is at most 0.5.\n",
    "                    if axd[plot_key].get_ylim()[1]>0.5:\n",
    "                        y_major, y_minor = 0.2, 0.1\n",
    "                    else:\n",
    "                        y_major, y_minor = 0.1, 0.05\n",
    "                    axd[plot_key].yaxis.set_major_locator(MultipleLocator(y_major))\n",
    "                    axd[plot_key].yaxis.set_minor_locator(MultipleLocator(y_minor))\n",
    "                    axd[plot_key].yaxis.grid(True, which='major', linewidth=0.8, alpha=0.4)\n",
    "                    axd[plot_key].yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.4)\n",
    "\n",
    "                    axd[plot_key].spines['right'].set_visible(False)\n",
    "                    axd[plot_key].spines['top'].set_visible(False)\n",
    "\n",
    "                    for axis in ['top','bottom','left','right']:\n",
    "                        axd[plot_key].spines[axis].set_linewidth(2)\n",
    "\n",
    "                    if insert_delete==\"insert\":\n",
    "                        axd[plot_key].set_xticks([])\n",
    "\n",
    "                    axd[plot_key].tick_params(axis = 'y', which = 'major', labelsize = 10, pad=-1)\n",
    "                    axd[plot_key].tick_params(axis = 'x', which = 'major', labelsize = 12)\n",
    "\n",
    "                    if idx1==0:\n",
    "                        axd[plot_key].set_ylabel(f\"Probability\")\n",
    "                    axd[plot_key].set_title(f\"{insert_delete_mapper(insert_delete, verbose=True)}\", pad=0)\n",
    "\n",
    "            legend_elements = [Line2D([0], [0],\n",
    "                                      linestyle=[':','-.','-'][idx1],\n",
    "                                      color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx2],\n",
    "                                      linewidth=3,\n",
    "                                      label=explanation_method_mapper(explanation_method))\n",
    "                                 for idx1, explanation_method_category in enumerate(explanation_method_main) for idx2, explanation_method in enumerate(explanation_method_category)]\n",
    "\n",
    "            fig.legend(handles=legend_elements,\n",
    "                        ncol=2,\n",
    "                        handlelength=3,\n",
    "                        handletextpad=0.6,\n",
    "                        columnspacing=1.5,\n",
    "                        fontsize=14,\n",
    "                        loc='lower center', bbox_to_anchor=(0.5, 0.055))\n",
    "\n",
    "            fig.savefig(f\"results/plots/attribution_map_{dataset_name}/{backbone_type}_{dataset_idx:04d}.jpg\", bbox_inches='tight')\n",
    "\n",
    "            # Write each imshow layer as its own image in the PDF instead of one composite.\n",
    "            with plt.rc_context({'image.composite_image': False}):\n",
    "                fig.savefig(f\"results/plots/attribution_map_{dataset_name}/{backbone_type}_{dataset_idx:04d}.pdf\", bbox_inches='tight')\n",
    "\n",
    "            # Additionally file misclassified MURA samples (threshold 0.5) into separate folders.\n",
    "            if dataset_name==\"MURA\":\n",
    "                if classifier_prob[0]>0.5 and label==False:\n",
    "                    os.makedirs(f\"results/plots/attribution_map_{dataset_name}_false_positive\", exist_ok=True)\n",
    "                    fig.savefig(f\"results/plots/attribution_map_{dataset_name}_false_positive/{backbone_type}_{dataset_idx:04d}.jpg\", bbox_inches='tight')\n",
    "                if classifier_prob[0]<=0.5 and label==True:\n",
    "                    os.makedirs(f\"results/plots/attribution_map_{dataset_name}_false_negative\", exist_ok=True)\n",
    "                    fig.savefig(f\"results/plots/attribution_map_{dataset_name}_false_negative/{backbone_type}_{dataset_idx:04d}.jpg\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72d089fd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "865fa5f3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1228418",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca4e520e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cmapy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "741b8559",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Top-level keys of data_loaded_all (pipeline stages such as \"1_classifier_evaluate\", \"3_explanation_generate\", \"4_insert_delete\").\n",
    "data_loaded_all.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ede75c19",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# dataset_idx_dict={\"ImageNette\": {\"target\": 43,\n",
    "#                             \"non-target\":43},\n",
    "#              \"MURA\": {\"target\": 8,\n",
    "#                       \"non-target\": 9},\n",
    "#             }\n",
    "img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]\n",
    "img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis] \n",
    "\n",
    "dataset_idx_dict={\"ImageNette\": {\"target\": 97,\n",
    "                            \"non-target\": 97},\n",
    "                  \"MURA\": {\"target\": 15,\n",
    "                           \"non-target\": 17},\n",
    "                 }\n",
    "\n",
    "for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], \n",
    "                                                                                Paired[12][3],\n",
    "                                                                                Paired[12][5],\n",
    "                                                                                Paired[12][7],\n",
    "                                                                                Paired[12][9],\n",
    "                                                                                Paired[12][11]\n",
    "                                                                                ]])\n",
    "    \n",
    "    plt.rcParams['font.family'] = 'PT Sans'\n",
    "    plt.rcParams[\"font.size\"] = 12\n",
    "\n",
    "    fig = plt.figure(figsize=(3*(1+1+0.01+1.7+0.01+1.7), 3*4))\n",
    "    box1 = gridspec.GridSpec(1, 6, wspace=0.3, hspace=0, width_ratios=[1, 1, 0.01, 1.7, 0.01, 1.7])\n",
    "\n",
    "    axd={}\n",
    "    for idx1, plot_type in enumerate([\"image\", \"map\", \"empty1\", \"insert\", \"empty2\", \"delete\"]):\n",
    "        box2 = gridspec.GridSpecFromSubplotSpec(2, 1,\n",
    "                        subplot_spec=box1[idx1], wspace=0, hspace=0.15)\n",
    "        for idx2, dataset_name in enumerate([\"ImageNette\", \"MURA\"]):\n",
    "            if plot_type==\"image\" and dataset_name==\"ImageNette\":\n",
    "                box3 = gridspec.GridSpecFromSubplotSpec(1, 1,\n",
    "                                    subplot_spec=box2[idx2], wspace=0.1, hspace=0.1)\n",
    "                ax = plt.Subplot(fig, box3[0])\n",
    "                fig.add_subplot(ax)\n",
    "\n",
    "                plot_key=f\"{dataset_name}_target_{plot_type}\"\n",
    "                axd[plot_key]=ax            \n",
    "                plot_key=f\"{dataset_name}_non-target_{plot_type}\"\n",
    "                axd[plot_key]=ax                            \n",
    "            else:\n",
    "                box3 = gridspec.GridSpecFromSubplotSpec(2, 1,\n",
    "                                    subplot_spec=box2[idx2], wspace=0.1, hspace=0.3)\n",
    "                for idx3, target_non_target in enumerate([\"target\", \"non-target\"]):\n",
    "                    ax = plt.Subplot(fig, box3[idx3])\n",
    "                    fig.add_subplot(ax)\n",
    "\n",
    "                    plot_key=f\"{dataset_name}_{target_non_target}_{plot_type}\"\n",
    "                    axd[plot_key]=ax\n",
    "    \n",
    "    for plot_key in axd.keys():\n",
    "        #continue\n",
    "        if 'empty' in plot_key:\n",
    "            axd[plot_key].set_xticks([])\n",
    "            axd[plot_key].set_yticks([])\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[plot_key].spines[axis].set_linewidth(0) \n",
    "\n",
    "    for idx1, dataset_name in enumerate([\"ImageNette\", \"MURA\"]):\n",
    "        classifier_result_list = data_loaded_all[\"1_classifier_evaluate\"][dataset_name][backbone_type]\n",
    "        explanation_save_dict_backbone = data_loaded_all[\"3_explanation_generate\"][dataset_name][backbone_type]        \n",
    "        insertdelete_save_dict_backbone = data_loaded_all[\"4_insert_delete\"][dataset_name][backbone_type]        \n",
    "        \n",
    "        for idx2, target_non_target in enumerate([\"target\", \"non-target\"]):    \n",
    "            dataset_item=dataset_dict[dataset_name][dataset_idx_dict[dataset_name][target_non_target]]\n",
    "\n",
    "            image = dataset_item[\"images\"]\n",
    "            label = dataset_item[\"labels\"]\n",
    "            path = dataset_item[\"path\"]\n",
    "            image_unnormlized=((image.numpy() * img_std) + img_mean).transpose(1,2,0)\n",
    "            assert image_unnormlized.min()>0 and image_unnormlized.max()<1\n",
    "            #image_unnormlized=cv2.applyColorMap(np.uint8(256 * image_unnormlized), cv2.COLOR_RGB2GRAY)/256\n",
    "            image_unnormlized_scaled=(image_unnormlized-image_unnormlized.min())/(image_unnormlized.max()-image_unnormlized.min())\n",
    "            classifier_prob=classifier_result_list[adapt_path(path, list(classifier_result_list.keys()))][\"prob\"]\n",
    "            \n",
    "            if len(classifier_prob)==1:\n",
    "                class_idx=0\n",
    "            else:\n",
    "                #target_idx=label\n",
    "                if target_non_target==\"target\":\n",
    "                    class_idx=label\n",
    "                elif target_non_target==\"non-target\":\n",
    "                    insert_data=insertdelete_save_dict_backbone[\"ours\"][adapt_path(path, list(insertdelete_save_dict_backbone[\"ours\"].keys()))]['insert']\n",
    "                    class_idx=pd.Series(insert_data[:,:10].mean(axis=1)).sort_values(ascending=False).index.tolist()#[1]\n",
    "                    class_idx.remove(label)\n",
    "                    class_idx=class_idx[0]\n",
    "                else:\n",
    "                    raise\n",
    "            print(class_idx)\n",
    "\n",
    "            # Image\n",
    "            plot_key=f\"{dataset_name}_{target_non_target}_image\"\n",
    "\n",
    "            axd[plot_key].imshow(image_unnormlized_scaled)\n",
    "\n",
    "            axd[plot_key].set_xticks([])\n",
    "            axd[plot_key].set_yticks([])\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[plot_key].spines[axis].set_linewidth(2)\n",
    "                \n",
    "            if dataset_name==\"ImageNette\":\n",
    "                axd[plot_key].set_title(f\"True: {label_dict[dataset_name][label]}\\nPred: {label_dict[dataset_name][np.argmax(classifier_prob)]}\", pad=7)\n",
    "            elif dataset_name==\"MURA\":\n",
    "                axd[plot_key].set_title(f\"True: {label_dict[dataset_name][label]}\\nPred: {label_dict[dataset_name][int(classifier_prob>0.5)]}\", pad=7)\n",
    "            else:\n",
    "                raise            \n",
    "                \n",
    "            # Map\n",
    "            explanation=explanation_save_dict_backbone[\"ours\"][adapt_path(path, list(explanation_save_dict_backbone[\"ours\"].keys()))]['explanation']\n",
    "            explanation_class=explanation[class_idx]\n",
    "\n",
    "            explanation_class_expanded=np.repeat(np.repeat(explanation_class.reshape(14, 14), 16, axis=0), 16, axis=1)\n",
    "            explanation_class_expanded=torch.nn.functional.interpolate(torch.Tensor(explanation_class.reshape(1, 1, 14, 14)), \n",
    "                                                                       scale_factor=16, align_corners=False, mode='bilinear').numpy().reshape(224, 224)\n",
    "            explanation_class_expanded_normalized=(explanation_class_expanded-explanation_class_expanded.min())/(explanation_class_expanded.max()-explanation_class_expanded.min())\n",
    "            #img_colorized = cv2.applyColorMap(img, cmapy.cmap('viridis'))\n",
    "#             explanation_class_expanded_normalized_heatmap=cv2.applyColorMap(np.uint8(256 * explanation_class_expanded_normalized), \n",
    "#                                                                             cv2.COLORMAP_JET)/256\n",
    "            explanation_class_expanded_normalized_heatmap=cv2.applyColorMap(np.uint8(256 * explanation_class_expanded_normalized), \n",
    "                                                                            cmapy.cmap('seismic'))/256\n",
    "            \n",
    "            image_explanation = image_unnormlized_scaled+explanation_class_expanded_normalized_heatmap\n",
    "            image_explanation = image_explanation/image_explanation.max()\n",
    "            image_explanation = cv2.cvtColor(np.uint8(256 * image_explanation), cv2.COLOR_RGB2BGR)            \n",
    "            \n",
    "            plot_key=f\"{dataset_name}_{target_non_target}_map\"\n",
    "            \n",
    "            axd[plot_key].imshow(image_explanation)\n",
    "            \n",
    "            axd[plot_key].set_xticks([])\n",
    "            axd[plot_key].set_yticks([])\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[plot_key].spines[axis].set_linewidth(2)\n",
    "            if dataset_name==\"ImageNette\":\n",
    "                axd[plot_key].set_title(label_dict[dataset_name][class_idx], pad=10)\n",
    "            elif dataset_name==\"MURA\":\n",
    "                axd[plot_key].set_title(\"Abnormal\", pad=10)\n",
    "            else:\n",
    "                raise\n",
    "\n",
    "            \n",
    "            for insert_delete in [\"insert\", \"delete\"]:\n",
    "                plot_key = f\"{dataset_name}_{target_non_target}_{insert_delete}\"\n",
    "                \n",
    "                for idx3, explanation_methods_category in enumerate(explanation_method_main_random):\n",
    "                    for idx4, explanation_method in enumerate(explanation_methods_category):                                                \n",
    "                        insertdelete_value=insertdelete_save_dict_backbone[explanation_method][adapt_path(path, list(insertdelete_save_dict_backbone[explanation_method].keys()))]                            \n",
    "\n",
    "                        axd[plot_key].plot(insertdelete_value[insert_delete][class_idx],\n",
    "                                           linestyle=[':','-.','-','-'][idx3],\n",
    "                                           c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx4] if explanation_method!=\"random\" else [0,0,0,0.4])\n",
    "                \n",
    "                axd[plot_key].set_xlim(-2, 200)                                   \n",
    "\n",
    "                axd[plot_key].xaxis.set_major_locator(MultipleLocator(28))\n",
    "                axd[plot_key].xaxis.set_minor_locator(MultipleLocator(14))    \n",
    "                axd[plot_key].xaxis.grid(True, which='major', linewidth=0.8, alpha=0.4)\n",
    "                axd[plot_key].xaxis.grid(True, which='minor', linewidth=0.6, alpha=0.4)  \n",
    "\n",
    "\n",
    "                if axd[plot_key].get_ylim()[1]>0.5:\n",
    "                    axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.2))\n",
    "                    axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.1))\n",
    "                    axd[plot_key].yaxis.grid(True, which='major', linewidth=0.8, alpha=0.4)\n",
    "                    axd[plot_key].yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.4)                        \n",
    "                else:\n",
    "                    axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.1))\n",
    "                    axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.05))\n",
    "                    axd[plot_key].yaxis.grid(True, which='major', linewidth=0.8, alpha=0.4)\n",
    "                    axd[plot_key].yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.4)                \n",
    "                \n",
    "                \n",
    "                axd[plot_key].spines['right'].set_visible(False)\n",
    "                axd[plot_key].spines['top'].set_visible(False)   \n",
    "\n",
    "                for axis in ['top','bottom','left','right']:\n",
    "                    axd[plot_key].spines[axis].set_linewidth(2)  \n",
    "\n",
    "                axd[plot_key].tick_params(axis = 'y', which = 'major', labelsize = 10, pad=-1)\n",
    "                axd[plot_key].tick_params(axis = 'x', which = 'major', labelsize = 12)\n",
    "                \n",
    "                axd[plot_key].set_ylabel(f\"Probability\", labelpad=-1)\n",
    "                axd[plot_key].set_title(f\"{insert_delete_mapper(insert_delete, verbose=True)}\", pad=0)\n",
    "                \n",
    "#     legend_elements = [Line2D([0], [0],\n",
    "#                               linestyle=[':','-.','-','-'][idx1],\n",
    "#                               color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx2] if explanation_method!=\"random\" else [0,0,0,0.8],\n",
    "#                               linewidth=3,\n",
    "#                               label=explanation_method_mapper(explanation_method))\n",
    "#                          for idx1, explanation_method_category in enumerate(explanation_method_main_random) for idx2, explanation_method in enumerate(explanation_method_category)]\n",
    "\n",
    "#     fig.legend(handles=legend_elements, \n",
    "#                 ncol=4, \n",
    "#                 handlelength=3,\n",
    "#                 handletextpad=0.6, \n",
    "#                 columnspacing=1.5,\n",
    "#                 fontsize=14,\n",
    "#                 loc='lower center', bbox_to_anchor=(0.5, 0.0)) \n",
    "    legend_elements = [Line2D([0], [0],\n",
    "                              linestyle=[':','-.','-','-'][idx1],\n",
    "                              color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx2] if explanation_method!=\"random\" else [0,0,0,0.8],\n",
    "                              linewidth=3,\n",
    "                              label=explanation_method_mapper(explanation_method))\n",
    "                         for idx1, explanation_method_category in enumerate(explanation_method_main_random) for idx2, explanation_method in enumerate(explanation_method_category)]\n",
    "    \n",
    "    legend_elements.insert(2, Line2D([0], [0], label='', color=(1,1,1,1)))\n",
    "    legend_elements.insert(2, Line2D([0], [0], label='', color=(1,1,1,1)))\n",
    "    legend_elements.insert(2, Line2D([0], [0], label='', color=(1,1,1,1)))\n",
    "    legend_elements.insert(2, Line2D([0], [0], label='', color=(1,1,1,1)))\n",
    "    \n",
    "    legend_elements.insert(15, Line2D([0], [0], label='', color=(1,1,1,1)))\n",
    "    legend_elements.insert(15, Line2D([0], [0], label='', color=(1,1,1,1)))\n",
    "    legend_elements.insert(15, Line2D([0], [0], label='', color=(1,1,1,1)))\n",
    "    \n",
    "    legend_elements.insert(20, Line2D([0], [0], label='', color=(1,1,1,1)))    \n",
    "    legend_elements.insert(20, Line2D([0], [0], label='', color=(1,1,1,1)))    \n",
    "    legend_elements.insert(20, Line2D([0], [0], label='', color=(1,1,1,1)))    \n",
    "    legend_elements.insert(20, Line2D([0], [0], label='', color=(1,1,1,1)))    \n",
    "    legend_elements.insert(20, Line2D([0], [0], label='', color=(1,1,1,1)))    \n",
    "    \n",
    "    fig.legend(handles=legend_elements, \n",
    "                ncol=4, \n",
    "                handlelength=3,\n",
    "                handletextpad=0.6, \n",
    "                columnspacing=1.5,\n",
    "                fontsize=14,\n",
    "                loc='lower center', bbox_to_anchor=(0.5, -0.07))\n",
    "    #fig.tight_layout()                "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "933aaa48",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ab63bd1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3911bfe",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40330aa9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69ae3507",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42c4a9c7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cddb7f73",
   "metadata": {},
   "outputs": [],
   "source": [
    "        \n",
    "        \n",
    "    legend_elements = [Line2D([0], [0],\n",
    "                              linestyle=[':','-.','-','-'][idx1],\n",
    "                              color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx2] if explanation_method!=\"random\" else [0,0,0,0.8],\n",
    "                              linewidth=3,\n",
    "                              label=explanation_method_mapper(explanation_method))\n",
    "                         for idx1, explanation_method_category in enumerate(explanation_method_main_random) for idx2, explanation_method in enumerate(explanation_method_category)]\n"
   ]
  },
  {
   "cell_type": "raw",
   "id": "668e1549",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2eec6e5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71773164",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db1f0920",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "fig = plt.figure(figsize=(10, 8))\n",
    "outer = gridspec.GridSpec(2, 2, wspace=0.2, hspace=0.2)\n",
    "\n",
    "for i in range(4):\n",
    "    inner = gridspec.GridSpecFromSubplotSpec(2, 1,\n",
    "                    subplot_spec=outer[i], wspace=0.1, hspace=0.1)\n",
    "\n",
    "    for j in range(2):\n",
    "        ax = plt.Subplot(fig, inner[j])\n",
    "        t = ax.text(0.5,0.5, 'outer=%d, inner=%d' % (i, j))\n",
    "        t.set_ha('center')\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        fig.add_subplot(ax)\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0a22188",
   "metadata": {},
   "outputs": [],
   "source": [
    "outer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08aaadc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "inner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3a8f08f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32ed7c50",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_idx_dict={\"ImageNette\": {\"target\": 4,\n",
    "                            \"non-target\":4},\n",
    "             \"MURA\": {\"target\": 8,\n",
    "                      \"non-target\": 9},\n",
    "            }\n",
    "\n",
    "for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], \n",
    "                                                                                Paired[12][3],\n",
    "                                                                                Paired[12][5],\n",
    "                                                                                Paired[12][7],\n",
    "                                                                                Paired[12][9],\n",
    "                                                                                Paired[12][11]\n",
    "                                                                                ]])\n",
    "\n",
    "    plt.rcParams['font.family'] = 'PT Sans'\n",
    "    plt.rcParams[\"font.size\"] = 12\n",
    "\n",
    "    \n",
    "    fig=plt.figure()\n",
    "    #fig.add\n",
    "    \n",
    "    axd[\"ImageNette_target_image\"] = axd[\"ImageNette_image\"]\n",
    "    axd[\"ImageNette_non-target_image\"] = axd[\"ImageNette_image\"]\n",
    "\n",
    "    for plot_key in axd.keys():\n",
    "        #continue\n",
    "        if 'empty' in plot_key:\n",
    "            axd[plot_key].set_xticks([])\n",
    "            axd[plot_key].set_yticks([])\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[plot_key].spines[axis].set_linewidth(0) \n",
    "\n",
    "    for idx1, dataset_name in enumerate([\"ImageNette\", \"MURA\"]):\n",
    "        classifier_result_list = data_loaded_all[\"4_0_classifier_evaluate\"][dataset_name][backbone_type]\n",
    "        explanation_save_dict_backbone = data_loaded_all[\"4_1_explanation_generate\"][dataset_name][backbone_type]        \n",
    "        insertdelete_save_dict_backbone = data_loaded_all[\"4_2_insert_delete\"][dataset_name][backbone_type]        \n",
    "        \n",
    "        for idx2, target_non_target in enumerate([\"target\", \"non-target\"]):    \n",
    "            dataset_item=dataset_dict[dataset_name][dataset_idx_dict[dataset_name][target_non_target]]\n",
    "\n",
    "            image = dataset_item[\"images\"]\n",
    "            label = dataset_item[\"labels\"]\n",
    "            path = dataset_item[\"path\"]\n",
    "            image_unnormlized=((image.numpy() * img_std) + img_mean).transpose(1,2,0)\n",
    "            assert image_unnormlized.min()>0 and image_unnormlized.max()<1\n",
    "            #image_unnormlized=cv2.applyColorMap(np.uint8(256 * image_unnormlized), cv2.COLOR_RGB2GRAY)/256\n",
    "            image_unnormlized_scaled=(image_unnormlized-image_unnormlized.min())/(image_unnormlized.max()-image_unnormlized.min())\n",
    "            classifier_prob=classifier_result_list[adapt_path(path, list(classifier_result_list.keys()))][\"prob\"]\n",
    "            \n",
    "            if len(classifier_prob)==1:\n",
    "                class_idx=0\n",
    "            else:\n",
    "                #target_idx=label\n",
    "                if target_non_target==\"target\":\n",
    "                    class_idx=label\n",
    "                elif target_non_target==\"non-target\":\n",
    "                    insert_data=insertdelete_save_dict_backbone[\"ours\"][adapt_path(path, list(insertdelete_save_dict_backbone[\"ours\"].keys()))]['insert']\n",
    "                    class_idx=pd.Series(insert_data[:,:10].mean(axis=1)).sort_values(ascending=False).index.tolist()#[1]\n",
    "                    class_idx.remove(label)\n",
    "                    class_idx=class_idx[0]\n",
    "                else:\n",
    "                    raise\n",
    "                \n",
    "\n",
    "            # Image\n",
    "            plot_key=f\"{dataset_name}_{target_non_target}_image\"\n",
    "\n",
    "            axd[plot_key].imshow(image_unnormlized_scaled)\n",
    "\n",
    "            axd[plot_key].set_xticks([])\n",
    "            axd[plot_key].set_yticks([])\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[plot_key].spines[axis].set_linewidth(2)\n",
    "                \n",
    "            if dataset_name==\"ImageNette\":\n",
    "                axd[plot_key].set_title(f\"{label_dict[dataset_name][label]}\")\n",
    "            elif dataset_name==\"MURA\":\n",
    "                axd[plot_key].set_title(\"True positive\", pad=10)\n",
    "            else:\n",
    "                raise            \n",
    "                \n",
    "            # Map\n",
    "            explanation=explanation_save_dict_backbone[\"ours\"][adapt_path(path, list(explanation_save_dict_backbone[\"ours\"].keys()))]['explanation']\n",
    "            explanation_class=explanation[class_idx]\n",
    "\n",
    "            explanation_class_expanded=np.repeat(np.repeat(explanation_class.reshape(14, 14), 16, axis=0), 16, axis=1)\n",
    "            explanation_class_expanded=torch.nn.functional.interpolate(torch.Tensor(explanation_class.reshape(1, 1, 14, 14)), \n",
    "                                                                       scale_factor=16, align_corners=False, mode='bilinear').numpy().reshape(224, 224)\n",
    "            explanation_class_expanded_normalized=(explanation_class_expanded-explanation_class_expanded.min())/(explanation_class_expanded.max()-explanation_class_expanded.min())\n",
    "            explanation_class_expanded_normalized_heatmap=cv2.applyColorMap(np.uint8(256 * explanation_class_expanded_normalized), cv2.COLORMAP_JET)/256\n",
    "            #image_unnormlized_scaled_grey=image_unnormlized_scaled\n",
    "            image_explanation = image_unnormlized_scaled+explanation_class_expanded_normalized_heatmap\n",
    "            image_explanation = image_explanation/image_explanation.max()\n",
    "            image_explanation = cv2.cvtColor(np.uint8(256 * image_explanation), cv2.COLOR_RGB2BGR)            \n",
    "            \n",
    "            plot_key=f\"{dataset_name}_{target_non_target}_map\"\n",
    "            \n",
    "            axd[plot_key].imshow(image_explanation)\n",
    "            \n",
    "            axd[plot_key].set_xticks([])\n",
    "            axd[plot_key].set_yticks([])\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[plot_key].spines[axis].set_linewidth(2)\n",
    "            if dataset_name==\"ImageNette\":\n",
    "                axd[plot_key].set_title(label_dict[dataset_name][class_idx], pad=10)\n",
    "\n",
    "            \n",
    "            for insert_delete in [\"insert\", \"delete\"]:\n",
    "                plot_key = f\"{dataset_name}_{target_non_target}_{insert_delete}\"\n",
    "                \n",
    "                for idx3, explanation_methods_category in enumerate(explanation_method_main):\n",
    "                    for idx4, explanation_method in enumerate(explanation_methods_category):                                                \n",
    "                        insertdelete_value=insertdelete_save_dict_backbone[explanation_method][adapt_path(path, list(insertdelete_save_dict_backbone[explanation_method].keys()))]                            \n",
    "\n",
    "                        axd[plot_key].plot(insertdelete_value[insert_delete][class_idx],\n",
    "                                            linestyle=[':','-.','-'][idx3],\n",
    "                                            c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx4])\n",
    "                \n",
    "                axd[plot_key].set_xlim(-2, 200)                                   \n",
    "\n",
    "                axd[plot_key].xaxis.set_major_locator(MultipleLocator(28))\n",
    "                axd[plot_key].xaxis.set_minor_locator(MultipleLocator(14))    \n",
    "                axd[plot_key].xaxis.grid(True, which='major', linewidth=0.8, alpha=0.4)\n",
    "                axd[plot_key].xaxis.grid(True, which='minor', linewidth=0.6, alpha=0.4)  \n",
    "\n",
    "\n",
    "                if axd[plot_key].get_ylim()[1]>0.5:\n",
    "                    axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.2))\n",
    "                    axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.1))\n",
    "                    axd[plot_key].yaxis.grid(True, which='major', linewidth=0.8, alpha=0.4)\n",
    "                    axd[plot_key].yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.4)                        \n",
    "                else:\n",
    "                    axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.1))\n",
    "                    axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.05))\n",
    "                    axd[plot_key].yaxis.grid(True, which='major', linewidth=0.8, alpha=0.4)\n",
    "                    axd[plot_key].yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.4)                \n",
    "                \n",
    "                \n",
    "                axd[plot_key].spines['right'].set_visible(False)\n",
    "                axd[plot_key].spines['top'].set_visible(False)   \n",
    "\n",
    "                for axis in ['top','bottom','left','right']:\n",
    "                    axd[plot_key].spines[axis].set_linewidth(2)\n",
    "\n",
    "                if insert_delete==\"insert\":\n",
    "                    axd[plot_key].set_xticks([])   \n",
    "\n",
    "                axd[plot_key].tick_params(axis = 'y', which = 'major', labelsize = 10, pad=-1)\n",
    "                axd[plot_key].tick_params(axis = 'x', which = 'major', labelsize = 12)\n",
    "                \n",
    "                axd[plot_key].set_ylabel(f\"Probability\")\n",
    "                axd[plot_key].set_title(f\"{insert_delete_mapper(insert_delete, verbose=True)}\", pad=0)\n",
    "                \n",
    "    legend_elements = [Line2D([0], [0],\n",
    "                              linestyle=[':','-.','-'][idx1],\n",
    "                              color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx2],\n",
    "                              linewidth=3,\n",
    "                              label=explanation_method_mapper(explanation_method))\n",
    "                         for idx1, explanation_method_category in enumerate(explanation_method_main) for idx2, explanation_method in enumerate(explanation_method_category)]\n",
    "\n",
    "    fig.legend(handles=legend_elements, \n",
    "                ncol=4, \n",
    "                handlelength=3,\n",
    "                handletextpad=0.6, \n",
    "                columnspacing=1.5,\n",
    "                fontsize=14,\n",
    "                loc='lower center', bbox_to_anchor=(0.5, 0.0))                \n",
    "    #fig.tight_layout()                "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "281d77eb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b1d5bbe",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e42c1e42",
   "metadata": {},
   "outputs": [],
   "source": [
    "label_dict[dataset_name][class_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54010f67",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a89fc61",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "265c6d1f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5a5a751",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]\n",
    "img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis] \n",
    "    \n",
    "for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    for dataset_name in [\"ImageNette\", \"MURA\"]:\n",
    "        classifier_result_list = data_loaded_all[\"1_classifier_evaluate\"][dataset_name][backbone_type]\n",
    "        explanation_save_dict_backbone = data_loaded_all[\"3_explanation_generate\"][dataset_name][backbone_type]\n",
    "        insertdelete_save_dict_backbone = data_loaded_all[\"4_insert_delete\"][dataset_name][backbone_type]\n",
    "        for dataset_idx, dataset_item in enumerate(dataset_dict[dataset_name]):\n",
    "            if dataset_idx==1:\n",
    "                break\n",
    "            print(dataset_idx)\n",
    "            image = dataset_item[\"images\"]\n",
    "            label = dataset_item[\"labels\"]\n",
    "            path = dataset_item[\"path\"]\n",
    "            image_unnormlized=((image.numpy() * img_std) + img_mean).transpose(1,2,0)\n",
    "            assert image_unnormlized.min()>0 and image_unnormlized.max()<1\n",
    "            #image_unnormlized=cv2.applyColorMap(np.uint8(256 * image_unnormlized), cv2.COLOR_RGB2GRAY)/256\n",
    "            image_unnormlized_scaled=(image_unnormlized-image_unnormlized.min())/(image_unnormlized.max()-image_unnormlized.min())\n",
    "            classifier_prob=classifier_result_list[adapt_path(path, list(classifier_result_list.keys()))][\"prob\"]\n",
    "            \n",
    "            \n",
    "            if len(classifier_prob)==1:\n",
    "                class_idx_list=[0]\n",
    "            else:\n",
    "                target_idx=label\n",
    "\n",
    "                path_mapped_insertdelete=adapt_path(path, list(insertdelete_save_dict_backbone[\"ours\"].keys())[0])\n",
    "                insert_data=insertdelete_save_dict_backbone[\"ours\"][path_mapped_insertdelete]['insert']\n",
    "\n",
    "                non_target_idx=pd.Series(insert_data[:,:10].mean(axis=1)).sort_values(ascending=False).index.tolist()#[1]\n",
    "                non_target_idx.remove(target_idx)\n",
    "                non_target_idx=non_target_idx[0]\n",
    "                class_idx_list=[target_idx, non_target_idx]\n",
    "\n",
    "            # Figure setting                        \n",
    "            #plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(i)/256 for i in Set1[6]])    \n",
    "            \n",
    "            plt.rcParams[\"axes.prop_cycle\"]=cycler('color', [np.array(i)/256 for i in [Paired[12][1], \n",
    "                                                                                        Paired[12][3],\n",
    "                                                                                        Paired[12][5],\n",
    "                                                                                        Paired[12][7],\n",
    "                                                                                        Paired[12][9],\n",
    "                                                                                        Paired[12][11]\n",
    "                                                                                        ]])             \n",
    "\n",
    "            plt.rcParams['font.family'] = 'PT Sans'\n",
    "            plt.rcParams[\"font.size\"] = 16\n",
    "        \n",
    "            mosaic_grid=[]\n",
    "            if len(classifier_prob)==1:\n",
    "                for i in range(2):\n",
    "                    mosaic_grid.append([\"image\", \"image\"])                \n",
    "            else:\n",
    "                for i in range(2):\n",
    "                    mosaic_grid.append([\"empty_left\", \"image\", \"image\", \"empty_right\"])                \n",
    "            mosaic_grid.append([\"empty_attribution_map\"]*2*len(class_idx_list))\n",
    "            for explanation_method in explanation_method_main_flatten:\n",
    "                for i in range(2):\n",
    "                    mosaic_grid.append([f\"{class_idx}_{explanation_method}\" for class_idx in class_idx_list for j in range(2)])\n",
    "            for insert_delete in [\"insert\", \"delete\"]:\n",
    "                mosaic_grid.append([f\"empty_{insert_delete}\"]*2*len(class_idx_list))\n",
    "                for i in range(2):\n",
    "                    mosaic_grid.append([f\"{class_idx}_{insert_delete}\" for class_idx in class_idx_list for j in range(2)])\n",
    "\n",
    "            fig, axd = plt.subplot_mosaic(mosaic_grid, \n",
    "                                          figsize=(3*len(class_idx_list),3*(3+len(explanation_method_main_flatten))),\n",
    "                                         gridspec_kw={\"height_ratios\": [1]*2 + [0.1] + [1]*2*len(explanation_method_main_flatten)+[0.05]+[1]*2+[0.05]+[1]*2})\n",
    "            \n",
    "            for plot_key in axd.keys():\n",
    "                #continue\n",
    "                if 'empty' in plot_key:\n",
    "                    axd[plot_key].set_xticks([])\n",
    "                    axd[plot_key].set_yticks([])\n",
    "                    for axis in ['top','bottom','left','right']:\n",
    "                        axd[plot_key].spines[axis].set_linewidth(0) \n",
    "            \n",
    "            axd[\"image\"].imshow(image_unnormlized_scaled)\n",
    "            axd[\"image\"].set_xticks([])\n",
    "            axd[\"image\"].set_yticks([])\n",
    "            for axis in ['top','bottom','left','right']:\n",
    "                axd[\"image\"].spines[axis].set_linewidth(2)\n",
    "            axd[\"image\"].set_title(f\"True label: {label_dict[dataset_name][label]}\", pad=10, fontsize=15)\n",
    "            #axd[plot_key].set_title(label_dict[dataset_name][class_idx], pad=10, fontsize=15)\n",
    "            \n",
    "            for idx1, class_idx in enumerate(class_idx_list):\n",
    "                for idx2, explanation_method in enumerate(explanation_method_main_flatten):\n",
    "                    explanation=explanation_save_dict_backbone[explanation_method][adapt_path(path, list(explanation_save_dict_backbone[explanation_method].keys()))]['explanation']\n",
    "                    \n",
    "                    if len(explanation.shape)==2:\n",
    "                        explanation_class=explanation[class_idx]\n",
    "                    else:\n",
    "                        explanation_class=explanation\n",
    "                    \n",
    "                    plot_key=f\"{class_idx}_{explanation_method}\"\n",
    "                    \n",
    "                    explanation_class_expanded=np.repeat(np.repeat(explanation_class.reshape(14, 14), 16, axis=0), 16, axis=1)\n",
    "                    explanation_class_expanded=torch.nn.functional.interpolate(torch.Tensor(explanation_class.reshape(1, 1, 14, 14)), \n",
    "                                                                               scale_factor=16, align_corners=False, mode='bilinear').numpy().reshape(224, 224)\n",
    "                    explanation_class_expanded_normalized=(explanation_class_expanded-explanation_class_expanded.min())/(explanation_class_expanded.max()-explanation_class_expanded.min())\n",
    "                    explanation_class_expanded_normalized_heatmap=cv2.applyColorMap(np.uint8(256 * explanation_class_expanded_normalized), cv2.COLORMAP_JET)/256\n",
    "                    #image_unnormlized_scaled_grey=image_unnormlized_scaled\n",
    "                    image_explanation = image_unnormlized_scaled+explanation_class_expanded_normalized_heatmap\n",
    "                    image_explanation = image_explanation/image_explanation.max()\n",
    "                    image_explanation = cv2.cvtColor(np.uint8(256 * image_explanation), cv2.COLOR_RGB2BGR)\n",
    "                    \n",
    "                    axd[plot_key].imshow(image_explanation)\n",
    "\n",
    "                    axd[plot_key].set_xticks([])\n",
    "                    axd[plot_key].set_yticks([])\n",
    "                    \n",
    "                    if idx2==0:\n",
    "                        if len(classifier_prob)!=1:\n",
    "                            axd[plot_key].set_title(label_dict[dataset_name][class_idx], pad=10, fontsize=15)\n",
    "                    \n",
    "                    for axis in ['top','bottom','left','right']:\n",
    "                        axd[plot_key].spines[axis].set_linewidth(2)                     \n",
    "                    \n",
    "                    if idx1==0:\n",
    "                        axd[plot_key].set_ylabel(explanation_method_mapper(explanation_method))\n",
    "                        \n",
    "                for idx2, insert_delete in enumerate(['insert', 'delete']):\n",
    "                    plot_key=f\"{class_idx}_{insert_delete}\"\n",
    "                    \n",
    "                    for idx3, explanation_methods_category in enumerate(explanation_method_main):\n",
    "                        for idx4, explanation_method in enumerate(explanation_methods_category):\n",
    "                            insertdelete_value=insertdelete_save_dict_backbone[explanation_method][adapt_path(path, list(insertdelete_save_dict_backbone[explanation_method].keys()))]                            \n",
    "\n",
    "                            axd[plot_key].plot(insertdelete_value[insert_delete][class_idx],\n",
    "                                                linestyle=[':','-.','-'][idx3],\n",
    "                                                c=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx4])\n",
    "                    #if idx1==0:\n",
    "                    #    axd[plot_key].set_ylim(0, 1.05)\n",
    "                    axd[plot_key].set_xlim(-2, 200)                                   \n",
    "                                \n",
    "                    axd[plot_key].xaxis.set_major_locator(MultipleLocator(28))\n",
    "                    axd[plot_key].xaxis.set_minor_locator(MultipleLocator(14))    \n",
    "                    axd[plot_key].xaxis.grid(True, which='major', linewidth=0.8, alpha=0.4)\n",
    "                    axd[plot_key].xaxis.grid(True, which='minor', linewidth=0.6, alpha=0.4)  \n",
    "                    \n",
    "                    \n",
    "                    if axd[plot_key].get_ylim()[1]>0.5:\n",
    "                        axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.2))\n",
    "                        axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.1))\n",
    "                        axd[plot_key].yaxis.grid(True, which='major', linewidth=0.8, alpha=0.4)\n",
    "                        axd[plot_key].yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.4)                        \n",
    "                    else:\n",
    "                        axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.1))\n",
    "                        axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.05))\n",
    "                        axd[plot_key].yaxis.grid(True, which='major', linewidth=0.8, alpha=0.4)\n",
    "                        axd[plot_key].yaxis.grid(True, which='minor', linewidth=0.6, alpha=0.4)\n",
    "                    \n",
    "                    axd[plot_key].spines['right'].set_visible(False)\n",
    "                    axd[plot_key].spines['top'].set_visible(False)   \n",
    "                    \n",
    "                    for axis in ['top','bottom','left','right']:\n",
    "                        axd[plot_key].spines[axis].set_linewidth(2)\n",
    "                        \n",
    "                    if insert_delete==\"insert\":\n",
    "                        axd[plot_key].set_xticks([])   \n",
    "                        \n",
    "                    axd[plot_key].tick_params(axis = 'y', which = 'major', labelsize = 10, pad=-1)\n",
    "                    axd[plot_key].tick_params(axis = 'x', which = 'major', labelsize = 12)\n",
    "                        \n",
    "                    if idx1==0:\n",
    "                        axd[plot_key].set_ylabel(f\"Probability\")# ({insert_delete_mapper(insert_delete, verbose=False)})\")#, labelpad=-10)                        \n",
    "                    axd[plot_key].set_title(f\"{insert_delete_mapper(insert_delete, verbose=True)}\", pad=0, fontsize=15)\n",
    "                        \n",
    "            legend_elements = [Line2D([0], [0],\n",
    "                                      linestyle=[':','-.','-'][idx1],\n",
    "                                      color=plt.rcParams[\"axes.prop_cycle\"].by_key()['color'][idx2],\n",
    "                                      linewidth=3,\n",
    "                                      label=explanation_method_mapper(explanation_method))\n",
    "                                 for idx1, explanation_method_category in enumerate(explanation_method_main) for idx2, explanation_method in enumerate(explanation_method_category)]\n",
    "\n",
    "            fig.legend(handles=legend_elements, \n",
    "                        ncol=2, \n",
    "                        handlelength=3,\n",
    "                        handletextpad=0.6, \n",
    "                        columnspacing=1.5,\n",
    "                        fontsize=14,\n",
    "                        loc='lower center', bbox_to_anchor=(0.5, 0.05))\n",
    "            \n",
    "            fig.savefig(f\"results/plots/attribution_map_{dataset_name}/{backbone_type}_{dataset_idx:04d}.jpg\", bbox_inches='tight')            \n",
    "            #lt.close(fig)\n",
    "            #fig.tight_layout(pad=0.0, w_pad=0.0, h_pad=0.0)\n",
    "            #fig.subplots_adjust(hspace=0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c52bf2ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_loaded_all.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbedf9b1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69b011c8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c15c321",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc66580b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "c801383f",
   "metadata": {},
   "source": [
    "# Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de288612",
   "metadata": {},
   "outputs": [],
   "source": [
    "explanation_methods_category_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3611f9a1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99c86e10",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bab4045b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "889572fa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d204495",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e56fa172",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "f85fef31",
   "metadata": {},
   "source": [
    "# Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1c2f956",
   "metadata": {},
   "outputs": [],
   "source": [
    "auc_sensitivity_table_dict={}\n",
    "for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    auc_sensitivity_table_dict.setdefault(backbone_type, {})\n",
    "    \n",
    "    for dataset_name in [\"ImageNette\", \"MURA\"]:\n",
    "        auc_sensitivity_table_dict[backbone_type].setdefault(dataset_name, {})\n",
    "        # AUC mean, std\n",
    "        roc_auc_table_mean=pd.DataFrame(roc_auc_result_dict[dataset_name][backbone_type])\\\n",
    "        .groupby(['metric_mode', 'explanation_method'])[['target_auc', 'non-target_auc']].apply(lambda x: (x.mean(), 1.96*x.std()/((len(x))**(0.5))))\n",
    "        roc_auc_table_mean=roc_auc_table_mean.add_suffix('_mean')\n",
    "    \n",
    "        \n",
    "        # Sensitivity mean, std\n",
    "        sensitivity_table_mean=pd.DataFrame(sensitivity_result_dict[dataset_name][backbone_type])\\\n",
    "        .groupby(['num_included_players', 'explanation_method'])[['target','non-target']].mean()\n",
    "        sensitivity_table_mean=sensitivity_table_mean.add_suffix('_mean').apply(lambda x: (x.mean(), 1.96*x.std()/((len(x))**(0.5))))\n",
    "        \n",
    "        \n",
    "        auc_sensitivity_table=pd.concat([roc_auc_table_mean.loc['insert'].add_prefix('insert_'), \n",
    "                                         roc_auc_table_mean.loc['delete'].add_prefix('delete_'), \n",
    "                                         sensitivity_table_mean.loc['all'].add_prefix('sensitivity_')], axis=1)\n",
    "        for subset_mode in [\"main\", \"supple\"]:\n",
    "            auc_sensitivity_table_dict[backbone_type][dataset_name].setdefault(subset_mode, {})\n",
    "            if subset_mode==\"main\":\n",
    "                #auc_sensitivity_table_select=auc_sensitivity_table.loc[explanation_methods_main]\n",
    "                auc_sensitivity_table_select=auc_sensitivity_table.loc[[j for i in explanation_methods_main for j in i]]\n",
    "                auc_sensitivity_table_select.index=auc_sensitivity_table_select.index.map(lambda x: explanation_method_mapper(x, \"main\"))\n",
    "            elif subset_mode==\"supple\":\n",
    "                auc_sensitivity_table_select=auc_sensitivity_table.loc[[j for i in explanation_methods_supple for j in i]]\n",
    "                auc_sensitivity_table_select.index=auc_sensitivity_table_select.index.map(lambda x: explanation_method_mapper(x, \"supple\"))\n",
    "            else:\n",
    "                raise\n",
    "            for target_non_target in [\"target\", \"non-target\"]:\n",
    "                auc_sensitivity_table_select_format=pd.concat([auc_sensitivity_table_select.apply(lambda x: '-' if np.isnan(x[f\"insert_{target_non_target}_auc_mean\"]) else\n",
    "                                                                                                            'bfs{:.3f} ({:.3f})bfe'.format(x[f\"insert_{target_non_target}_auc_mean\"], x[f\"insert_{target_non_target}_auc_std\"]) if x[f\"insert_{target_non_target}_auc_mean\"]>=(auc_sensitivity_table_select[f'insert_{target_non_target}_auc_mean'][~auc_sensitivity_table_select[f'insert_{target_non_target}_auc_mean'].isnull()]).max() else\n",
    "                                                                                                            '{:.3f} ({:.3f})'.format(x[f\"insert_target_auc_mean\"], x[f\"insert_{target_non_target}_auc_std\"]), axis=1),\n",
    "\n",
    "                                                               auc_sensitivity_table_select.apply(lambda x: '-' if np.isnan(x[f\"delete_{target_non_target}_auc_mean\"]) else\n",
    "                                                                                                            'bfs{:.3f} ({:.3f})bfe'.format(x[f\"delete_{target_non_target}_auc_mean\"], x[f\"delete_{target_non_target}_auc_std\"]) if x[f\"delete_{target_non_target}_auc_mean\"]<=(auc_sensitivity_table_select[f'delete_{target_non_target}_auc_mean'][~auc_sensitivity_table_select[f'delete_{target_non_target}_auc_mean'].isnull()]).min() else\n",
    "                                                                                                            '{:.3f} ({:.3f})'.format(x[f\"delete_{target_non_target}_auc_mean\"], x[f\"delete_target_auc_std\"]), axis=1),                                        \n",
    "\n",
    "                                                               auc_sensitivity_table_select.apply(lambda x: '-' if np.isnan(x[f\"sensitivity_{target_non_target}_mean\"]) else\n",
    "                                                                                                            'bfs{:.3f} ({:.3f})bfe'.format(x[f\"sensitivity_{target_non_target}_mean\"], x[f\"sensitivity_{target_non_target}_std\"]) if x[f\"sensitivity_{target_non_target}_mean\"]>=(auc_sensitivity_table_select[f'sensitivity_{target_non_target}_mean'][~auc_sensitivity_table_select[f'sensitivity_{target_non_target}_mean'].isnull()]).max() else\n",
    "                                                                                                            '{:.3f} ({:.3f})'.format(x[f\"sensitivity_{target_non_target}_mean\"], x[f\"sensitivity_{target_non_target}_std\"]), axis=1)  \n",
    "                                                               ], axis=1)\n",
    "                \n",
    "                #whole_table_select.columns=['Insertion', 'Deletion', 'Sensitivity-n']        \n",
    "                auc_sensitivity_table_select_format.columns=pd.MultiIndex.from_tuples([(dataset_name, metric) for metric in [\"Insertion (↑)\", \"Deletion (↓)\", \"Sensitivity-n (↑)\"]])\n",
    "                auc_sensitivity_table_dict[backbone_type][dataset_name][subset_mode][target_non_target]=\\\n",
    "                auc_sensitivity_table_select_format\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60a5b581",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70c4d293",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1631ab46",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4aa038a5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8387dbe7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bd1db6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "list(map(lambda x: explanation_method_mapper(x, \"main\"), explanation_methods_main))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e67104dd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3476c315",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(df_to_latex(table_df))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4aa9f229",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2e1fcfe",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aa73656",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c08680e",
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(sensitivity_result_dict[dataset_name][backbone_type])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1421d1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "sensitivity_result_dict[dataset_name][backbone_type]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27e3b4d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "auc_sensitivity_table.loc[[j for i in explanation_methods_main for j in i]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "258412eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "auc_sensitivity_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a6fbec5",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "        \n",
    "    table_main=pd.concat(auc_sensitivity_table_select_format_list, axis=1)\n",
    "\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "631f5e18",
   "metadata": {},
   "outputs": [],
   "source": [
    "table_main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1424747",
   "metadata": {},
   "outputs": [],
   "source": [
    "#multicol=pd.MultiIndex.from_tuples([(dataset_name, metric) for dataset_name in [\"ImageNette\", \"MURA\"] for metric in [\"Insertion\", \"Deletion\", \"Sensitivity-n\"]])\n",
    "latex_output=table_main.to_latex()\n",
    "latex_output_split=latex_output.split('\\n')\n",
    "latex_output_split.insert(3, '\\\\cmidrule(lr){2-4} \\cmidrule(lr){5-7}')\n",
    "latex_output_split.insert(9, '\\\\midrule')\n",
    "latex_output_split.insert(9+5, '\\\\midrule')\n",
    "latex_output='\\n'.join(latex_output_split)\n",
    "latex_output=latex_output.replace(\"ViT Shapley\", \"\\\\textbf{ViT Shapley}\")\n",
    "latex_output=latex_output.replace(\"explanation\\_method\",\"\")\n",
    "latex_output=latex_output.replace(\"{l}\",\"{c}\")\n",
    "latex_output=latex_output.replace(\"bfs\",\"\\\\textbf{\")\n",
    "latex_output=latex_output.replace(\"bfe\",\"}\")\n",
    "\n",
    "\n",
    "print(latex_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be63280a",
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(sensitivity_result_dict[dataset_name][backbone_type])\\\n",
    "        .groupby(['num_included_players', 'explanation_method'])['target','non-target'].apply(lambda x: len(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6907269a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae39e556",
   "metadata": {},
   "outputs": [],
   "source": [
    "for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    whole_table_main_select_list=[]\n",
    "    for dataset_name in [\"ImageNette\", \"MURA\"]:\n",
    "        roc_auc_table_mean=pd.DataFrame(roc_auc_result_dict[dataset_name][backbone_type])\\\n",
    "        .groupby(['metric_mode', 'explanation_method'])['target_auc', 'non-target_auc'].mean()\n",
    "        roc_auc_table_mean=roc_auc_table_mean.add_suffix('_mean')\n",
    "        \n",
    "        roc_auc_table_std=pd.DataFrame(roc_auc_result_dict[dataset_name][backbone_type])\\\n",
    "        .groupby(['metric_mode', 'explanation_method'])['target_auc', 'non-target_auc'].std()        \n",
    "        roc_auc_table_std=roc_auc_table_std.add_suffix('_std')\n",
    "        \n",
    "        \n",
    "        sensitivity_table_mean=pd.DataFrame(sensitivity_result_dict[dataset_name][backbone_type])\\\n",
    "        .groupby(['num_included_players', 'explanation_method'])['target','non-target'].mean()\n",
    "        sensitivity_table_mean=sensitivity_table_mean.add_suffix('_mean')\n",
    "        \n",
    "        \n",
    "        sensitivity_table_std=pd.DataFrame(sensitivity_result_dict[dataset_name][backbone_type])\\\n",
    "        .groupby(['num_included_players', 'explanation_method'])['target','non-target'].std()\n",
    "        sensitivity_table_std=sensitivity_table_std.add_suffix('_std')\n",
    "        \n",
    "        \n",
    "        whole_table=pd.concat([roc_auc_table_mean.loc['insert'].add_prefix('insert_'), roc_auc_table_std.loc['insert'].add_prefix('insert_'), \n",
    "           \n",
    "           roc_auc_table_mean.loc['delete'].add_prefix('delete_'), roc_auc_table_std.loc['delete'].add_prefix('delete_'), \n",
    "           sensitivity_table_mean.loc['all'].add_prefix('sensitivity_'), sensitivity_table_std.loc['all'].add_prefix('sensitivity_')\n",
    "          \n",
    "          ],axis=1)       \n",
    "        \n",
    "        \n",
    "        whole_table_main=whole_table.loc[explanation_methods_main]\n",
    "        #print(whole_table)\n",
    "        #print(whole_table_main)\n",
    "        whole_table_main.index=whole_table_main.index.map(explanation_method_mapper)        \n",
    "        \n",
    "        \n",
    "        whole_table_main_select=pd.concat([\n",
    "        whole_table_main.apply(lambda x: '-' if np.isnan(x[\"insert_non-target_auc_mean\"]) else\n",
    "                          'bfs{:.3f} ({:.2f})bfe'.format(x[\"insert_non-target_auc_mean\"], x[\"insert_non-target_auc_std\"]) if x[\"insert_non-target_auc_mean\"]>=whole_table_main['insert_non-target_auc_mean'].max() else\n",
    "                          '{:.3f} ({:.2f})'.format(x[\"insert_non-target_auc_mean\"], x[\"insert_non-target_auc_std\"]), axis=1),\n",
    "                                      \n",
    "        whole_table_main.apply(lambda x: '-' if np.isnan(x[\"delete_non-target_auc_mean\"]) else\n",
    "                          'bfs{:.3f} ({:.2f})bfe'.format(x[\"delete_non-target_auc_mean\"], x[\"delete_non-target_auc_std\"]) if x[\"delete_non-target_auc_mean\"]<=whole_table_main['delete_non-target_auc_mean'].min() else\n",
    "                          '{:.3f} ({:.2f})'.format(x[\"delete_non-target_auc_mean\"], x[\"delete_non-target_auc_std\"]), axis=1),                                        \n",
    "\n",
    "        whole_table_main.apply(lambda x: '-' if np.isnan(x[\"sensitivity_non-target_mean\"]) else\n",
    "                          'bfs{:.3f} ({:.2f})bfe'.format(x[\"sensitivity_non-target_mean\"], x[\"sensitivity_non-target_std\"]) if x[\"sensitivity_non-target_mean\"]>=(whole_table_main['sensitivity_non-target_mean'][~whole_table_main['sensitivity_non-target_mean'].isnull()]).max() else\n",
    "                          '{:.3f} ({:.2f})'.format(x[\"sensitivity_non-target_mean\"], x[\"sensitivity_non-target_std\"]), axis=1)  \n",
    "        ], axis=1)\n",
    "        #whole_table_select.columns=['Insertion', 'Deletion', 'Sensitivity-n']        \n",
    "        whole_table_main_select.columns=pd.MultiIndex.from_tuples([(dataset_name, metric) for metric in [\"Insertion (↑)\", \"Deletion (↓)\", \"Sensitivity-n (↑)\"]])\n",
    "        \n",
    "        whole_table_main_select_list.append(whole_table_main_select)\n",
    "        \n",
    "        \n",
    "    table_main=pd.concat(whole_table_main_select_list, axis=1)\n",
    "        #print(whole_table_select)\n",
    "\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a47a078",
   "metadata": {},
   "outputs": [],
   "source": [
    "table_main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecab96d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "#multicol=pd.MultiIndex.from_tuples([(dataset_name, metric) for dataset_name in [\"ImageNette\", \"MURA\"] for metric in [\"Insertion\", \"Deletion\", \"Sensitivity-n\"]])\n",
    "latex_output=table_main.to_latex()\n",
    "latex_output_split=latex_output.split('\\n')\n",
    "latex_output_split.insert(3, '\\\\cmidrule(lr){2-4} \\cmidrule(lr){5-7}')\n",
    "latex_output_split.insert(9, '\\\\midrule')\n",
    "latex_output_split.insert(9+5, '\\\\midrule')\n",
    "latex_output='\\n'.join(latex_output_split)\n",
    "latex_output=latex_output.replace(\"ViT Shapley\", \"\\\\textbf{ViT Shapley}\")\n",
    "latex_output=latex_output.replace(\"explanation\\_method\",\"\")\n",
    "latex_output=latex_output.replace(\"{l}\",\"{c}\")\n",
    "latex_output=latex_output.replace(\"bfs\",\"\\\\textbf{\")\n",
    "latex_output=latex_output.replace(\"bfe\",\"}\")\n",
    "\n",
    "\n",
    "print(latex_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bd7f931",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9655eff",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e454866",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c049fe26",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdfe4717",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "796b2b4d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c87b085e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6283eaa3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1d31b2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.MultiIndex.from_tuples([(dataset_name, metric) for metric in [\"Insertion\", \"Deletion\", \"Sensitivity-n\"]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9639d35c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "887a5410",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "744b5f93",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(pd.concat(whole_table_select_list, axis=1).to_latex())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1aa4aafb",
   "metadata": {},
   "source": [
    "\n",
    "Given : $\\pi = 3.14$ , $\\alpha = \\frac{3\\pi}{4}\\, rad$\n",
    "$$\n",
    "\\omega = 2\\pi f \\\\\n",
    "f = \\frac{c}{\\lambda}\\\\\n",
    "\\lambda_0=\\theta^2+\\delta\\\\\n",
    "\\Delta\\lambda = \\frac{1}{\\lambda^2}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1674e9ab",
   "metadata": {},
   "source": [
    "$$\n",
    "\\\\begin{tabular}{lllllll}\\n\\\\toprule\\n{} &      Insertion &       Deletion &  Sensitivity-n &      Insertion &       Deletion &  Sensitivity-n \\\\\\\\\\nexplanation\\\\_method    &                &                &                &                &                &                \\\\\\\\\\n\\\\midrule\\nLRP                   &  0.967 (0.062) &  0.778 (0.224) &  0.705 (0.247) &  0.900 (0.154) &  0.551 (0.217) &  0.646 (0.262) \\\\\\\\\\nattention\\\\_last        &  0.962 (0.066) &  0.791 (0.211) &  0.694 (0.245) &  0.890 (0.167) &  0.592 (0.218) &  0.635 (0.266) \\\\\\\\\\nattention\\\\_rollout     &  0.938 (0.084) &  0.880 (0.167) &  0.704 (0.247) &  0.845 (0.177) &  0.692 (0.229) &  0.618 (0.264) \\\\\\\\\\ngradcam               &  0.937 (0.098) &  0.948 (0.096) &  0.656 (0.235) &  0.843 (0.193) &  0.835 (0.180) &  0.580 (0.258) \\\\\\\\\\ngradcamgithub         &  0.916 (0.097) &  0.938 (0.129) &  0.680 (0.240) &  0.900 (0.140) &  0.676 (0.249) &  0.631 (0.258) \\\\\\\\\\nigembedding           &  0.968 (0.058) &  0.929 (0.128) &  0.403 (0.400) &  0.897 (0.157) &  0.796 (0.245) &  0.201 (0.362) \\\\\\\\\\nigpixel               &  0.968 (0.058) &  0.929 (0.128) &  0.403 (0.400) &  0.897 (0.157) &  0.796 (0.245) &  0.201 (0.362) \\\\\\\\\\nleaveoneoutclassifier &  0.941 (0.079) &  0.879 (0.185) &  0.140 (0.649) &  0.926 (0.127) &  0.691 (0.283) &  0.308 (0.523) \\\\\\\\\\nleaveoneoutsurrogate  &  0.951 (0.060) &  0.830 (0.258) &  0.351 (0.585) &  0.972 (0.042) &  0.467 (0.322) &  0.709 (0.209) \\\\\\\\\\nours                  &  0.985 (0.039) &  0.688 (0.222) &  0.711 (0.246) &  0.971 (0.036) &  0.307 (0.213) &  0.707 (0.208) \\\\\\\\\\nrandom                &  0.951 (0.077) &  0.951 (0.078) &              - &  0.848 (0.162) &  0.847 (0.161) &              - \\\\\\\\\\nriseclassifier        &  0.906 (0.216) &  0.825 (0.268) &  0.704 (0.247) &  0.957 (0.064) &  0.573 (0.292) &  0.618 (0.263) \\\\\\\\\\nrisesurrogate         &  0.924 (0.200) &  0.745 
(0.305) &  0.704 (0.247) &  0.978 (0.018) &  0.341 (0.268) &  0.619 (0.263) \\\\\\\\\\nsgembedding           &  0.946 (0.090) &  0.942 (0.093) &  0.703 (0.247) &  0.870 (0.169) &  0.813 (0.176) &  0.617 (0.263) \\\\\\\\\\nsgpixel               &  0.960 (0.078) &  0.778 (0.219) &  0.706 (0.248) &  0.873 (0.170) &  0.634 (0.226) &  0.618 (0.265) \\\\\\\\\\nvanillaembedding      &  0.949 (0.072) &  0.806 (0.209) &  0.703 (0.248) &  0.890 (0.163) &  0.537 (0.223) &  0.629 (0.265) \\\\\\\\\\nvanillapixel          &  0.938 (0.081) &  0.859 (0.177) &  0.700 (0.247) &  0.890 (0.158) &  0.561 (0.228) &  0.627 (0.265) \\\\\\\\\\nvargradembedding      &  0.948 (0.086) &  0.947 (0.088) &  0.700 (0.246) &  0.857 (0.178) &  0.823 (0.179) &  0.615 (0.263) \\\\\\\\\\nvargradpixel          &  0.958 (0.078) &  0.794 (0.210) &  0.682 (0.242) &  0.871 (0.167) &  0.660 (0.218) &  0.577 (0.246) \\\\\\\\\\n\\\\bottomrule\\n\\\\end{tabular}\\n\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da44f066",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "950ec55c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean sensitivity-n results at index label 'all' (the level whole_table uses for its sensitivity_ columns).\n",
    "sensitivity_table_mean.loc['all']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce12bc61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side-by-side results: insertion AUC, deletion AUC and sensitivity-n, mean and std of each.\n",
    "# NOTE(review): mean and std tables receive the same prefix, so their columns only stay distinct if\n",
    "# the tables' columns already carry _mean/_std suffixes (whole_table_select reads e.g.\n",
    "# insert_target_auc_mean / insert_target_auc_std) - confirm.\n",
    "whole_table = pd.concat(\n",
    "    [\n",
    "        roc_auc_table_mean.loc['insert'].add_prefix('insert_'),\n",
    "        roc_auc_table_std.loc['insert'].add_prefix('insert_'),\n",
    "        roc_auc_table_mean.loc['delete'].add_prefix('delete_'),\n",
    "        roc_auc_table_std.loc['delete'].add_prefix('delete_'),\n",
    "        sensitivity_table_mean.loc['all'].add_prefix('sensitivity_'),\n",
    "        sensitivity_table_std.loc['all'].add_prefix('sensitivity_'),\n",
    "    ],\n",
    "    axis=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f71e25f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the combined insertion / deletion / sensitivity table built by the concat cell above.\n",
    "whole_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e85845d7",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def format_mean_std(row, prefix):\n",
    "    \"\"\"Format row[prefix + '_mean'] and row[prefix + '_std'] as 'mean (std)' to 3 decimals; '-' if the mean is missing.\n",
    "\n",
    "    pd.isna is used rather than `~np.isnan(...)`: it also accepts None/object values, and `~` only\n",
    "    acts as a logical NOT on numpy booleans (on a plain Python bool it yields -1 or -2, both truthy).\n",
    "    \"\"\"\n",
    "    mean = row[f'{prefix}_mean']\n",
    "    if pd.isna(mean):\n",
    "        return '-'\n",
    "    return '{:.3f} ({:.3f})'.format(mean, row[f'{prefix}_std'])\n",
    "\n",
    "\n",
    "# One 'mean (std)' string column per metric; '-' where a metric is undefined (e.g. sensitivity-n for 'random').\n",
    "whole_table_select = pd.concat(\n",
    "    [whole_table.apply(format_mean_std, axis=1, prefix=prefix)\n",
    "     for prefix in ['insert_target_auc', 'delete_target_auc', 'sensitivity_target']],\n",
    "    axis=1,\n",
    ")\n",
    "whole_table_select.columns = ['Insertion', 'Deletion', 'Sensitivity-n']\n",
    "whole_table_select"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0eb7bad",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb9e62da",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf5bddfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard deviations matching roc_auc_table_mean (indexed by 'insert'/'delete', with a 'target_auc' column).\n",
    "roc_auc_table_std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41c234bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display only: the renamed frame is not assigned, so roc_auc_table_mean itself keeps its column names.\n",
    "roc_auc_table_mean.rename(columns={\n",
    "    'target_auc': 'target_auc_mean',\n",
    "    'non-target_auc': 'non-target_auc_mean',\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50f511c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Insertion target AUC as 'mean (std)' strings; the ' (' ... ')' separator keeps the two numbers\n",
    "# from running together.\n",
    "(roc_auc_table_mean.loc['insert']['target_auc'].map('{:.3f}'.format)\n",
    " + ' ('\n",
    " + roc_auc_table_std.loc['insert']['target_auc'].map('{:.3f}'.format)\n",
    " + ')')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e45ada7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Insertion and deletion means side by side. keys= adds an 'insert'/'delete' column level: both halves\n",
    "# come from the same table, so without it the two column sets would have identical names.\n",
    "pd.concat([roc_auc_table_mean.loc['insert'], roc_auc_table_mean.loc['delete']], axis=1,\n",
    "          keys=['insert', 'delete'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "817d103d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): roc_auc_table is not built in this part of the notebook - presumably the un-aggregated\n",
    "# table behind roc_auc_table_mean/_std; confirm it is defined before this cell runs.\n",
    "roc_auc_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "965f17ad",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "438e3e90",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1742d5a7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dcddbf7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4d8d8cb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63624cf7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "995a9b54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ImageNette / vit_base_patch16_224: mean target and non-target AUC per (metric_mode, explanation_method).\n",
    "# Columns are selected with a list: tuple-style selection on a GroupBy is deprecated in pandas 1.x\n",
    "# and raises ValueError in pandas >= 2.0.\n",
    "df1 = (pd.DataFrame(roc_auc_df_dict['ImageNette']['vit_base_patch16_224'])\n",
    "       .groupby(['metric_mode', 'explanation_method'])[['target_auc', 'non-target_auc']]\n",
    "       .mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "674348b0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Draft LaTeX results table (ImageNette | MURA). Only the ImageNette 'Attention raw' insertion/deletion\n",
    "# cells are filled (from df1); every other cell is still a '-' placeholder.\n",
    "# TODO: fill the remaining rows and the Sensitivity-n / MURA columns.\n",
    "# Raw strings keep the LaTeX backslashes literal, so every row really ends with the LaTeX row break \\\\.\n",
    "attention_raw_insert = df1.loc['insert']['target_auc']['attention_last']\n",
    "attention_raw_delete = df1.loc['delete']['target_auc']['attention_last']\n",
    "latex_lines = [\n",
    "    r'\\begin{table}',\n",
    "    r'\\caption{Results.}',\n",
    "    r'\\label{tab:metrics}',\n",
    "    r'\\begin{center}',\n",
    "    r'\\begin{small}',\n",
    "    r'\\begin{tabular}{lcccccc}',\n",
    "    r'\\toprule',\n",
    "    r' & \\multicolumn{3}{c}{ImageNette} & \\multicolumn{3}{c}{MURA} \\\\',\n",
    "    r'\\cmidrule(lr){2-4} \\cmidrule(lr){5-7}',\n",
    "    r' & Insertion & Deletion & Sensitivity-$n$ & Insertion & Deletion & Sensitivity-$n$ \\\\',\n",
    "    r'\\midrule',\n",
    "    # The tabular has 7 columns (lcccccc), so each data row needs 6 cells after the method name.\n",
    "    rf'Attention raw & {attention_raw_insert:.3f} & {attention_raw_delete:.3f} & - & - & - & - \\\\',\n",
    "    r'Attention rollout \\\\',\n",
    "    r'\\midrule',\n",
    "    r'GradCAM \\\\',\n",
    "    r'Integrated Gradients \\\\',\n",
    "    r'SmoothGrad \\\\',\n",
    "    r'LRP \\\\',\n",
    "    r'\\midrule',\n",
    "    r'Leave-one-out \\\\',\n",
    "    r'RISE \\\\',\n",
    "    r'\\textbf{ViT Shapley} \\\\',\n",
    "    r'\\bottomrule',\n",
    "    r'\\end{tabular}',\n",
    "    r'\\end{small}',\n",
    "    r'\\end{center}',\n",
    "    r'\\end{table}',\n",
    "]\n",
    "output = '\\n'.join(latex_lines)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bd194fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_table(df1, df2):\n",
    "    \"\"\"Build a draft LaTeX results table (ImageNette | MURA) and return it as a string.\n",
    "\n",
    "    df1 / df2: ImageNette / MURA mean AUCs indexed by (metric_mode, explanation_method) with a\n",
    "    'target_auc' column, as built by the call below. Only the 'Attention raw' insertion/deletion\n",
    "    cells are filled in; every other cell is still a '-' placeholder.\n",
    "    TODO: fill the remaining rows and the Sensitivity-n columns.\n",
    "    \"\"\"\n",
    "    def target_auc(df, metric_mode, explanation_method):\n",
    "        # Mean target-class AUC of one method, formatted to 3 decimals.\n",
    "        return '{:.3f}'.format(df.loc[metric_mode]['target_auc'][explanation_method])\n",
    "\n",
    "    attention_raw_cells = ' & '.join([\n",
    "        target_auc(df1, 'insert', 'attention_last'), target_auc(df1, 'delete', 'attention_last'), '-',\n",
    "        target_auc(df2, 'insert', 'attention_last'), target_auc(df2, 'delete', 'attention_last'), '-',\n",
    "    ])\n",
    "    # Plain raw strings (no f-string): LaTeX braces need no escaping and backslashes stay literal,\n",
    "    # so each row ends with the LaTeX row break \\\\.\n",
    "    latex_lines = [\n",
    "        r'\\begin{table}',\n",
    "        r'\\caption{Results.}',\n",
    "        r'\\label{tab:metrics}',\n",
    "        r'\\begin{center}',\n",
    "        r'\\begin{small}',\n",
    "        r'\\begin{tabular}{lcccccc}',\n",
    "        r'\\toprule',\n",
    "        r' & \\multicolumn{3}{c}{ImageNette} & \\multicolumn{3}{c}{MURA} \\\\',\n",
    "        r'\\cmidrule(lr){2-4} \\cmidrule(lr){5-7}',\n",
    "        r' & Insertion & Deletion & Sensitivity-$n$ & Insertion & Deletion & Sensitivity-$n$ \\\\',\n",
    "        r'\\midrule',\n",
    "        'Attention raw & ' + attention_raw_cells + r' \\\\',\n",
    "        r'Attention rollout \\\\',\n",
    "        r'\\midrule',\n",
    "        r'GradCAM \\\\',\n",
    "        r'Integrated Gradients \\\\',\n",
    "        r'SmoothGrad \\\\',\n",
    "        r'LRP \\\\',\n",
    "        r'\\midrule',\n",
    "        r'Leave-one-out \\\\',\n",
    "        r'RISE \\\\',\n",
    "        r'\\textbf{ViT Shapley} \\\\',\n",
    "        r'\\bottomrule',\n",
    "        r'\\end{tabular}',\n",
    "        r'\\end{small}',\n",
    "        r'\\end{center}',\n",
    "        r'\\end{table}',\n",
    "    ]\n",
    "    return '\\n'.join(latex_lines) + '\\n'\n",
    "\n",
    "\n",
    "# Columns are selected with a list: tuple-style GroupBy selection raises in pandas >= 2.0.\n",
    "latex_table = make_table(\n",
    "    pd.DataFrame(roc_auc_df_dict['ImageNette']['vit_base_patch16_224'])\n",
    "    .groupby(['metric_mode', 'explanation_method'])[['target_auc', 'non-target_auc']].mean(),\n",
    "    pd.DataFrame(roc_auc_df_dict['MURA']['vit_base_patch16_224'])\n",
    "    .groupby(['metric_mode', 'explanation_method'])[['target_auc', 'non-target_auc']].mean(),\n",
    ")\n",
    "print(latex_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "527e46d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check: how a string holding a single backslash is displayed.\n",
    "'\\\\'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5748c4da",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49bf7e42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ImageNette / vit_base_patch16_224: mean target and non-target AUC per (metric_mode, explanation_method).\n",
    "# List column selection: tuple-style selection on a GroupBy raises ValueError in pandas >= 2.0.\n",
    "(pd.DataFrame(roc_auc_df_dict['ImageNette']['vit_base_patch16_224'])\n",
    "   .groupby(['metric_mode', 'explanation_method'])[['target_auc', 'non-target_auc']]\n",
    "   .mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fc1869f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MURA / vit_base_patch16_224: mean target and non-target AUC per (metric_mode, explanation_method).\n",
    "# List column selection: tuple-style selection on a GroupBy raises ValueError in pandas >= 2.0.\n",
    "(pd.DataFrame(roc_auc_df_dict['MURA']['vit_base_patch16_224'])\n",
    "   .groupby(['metric_mode', 'explanation_method'])[['target_auc', 'non-target_auc']]\n",
    "   .mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb736550",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MURA / vit_base_patch16_224: mean target-class AUC per (metric_mode, explanation_method).\n",
    "(pd.DataFrame(roc_auc_df_dict['MURA']['vit_base_patch16_224'])\n",
    "   .groupby(['metric_mode', 'explanation_method'])['target_auc']\n",
    "   .mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc050082",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MURA only: per-method mean sensitivity-n pooled over the subset sizes 'all', 14, 28, ..., 182;\n",
    "# images the classifier scores below 0.5 (negative/normal) are skipped.\n",
    "for dataset_name in ['MURA']:\n",
    "    for backbone_type in backbone_type_config_dict:\n",
    "        # The classifier predictions depend only on dataset/backbone, so look them up once per backbone.\n",
    "        classifier_prob_data = data_loaded_all['4_0_classifier_evaluate'][dataset_name][backbone_type]\n",
    "        sensitivity_dict_list = []\n",
    "        for explanation_method, sensitivity_save_dict in data_loaded_all['4_3_sensitivity'][dataset_name][backbone_type].items():\n",
    "            # Attention-based explanations get no non-target score.\n",
    "            has_non_target = explanation_method not in ['attention_rollout', 'attention_last']\n",
    "            for path, sensitivity in sensitivity_save_dict.items():\n",
    "                # 'prob' is parsed as a bracketed, whitespace-separated list of floats, e.g. '[0.1 0.9]'.\n",
    "                prob_str = classifier_prob_data.loc[adapt_path(path, classifier_prob_data.index[0])]['prob']\n",
    "                classifier_prob = [float(value.replace('[', '').replace(']', '')) for value in prob_str.split()]\n",
    "\n",
    "                if len(classifier_prob) == 1:  # MURA\n",
    "                    classifier_prob_argmax = 0\n",
    "                    if classifier_prob[0] < 0.5:  # Ignore negative(=normal) images\n",
    "                        continue\n",
    "                else:  # ImageNette\n",
    "                    classifier_prob_argmax = np.argmax(classifier_prob)\n",
    "\n",
    "                for num_included_players in ['all'] + list(range(14, 196, 14)):\n",
    "                    scores = sensitivity[num_included_players]\n",
    "                    non_target = (scores[np.arange(len(scores)) != classifier_prob_argmax].mean()\n",
    "                                  if has_non_target else None)\n",
    "                    sensitivity_dict_list.append({'explanation_method': explanation_method,\n",
    "                                                  'path': path,\n",
    "                                                  'num_included_players': num_included_players,\n",
    "                                                  'target': scores[classifier_prob_argmax],\n",
    "                                                  'non-target': non_target,\n",
    "                                                  })\n",
    "        print(f'{dataset_name}   {backbone_type}')\n",
    "        # numeric_only=True: 'path' and the mixed 'all'/int 'num_included_players' cannot be averaged;\n",
    "        # older pandas dropped such columns (with a FutureWarning in 1.5), pandas >= 2 raises TypeError.\n",
    "        print(pd.DataFrame(sensitivity_dict_list).groupby(['explanation_method']).mean(numeric_only=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c8f9cb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Columns of the records from whichever sensitivity loop ran last (each loop rebuilds sensitivity_dict_list).\n",
    "pd.DataFrame(sensitivity_dict_list).columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d711d519",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Per dataset and backbone: mean sensitivity of the predicted class ('target') and of the remaining\n",
    "# classes ('non-target') for each explanation method.\n",
    "# NOTE(review): unlike the MURA subset-size loop, negative MURA predictions are not skipped here - confirm.\n",
    "for dataset_name in ['ImageNette', 'MURA']:\n",
    "    for backbone_type in backbone_type_config_dict:\n",
    "        # The classifier predictions depend only on dataset/backbone, so look them up once per backbone.\n",
    "        classifier_prob_data = data_loaded_all['4_0_classifier_evaluate'][dataset_name][backbone_type]\n",
    "        sensitivity_dict_list = []\n",
    "        for explanation_method, sensitivity_save_dict in data_loaded_all['4_3_sensitivity'][dataset_name][backbone_type].items():\n",
    "            # Attention-based explanations get no non-target score.\n",
    "            has_non_target = explanation_method not in ['attention_rollout', 'attention_last']\n",
    "            for path, sensitivity in sensitivity_save_dict.items():\n",
    "                # 'prob' is parsed as a bracketed, whitespace-separated list of floats, e.g. '[0.1 0.9]'.\n",
    "                prob_str = classifier_prob_data.loc[adapt_path(path, classifier_prob_data.index[0])]['prob']\n",
    "                classifier_prob = [float(value.replace('[', '').replace(']', '')) for value in prob_str.split()]\n",
    "                # A single probability (MURA) means output 0 is the target; otherwise use the argmax class.\n",
    "                classifier_prob_argmax = 0 if len(classifier_prob) == 1 else np.argmax(classifier_prob)\n",
    "\n",
    "                scores = sensitivity['sensitivity']\n",
    "                non_target = (scores[np.arange(len(scores)) != classifier_prob_argmax].mean()\n",
    "                              if has_non_target else None)\n",
    "                sensitivity_dict_list.append({'explanation_method': explanation_method,\n",
    "                                              'path': path,\n",
    "                                              'target': scores[classifier_prob_argmax],\n",
    "                                              'non-target': non_target,\n",
    "                                              })\n",
    "        print(f'{dataset_name}   {backbone_type}')\n",
    "        # numeric_only=True: the string 'path' column cannot be averaged; older pandas dropped it\n",
    "        # (with a FutureWarning in 1.5), pandas >= 2 raises TypeError.\n",
    "        print(pd.DataFrame(sensitivity_dict_list).groupby(['explanation_method']).mean(numeric_only=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e800bc89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MURA only: mean sensitivity of the predicted class ('target') and of the remaining classes\n",
    "# ('non-target') for each explanation method.\n",
    "# classifier_prob_argmax is computed per image in this cell so it never depends on another cell's loop state.\n",
    "for dataset_name in ['MURA']:\n",
    "    for backbone_type in backbone_type_config_dict:\n",
    "        # The classifier predictions depend only on dataset/backbone, so look them up once per backbone.\n",
    "        classifier_prob_data = data_loaded_all['4_0_classifier_evaluate'][dataset_name][backbone_type]\n",
    "        sensitivity_dict_list = []\n",
    "        for explanation_method, sensitivity_save_dict in data_loaded_all['4_3_sensitivity'][dataset_name][backbone_type].items():\n",
    "            # Attention-based explanations get no non-target score.\n",
    "            has_non_target = explanation_method not in ['attention_rollout', 'attention_last']\n",
    "            for path, sensitivity in sensitivity_save_dict.items():\n",
    "                # 'prob' is parsed as a bracketed, whitespace-separated list of floats, e.g. '[0.1 0.9]'.\n",
    "                prob_str = classifier_prob_data.loc[adapt_path(path, classifier_prob_data.index[0])]['prob']\n",
    "                classifier_prob = [float(value.replace('[', '').replace(']', '')) for value in prob_str.split()]\n",
    "                # A single probability (MURA) means output 0 is the target; otherwise use the argmax class.\n",
    "                classifier_prob_argmax = 0 if len(classifier_prob) == 1 else np.argmax(classifier_prob)\n",
    "\n",
    "                scores = sensitivity['sensitivity']\n",
    "                non_target = (scores[np.arange(len(scores)) != classifier_prob_argmax].mean()\n",
    "                              if has_non_target else None)\n",
    "                sensitivity_dict_list.append({'explanation_method': explanation_method,\n",
    "                                              'path': path,\n",
    "                                              'target': scores[classifier_prob_argmax],\n",
    "                                              'non-target': non_target,\n",
    "                                              })\n",
    "        print(f'{dataset_name}   {backbone_type}')\n",
    "        # numeric_only=True: the string 'path' column cannot be averaged; older pandas dropped it\n",
    "        # (with a FutureWarning in 1.5), pandas >= 2 raises TypeError.\n",
    "        print(pd.DataFrame(sensitivity_dict_list).groupby(['explanation_method']).mean(numeric_only=True))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vitshapley",
   "language": "python",
   "name": "vitshapley"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
