{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a263fe4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "print(os.getcwd())\n",
    "os.chdir('../')\n",
    "print(os.getcwd())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86f348f8",
   "metadata": {},
   "source": [
    "# config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72c897bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from collections import OrderedDict\n",
    "import copy\n",
    "import pickle\n",
    "import time\n",
    "from scipy import stats\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from vit_shapley.datamodules.ImageNette_datamodule import ImageNetteDataModule\n",
    "from vit_shapley.datamodules.MURA_datamodule import MURADataModule\n",
    "\n",
    "from vit_shapley.modules.classifier import Classifier\n",
    "from vit_shapley.modules.classifier_masked import ClassifierMasked\n",
    "from vit_shapley.modules.surrogate import Surrogate\n",
    "from vit_shapley.modules.explainer import Explainer\n",
    "\n",
    "from vit_shapley.config import ex\n",
    "from vit_shapley.config import config, env_username, dataset_ImageNette, dataset_MURA\n",
    "_config=config()\n",
    "\n",
    "dataset_split=\"test\"\n",
    "parallel_mode = (0, 4)\n",
    "backbone_to_use=[\"vit_base_patch16_224\"]\n",
    "_config.update(dataset_ImageNette())\n",
    "evaluation_stage=[\"1_classifier_evaluate\",\n",
    "                  \"2_surrogate_evaluate\",\n",
    "                  \"3_explanation_generate\",\n",
    "                  \"4_insert_delete\",\n",
    "                  \"5_sensitivity\",\n",
    "                  \"6_noretraining\",\n",
    "                  \"7_classifiermasked\",\n",
    "                  \"8_elapsedtime\",\n",
    "                  \"9_estimationerror\"][-1]\n",
    "\n",
    "_config.update(env_username()); _config.update({'gpus_classifier':[1,],\n",
    "                                                'gpus_surrogate':[1,],\n",
    "                                                'gpus_explainer':[1,]})\n",
    "\n",
    "_config.update({'classifier_backbone_type': None,\n",
    "                'classifier_download_weight': False,\n",
    "                'classifier_load_path': None})\n",
    "_config.update({'classifier_masked_mask_location': \"pre-softmax\",\n",
    "                'classifier_enable_pos_embed': True,\n",
    "                })\n",
    "_config.update({'surrogate_mask_location': \"pre-softmax\"})\n",
    "_config.update({'surrogate_backbone_type': None,\n",
    "                'surrogate_download_weight': False,\n",
    "                'surrogate_load_path': None})\n",
    "_config.update({'explainer_num_mask_samples': 2,\n",
    "                'explainer_paired_mask_samples': True})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc42bddb",
   "metadata": {},
   "outputs": [],
   "source": [
    "!gpustat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6058196f",
   "metadata": {},
   "outputs": [],
   "source": [
    "if _config[\"datasets\"]==\"ImageNette\":\n",
    "    backbone_type_config_dict_=OrderedDict({\n",
    "        \"vit_small_patch16_224\":{\n",
    "            \"classifier_path\": \"results/wandb_transformer_interpretability_project/1yndrggu/checkpoints/epoch=14-step=2204.ckpt\",\n",
    "            \"classifier_masked_path\": \"results/wandb_transformer_interpretability_project/fdm70w72/checkpoints/epoch=19-step=2939.ckpt\",\n",
    "            \"surrogate_path\":{\n",
    "                \"pre-softmax\": \"results/wandb_transformer_interpretability_project/3lfv4nmn/checkpoints/epoch=39-step=5879.ckpt\"\n",
    "            },\n",
    "            \"explainer_path\":\"results/wandb_transformer_interpretability_project/3biv2s85/checkpoints/epoch=60-step=9027.ckpt\"\n",
    "\n",
    "        },\n",
    "        \"deit_small_patch16_224\":{\n",
    "        },\n",
    "        \"vit_base_patch16_224\":{\n",
    "            \"classifier_path\": \"results/wandb_transformer_interpretability_project/2rq1issn/checkpoints/epoch=16-step=2498.ckpt\",\n",
    "            \"classifier_masked_path\": \"results/wandb_transformer_interpretability_project/x59c992d/checkpoints/epoch=21-step=3233.ckpt\",\n",
    "            \"surrogate_path\":{\n",
    "                #\"original\": \"results/wandb_transformer_interpretability_project/2rq1issn/checkpoints/epoch=16-step=2498.ckpt\",\n",
    "                \"pre-softmax\": \"results/wandb_transformer_interpretability_project/3i6zzjnp/checkpoints/epoch=38-step=5732.ckpt\",\n",
    "                #\"zero-input\": \"results/wandb_transformer_interpretability_project/zyybgzcm/checkpoints/epoch=22-step=3380.ckpt\",\n",
    "                #\"zero-embedding\": \"results/wandb_transformer_interpretability_project/1gi5gmrm/checkpoints/epoch=36-step=5438.ckpt\"\n",
    "                },\n",
    "            \"explainer_path\": \"results/wandb_transformer_interpretability_project/3ty85eft/checkpoints/epoch=83-step=12431.ckpt\"\n",
    "        },\n",
    "        \"deit_base_patch16_224\":{\n",
    "\n",
    "        }\n",
    "    })    \n",
    "elif _config[\"datasets\"]==\"MURA\":\n",
    "    backbone_type_config_dict_=OrderedDict({\n",
    "        \"vit_small_patch16_224\":{\n",
    "\n",
    "        },\n",
    "        \"deit_small_patch16_224\":{\n",
    "        },\n",
    "        \"vit_base_patch16_224\":{\n",
    "            \"classifier_path\":\"results/wandb_transformer_interpretability_project/1u2xgwks/checkpoints/epoch=15-step=8255.ckpt\",\n",
    "            \"surrogate_path\": {\n",
    "                #\"original\": \"results/wandb_transformer_interpretability_project/1u2xgwks/checkpoints/epoch=15-step=8255.ckpt\",\n",
    "                \"pre-softmax\": \"results/wandb_transformer_interpretability_project/22ompjqu/checkpoints/epoch=47-step=24767.ckpt\",\n",
    "                #\"zero-input\": \"results/wandb_transformer_interpretability_project/2z2qs6t0/checkpoints/epoch=44-step=23219.ckpt\",\n",
    "                #\"zero-embedding\": \"results/wandb_transformer_interpretability_project/1pbmwnvb/checkpoints/epoch=45-step=23735.ckpt\"\n",
    "            },\n",
    "            \"explainer_path\":\"results/wandb_transformer_interpretability_project/1dmhcwej/checkpoints/epoch=93-step=48597.ckpt\"\n",
    "        },\n",
    "        \"deit_base_patch16_224\":{\n",
    "\n",
    "        }\n",
    "    })"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8426be8",
   "metadata": {},
   "source": [
    "# Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54d6d39a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_mask(num_players: int, num_mask_samples: int or None = None, paired_mask_samples: bool = True,\n",
    "                  mode: str = 'uniform', random_state: np.random.RandomState or None = None) -> np.array:\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        num_players: the number of players in the coalitional game\n",
    "        num_mask_samples: the number of masks to generate\n",
    "        paired_mask_samples: if True, the generated masks are pairs of x and 1-x.\n",
    "        mode: the distribution that the number of masked features follows. ('uniform' or 'shapley')\n",
    "        random_state: random generator\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor of shape\n",
    "        (num_masks, num_players) if num_masks is int\n",
    "        (num_players) if num_masks is None\n",
    "\n",
    "    \"\"\"\n",
    "    random_state = random_state or np.random\n",
    "\n",
    "    num_samples_ = num_mask_samples or 1\n",
    "\n",
    "    if paired_mask_samples:\n",
    "        assert num_samples_ % 2 == 0, \"'num_samples' must be a multiple of 2 if 'paired' is True\"\n",
    "        num_samples_ = num_samples_ // 2\n",
    "    else:\n",
    "        num_samples_ = num_samples_\n",
    "\n",
    "    if mode == 'uniform':\n",
    "        masks = (random_state.rand(num_samples_, num_players) > random_state.rand(num_samples_, 1)).astype('int')\n",
    "    elif mode == 'shapley':\n",
    "        probs = 1 / (np.arange(1, num_players) * (num_players - np.arange(1, num_players)))\n",
    "        probs = probs / probs.sum()\n",
    "        masks = (random_state.rand(num_samples_, num_players) > 1 / num_players * random_state.choice(\n",
    "            np.arange(num_players - 1), p=probs, size=[num_samples_, 1])).astype('int')\n",
    "    else:\n",
    "        raise ValueError(\"'mode' must be 'random' or 'shapley'\")\n",
    "\n",
    "    if paired_mask_samples:\n",
    "        masks = np.stack([masks, 1 - masks], axis=1).reshape(num_samples_ * 2, num_players)\n",
    "\n",
    "    if num_mask_samples is None:\n",
    "        masks = masks.squeeze(0)\n",
    "        return masks  # (num_masks)\n",
    "    else:\n",
    "        return masks  # (num_samples, num_masks)\n",
    "\n",
    "def set_datamodule(datasets,\n",
    "                   dataset_location,\n",
    "                   explanation_location_train,\n",
    "                   explanation_mask_amount_train,\n",
    "                   explanation_mask_ascending_train,\n",
    "                   \n",
    "                   explanation_location_val,\n",
    "                   explanation_mask_amount_val,\n",
    "                   explanation_mask_ascending_val,                   \n",
    "                   \n",
    "                   explanation_location_test,\n",
    "                   explanation_mask_amount_test,\n",
    "                   explanation_mask_ascending_test,                   \n",
    "                   \n",
    "                   transforms_train,\n",
    "                   transforms_val,\n",
    "                   transforms_test,\n",
    "                   num_workers,\n",
    "                   per_gpu_batch_size,\n",
    "                   test_data_split):\n",
    "    dataset_parameters = {\n",
    "        \"dataset_location\": dataset_location,\n",
    "        \"explanation_location_train\": explanation_location_train,\n",
    "        \"explanation_mask_amount_train\": explanation_mask_amount_train,\n",
    "        \"explanation_mask_ascending_train\": explanation_mask_ascending_train,\n",
    "        \n",
    "        \"explanation_location_val\": explanation_location_val,\n",
    "        \"explanation_mask_amount_val\": explanation_mask_amount_val,\n",
    "        \"explanation_mask_ascending_val\": explanation_mask_ascending_val,\n",
    "        \n",
    "        \"explanation_location_test\": explanation_location_test,\n",
    "        \"explanation_mask_amount_test\": explanation_mask_amount_test,\n",
    "        \"explanation_mask_ascending_test\": explanation_mask_ascending_test,        \n",
    "        \n",
    "        \"transforms_train\": transforms_train,\n",
    "        \"transforms_val\": transforms_val,\n",
    "        \"transforms_test\": transforms_test,\n",
    "        \"num_workers\": num_workers,\n",
    "        \"per_gpu_batch_size\": per_gpu_batch_size,\n",
    "        \"test_data_split\": test_data_split\n",
    "    }\n",
    "\n",
    "    if datasets == \"CheXpert\":\n",
    "        datamodule = CheXpertDataModule(**dataset_parameters)\n",
    "    elif datasets == \"MIMIC\":\n",
    "        datamodule = MIMICDataModule(**dataset_parameters)\n",
    "    elif datasets == \"MURA\":\n",
    "        datamodule = MURADataModule(**dataset_parameters)\n",
    "    elif datasets == \"ImageNette\":\n",
    "        datamodule = ImageNetteDataModule(**dataset_parameters)\n",
    "    else:\n",
    "        ValueError(\"Invalid 'datasets' configuration\")\n",
    "    return datamodule\n",
    "\n",
    "datamodule = set_datamodule(datasets=_config[\"datasets\"],\n",
    "                            dataset_location=_config[\"dataset_location\"],\n",
    "\n",
    "                            explanation_location_train=_config[\"explanation_location_train\"],\n",
    "                            explanation_mask_amount_train=_config[\"explanation_mask_amount_train\"],\n",
    "                            explanation_mask_ascending_train=_config[\"explanation_mask_ascending_train\"],\n",
    "\n",
    "                            explanation_location_val=_config[\"explanation_location_val\"],\n",
    "                            explanation_mask_amount_val=_config[\"explanation_mask_amount_val\"],\n",
    "                            explanation_mask_ascending_val=_config[\"explanation_mask_ascending_val\"],\n",
    "\n",
    "                            explanation_location_test=_config[\"explanation_location_test\"],\n",
    "                            explanation_mask_amount_test=_config[\"explanation_mask_amount_test\"],\n",
    "                            explanation_mask_ascending_test=_config[\"explanation_mask_ascending_test\"],                            \n",
    "\n",
    "                            transforms_train=_config[\"transforms_train\"],\n",
    "                            transforms_val=_config[\"transforms_val\"],\n",
    "                            transforms_test=_config[\"transforms_test\"],\n",
    "                            num_workers=_config[\"num_workers\"],\n",
    "                            per_gpu_batch_size=_config[\"per_gpu_batch_size\"],\n",
    "                            test_data_split=_config[\"test_data_split\"])\n",
    "\n",
    "# The batch for training classifier consists of images and labels, but the batch for training explainer consists of images and masks.\n",
    "# The masks are generated to follow the Shapley distribution.\n",
    "\"\"\"\n",
    "original_getitem = copy.deepcopy(datamodule.dataset_cls.__getitem__)\n",
    "def __getitem__(self, idx):\n",
    "    if self.split == 'train':\n",
    "        masks = generate_mask(num_players=surrogate.num_players,\n",
    "                              num_mask_samples=_config[\"explainer_num_mask_samples\"],\n",
    "                              paired_mask_samples=_config[\"explainer_paired_mask_samples\"], mode='shapley')\n",
    "    elif self.split == 'val' or self.split == 'test':\n",
    "        # get cached if available\n",
    "        if not hasattr(self, \"masks_cached\"):\n",
    "            self.masks_cached = {}\n",
    "        masks = self.masks_cached.setdefault(idx, generate_mask(num_players=surrogate.num_players,\n",
    "                                                                num_mask_samples=_config[\n",
    "                                                                    \"explainer_num_mask_samples\"],\n",
    "                                                                paired_mask_samples=_config[\n",
    "                                                                    \"explainer_paired_mask_samples\"],\n",
    "                                                                mode='shapley'))\n",
    "    else:\n",
    "        raise ValueError(\"'split' variable must be train, val or test.\")\n",
    "    return {\"images\": original_getitem(self, idx)[\"images\"],\n",
    "            \"labels\": original_getitem(self, idx)[\"labels\"],\n",
    "            \"masks\": masks}\n",
    "datamodule.dataset_cls.__getitem__ = __getitem__\n",
    "\"\"\"\n",
    "\n",
    "datamodule.set_train_dataset()\n",
    "datamodule.set_val_dataset()\n",
    "datamodule.set_test_dataset()\n",
    "\n",
    "train_dataset=datamodule.train_dataset\n",
    "val_dataset=datamodule.val_dataset\n",
    "test_dataset=datamodule.test_dataset\n",
    "\n",
    "dset=test_dataset\n",
    "\n",
    "if dataset_split==\"train\":\n",
    "    dset.data = train_dataset.data\n",
    "elif dataset_split==\"val\":\n",
    "    dset.data = val_dataset.data     \n",
    "elif dataset_split==\"test\": \n",
    "    dset.data = test_dataset.data\n",
    "else:\n",
    "    raise\n",
    "\n",
    "labels = np.array([i['label'] for i in dset.data])\n",
    "num_classes = labels.max() + 1\n",
    "\n",
    "images_idx_list = [np.where(labels == category)[0] for category in range(num_classes)]\n",
    "\n",
    "images_idx=[]\n",
    "for classidx in range(4,4+int(10/len(images_idx_list))):\n",
    "    images_idx+=[category_idx[classidx] for category_idx in images_idx_list]\n",
    "\n",
    "xy=[dset[idx] for idx in images_idx]\n",
    "x, y = zip(*[(i['images'], i['labels']) for i in xy])\n",
    "x = torch.stack(x)\n",
    "y_labels=[dset.labels[i] for i in y]\n",
    "\n",
    "\n",
    "if _config[\"datasets\"]==\"ImageNette\":\n",
    "    label_name_list=['Cassette player', \n",
    "                      'Garbage truck', \n",
    "                      'Tench', \n",
    "                      'English springer', \n",
    "                      'Church', \n",
    "                      'Parachute', \n",
    "                      'French horn', \n",
    "                      'Chain saw', \n",
    "                      'Golf ball', \n",
    "                      'Gas pump']\n",
    "    \n",
    "elif _config[\"datasets\"]==\"MURA\":\n",
    "    label_name_list=[\"Normal\", \"Abnormal\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7380660e",
   "metadata": {},
   "source": [
    "# Load models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbc25f70",
   "metadata": {},
   "outputs": [],
   "source": [
    "backbone_type_config_dict = OrderedDict()\n",
    "for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict_.items()):\n",
    "    if backbone_type in backbone_to_use:\n",
    "        print(backbone_type)\n",
    "        backbone_type_config_dict[backbone_type]=backbone_type_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13aafbd0",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "classifier_dict = OrderedDict()\n",
    "\n",
    "for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    classifier_dict[backbone_type] = Classifier(backbone_type=backbone_type,\n",
    "                                               download_weight=_config['classifier_download_weight'],\n",
    "                                               load_path=backbone_type_config[\"classifier_path\"],\n",
    "                                               target_type=_config[\"target_type\"],\n",
    "                                               output_dim=_config[\"output_dim\"],\n",
    "                                               enable_pos_embed=_config[\"classifier_enable_pos_embed\"],\n",
    "\n",
    "                                               checkpoint_metric=None,\n",
    "                                               loss_weight=None,\n",
    "                                               optim_type=None,\n",
    "                                               learning_rate=None,\n",
    "                                               weight_decay=None,\n",
    "                                               decay_power=None,\n",
    "                                               warmup_steps=None).to(_config[\"gpus_classifier\"][idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3016a85",
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier_dict_ = OrderedDict()\n",
    "\n",
    "for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    classifier_dict_[backbone_type] = Surrogate(mask_location=_config[\"surrogate_mask_location\"],\n",
    "                                                   backbone_type=backbone_type,\n",
    "                                                   download_weight=_config['classifier_download_weight'],\n",
    "                                                   load_path=backbone_type_config[\"classifier_path\"],\n",
    "                                                   target_type=_config[\"target_type\"],\n",
    "                                                   output_dim=_config[\"output_dim\"],\n",
    "\n",
    "                                                   target_model=None,\n",
    "                                                   checkpoint_metric=None,\n",
    "                                                   optim_type=None,\n",
    "                                                   learning_rate=None,\n",
    "                                                   weight_decay=None,\n",
    "                                                   decay_power=None,\n",
    "                                                   warmup_steps=None).to(_config[\"gpus_classifier\"][idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b05bf9de",
   "metadata": {},
   "outputs": [],
   "source": [
    "if evaluation_stage==\"7_classifiermasked\":\n",
    "    classifier_masked_dict = OrderedDict()\n",
    "\n",
    "    for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "        classifier_masked_dict[backbone_type] = ClassifierMasked(mask_location=_config[\"classifier_masked_mask_location\"],\n",
    "                                                               backbone_type=backbone_type,\n",
    "                                                               download_weight=_config['classifier_download_weight'],\n",
    "                                                               load_path=backbone_type_config[\"classifier_masked_path\"],\n",
    "                                                               target_type=_config[\"target_type\"],\n",
    "                                                               output_dim=_config[\"output_dim\"],\n",
    "\n",
    "                                                               checkpoint_metric=None,\n",
    "                                                               loss_weight=None,                                                             \n",
    "                                                               optim_type=None,\n",
    "                                                               learning_rate=None,\n",
    "                                                               weight_decay=None,\n",
    "                                                               decay_power=None,\n",
    "                                                               warmup_steps=None).to(_config[\"gpus_classifier\"][idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a294aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "surrogate_dict = OrderedDict()\n",
    "\n",
    "for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    mask_method_dict = OrderedDict()\n",
    "    for mask_location in backbone_type_config[\"surrogate_path\"].keys():\n",
    "        mask_method_dict[mask_location] = Surrogate(mask_location=mask_location if mask_location!=\"original\" else \"pre-softmax\",\n",
    "                                          backbone_type=backbone_type,\n",
    "                                          download_weight=_config['surrogate_download_weight'],\n",
    "                                          load_path=backbone_type_config[\"surrogate_path\"][mask_location],\n",
    "                                          target_type=_config[\"target_type\"],\n",
    "                                          output_dim=_config[\"output_dim\"],\n",
    "\n",
    "                                          target_model=None,\n",
    "                                          checkpoint_metric=None,\n",
    "                                          optim_type=None,\n",
    "                                          learning_rate=None,\n",
    "                                          weight_decay=None,\n",
    "                                          decay_power=None,\n",
    "                                          warmup_steps=None).to(_config[\"gpus_surrogate\"][idx])\n",
    "    surrogate_dict[backbone_type]=mask_method_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8414a2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from vitmedical.modules.explainer import Explainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9276fe1b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "_config.update({'explainer_normalization': \"additive\",\n",
    "                'explainer_activation': \"tanh\",\n",
    "                'explainer_link': 'sigmoid' if _config[\"output_dim\"]==1 else 'softmax',\n",
    "                'explainer_head_num_attention_blocks': 1,\n",
    "                'explainer_head_include_cls': True,\n",
    "                'explainer_head_num_mlp_layers': 3,\n",
    "                'explainer_head_mlp_layer_ratio': 4,\n",
    "                'explainer_residual': [],\n",
    "                'explainer_freeze_backbone': \"all\"})\n",
    "\n",
    "explainer_dict = OrderedDict()\n",
    "for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    explainer_dict[backbone_type] = Explainer(normalization=_config[\"explainer_normalization\"],\n",
    "                                              normalization_class=_config[\"explainer_normalization_class\"],\n",
    "                                              activation=_config[\"explainer_activation\"],\n",
    "                                              surrogate=surrogate_dict[backbone_type][\"pre-softmax\"],\n",
    "                                              link=_config[\"explainer_link\"],\n",
    "                                              backbone_type=backbone_type,\n",
    "                                              download_weight=False,\n",
    "                                              residual=_config['explainer_residual'],\n",
    "                                              load_path=backbone_type_config[\"explainer_path\"],\n",
    "                                              target_type=_config[\"target_type\"],\n",
    "                                              output_dim=_config[\"output_dim\"],\n",
    "\n",
    "                                              explainer_head_num_attention_blocks=_config[\"explainer_head_num_attention_blocks\"],\n",
    "                                              explainer_head_include_cls=_config[\"explainer_head_include_cls\"],\n",
    "                                              explainer_head_num_mlp_layers=_config[\"explainer_head_num_mlp_layers\"],\n",
    "                                              explainer_head_mlp_layer_ratio=_config[\"explainer_head_mlp_layer_ratio\"],\n",
    "                                              explainer_norm=_config[\"explainer_norm\"],\n",
    "\n",
    "                                              efficiency_lambda=_config[\"explainer_efficiency_lambda\"],\n",
    "                                              efficiency_class_lambda=_config[\"explainer_efficiency_class_lambda\"],\n",
    "                                              freeze_backbone=_config[\"explainer_freeze_backbone\"],\n",
    "\n",
    "                                              checkpoint_metric=_config[\"checkpoint_metric\"],\n",
    "                                              optim_type=_config[\"optim_type\"],\n",
    "                                              learning_rate=_config[\"learning_rate\"],\n",
    "                                              weight_decay=_config[\"weight_decay\"],\n",
    "                                              decay_power=_config[\"decay_power\"],\n",
    "                                              warmup_steps=_config[\"warmup_steps\"]).to(_config[\"gpus_explainer\"][idx])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9e7f6fb",
   "metadata": {},
   "source": [
    "# explanation methods"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02048ea8",
   "metadata": {},
   "source": [
    "## attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b12d90ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_joint_attention(attentions, add_residual=True):\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        attentions: (num_batches, num_layers, num_players, num_players)\n",
    "        add_residual: bool\n",
    "    Returns:\n",
    "        joint_attentions: (num_batches, num_layers, num_players, num_players)\n",
    "    \"\"\"\n",
    "    assert len(attentions.shape)==4\n",
    "    if add_residual:\n",
    "        residual_att = np.eye(attentions.shape[2])[np.newaxis, np.newaxis, ...]\n",
    "        aug_attentions = attentions + residual_att\n",
    "        aug_attentions = aug_attentions / aug_attentions.sum(axis=-1)[..., np.newaxis]\n",
    "    else:\n",
    "        aug_attentions =  attentions\n",
    "    \n",
    "    joint_attentions = np.zeros(aug_attentions.shape) # (num_batches, num_layers, num_players, num_players)\n",
    "\n",
    "    for i in np.arange(joint_attentions.shape[1]):\n",
    "        if i==0:\n",
    "            joint_attentions[:,i] = aug_attentions[:,0]\n",
    "        else:\n",
    "            joint_attentions[:,i] = (aug_attentions[:,i] @ joint_attentions[:,i-1])\n",
    "    return joint_attentions\n",
    "\n",
    "\n",
    "def attentions_to_explanation(attentions, mode='rollout'):\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        attentions: (num_batches, num_layers, num_heads, num_players, num_players)\n",
    "    \"\"\"\n",
    "    assert len(attentions.shape)==5 and attentions.shape[-1]==attentions.shape[-2]\n",
    "    attentions_nohead = attentions.sum(axis=2)/attentions.shape[2] # (num_batch, num_layers, num_players, num_players)\n",
    "    attentions_nohead_residual = attentions_nohead + np.eye(attentions_nohead.shape[2])[np.newaxis, np.newaxis, ...] # (num_batch, num_layers, num_players, num_players)\n",
    "    attentions_nohead_residual_normalized = attentions_nohead_residual / attentions_nohead_residual.sum(axis=-1)[..., np.newaxis] # (num_batch, num_layers, num_players, num_players)\n",
    "    \n",
    "    if isinstance(mode, int):\n",
    "        return attentions_nohead_residual_normalized[:, mode, 0, 1:]\n",
    "    elif mode=='raw':\n",
    "        return attentions_nohead_residual_normalized[:, -1, 0, 1:]\n",
    "    elif mode=='rollout':\n",
    "        attentions_nohead_residual_normalized_rollout = compute_joint_attention(attentions_nohead_residual_normalized,\n",
    "                                                                                add_residual=False)\n",
    "        return attentions_nohead_residual_normalized_rollout[:, -1, 0, 1:]\n",
    "#explanation_to_mask(attention_rollout).argmin(axis=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "220f301b",
   "metadata": {},
   "source": [
    "## lrp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8045007",
   "metadata": {},
   "outputs": [],
   "source": [
    "import utils.transformer_explainability.baselines.ViT.ViT_new as ViT_new\n",
    "import utils.transformer_explainability.baselines.ViT.ViT_LRP as ViT_LRP\n",
    "import utils.transformer_explainability.baselines.ViT.ViT_orig_LRP as ViT_orig_LRP\n",
    "\n",
    "from utils.transformer_explainability.baselines.ViT.ViT_explanation_generator import Baselines, LRP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e055b095",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "baselines_dict = OrderedDict()\n",
    "lrp_dict = OrderedDict()\n",
    "orig_lrp_dict = OrderedDict()\n",
    "\n",
    "\n",
    "def _load_and_report(module, weights, checkpoint_path):\n",
    "    \"\"\"Load `weights` into `module` non-strictly, print the unmatched keys, return the load result.\"\"\"\n",
    "    load_result = module.load_state_dict(weights, strict=False)\n",
    "    print(f\"Model parameters were updated from a checkpoint file {checkpoint_path}\")\n",
    "    print(f\"Unmatched parameters - missing_keys:    {load_result.missing_keys}\")\n",
    "    print(f\"Unmatched parameters - unexpected_keys: {load_result.unexpected_keys}\")\n",
    "    return load_result\n",
    "\n",
    "\n",
    "for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    checkpoint = torch.load(backbone_type_config[\"classifier_path\"], map_location=\"cpu\")\n",
    "    # Strip the 'backbone.' prefix so checkpoint keys match the stand-alone ViT modules below.\n",
    "    checkpoint[\"state_dict\"] = OrderedDict((key.replace('backbone.', ''), value)\n",
    "                                           for key, value in checkpoint[\"state_dict\"].items())\n",
    "    state_dict = checkpoint[\"state_dict\"]\n",
    "\n",
    "    # Plain ViT behind the Baselines explainer (rollout / attention GradCAM).\n",
    "    model = getattr(ViT_new, backbone_type)(num_classes=_config[\"output_dim\"]).to(_config[\"gpus_classifier\"][idx])\n",
    "    ret = _load_and_report(model, state_dict, backbone_type_config['classifier_path'])\n",
    "    model.eval()\n",
    "    # Sanity check: this ViT must reproduce the project classifier's logits on `x`\n",
    "    # (`x` comes from an earlier cell).\n",
    "    with torch.no_grad():\n",
    "        output1 = model(x.to(next(model.parameters()).device))\n",
    "        output2 = classifier_dict[backbone_type](x.to(next(model.parameters()).device))['logits']\n",
    "        assert torch.allclose(output1, output2, atol=1e-03)\n",
    "    baselines = Baselines(model)\n",
    "    baselines_dict[backbone_type] = baselines\n",
    "\n",
    "    # LRP-instrumented ViT behind the LRP explainer.\n",
    "    model_LRP = getattr(ViT_LRP, backbone_type)(num_classes=_config[\"output_dim\"]).to(_config[\"gpus_classifier\"][idx])\n",
    "    ret = _load_and_report(model_LRP, state_dict, backbone_type_config['classifier_path'])\n",
    "    model_LRP.eval()\n",
    "    lrp = LRP(model_LRP)\n",
    "    lrp_dict[backbone_type] = lrp\n",
    "\n",
    "    # Original-LRP ViT: disabled, so orig_lrp_dict stays empty.\n",
    "#     model_orig_LRP = getattr(ViT_orig_LRP, backbone_type)(num_classes=_config[\"output_dim\"]).to(_config[\"gpus_classifier\"][idx])\n",
    "#     ret = _load_and_report(model_orig_LRP, state_dict, backbone_type_config['classifier_path'])\n",
    "#     model_orig_LRP.eval()\n",
    "#     orig_lrp = LRP(model_orig_LRP)\n",
    "#     orig_lrp_dict[backbone_type] = orig_lrp\n",
    "\n",
    "\n",
    "def get_lrp_module_explanation(backbone_type, original_image, class_index=None, mode='transformer_attribution'):\n",
    "    \"\"\"\n",
    "    Compute one Transformer-Explainability attribution for a single image.\n",
    "\n",
    "    Args:\n",
    "        backbone_type: key into lrp_dict / baselines_dict / orig_lrp_dict.\n",
    "        original_image: image tensor without a batch axis (one is added here).\n",
    "        class_index: class to explain, forwarded as `index` (not used by 'rollout').\n",
    "        mode: 'transformer_attribution' (ours), 'rollout', 'attn_last_layer' (raw attention),\n",
    "            'attn_gradcam' (GradCAM), 'full_lrp' or 'lrp_last_layer'.\n",
    "\n",
    "    Returns:\n",
    "        Detached attribution tensor; the 'attn_gradcam' result is reshaped to (1, -1).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: unknown `mode`.\n",
    "        KeyError: 'full_lrp' / 'lrp_last_layer' requested but orig_lrp_dict has no model\n",
    "            for `backbone_type`.\n",
    "    \"\"\"\n",
    "    def _batched_on(explainer):\n",
    "        # Add the batch axis and move the image to the explainer model's device.\n",
    "        return original_image.unsqueeze(0).to(next(explainer.model.parameters()).device)\n",
    "\n",
    "    if mode == \"transformer_attribution\":  # ours\n",
    "        explainer = lrp_dict[backbone_type]\n",
    "        attribution = explainer.generate_LRP(_batched_on(explainer), method=\"transformer_attribution\", index=class_index)\n",
    "    elif mode == \"rollout\":\n",
    "        explainer = baselines_dict[backbone_type]\n",
    "        attribution = explainer.generate_rollout(_batched_on(explainer), start_layer=1)\n",
    "    elif mode == \"attn_last_layer\":  # raw attention\n",
    "        explainer = lrp_dict[backbone_type]\n",
    "        attribution = explainer.generate_LRP(_batched_on(explainer), method=\"last_layer_attn\", index=class_index)\n",
    "    elif mode == \"attn_gradcam\":  # GradCAM\n",
    "        explainer = baselines_dict[backbone_type]\n",
    "        attribution = explainer.generate_cam_attn(_batched_on(explainer), index=class_index).reshape(1, -1)\n",
    "    elif mode in (\"full_lrp\", \"lrp_last_layer\"):\n",
    "        if backbone_type not in orig_lrp_dict:\n",
    "            # orig_lrp_dict is only filled when the ViT_orig_LRP setup code is enabled.\n",
    "            raise KeyError(f\"mode={mode!r} needs orig_lrp_dict[{backbone_type!r}], which was never built\")\n",
    "        explainer = orig_lrp_dict[backbone_type]\n",
    "        method = \"full\" if mode == \"full_lrp\" else \"last_layer\"\n",
    "        attribution = explainer.generate_LRP(_batched_on(explainer), method=method, index=class_index)\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown mode {mode!r}\")\n",
    "    return attribution.detach()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40064fa3",
   "metadata": {},
   "source": [
    "## CAM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fb8e73b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.pytorch_grad_cam import GradCAM\n",
    "from utils.pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget\n",
    "\n",
    "def reshape_transform(tensor, height=14, width=14):\n",
    "    \"\"\"\n",
    "    Lay ViT tokens out as a channels-first feature map for GradCAM.\n",
    "\n",
    "    Token 0 is dropped and the remaining height*width tokens are arranged on a grid:\n",
    "    (B, 1 + height*width, C) -> (B, C, height, width).\n",
    "    \"\"\"\n",
    "    patch_tokens = tensor[:, 1:, :]\n",
    "    token_grid = patch_tokens.reshape(tensor.size(0), height, width, tensor.size(2))  # (B, H, W, C)\n",
    "    return token_grid.permute(0, 3, 1, 2)  # channels first, as in CNNs\n",
    "\n",
    "class WrapperLogits(nn.Module):\n",
    "    \"\"\"Return only the 'logits' entry of the wrapped classifier's output dict.\"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "\n",
    "    def forward(self, images):\n",
    "        return self.model(images)['logits']\n",
    "\n",
    "\n",
    "# One GradCAM per backbone, hooked on `norm1` of the last transformer block.\n",
    "cam_dict = OrderedDict()\n",
    "for backbone_type, backbone_type_config in backbone_type_config_dict.items():\n",
    "    cam_dict[backbone_type] = GradCAM(\n",
    "        model=WrapperLogits(classifier_dict[backbone_type]),\n",
    "        target_layers=[classifier_dict[backbone_type].backbone.blocks[-1].norm1],\n",
    "        reshape_transform=reshape_transform,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48e2b912",
   "metadata": {},
   "source": [
    "## Gradient-based"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05bc9f4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from captum.attr import IntegratedGradients, InputXGradient, Saliency, NoiseTunnel\n",
    "import torch.nn as nn\n",
    "\n",
    "def _logits_to_probs(logits):\n",
    "    \"\"\"Sigmoid when the head has a single output, softmax over the last axis otherwise.\"\"\"\n",
    "    if _config[\"output_dim\"] == 1:\n",
    "        return logits.sigmoid()\n",
    "    return logits.softmax(dim=-1)\n",
    "\n",
    "\n",
    "class FromPixel(nn.Module):\n",
    "    \"\"\"Pixel-input wrapper: images -> patch_embed -> forward_features -> head -> probabilities.\"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "\n",
    "    def forward(self, images):\n",
    "        embedding = self.model.backbone.patch_embed(images)\n",
    "        features = self.model.backbone.forward_features(embedding)['x']\n",
    "        return _logits_to_probs(self.model.head(features))\n",
    "\n",
    "\n",
    "class FromEmbedding(nn.Module):\n",
    "    \"\"\"Embedding-input wrapper: patch embeddings -> forward_features -> head -> probabilities.\"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "\n",
    "    def forward(self, embedding):\n",
    "        features = self.model.backbone.forward_features(embedding)['x']\n",
    "        return _logits_to_probs(self.model.head(features))\n",
    "\n",
    "# Per backbone: the classifier (`classifier_dict_`, from an earlier cell) seen from pixels and\n",
    "# from patch embeddings, plus Saliency, NoiseTunnel (over Saliency) and IntegratedGradients on each.\n",
    "classifier_pixel_dict = OrderedDict()\n",
    "classifier_embedding_dict = OrderedDict()\n",
    "saliency_pixel_dict = OrderedDict()\n",
    "saliency_embedding_dict = OrderedDict()\n",
    "noisetunnel_pixel_dict = OrderedDict()\n",
    "noisetunnel_embedding_dict = OrderedDict()\n",
    "ig_pixel_dict = OrderedDict()\n",
    "ig_embedding_dict = OrderedDict()\n",
    "\n",
    "for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    classifier_pixel_dict[backbone_type] = FromPixel(classifier_dict_[backbone_type])\n",
    "    classifier_embedding_dict[backbone_type] = FromEmbedding(classifier_dict_[backbone_type])\n",
    "\n",
    "    saliency_pixel_dict[backbone_type] = Saliency(classifier_pixel_dict[backbone_type])\n",
    "    saliency_embedding_dict[backbone_type] = Saliency(classifier_embedding_dict[backbone_type])\n",
    "\n",
    "    noisetunnel_pixel_dict[backbone_type] = NoiseTunnel(saliency_pixel_dict[backbone_type])\n",
    "    noisetunnel_embedding_dict[backbone_type] = NoiseTunnel(saliency_embedding_dict[backbone_type])\n",
    "\n",
    "    ig_pixel_dict[backbone_type] = IntegratedGradients(classifier_pixel_dict[backbone_type])\n",
    "    ig_embedding_dict[backbone_type] = IntegratedGradients(classifier_embedding_dict[backbone_type])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdb4e500",
   "metadata": {},
   "outputs": [],
   "source": [
    "def attributions_pixel_process(attributions_pixel, patch_size=16):\n",
    "    \"\"\"\n",
    "    Reduce pixel attributions to per-pixel and per-patch scores.\n",
    "\n",
    "    Args:\n",
    "        attributions_pixel: tensor of shape (N, C, H, W).\n",
    "        patch_size: side of the non-overlapping square patches summed by the 'patch*' entries\n",
    "            (default 16).\n",
    "\n",
    "    Returns:\n",
    "        dict of detached CPU tensors:\n",
    "            'attributions_pixel_sum'         (N, H, W): sum over channels.\n",
    "            'attributions_pixel_abssum'      (N, H, W): sum of absolute values over channels.\n",
    "            'attributions_pixel_patchsum'    (N, H//patch_size, W//patch_size): sum over channels and patch.\n",
    "            'attributions_pixel_patchabssum' (N, H//patch_size, W//patch_size): same, of absolute values.\n",
    "    \"\"\"\n",
    "    abs_attributions = attributions_pixel.abs()\n",
    "    # A ones kernel with stride == kernel size sums each patch over all channels; the channel\n",
    "    # count is read from the input rather than fixed at 3.\n",
    "    patch_kernel = torch.ones(size=(1, attributions_pixel.shape[-3], patch_size, patch_size),\n",
    "                              dtype=attributions_pixel.dtype,\n",
    "                              device=attributions_pixel.device)\n",
    "    patchsum = F.conv2d(attributions_pixel, weight=patch_kernel, stride=patch_size).squeeze(1)\n",
    "    patchabssum = F.conv2d(abs_attributions, weight=patch_kernel, stride=patch_size).squeeze(1)\n",
    "\n",
    "    return {'attributions_pixel_sum': attributions_pixel.sum(axis=-3).detach().cpu(),\n",
    "            'attributions_pixel_abssum': abs_attributions.sum(axis=-3).detach().cpu(),\n",
    "            'attributions_pixel_patchsum': patchsum.detach().cpu(),\n",
    "            'attributions_pixel_patchabssum': patchabssum.detach().cpu()}\n",
    "    \n",
    "    \n",
    "def attributions_embedding_process(attributions_embedding):\n",
    "    \"\"\"\n",
    "    Collapse embedding attributions over the last (embedding) axis.\n",
    "\n",
    "    Returns a dict of detached CPU tensors holding the plain sum ('attributions_embedding_sum')\n",
    "    and the sum of absolute values ('attributions_embedding_abssum').\n",
    "    \"\"\"\n",
    "    reduced = {'attributions_embedding_sum': attributions_embedding.sum(axis=-1),\n",
    "               'attributions_embedding_abssum': attributions_embedding.abs().sum(axis=-1)}\n",
    "    return {name: values.detach().cpu() for name, values in reduced.items()}\n",
    "\n",
    "def get_vanilla(image, saliency_pixel=None, saliency_embedding=None):\n",
    "    \"\"\"\n",
    "    Vanilla-gradient (Captum Saliency) attributions for every output class.\n",
    "\n",
    "    Args:\n",
    "        image: single image tensor without a batch axis.\n",
    "        saliency_pixel: Saliency over a pixel-input wrapper, or None to skip.\n",
    "        saliency_embedding: Saliency over an embedding-input wrapper, or None to skip; its\n",
    "            inputs are the wrapped model's patch embeddings of `image`.\n",
    "\n",
    "    Returns:\n",
    "        dict merging attributions_pixel_process / attributions_embedding_process results; the\n",
    "        per-class attributions for targets 0 .. output_dim-1 are concatenated along axis 0.\n",
    "    \"\"\"\n",
    "    result = {}\n",
    "    with torch.no_grad():\n",
    "        if saliency_pixel is not None:\n",
    "            per_target = []\n",
    "            for target in range(_config[\"output_dim\"]):\n",
    "                device = next(saliency_pixel.forward_func.parameters()).device\n",
    "                per_target.append(saliency_pixel.attribute(inputs=image.unsqueeze(0).to(device), target=target))\n",
    "            result.update(attributions_pixel_process(torch.concat(per_target)))\n",
    "\n",
    "        if saliency_embedding is not None:\n",
    "            wrapped = saliency_embedding.forward_func\n",
    "            per_target = []\n",
    "            for target in range(_config[\"output_dim\"]):\n",
    "                device = next(wrapped.parameters()).device\n",
    "                embedding = wrapped.model.backbone.patch_embed(image.unsqueeze(0).to(device)).detach()\n",
    "                per_target.append(saliency_embedding.attribute(inputs=embedding, target=target))\n",
    "            result.update(attributions_embedding_process(torch.concat(per_target)))\n",
    "\n",
    "    return result\n",
    "\n",
    "def _get_noisetunnel(image, noisetunnel_pixel, noisetunnel_embedding, nt_type, nt_samples):\n",
    "    \"\"\"\n",
    "    Shared body of get_sg / get_vargrad: Captum NoiseTunnel attributions for every class.\n",
    "\n",
    "    Args:\n",
    "        image: single image tensor without a batch axis.\n",
    "        noisetunnel_pixel / noisetunnel_embedding: NoiseTunnel over a pixel- / embedding-input\n",
    "            wrapper, or None to skip that variant.\n",
    "        nt_type: NoiseTunnel type ('smoothgrad' or 'vargrad' here).\n",
    "        nt_samples: number of noisy samples per attribution.\n",
    "\n",
    "    Returns:\n",
    "        dict merging attributions_pixel_process / attributions_embedding_process results.\n",
    "    \"\"\"\n",
    "    result = {}\n",
    "    with torch.no_grad():\n",
    "        if noisetunnel_pixel is not None:\n",
    "            attributions_pixel = [noisetunnel_pixel.attribute(inputs=image.unsqueeze(0).to(next(noisetunnel_pixel.forward_func.parameters()).device),\n",
    "                                                              nt_type=nt_type,\n",
    "                                                              nt_samples=nt_samples,\n",
    "                                                              target=i) for i in range(_config[\"output_dim\"])]\n",
    "            result.update(attributions_pixel_process(torch.concat(attributions_pixel)))\n",
    "\n",
    "        if noisetunnel_embedding is not None:\n",
    "            wrapped = noisetunnel_embedding.forward_func\n",
    "            # Attribute w.r.t. the (detached) patch embeddings of the image.\n",
    "            attributions_embedding = [noisetunnel_embedding.attribute(inputs=wrapped.model.backbone.patch_embed(image.unsqueeze(0).to(next(wrapped.parameters()).device)).detach(),\n",
    "                                                                      nt_type=nt_type,\n",
    "                                                                      nt_samples=nt_samples,\n",
    "                                                                      target=i) for i in range(_config[\"output_dim\"])]\n",
    "            result.update(attributions_embedding_process(torch.concat(attributions_embedding)))\n",
    "\n",
    "    return result\n",
    "\n",
    "\n",
    "def get_sg(image, noisetunnel_pixel=None, noisetunnel_embedding=None, nt_samples=10):\n",
    "    \"\"\"SmoothGrad attributions (NoiseTunnel with nt_type='smoothgrad') for every output class.\"\"\"\n",
    "    return _get_noisetunnel(image, noisetunnel_pixel, noisetunnel_embedding, 'smoothgrad', nt_samples)\n",
    "\n",
    "\n",
    "def get_vargrad(image, noisetunnel_pixel=None, noisetunnel_embedding=None, nt_samples=10):\n",
    "    \"\"\"VarGrad attributions (NoiseTunnel with nt_type='vargrad') for every output class.\"\"\"\n",
    "    return _get_noisetunnel(image, noisetunnel_pixel, noisetunnel_embedding, 'vargrad', nt_samples)\n",
    "\n",
    "def get_ig(image, ig_pixel=None, ig_embedding=None):\n",
    "    \"\"\"\n",
    "    Integrated Gradients (n_steps=10) attributions for every output class.\n",
    "\n",
    "    The pixel variant integrates from an all-zero image; the embedding variant integrates\n",
    "    from the patch embedding of that all-zero image (not from a zero embedding).\n",
    "\n",
    "    Returns:\n",
    "        dict merging attributions_pixel_process / attributions_embedding_process results.\n",
    "    \"\"\"\n",
    "    result = {}\n",
    "    with torch.no_grad():\n",
    "        if ig_pixel is not None:\n",
    "            per_target = []\n",
    "            for target in range(_config[\"output_dim\"]):\n",
    "                device = next(ig_pixel.forward_func.parameters()).device\n",
    "                per_target.append(ig_pixel.attribute(inputs=image.unsqueeze(0).to(device),\n",
    "                                                     baselines=torch.zeros_like(image).unsqueeze(0).to(device),\n",
    "                                                     target=target,\n",
    "                                                     n_steps=10))\n",
    "            result.update(attributions_pixel_process(torch.concat(per_target)))\n",
    "\n",
    "        if ig_embedding is not None:\n",
    "            embed = ig_embedding.forward_func.model.backbone.patch_embed\n",
    "            per_target = []\n",
    "            for target in range(_config[\"output_dim\"]):\n",
    "                device = next(ig_embedding.forward_func.parameters()).device\n",
    "                per_target.append(ig_embedding.attribute(inputs=embed(image.unsqueeze(0).to(device)).detach(),\n",
    "                                                         baselines=embed(torch.zeros_like(image).unsqueeze(0).to(device)).detach(),\n",
    "                                                         target=target,\n",
    "                                                         n_steps=10))\n",
    "            result.update(attributions_embedding_process(torch.concat(per_target)))\n",
    "\n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c24d0766",
   "metadata": {},
   "source": [
    "## leave-one-out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "232428ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "def leave_one_out(image, surrogate=None, classifier=None):\n",
    "    \"\"\"\n",
    "    Leave-one-out attribution over a 14x14 grid of 16x16-pixel patches (196 players).\n",
    "\n",
    "    Mask row 0 keeps every patch and row k+1 drops patch k. A surrogate receives the masks\n",
    "    directly; for a classifier the dropped patch is zeroed in pixel space.\n",
    "\n",
    "    Args:\n",
    "        image: single image tensor without a batch axis; presumably (C, 224, 224) given the\n",
    "            14*16 grid -- TODO confirm.\n",
    "        surrogate: model called as surrogate(images, masks=...); takes precedence if given.\n",
    "        classifier: Classifier or Surrogate applied to the pixel-masked images.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray of shape (num_classes, 196): p(all patches) - p(patch k removed).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: neither `surrogate` nor `classifier` was given.\n",
    "        TypeError: `classifier` is neither a Classifier nor a Surrogate.\n",
    "    \"\"\"\n",
    "    num_players, grid_size, patch_size = 196, 14, 16\n",
    "    with torch.no_grad():\n",
    "        mask = torch.cat([torch.ones(1, num_players), 1 - torch.eye(num_players)])\n",
    "        if surrogate is not None:\n",
    "            out = surrogate(image.unsqueeze(0).repeat(num_players + 1, 1, 1, 1).to(surrogate.device),\n",
    "                            masks=mask.to(surrogate.device))\n",
    "        elif classifier is not None:\n",
    "            # Upsample the (rows, 14, 14) patch mask to pixels: (rows, 224, 224).\n",
    "            mask_scaled = torch.repeat_interleave(torch.repeat_interleave(mask.reshape(-1, grid_size, grid_size), patch_size, dim=2), patch_size, dim=1)\n",
    "            image_masked = image * mask_scaled.unsqueeze(1).to(image.device)\n",
    "\n",
    "            if classifier.__class__ == Classifier:\n",
    "                out = classifier(image_masked.to(classifier.device))\n",
    "            elif classifier.__class__ == Surrogate:\n",
    "                # Pixels are already masked, so the surrogate is told every patch is present; the\n",
    "                # mask goes to the model's device like the masks in the surrogate branch above.\n",
    "                out = classifier(image_masked.to(classifier.device),\n",
    "                                 masks=torch.ones((len(image_masked), num_players), device=classifier.device))\n",
    "            else:\n",
    "                raise TypeError(f\"Unsupported classifier type: {type(classifier).__name__}\")\n",
    "        else:\n",
    "            raise ValueError(\"leave_one_out needs either a surrogate or a classifier\")\n",
    "\n",
    "        if _config[\"output_dim\"] == 1:\n",
    "            prob = out['logits'].sigmoid().detach().cpu().numpy()\n",
    "        else:\n",
    "            prob = out['logits'].softmax(dim=-1).detach().cpu().numpy()\n",
    "\n",
    "        # Drop in probability when each single patch is removed.\n",
    "        result = prob[0:1] - prob[1:]\n",
    "\n",
    "    return result.transpose(1, 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ef518bd",
   "metadata": {},
   "source": [
    "## RISE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f988c172",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rise(image, surrogate=None, classifier=None, include_prob=0.5, N=2000, batch_size=100):\n",
    "    \"\"\"\n",
    "    RISE-style attribution over the 196 patches (14x14 grid of 16x16 pixels).\n",
    "\n",
    "    Draws N random patch masks (each patch kept with probability `include_prob`) in batches of\n",
    "    `batch_size`, evaluates the model on each masked input and, per patch, averages the class\n",
    "    probabilities over the masks that kept that patch.\n",
    "\n",
    "    Args:\n",
    "        image: single image tensor without a batch axis.\n",
    "        surrogate: model called as surrogate(images, masks=...); exactly one of surrogate /\n",
    "            classifier must be given.\n",
    "        classifier: Classifier or Surrogate applied to the pixel-masked images.\n",
    "        include_prob: probability of keeping each patch.\n",
    "        N: total number of masks; the last batch holds the remainder when N is not a multiple\n",
    "            of `batch_size`.\n",
    "        batch_size: masks evaluated per forward pass.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray of shape (num_classes, 196); a patch kept by no mask yields nan (0/0).\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: not exactly one of `surrogate` / `classifier` was given.\n",
    "        ValueError: N < 1 or batch_size < 1.\n",
    "        TypeError: `classifier` is neither a Classifier nor a Surrogate.\n",
    "    \"\"\"\n",
    "    assert (surrogate is None) != (classifier is None)\n",
    "    if N < 1 or batch_size < 1:\n",
    "        raise ValueError(f\"N and batch_size must be >= 1, got N={N}, batch_size={batch_size}\")\n",
    "\n",
    "    num_players, grid_size, patch_size = 196, 14, 16\n",
    "    prob_list = []\n",
    "    mask_list = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for start in range(0, N, batch_size):\n",
    "            num_masks = min(batch_size, N - start)\n",
    "            mask = torch.rand(num_masks, num_players) < include_prob\n",
    "            if surrogate is not None:\n",
    "                out = surrogate(image.unsqueeze(0).repeat(num_masks, 1, 1, 1).to(surrogate.device),\n",
    "                                masks=mask.to(surrogate.device))\n",
    "            else:\n",
    "                # Upsample the (num_masks, 14, 14) patch mask to pixels and zero the dropped patches.\n",
    "                mask_scaled = torch.repeat_interleave(torch.repeat_interleave(mask.reshape(-1, grid_size, grid_size), patch_size, dim=2), patch_size, dim=1)\n",
    "                image_masked = image * mask_scaled.unsqueeze(1).to(image.device)\n",
    "                del mask_scaled\n",
    "                if classifier.__class__ == Classifier:\n",
    "                    out = classifier(image_masked.to(classifier.device))\n",
    "                elif classifier.__class__ == Surrogate:\n",
    "                    # Pixels are already masked, so the surrogate is told every patch is present.\n",
    "                    out = classifier(image_masked.to(classifier.device),\n",
    "                                     masks=torch.ones_like(mask).to(classifier.device))\n",
    "                else:\n",
    "                    raise TypeError(f\"Unsupported classifier type: {type(classifier).__name__}\")\n",
    "\n",
    "            if _config[\"output_dim\"] == 1:\n",
    "                prob = out['logits'].sigmoid().detach().cpu().numpy()\n",
    "            else:\n",
    "                prob = out['logits'].softmax(dim=-1).detach().cpu().numpy()\n",
    "            del out\n",
    "\n",
    "            prob_list.append(prob)\n",
    "            mask_list.append(mask.numpy())\n",
    "            del mask\n",
    "\n",
    "    prob_list_array = np.concatenate(prob_list)  # (num_trials, num_classes)\n",
    "    mask_list_array = np.concatenate(mask_list)  # (num_trials, num_players)\n",
    "\n",
    "    # Per patch: summed probabilities over the masks that kept it / number of such masks.\n",
    "    result = prob_list_array.T @ mask_list_array  # (num_classes, num_players)\n",
    "    result = result / mask_list_array.sum(axis=0)\n",
    "\n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95e1e499",
   "metadata": {},
   "source": [
    "## KernelSHAP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6fb0a14",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.shapreg import removal, games, shapley\n",
    "\n",
    "class SurrogateSHAPWrapper(nn.Module):\n",
    "    \"\"\"\n",
    "    Adapt a surrogate to an `(images, mask)` tuple input (presumably as supplied by shapreg's\n",
    "    PredictionGame_torchimagetensor -- confirm).\n",
    "\n",
    "    The mask is squeezed at dim 1 and flattened to (batch, -1) before being given to the\n",
    "    surrogate; its logits go through a sigmoid (output_dim == 1) or a softmax over the last axis.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "        self.activation = nn.Sigmoid() if _config[\"output_dim\"] == 1 else nn.Softmax(dim=-1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        images, mask = x\n",
    "        flat_mask = mask.squeeze(1).flatten(1)\n",
    "        logits = self.model(images, flat_mask)['logits']\n",
    "        return self.activation(logits)\n",
    "\n",
    "\n",
    "# Wrap the \"pre-softmax\" surrogate of every backbone.\n",
    "surrogate_SHAP_wrapped_dict = OrderedDict()\n",
    "\n",
    "for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    surrogate_SHAP_wrapped_dict[backbone_type] = SurrogateSHAPWrapper(surrogate_dict[backbone_type][\"pre-softmax\"])\n",
    "\n",
    "\n",
    "def get_shap(surrogate_SHAP_wrapped, x, batch_size=64, thresh=0.2, variance_batches=60):\n",
    "    \"\"\"Run shapreg's ShapleyRegression (KernelSHAP) on `x` and return its explanation object.\"\"\"\n",
    "    prediction_game = games.PredictionGame_torchimagetensor(surrogate_SHAP_wrapped, x)\n",
    "    return shapley.ShapleyRegression(prediction_game,\n",
    "                                     batch_size=batch_size,\n",
    "                                     thresh=thresh,\n",
    "                                     variance_batches=variance_batches)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f3c4f5f",
   "metadata": {},
   "source": [
    "# save_dict_setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9cee727",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per backbone: one (initially empty) {image path -> result} dict per explanation method.\n",
    "# The loading cell below iterates over these keys in this order.\n",
    "explanation_save_dict = {}\n",
    "for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    explanation_save_dict_backbone = {\n",
    "        method: {}\n",
    "        for method in (\"random\", \"attention_rollout\", \"attention_last\", \"LRP\",\n",
    "                       \"gradcam\", \"gradcamgithub\",\n",
    "                       \"vanillapixel\", \"vanillaembedding\", \"sgpixel\", \"sgembedding\",\n",
    "                       \"vargradpixel\", \"vargradembedding\", \"igpixel\", \"igembedding\",\n",
    "                       \"leaveoneoutclassifier\", \"leaveoneoutsurrogate\",\n",
    "                       \"riseclassifier\", \"risesurrogate\",\n",
    "                       \"ours\", \"kernelshap\")\n",
    "    }\n",
    "    explanation_save_dict[backbone_type] = explanation_save_dict_backbone\n",
    "    \n",
    "def explanation_save_dict_update(backbone_type, explanation_method,\n",
    "                                 path_list, explanation_list, elapsed_time_list,\n",
    "                                 shape=None):\n",
    "    \"\"\"\n",
    "    Store explanations in explanation_save_dict[backbone_type][explanation_method], keyed by path.\n",
    "\n",
    "    Each entry becomes {\"explanation\": explanation.astype(float), \"elapsed_time\": elapsed_time}.\n",
    "    The three lists are parallel; asserts check their lengths, the exact types (np.ndarray, str,\n",
    "    float) and, when `shape` is given, each explanation's shape. Existing paths are overwritten.\n",
    "    \"\"\"\n",
    "    method_store = explanation_save_dict[backbone_type][explanation_method]\n",
    "\n",
    "    assert len(path_list) == len(explanation_list) == len(elapsed_time_list)\n",
    "\n",
    "    for path, explanation, elapsed_time in zip(path_list, explanation_list, elapsed_time_list):\n",
    "        assert type(explanation) == np.ndarray\n",
    "        assert type(path) == str\n",
    "        assert type(elapsed_time) == float\n",
    "        if shape is not None:\n",
    "            assert explanation.shape == shape\n",
    "        method_store[path] = {\"explanation\": explanation.astype(float),\n",
    "                              \"elapsed_time\": elapsed_time}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79d38051",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Merge previously saved explanation pickles into explanation_save_dict (best effort).\n",
    "# NOTE: pickle.load can execute arbitrary code; only load result files this pipeline wrote.\n",
    "for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    for explanation_method, explanation_save_dict_backbone_method in explanation_save_dict[backbone_type].items():\n",
    "        try:\n",
    "            explanation_save_dict_path=f'results/3_explanation_generate/{_config[\"datasets\"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'\n",
    "\n",
    "            if os.path.isfile(explanation_save_dict_path):\n",
    "                with open(explanation_save_dict_path, 'rb') as f:\n",
    "                    explanation_save_dict_loaded=pickle.load(f)\n",
    "            else:\n",
    "                explanation_save_dict_loaded={}\n",
    "\n",
    "            len_original=len(explanation_save_dict_backbone_method)\n",
    "            len_loaded=len(explanation_save_dict_loaded)\n",
    "            explanation_save_dict_backbone_method.update(explanation_save_dict_loaded)\n",
    "            len_updated=len(explanation_save_dict_backbone_method)\n",
    "\n",
    "            print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')\n",
    "        except Exception as error:\n",
    "            # Keep going with the other methods, but report which one failed and why; catching\n",
    "            # Exception (not a bare except) lets KeyboardInterrupt still stop the cell.\n",
    "            print(f'{explanation_method:24}  failed to load saved explanations: {error!r}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a028c2cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per backbone: one (initially empty) {image path -> insert/delete result} dict per method.\n",
    "# The loading cell below iterates over these keys in this order.\n",
    "insertdelete_save_dict = {}\n",
    "for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    insertdelete_save_dict_backbone = {\n",
    "        method: {}\n",
    "        for method in (\"random\", \"attention_rollout\", \"attention_last\", \"LRP\",\n",
    "                       \"gradcam\", \"gradcamgithub\",\n",
    "                       \"vanillapixel\", \"vanillaembedding\", \"sgpixel\", \"sgembedding\",\n",
    "                       \"vargradpixel\", \"vargradembedding\", \"igpixel\", \"igembedding\",\n",
    "                       \"leaveoneoutclassifier\", \"leaveoneoutsurrogate\",\n",
    "                       \"riseclassifier\", \"risesurrogate\",\n",
    "                       \"ours\", \"kernelshap\")\n",
    "    }\n",
    "    insertdelete_save_dict[backbone_type] = insertdelete_save_dict_backbone\n",
    "    \n",
    "def insertdelete_save_dict_update(backbone_type, explanation_method,\n",
    "                                  path_list, insert_list, delete_list,\n",
    "                                  shape=None):\n",
    "    \"\"\"\n",
    "    Store insertion/deletion results in insertdelete_save_dict[backbone_type][explanation_method].\n",
    "\n",
    "    Each path maps to {\"insert\": insert.astype(float), \"delete\": delete.astype(float)}. The three\n",
    "    lists are parallel; asserts check their lengths, the exact types (np.ndarray, np.ndarray, str)\n",
    "    and, when `shape` is given, both arrays' shapes. Existing paths are overwritten.\n",
    "    \"\"\"\n",
    "    method_store = insertdelete_save_dict[backbone_type][explanation_method]\n",
    "\n",
    "    assert len(path_list) == len(insert_list) == len(delete_list)\n",
    "\n",
    "    for path, insert, delete in zip(path_list, insert_list, delete_list):\n",
    "        assert type(insert) == np.ndarray\n",
    "        assert type(delete) == np.ndarray\n",
    "        assert type(path) == str\n",
    "        if shape is not None:\n",
    "            assert insert.shape == shape\n",
    "            assert delete.shape == shape\n",
    "        method_store[path] = {\"insert\": insert.astype(float),\n",
    "                              \"delete\": delete.astype(float)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34817dc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge previously saved insertion/deletion results from disk into the in-memory store\n",
    "# (loaded entries replace in-memory entries with the same path) and report counts per method.\n",
    "for backbone_type, backbone_type_config in backbone_type_config_dict.items():\n",
    "    for explanation_method, insertdelete_save_dict_backbone_method in insertdelete_save_dict[backbone_type].items():\n",
    "        insertdelete_save_dict_path = (f'results/4_insert_delete/{_config[\"datasets\"]}/'\n",
    "                                       f'{backbone_type}_{explanation_method}_{dataset_split}.pickle')\n",
    "\n",
    "        insertdelete_save_dict_loaded = {}\n",
    "        if os.path.isfile(insertdelete_save_dict_path):\n",
    "            # NOTE: pickle.load can execute arbitrary code; only load result files you produced.\n",
    "            with open(insertdelete_save_dict_path, 'rb') as f:\n",
    "                insertdelete_save_dict_loaded = pickle.load(f)\n",
    "\n",
    "        len_original, len_loaded = len(insertdelete_save_dict_backbone_method), len(insertdelete_save_dict_loaded)\n",
    "        insertdelete_save_dict_backbone_method.update(insertdelete_save_dict_loaded)\n",
    "        len_updated = len(insertdelete_save_dict_backbone_method)\n",
    "\n",
    "        print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7a38937",
   "metadata": {},
   "outputs": [],
   "source": [
    "# In-memory store of sensitivity arrays (\"dit\" is the historical spelling of \"dict\"; the\n",
    "# load cell below uses the same names), filled by sensitivity_save_dit_update:\n",
    "# sensitivity_save_dit[backbone_type][explanation_method][image_path][num_included_players] = array\n",
    "sensitivity_save_dit = {}\n",
    "for backbone_type, backbone_type_config in backbone_type_config_dict.items():\n",
    "    sensitivity_save_dit_backbone = {method: {} for method in [\n",
    "        \"attention_rollout\", \"attention_last\", \"LRP\",\n",
    "        \"gradcam\", \"gradcamgithub\",\n",
    "        \"vanillapixel\", \"vanillaembedding\",\n",
    "        \"sgpixel\", \"sgembedding\",\n",
    "        \"vargradpixel\", \"vargradembedding\",\n",
    "        \"igpixel\", \"igembedding\",\n",
    "        \"leaveoneoutclassifier\", \"leaveoneoutsurrogate\",\n",
    "        \"riseclassifier\", \"risesurrogate\",\n",
    "        \"ours\",\n",
    "    ]}\n",
    "    sensitivity_save_dit[backbone_type] = sensitivity_save_dit_backbone\n",
    "\n",
    "\n",
    "def sensitivity_save_dit_update(backbone_type, explanation_method, num_included_players,\n",
    "                                path_list, sensitivity_list,\n",
    "                                shape=None):\n",
    "    \"\"\"Record one sensitivity array per image for a given `num_included_players`.\n",
    "\n",
    "    Stores sensitivity.astype(float) under\n",
    "    sensitivity_save_dit[backbone_type][explanation_method][path][num_included_players],\n",
    "    creating the per-path dict on first use; other player counts already stored for\n",
    "    the same path are kept.\n",
    "\n",
    "    Raises AssertionError if the lists differ in length, if an array is not exactly an\n",
    "    np.ndarray, if a path is not exactly a str, or, when `shape` is given, on a shape mismatch.\n",
    "    \"\"\"\n",
    "    method_store = sensitivity_save_dit[backbone_type][explanation_method]\n",
    "\n",
    "    assert len(path_list) == len(sensitivity_list)\n",
    "\n",
    "    for path, sensitivity in zip(path_list, sensitivity_list):\n",
    "        assert type(sensitivity) is np.ndarray\n",
    "        assert type(path) is str\n",
    "        if shape is not None:\n",
    "            assert sensitivity.shape == shape\n",
    "        per_path = method_store.setdefault(path, {})\n",
    "        per_path[num_included_players] = sensitivity.astype(float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df2a81b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge previously saved sensitivity results from disk into the in-memory store. update()\n",
    "# replaces whole per-path entries, and the counts per method are printed.\n",
    "for backbone_type, backbone_type_config in backbone_type_config_dict.items():\n",
    "    for explanation_method, sensitivity_save_dit_backbone_method in sensitivity_save_dit[backbone_type].items():\n",
    "        sensitivity_save_dit_path = (f'results/5_sensitivity/{_config[\"datasets\"]}/'\n",
    "                                     f'{backbone_type}_{explanation_method}_{dataset_split}.pickle')\n",
    "\n",
    "        sensitivity_save_dit_loaded = {}\n",
    "        if os.path.isfile(sensitivity_save_dit_path):\n",
    "            # NOTE: pickle.load can execute arbitrary code; only load result files you produced.\n",
    "            with open(sensitivity_save_dit_path, 'rb') as f:\n",
    "                sensitivity_save_dit_loaded = pickle.load(f)\n",
    "\n",
    "        len_original, len_loaded = len(sensitivity_save_dit_backbone_method), len(sensitivity_save_dit_loaded)\n",
    "        sensitivity_save_dit_backbone_method.update(sensitivity_save_dit_loaded)\n",
    "        len_updated = len(sensitivity_save_dit_backbone_method)\n",
    "\n",
    "        print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c313913",
   "metadata": {},
   "outputs": [],
   "source": [
    "# In-memory store of no-retraining insert/delete arrays, filled by noretraining_save_dict_update:\n",
    "# noretraining_save_dict[backbone_type][explanation_method][image_path] = {\"insert\": ..., \"delete\": ...}\n",
    "noretraining_save_dict = {}\n",
    "for backbone_type, backbone_type_config in backbone_type_config_dict.items():\n",
    "    noretraining_save_dict_backbone = {method: {} for method in [\n",
    "        \"random\", \"attention_rollout\", \"attention_last\", \"LRP\",\n",
    "        \"gradcam\", \"gradcamgithub\",\n",
    "        \"vanillapixel\", \"vanillaembedding\",\n",
    "        \"sgpixel\", \"sgembedding\",\n",
    "        \"vargradpixel\", \"vargradembedding\",\n",
    "        \"igpixel\", \"igembedding\",\n",
    "        \"leaveoneoutclassifier\", \"leaveoneoutsurrogate\",\n",
    "        \"riseclassifier\", \"risesurrogate\",\n",
    "        \"ours\",\n",
    "    ]}\n",
    "    noretraining_save_dict[backbone_type] = noretraining_save_dict_backbone\n",
    "\n",
    "\n",
    "def noretraining_save_dict_update(backbone_type, explanation_method,\n",
    "                                  path_list, insert_list, delete_list,\n",
    "                                  shape=None):\n",
    "    \"\"\"Record insert/delete arrays for a batch of images.\n",
    "\n",
    "    Each (path, insert, delete) triple is stored under\n",
    "    noretraining_save_dict[backbone_type][explanation_method][path] as\n",
    "    {\"insert\": ..., \"delete\": ...}, with both arrays converted to float.\n",
    "    An existing entry for the same path is overwritten.\n",
    "\n",
    "    Raises AssertionError if the three lists differ in length, if an array is not\n",
    "    exactly an np.ndarray, if a path is not exactly a str, or, when `shape` is given,\n",
    "    if an array's shape differs from it.\n",
    "    \"\"\"\n",
    "    method_store = noretraining_save_dict[backbone_type][explanation_method]\n",
    "\n",
    "    assert len(path_list) == len(insert_list) == len(delete_list)\n",
    "\n",
    "    for path, insert_arr, delete_arr in zip(path_list, insert_list, delete_list):\n",
    "        assert type(insert_arr) is np.ndarray\n",
    "        assert type(delete_arr) is np.ndarray\n",
    "        assert type(path) is str\n",
    "        if shape is not None:\n",
    "            assert insert_arr.shape == shape\n",
    "            assert delete_arr.shape == shape\n",
    "        method_store[path] = {\"insert\": insert_arr.astype(float),\n",
    "                              \"delete\": delete_arr.astype(float)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "644df8bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge previously saved no-retraining results from disk into the in-memory store\n",
    "# (loaded entries replace in-memory entries with the same path) and report counts per method.\n",
    "for backbone_type, backbone_type_config in backbone_type_config_dict.items():\n",
    "    for explanation_method, noretraining_save_dict_backbone_method in noretraining_save_dict[backbone_type].items():\n",
    "        noretraining_save_dict_path = (f'results/6_noretraining/{_config[\"datasets\"]}/'\n",
    "                                       f'{backbone_type}_{explanation_method}_{dataset_split}.pickle')\n",
    "\n",
    "        noretraining_save_dict_loaded = {}\n",
    "        if os.path.isfile(noretraining_save_dict_path):\n",
    "            # NOTE: pickle.load can execute arbitrary code; only load result files you produced.\n",
    "            with open(noretraining_save_dict_path, 'rb') as f:\n",
    "                noretraining_save_dict_loaded = pickle.load(f)\n",
    "\n",
    "        len_original, len_loaded = len(noretraining_save_dict_backbone_method), len(noretraining_save_dict_loaded)\n",
    "        noretraining_save_dict_backbone_method.update(noretraining_save_dict_loaded)\n",
    "        len_updated = len(noretraining_save_dict_backbone_method)\n",
    "\n",
    "        print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0cc56a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# In-memory store of masked-classifier insert/delete arrays, filled by classifiermasked_save_dict_update:\n",
    "# classifiermasked_save_dict[backbone_type][explanation_method][image_path] = {\"insert\": ..., \"delete\": ...}\n",
    "classifiermasked_save_dict = {}\n",
    "for backbone_type, backbone_type_config in backbone_type_config_dict.items():\n",
    "    classifiermasked_save_dict_backbone = {method: {} for method in [\n",
    "        \"random\", \"attention_rollout\", \"attention_last\", \"LRP\",\n",
    "        \"gradcam\", \"gradcamgithub\",\n",
    "        \"vanillapixel\", \"vanillaembedding\",\n",
    "        \"sgpixel\", \"sgembedding\",\n",
    "        \"vargradpixel\", \"vargradembedding\",\n",
    "        \"igpixel\", \"igembedding\",\n",
    "        \"leaveoneoutclassifier\", \"leaveoneoutsurrogate\",\n",
    "        \"riseclassifier\", \"risesurrogate\",\n",
    "        \"ours\",\n",
    "    ]}\n",
    "    classifiermasked_save_dict[backbone_type] = classifiermasked_save_dict_backbone\n",
    "\n",
    "\n",
    "def classifiermasked_save_dict_update(backbone_type, explanation_method,\n",
    "                                      path_list, insert_list, delete_list,\n",
    "                                      shape=None):\n",
    "    \"\"\"Record insert/delete arrays for a batch of images.\n",
    "\n",
    "    Each (path, insert, delete) triple is stored under\n",
    "    classifiermasked_save_dict[backbone_type][explanation_method][path] as\n",
    "    {\"insert\": ..., \"delete\": ...}, with both arrays converted to float.\n",
    "    An existing entry for the same path is overwritten.\n",
    "\n",
    "    Raises AssertionError if the three lists differ in length, if an array is not\n",
    "    exactly an np.ndarray, if a path is not exactly a str, or, when `shape` is given,\n",
    "    if an array's shape differs from it.\n",
    "    \"\"\"\n",
    "    method_store = classifiermasked_save_dict[backbone_type][explanation_method]\n",
    "\n",
    "    assert len(path_list) == len(insert_list) == len(delete_list)\n",
    "\n",
    "    for path, insert_arr, delete_arr in zip(path_list, insert_list, delete_list):\n",
    "        assert type(insert_arr) is np.ndarray\n",
    "        assert type(delete_arr) is np.ndarray\n",
    "        assert type(path) is str\n",
    "        if shape is not None:\n",
    "            assert insert_arr.shape == shape\n",
    "            assert delete_arr.shape == shape\n",
    "        method_store[path] = {\"insert\": insert_arr.astype(float),\n",
    "                              \"delete\": delete_arr.astype(float)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdfd4cb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge previously saved masked-classifier results from disk into the in-memory store\n",
    "# (loaded entries replace in-memory entries with the same path) and report counts per method.\n",
    "for backbone_type, backbone_type_config in backbone_type_config_dict.items():\n",
    "    for explanation_method, classifiermasked_save_dict_backbone_method in classifiermasked_save_dict[backbone_type].items():\n",
    "        classifiermasked_save_dict_path = (f'results/7_classifiermasked/{_config[\"datasets\"]}/'\n",
    "                                           f'{backbone_type}_{explanation_method}_{dataset_split}.pickle')\n",
    "\n",
    "        classifiermasked_save_dict_loaded = {}\n",
    "        if os.path.isfile(classifiermasked_save_dict_path):\n",
    "            # NOTE: pickle.load can execute arbitrary code; only load result files you produced.\n",
    "            with open(classifiermasked_save_dict_path, 'rb') as f:\n",
    "                classifiermasked_save_dict_loaded = pickle.load(f)\n",
    "\n",
    "        len_original, len_loaded = len(classifiermasked_save_dict_backbone_method), len(classifiermasked_save_dict_loaded)\n",
    "        classifiermasked_save_dict_backbone_method.update(classifiermasked_save_dict_loaded)\n",
    "        len_updated = len(classifiermasked_save_dict_backbone_method)\n",
    "\n",
    "        print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "882fc10e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# In-memory store of per-image elapsed times, filled by elapsedtime_save_dict_update:\n",
    "# elapsedtime_save_dict[backbone_type][explanation_method][image_path] = {\"time\": elapsed_time}\n",
    "elapsedtime_save_dict = {}\n",
    "for backbone_type, backbone_type_config in backbone_type_config_dict.items():\n",
    "    elapsedtime_save_dict_backbone = {method: {} for method in [\n",
    "        \"random\", \"attention_rollout\", \"attention_last\", \"LRP\",\n",
    "        \"gradcam\", \"gradcamgithub\",\n",
    "        \"vanillapixel\", \"vanillaembedding\",\n",
    "        \"sgpixel\", \"sgembedding\",\n",
    "        \"vargradpixel\", \"vargradembedding\",\n",
    "        \"igpixel\", \"igembedding\",\n",
    "        \"leaveoneoutclassifier\", \"leaveoneoutsurrogate\",\n",
    "        \"riseclassifier\", \"risesurrogate\",\n",
    "        \"ours\",\n",
    "    ]}\n",
    "    elapsedtime_save_dict[backbone_type] = elapsedtime_save_dict_backbone\n",
    "\n",
    "\n",
    "def elapsedtime_save_dict_update(backbone_type, explanation_method,\n",
    "                                 path_list, elapsed_time_list,\n",
    "                                 shape=None):\n",
    "    \"\"\"Record one elapsed time per image path.\n",
    "\n",
    "    Stores {\"time\": elapsed_time} under\n",
    "    elapsedtime_save_dict[backbone_type][explanation_method][path], overwriting any\n",
    "    previous entry for that path. `shape` is accepted but not used.\n",
    "\n",
    "    Raises AssertionError if the lists differ in length, if an elapsed time is not\n",
    "    exactly a float, or if a path is not exactly a str.\n",
    "    \"\"\"\n",
    "    method_store = elapsedtime_save_dict[backbone_type][explanation_method]\n",
    "\n",
    "    assert len(path_list) == len(elapsed_time_list)\n",
    "\n",
    "    for path, elapsed_time in zip(path_list, elapsed_time_list):\n",
    "        assert type(elapsed_time) is float\n",
    "        assert type(path) is str\n",
    "        method_store[path] = {\"time\": elapsed_time}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50926815",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge previously saved elapsed times from disk into the in-memory store\n",
    "# (loaded entries replace in-memory entries with the same path) and report counts per method.\n",
    "for backbone_type, backbone_type_config in backbone_type_config_dict.items():\n",
    "    for explanation_method, elapsedtime_save_dict_backbone_method in elapsedtime_save_dict[backbone_type].items():\n",
    "        elapsedtime_save_dict_path = (f'results/8_elapsedtime/{_config[\"datasets\"]}/'\n",
    "                                      f'{backbone_type}_{explanation_method}_{dataset_split}.pickle')\n",
    "\n",
    "        elapsedtime_save_dict_loaded = {}\n",
    "        if os.path.isfile(elapsedtime_save_dict_path):\n",
    "            # NOTE: pickle.load can execute arbitrary code; only load result files you produced.\n",
    "            with open(elapsedtime_save_dict_path, 'rb') as f:\n",
    "                elapsedtime_save_dict_loaded = pickle.load(f)\n",
    "\n",
    "        len_original, len_loaded = len(elapsedtime_save_dict_backbone_method), len(elapsedtime_save_dict_loaded)\n",
    "        elapsedtime_save_dict_backbone_method.update(elapsedtime_save_dict_loaded)\n",
    "        len_updated = len(elapsedtime_save_dict_backbone_method)\n",
    "\n",
    "        print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bf838f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# In-memory store of estimates and labels, filled by estimationerror_save_dict_update:\n",
    "# estimationerror_save_dict[backbone_type][explanation_method][image_path] = {\"estimation\": ..., \"label\": ...}\n",
    "estimationerror_save_dict = {}\n",
    "for backbone_type, backbone_type_config in backbone_type_config_dict.items():\n",
    "    estimationerror_save_dict_backbone = {method: {} for method in [\"kernelshap\", \"kernelshapnopair\", \"ours\"]}\n",
    "    estimationerror_save_dict[backbone_type] = estimationerror_save_dict_backbone\n",
    "\n",
    "\n",
    "def estimationerror_save_dict_update(backbone_type, explanation_method,\n",
    "                                     path_list, estimation_list, label_list,\n",
    "                                     shape=None):\n",
    "    \"\"\"Record one (estimation, label) pair per image path.\n",
    "\n",
    "    Stores {\"estimation\": estimation, \"label\": label} under\n",
    "    estimationerror_save_dict[backbone_type][explanation_method][path]. The estimation\n",
    "    object is stored as-is (no copy or dtype conversion), and an existing entry for the\n",
    "    same path is overwritten.\n",
    "\n",
    "    Raises AssertionError if the lists differ in length, if a path is not exactly a str,\n",
    "    if a label is not exactly an int, or, when `shape` is given, if estimation.shape differs.\n",
    "    \"\"\"\n",
    "    method_store = estimationerror_save_dict[backbone_type][explanation_method]\n",
    "\n",
    "    assert len(path_list) == len(estimation_list) == len(label_list)\n",
    "\n",
    "    for path, estimation, label in zip(path_list, estimation_list, label_list):\n",
    "        assert type(path) is str\n",
    "        # The type of `estimation` is not checked (the ndarray check was disabled); only its shape below.\n",
    "        assert type(label) is int\n",
    "        if shape is not None:\n",
    "            assert estimation.shape == shape\n",
    "        method_store[path] = {\"estimation\": estimation, \"label\": label}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25cfd7ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge previously saved estimation results from disk into the in-memory store\n",
    "# (loaded entries replace in-memory entries with the same path) and report counts per method.\n",
    "for backbone_type, backbone_type_config in backbone_type_config_dict.items():\n",
    "    for explanation_method, estimationerror_save_dict_backbone_method in estimationerror_save_dict[backbone_type].items():\n",
    "        estimationerror_save_dict_path = (f'results/9_estimationerror/{_config[\"datasets\"]}/'\n",
    "                                          f'{backbone_type}_{explanation_method}_{dataset_split}.pickle')\n",
    "\n",
    "        estimationerror_save_dict_loaded = {}\n",
    "        if os.path.isfile(estimationerror_save_dict_path):\n",
    "            # NOTE: pickle.load can execute arbitrary code; only load result files you produced.\n",
    "            with open(estimationerror_save_dict_path, 'rb') as f:\n",
    "                estimationerror_save_dict_loaded = pickle.load(f)\n",
    "\n",
    "        len_original, len_loaded = len(estimationerror_save_dict_backbone_method), len(estimationerror_save_dict_loaded)\n",
    "        estimationerror_save_dict_backbone_method.update(estimationerror_save_dict_loaded)\n",
    "        len_updated = len(estimationerror_save_dict_backbone_method)\n",
    "\n",
    "        print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d7c1111",
   "metadata": {},
   "source": [
    "# utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16405739",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_relative_value(x, random_seed=None):\n",
    "    \"\"\"Return the rank (0 = smallest) of every entry of a 1-D array.\n",
    "\n",
    "    The indices are randomly permuted before sorting, so tied values are not\n",
    "    ranked simply by their original position. With an int `random_seed` the\n",
    "    permutation comes from np.random.default_rng(random_seed) and the result is\n",
    "    reproducible; otherwise the global np.random state is used.\n",
    "\n",
    "    Returns an integer array `rank` of the same length, where rank[i] is the\n",
    "    position of x[i] in the sorted order. Raises AssertionError if `x` is not 1-D.\n",
    "    \"\"\"\n",
    "    assert len(x.shape) == 1\n",
    "    n = len(x)\n",
    "\n",
    "    if not isinstance(random_seed, int):\n",
    "        shuffled = np.random.permutation(np.arange(n))\n",
    "    else:\n",
    "        shuffled = np.random.default_rng(random_seed).permutation(np.arange(n))\n",
    "\n",
    "    # Original indices listed in ascending order of their values.\n",
    "    order = shuffled[np.argsort(x[shuffled])]\n",
    "\n",
    "    # Invert that permutation: the element at sorted position k gets rank k.\n",
    "    rank = np.empty(n, dtype=np.intp)\n",
    "    rank[order] = np.arange(n)\n",
    "    return rank"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "252a017d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def adapt_path(path_original, dict_keys):\n",
    "    \"\"\"Find the variant of `path_original` that is present in `dict_keys`.\n",
    "\n",
    "    For every component of `path_list` that occurs in `path_original`, each component\n",
    "    of `path_list` is substituted in turn (via str.replace). The first rewritten path\n",
    "    that is a member of `dict_keys` is returned. If nothing matches, `path_original`\n",
    "    is returned unchanged.\n",
    "\n",
    "    dict_keys: any iterable of paths, e.g. `some_dict.keys()`. Set-like views and\n",
    "    mappings are queried directly; other iterables are converted to a set once. The\n",
    "    previous list conversion made each of the up-to-36 lookups a linear scan.\n",
    "    \"\"\"\n",
    "    import collections.abc\n",
    "\n",
    "    # presumably per-machine dataset mount components; TODO confirm\n",
    "    path_list = ['l0.cs.hostname', 'l1lambda.cs.hostname', 'l2lambda.cs.hostname',\n",
    "                 'l3.cs.hostname', 'deeper.cs.hostname', 'sync']\n",
    "\n",
    "    if isinstance(dict_keys, (collections.abc.Set, collections.abc.Mapping)):\n",
    "        known_paths = dict_keys\n",
    "    else:\n",
    "        known_paths = set(dict_keys)\n",
    "\n",
    "    for path1 in path_list:\n",
    "        if path1 in path_original:\n",
    "            for path2 in path_list:\n",
    "                path_replaced = path_original.replace(path1, path2)\n",
    "                if path_replaced in known_paths:\n",
    "                    return path_replaced\n",
    "    return path_original\n",
    "    # raise ValueError(f\"not found {path_original}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4aef1c66",
   "metadata": {},
   "source": [
    "# Methods to run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b71c18f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Explanation methods to evaluate in this run.\n",
    "explanation_method_to_run_ = [\n",
    "    \"random\", \"attention_rollout\", \"attention_last\",\n",
    "    \"LRP\", \"gradcam\", \"gradcamgithub\",\n",
    "    \"vanillapixel\", \"vanillaembedding\",\n",
    "    \"sgpixel\", \"sgembedding\",\n",
    "    \"vargradpixel\", \"vargradembedding\",\n",
    "    \"igpixel\", \"igembedding\",\n",
    "    \"leaveoneoutclassifier\",\n",
    "    \"riseclassifier\",\n",
    "    \"ours\",\n",
    "]\n",
    "# Alternative selections (uncomment one to replace the list above):\n",
    "# explanation_method_to_run_ = [\"kernelshap\"]\n",
    "# explanation_method_to_run_ = [\"random\", \"attention_rollout\", \"attention_last\",\n",
    "#                               \"LRP\", \"gradcam\",\n",
    "#                               \"vanillaembedding\", \"sgembedding\",\n",
    "#                               \"vargradembedding\", \"igembedding\",\n",
    "#                               \"leaveoneoutclassifier\", \"riseclassifier\",\n",
    "#                               \"ours\"]\n",
    "\n",
    "# Independent copy: later edits to the run list leave explanation_method_to_run_ untouched.\n",
    "explanation_method_to_run = list(explanation_method_to_run_)\n",
    "print(explanation_method_to_run)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "213d97c8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Evaluation loader: one image per batch, in dataset order, no samples dropped.\n",
    "data_loader = DataLoader(\n",
    "    dset,\n",
    "    batch_size=1,\n",
    "    shuffle=False,\n",
    "    drop_last=False,\n",
    "    num_workers=4,  # alt: 16\n",
    ")\n",
    "print(len(dset))\n",
    "print(len(data_loader))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "045b09d1",
   "metadata": {},
   "source": [
    "# 1_classifier_evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4328d78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stage 1: store each classifier's predicted probabilities for every image in the split,\n",
    "# one pickle per backbone: {path: {\"label\": ..., \"prob\": float ndarray}}.\n",
    "# NOTE(review): evaluation_stage is defined as a list of stage names in the config cell, so the\n",
    "# previous `evaluation_stage == \"1_classifier_evaluate\"` test was never true. `in` works for a\n",
    "# list and for a single stage-name string.\n",
    "if \"1_classifier_evaluate\" in evaluation_stage:\n",
    "    classifier_result_list_all = {backbone_type: {} for backbone_type in backbone_type_config_dict}\n",
    "\n",
    "    for batch in tqdm(data_loader):\n",
    "        images = batch['images']\n",
    "        labels = batch['labels']\n",
    "        paths = batch['path']\n",
    "\n",
    "        for backbone_type in backbone_type_config_dict:\n",
    "            classifier = classifier_dict[backbone_type]\n",
    "            classifier.eval()\n",
    "            with torch.no_grad():\n",
    "                classifier_output = classifier(images.to(classifier.device),\n",
    "                                               output_attentions=True)\n",
    "            if _config[\"output_dim\"] == 1:\n",
    "                probs = classifier_output['logits'].sigmoid().cpu().numpy()\n",
    "            else:\n",
    "                probs = classifier_output['logits'].softmax(dim=-1).cpu().numpy()\n",
    "\n",
    "            # Separate name for each image's row so the batch array `probs` is not shadowed.\n",
    "            for path, label, prob in zip(paths, labels, probs):\n",
    "                classifier_result_list_all[backbone_type][path] = {'label': label.item(),\n",
    "                                                                   'prob': prob.astype(float)}\n",
    "\n",
    "    for backbone_type in backbone_type_config_dict:\n",
    "        classifier_result_list_path = f'results/1_classifier_evaluate/{_config[\"datasets\"]}/{backbone_type}_{dataset_split}.pickle'\n",
    "        # Create the output directory on first use instead of failing in open().\n",
    "        os.makedirs(os.path.dirname(classifier_result_list_path), exist_ok=True)\n",
    "        with open(classifier_result_list_path, \"wb\") as f:\n",
    "            pickle.dump(classifier_result_list_all[backbone_type], f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d237daa",
   "metadata": {},
   "source": [
    "# 2_surrogate_evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5e7849b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stage 2: measure how closely each masked surrogate tracks the unmasked \"original\" model.\n",
    "# For every batch and masking level it records KL(original || surrogate) and accuracy for each\n",
    "# (mask_location_model, mask_location_parameter) pair, then writes one CSV per backbone.\n",
    "\n",
    "\n",
    "def _surrogate_accuracy(logits, labels):\n",
    "    \"\"\"Return the fraction of rows whose predicted class equals `labels`, as a Python float.\"\"\"\n",
    "    if _config[\"output_dim\"] == 1:\n",
    "        # Single-logit head: flatten predictions and labels so the comparison is element-wise.\n",
    "        # Comparing a (batch, 1) tensor with a (batch,) tensor would broadcast to (batch, batch)\n",
    "        # and silently produce a wrong mean.\n",
    "        predictions = (logits.sigmoid() > 0.5).cpu().int().reshape(-1)\n",
    "        return (predictions == labels.reshape(-1)).float().mean().item()\n",
    "    return (torch.argmax(logits, dim=1).cpu() == labels).float().mean().item()\n",
    "\n",
    "\n",
    "def _surrogate_kl(surrogate_logits, original_logits):\n",
    "    \"\"\"Return the batch-mean KL(original || surrogate) as a Python float.\"\"\"\n",
    "    if _config[\"output_dim\"] == 1:\n",
    "        # Expand the single logit into a two-class (positive, negative) distribution.\n",
    "        log_q = torch.concat([F.logsigmoid(surrogate_logits), F.logsigmoid(-surrogate_logits)], dim=1)\n",
    "        p = torch.concat([torch.sigmoid(original_logits), torch.sigmoid(-original_logits)], dim=1)\n",
    "    else:\n",
    "        log_q = torch.log_softmax(surrogate_logits, dim=1)\n",
    "        p = torch.softmax(original_logits, dim=1)\n",
    "    # .item() for both head types: the binary branch used to keep a tensor (possibly on GPU),\n",
    "    # which was written to the CSV as \"tensor(...)\" text.\n",
    "    return F.kl_div(input=log_q, target=p, reduction=\"batchmean\", log_target=False).item()\n",
    "\n",
    "\n",
    "# NOTE(review): evaluation_stage is defined as a list of stage names in the config cell, so the\n",
    "# previous `evaluation_stage == \"2_surrogate_evaluate\"` test was never true. `in` works for a\n",
    "# list and for a single stage-name string.\n",
    "if \"2_surrogate_evaluate\" in evaluation_stage:\n",
    "    num_players = 196  # mask length per image (was hard-coded as 196)\n",
    "    mask_step = 14     # increase in num_mask between masking levels\n",
    "    mask_location_models = [\"original\", \"pre-softmax\", \"zero-input\", \"zero-embedding\"]\n",
    "    mask_location_parameters = [\"pre-softmax\", \"post-softmax\", \"zero-input\", \"zero-embedding\", \"random-sampling\"]\n",
    "\n",
    "    result_list_all = {backbone_type: [] for backbone_type in backbone_type_config_dict}\n",
    "\n",
    "    dset_loader = DataLoader(dset, batch_size=64, num_workers=4, shuffle=False, drop_last=True)\n",
    "\n",
    "    for batch_idx, batch in enumerate(tqdm(dset_loader, unit='batch')):\n",
    "        images = batch[\"images\"]\n",
    "        labels = batch[\"labels\"]\n",
    "\n",
    "        # The all-ones reference output does not depend on the random mask, so compute it once\n",
    "        # per batch and backbone instead of once per masking level. Only the logits are kept.\n",
    "        original_logits_dict = {}\n",
    "        for backbone_type in backbone_type_config_dict:\n",
    "            original_model = surrogate_dict[backbone_type][\"original\"]\n",
    "            original_model.eval()\n",
    "            with torch.no_grad():\n",
    "                original_logits_dict[backbone_type] = original_model(\n",
    "                    images.to(original_model.device),\n",
    "                    torch.ones((len(images), num_players)).to(original_model.device))[\"logits\"]\n",
    "\n",
    "        for num_mask in range(0, num_players + 1, mask_step):\n",
    "            # Each row gets exactly num_mask ones at randomly permuted positions.\n",
    "            mask = torch.zeros((len(images), num_players))\n",
    "            mask[:, :num_mask] = 1\n",
    "            for i in range(len(mask)):\n",
    "                mask[i] = mask[i][torch.randperm(len(mask[i]))]\n",
    "\n",
    "            for backbone_type in backbone_type_config_dict:\n",
    "                original_logits = original_logits_dict[backbone_type]\n",
    "\n",
    "                for mask_location_model in mask_location_models:\n",
    "                    if mask_location_model == \"original\":\n",
    "                        # Reference row: the all-ones output compared with itself.\n",
    "                        result_list_all[backbone_type].append({\"batch_idx\": batch_idx,\n",
    "                                                               \"backbone_type\": backbone_type,\n",
    "                                                               \"num_mask\": num_mask,\n",
    "                                                               \"mask_location_model\": mask_location_model,\n",
    "                                                               \"mask_location_parameter\": \"original\",\n",
    "                                                               \"kl_divergence\": 0,\n",
    "                                                               \"accuracy\": _surrogate_accuracy(original_logits, labels)})\n",
    "\n",
    "                    surrogate = surrogate_dict[backbone_type][mask_location_model]\n",
    "                    surrogate.eval()\n",
    "                    for mask_location_parameter in mask_location_parameters:\n",
    "                        with torch.no_grad():\n",
    "                            out_surrogate = surrogate(images.to(surrogate.device),\n",
    "                                                      mask.to(surrogate.device),\n",
    "                                                      mask_location_parameter)\n",
    "\n",
    "                        result_list_all[backbone_type].append({\"batch_idx\": batch_idx,\n",
    "                                                               \"backbone_type\": backbone_type,\n",
    "                                                               \"num_mask\": num_mask,\n",
    "                                                               \"mask_location_model\": mask_location_model,\n",
    "                                                               \"mask_location_parameter\": mask_location_parameter,\n",
    "                                                               \"kl_divergence\": _surrogate_kl(out_surrogate[\"logits\"], original_logits),\n",
    "                                                               \"accuracy\": _surrogate_accuracy(out_surrogate[\"logits\"], labels)})\n",
    "\n",
    "    for backbone_type in backbone_type_config_dict:\n",
    "        # NOTE(review): the output directory keeps its \"4_0_\" prefix even though this is stage 2.\n",
    "        result_csv_path = f'results/4_0_surrogate_evaluate/{_config[\"datasets\"]}/{backbone_type}.csv'\n",
    "        os.makedirs(os.path.dirname(result_csv_path), exist_ok=True)\n",
    "        pd.DataFrame(result_list_all[backbone_type]).to_csv(result_csv_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d352d73b",
   "metadata": {},
   "source": [
    "# 3_explanation_generate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bb324d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# def adapt_path(path_original, path_format):\n",
    "#     path_list=['l0.cs.hostname', 'l1lambda.cs.hostname', 'l2lambda.cs.hostname', 'l3.cs.hostname', 'deeper.cs.hostname']\n",
    "    \n",
    "#     for path1 in path_list:\n",
    "#         if path1 in path_original:\n",
    "#             for path2 in path_list:\n",
    "#                 if path2 in path_format:\n",
    "#                     return path_original.replace(path1, path2)\n",
    "#             raise\n",
    "#     return path_original"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfe91774",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_random_explanation(num_players, num_samples=None):\n",
    "    if num_samples is None:\n",
    "        return np.random.RandomState(42).uniform(low=0, high=1e-40, size=(num_players,))\n",
    "    else:\n",
    "        return np.random.RandomState(42).uniform(low=0, high=1e-40, size=(num_samples, num_players))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "439314e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "label_to_use=['Garbage truck', \n",
    "              'Tench', \n",
    "              'English springer', \n",
    "              'Parachute',  \n",
    "              'Golf ball', \n",
    "              'Gas pump']\n",
    "kernelshap_sample_idx_list_all=[]\n",
    "for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    for random_seed in [2, 3, 4, 5]:\n",
    "        label_data_list=np.array([i['label'] for i in dset.data])\n",
    "        kernelshap_sample_idx_list=[np.random.RandomState(random_seed).choice(np.arange(len(label_data_list))[(label_data_list==label_idx)]) for label_idx in [label_name_list.index(label) for label in label_to_use]]\n",
    "        kernelshap_sample_idx_list_all+=kernelshap_sample_idx_list\n",
    "kernelshap_sample_path_list_all=[dset[i]['path'] for i in kernelshap_sample_idx_list_all]        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1478944c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "if evaluation_stage==\"3_explanation_generate\":\n",
    "    for idx, batch in enumerate(tqdm(data_loader)):#, total = int(1000/data_loader.batch_size+0.5) if dataset_split==\"test\" else None)):\n",
    "#         if dataset_split==\"test\":\n",
    "#             if idx>int(1000/data_loader.batch_size+0.5):\n",
    "#                 break\n",
    "        if (idx%parallel_mode[1])!=parallel_mode[0]:\n",
    "            continue\n",
    "            \n",
    "        images=batch['images']\n",
    "        labels=batch['labels']\n",
    "        paths=batch['path']\n",
    "        updated_signal_list=[]\n",
    "        \n",
    "        for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "            # Get classifier output\n",
    "            classifier_dict[backbone_type].eval()\n",
    "            with torch.no_grad():\n",
    "                classifier_output=classifier_dict[backbone_type](images.to(classifier_dict[backbone_type].device),\n",
    "                                                                 output_attentions=True)        \n",
    "            for explanation_method in explanation_method_to_run:\n",
    "                data_keys=explanation_save_dict[backbone_type][explanation_method].keys()\n",
    "                data_keys=[adapt_path(data_keys_path, paths) for data_keys_path in data_keys]\n",
    "                if all([path in data_keys for path in paths]):\n",
    "                    continue\n",
    "                else:\n",
    "                    print(explanation_method,'not exist')\n",
    "                    updated_signal_list.append(explanation_method)\n",
    "                if explanation_method==\"random\":\n",
    "                    start_time=time.time()\n",
    "                    explanation_random_list=[get_random_explanation(num_players=196) for path in paths]\n",
    "                    elapsed_time=time.time()-start_time\n",
    "                    explanation_save_dict_update(backbone_type, 'random', path_list=paths, explanation_list=explanation_random_list, elapsed_time_list=[elapsed_time/len(paths) for i in range(len(paths))], shape=(196, ))\n",
    "                \n",
    "                elif explanation_method==\"attention_rollout\":\n",
    "                    attentions = np.asarray([att.cpu().detach().numpy() for att in classifier_output['self_attentions']]).transpose(1,0,2,3,4)\n",
    "                    start_time=time.time()\n",
    "                    explanation_attention_rollout_list=attentions_to_explanation(attentions, mode='rollout')\n",
    "                    elapsed_time=time.time()-start_time\n",
    "                    explanation_save_dict_update(backbone_type, 'attention_rollout', path_list=paths, explanation_list=explanation_attention_rollout_list, elapsed_time_list=[elapsed_time/len(paths) for i in range(len(paths))], shape=(196,))\n",
    "\n",
    "                elif explanation_method==\"attention_last\":\n",
    "                    attentions = np.asarray([att.cpu().detach().numpy() for att in classifier_output['self_attentions']]).transpose(1,0,2,3,4)\n",
    "                    start_time=time.time()\n",
    "                    explanation_attention_last_list=attentions_to_explanation(attentions, mode=-1)\n",
    "                    elapsed_time=time.time()-start_time\n",
    "                    explanation_save_dict_update(backbone_type, 'attention_last', path_list=paths, explanation_list=explanation_attention_last_list, elapsed_time_list=[elapsed_time/len(paths) for i in range(len(paths))], shape=(196,))            \n",
    "\n",
    "                elif explanation_method==\"LRP\":\n",
    "                    explanation_lrp_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        explanation_lrp_list.append(np.concatenate([get_lrp_module_explanation(backbone_type=backbone_type,\n",
    "                                                                              original_image=image.squeeze(0),\n",
    "                                                                              class_index=i,\n",
    "                                                                              mode='transformer_attribution').cpu().numpy() for i in range(_config[\"output_dim\"])], axis=0))\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    explanation_save_dict_update(backbone_type, 'LRP', path_list=paths, explanation_list=explanation_lrp_list, elapsed_time_list=elapsed_time_list, shape=(_config[\"output_dim\"], 196))\n",
    "\n",
    "                elif explanation_method==\"gradcam\":\n",
    "                    explanation_gradcam_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        explanation_gradcam = np.concatenate([get_lrp_module_explanation(backbone_type=backbone_type,\n",
    "                                                                                              original_image=image.squeeze(0),\n",
    "                                                                                              class_index=i,\n",
    "                                                                                              mode='attn_gradcam').cpu().numpy() for i in range(_config[\"output_dim\"])], axis=0)\n",
    "                        explanation_gradcam = np.nan_to_num(explanation_gradcam,nan=0)+np.random.uniform(low=0, high=1e-20, size=explanation_gradcam.shape)                    \n",
    "                        explanation_gradcam_list.append(explanation_gradcam)\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    explanation_save_dict_update(backbone_type, 'gradcam', path_list=paths, explanation_list=explanation_gradcam_list, elapsed_time_list=elapsed_time_list, shape=(_config[\"output_dim\"], 196))\n",
    "\n",
    "\n",
    "                elif explanation_method==\"gradcamgithub\":\n",
    "                    explanation_gradcamgithub_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        explanation_gradcamgithub = np.concatenate([cam_dict[backbone_type](input_tensor=image.unsqueeze(0).to(next(cam_dict[backbone_type].model.parameters()).device),\n",
    "                                                                                            targets=[ClassifierOutputTarget(i)], resize=False).flatten()[np.newaxis,:] for i in range(_config[\"output_dim\"])], axis=0)\n",
    "                        explanation_gradcamgithub_list.append(explanation_gradcamgithub)\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    explanation_save_dict_update(backbone_type, 'gradcamgithub', path_list=paths, explanation_list=explanation_gradcamgithub_list, elapsed_time_list=elapsed_time_list, shape=(_config[\"output_dim\"], 196))\n",
    "\n",
    "                elif explanation_method==\"vanillapixel\":\n",
    "                    explanation_vanillapixel_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        grad=get_vanilla(image, saliency_pixel=saliency_pixel_dict[backbone_type])\n",
    "                        explanation_vanillapixel_list.append(grad[\"attributions_pixel_patchabssum\"].reshape(_config[\"output_dim\"],196).numpy())\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    explanation_save_dict_update(backbone_type, 'vanillapixel', path_list=paths, explanation_list=explanation_vanillapixel_list, elapsed_time_list=elapsed_time_list, shape=(_config[\"output_dim\"], 196))\n",
    "\n",
    "                elif explanation_method==\"vanillaembedding\":\n",
    "                    explanation_vanillaembedding_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:\n",
    "                        start_time=time.time()\n",
    "                        grad=get_vanilla(image, saliency_embedding=saliency_embedding_dict[backbone_type])\n",
    "                        explanation_vanillaembedding_list.append(grad[\"attributions_embedding_abssum\"].reshape(_config[\"output_dim\"],196).numpy())\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    explanation_save_dict_update(backbone_type, 'vanillaembedding', path_list=paths, explanation_list=explanation_vanillaembedding_list, elapsed_time_list=elapsed_time_list, shape=(_config[\"output_dim\"], 196))\n",
    "\n",
    "                elif explanation_method==\"sgpixel\":\n",
    "                    explanation_sgpixel_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        grad=get_sg(image, noisetunnel_pixel=noisetunnel_pixel_dict[backbone_type])\n",
    "                        explanation_sgpixel_list.append(grad[\"attributions_pixel_patchabssum\"].reshape(_config[\"output_dim\"],196).numpy())\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    explanation_save_dict_update(backbone_type, 'sgpixel', path_list=paths, explanation_list=explanation_sgpixel_list, elapsed_time_list=elapsed_time_list, shape=(_config[\"output_dim\"], 196))\n",
    "\n",
    "\n",
    "                elif explanation_method==\"sgembedding\":\n",
    "                    explanation_sgembedding_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        grad=get_sg(image, noisetunnel_embedding=noisetunnel_embedding_dict[backbone_type])\n",
    "                        explanation_sgembedding_list.append(grad[\"attributions_embedding_abssum\"].reshape(_config[\"output_dim\"],196).numpy())\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    explanation_save_dict_update(backbone_type, 'sgembedding', path_list=paths, explanation_list=explanation_sgembedding_list, elapsed_time_list=elapsed_time_list, shape=(_config[\"output_dim\"], 196))                                \n",
    "\n",
    "                elif explanation_method==\"vargradpixel\":\n",
    "                    explanation_vargradpixel_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        grad=get_vargrad(image, noisetunnel_pixel=noisetunnel_pixel_dict[backbone_type])\n",
    "                        explanation_vargradpixel_list.append(grad[\"attributions_pixel_patchabssum\"].reshape(_config[\"output_dim\"],196).numpy())\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    explanation_save_dict_update(backbone_type, 'vargradpixel', path_list=paths, explanation_list=explanation_vargradpixel_list, elapsed_time_list=elapsed_time_list, shape=(_config[\"output_dim\"], 196))\n",
    "\n",
    "                elif explanation_method==\"vargradembedding\":\n",
    "                    explanation_vargradembedding_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        grad=get_vargrad(image, noisetunnel_embedding=noisetunnel_embedding_dict[backbone_type])\n",
    "                        explanation_vargradembedding_list.append(grad[\"attributions_embedding_abssum\"].reshape(_config[\"output_dim\"],196).numpy())\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    explanation_save_dict_update(backbone_type, 'vargradembedding', path_list=paths, explanation_list=explanation_vargradembedding_list, elapsed_time_list=elapsed_time_list, shape=(_config[\"output_dim\"], 196))                                \n",
    "\n",
    "                elif explanation_method==\"igpixel\":\n",
    "                    explanation_igpixel_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        grad=get_ig(image, ig_pixel=ig_pixel_dict[backbone_type])\n",
    "                        explanation_igpixel_list.append(grad[\"attributions_pixel_patchsum\"].reshape(_config[\"output_dim\"],196).numpy())\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    explanation_save_dict_update(backbone_type, 'igpixel', path_list=paths, explanation_list=explanation_igpixel_list, elapsed_time_list=elapsed_time_list, shape=(_config[\"output_dim\"], 196))\n",
    "\n",
    "                elif explanation_method==\"igembedding\":\n",
    "                    explanation_igembedding_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        grad=get_ig(image, ig_embedding=ig_embedding_dict[backbone_type])\n",
    "                        explanation_igembedding_list.append(grad[\"attributions_embedding_sum\"].reshape(_config[\"output_dim\"],196).numpy())\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    explanation_save_dict_update(backbone_type, 'igembedding', path_list=paths, explanation_list=explanation_igembedding_list, elapsed_time_list=elapsed_time_list, shape=(_config[\"output_dim\"], 196))                \n",
    "\n",
    "                elif explanation_method==\"leaveoneoutclassifier\":\n",
    "                    explanation_leaveoneoutclassifier_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        explanation_leaveoneoutclassifier = leave_one_out(classifier=classifier_dict_[backbone_type], image=image).reshape(_config[\"output_dim\"], 196)\n",
    "                        explanation_leaveoneoutclassifier_list.append(explanation_leaveoneoutclassifier)\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    explanation_save_dict_update(backbone_type, 'leaveoneoutclassifier', path_list=paths, explanation_list=explanation_leaveoneoutclassifier_list, elapsed_time_list=elapsed_time_list, shape=(_config[\"output_dim\"], 196))\n",
    "\n",
    "                elif explanation_method==\"leaveoneoutsurrogate\":\n",
    "                    explanation_leaveoneoutsurrogate_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        explanation_leaveoneoutsurrogate = leave_one_out(surrogate=surrogate_dict[backbone_type][\"pre-softmax\"], image=image).reshape(_config[\"output_dim\"], 196)\n",
    "                        explanation_leaveoneoutsurrogate_list.append(explanation_leaveoneoutsurrogate)\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    explanation_save_dict_update(backbone_type, 'leaveoneoutsurrogate', path_list=paths, explanation_list=explanation_leaveoneoutsurrogate_list, elapsed_time_list=elapsed_time_list, shape=(_config[\"output_dim\"], 196))                \n",
    "\n",
    "                elif explanation_method==\"riseclassifier\":\n",
    "                    explanation_riseclassifier_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        explanation_riseclassifier = rise(classifier=classifier_dict_[backbone_type], image=image, N=2000, include_prob=0.5).reshape(_config[\"output_dim\"], 196)\n",
    "                        explanation_riseclassifier_list.append(explanation_riseclassifier)\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    explanation_save_dict_update(backbone_type, 'riseclassifier', path_list=paths, explanation_list=explanation_riseclassifier_list, elapsed_time_list=elapsed_time_list, shape=(_config[\"output_dim\"], 196))\n",
    "\n",
    "                elif explanation_method==\"risesurrogate\":\n",
    "                    explanation_risesurrogate_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        explanation_risesurrogate = rise(surrogate=surrogate_dict[backbone_type][\"pre-softmax\"], image=image, N=2000, include_prob=0.5).reshape(_config[\"output_dim\"], 196)\n",
    "                        explanation_risesurrogate_list.append(explanation_risesurrogate)\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    explanation_save_dict_update(backbone_type, 'risesurrogate', path_list=paths, explanation_list=explanation_risesurrogate_list, elapsed_time_list=elapsed_time_list, shape=(_config[\"output_dim\"], 196))                \n",
    "                    \n",
    "                elif explanation_method==\"ours\":\n",
    "                    start_time=time.time()\n",
    "                    explainer_dict[backbone_type].eval()\n",
    "                    with torch.no_grad():\n",
    "                        explanation_ours = explainer_dict[backbone_type](images.to(explainer_dict[backbone_type].device))[0].detach().cpu().numpy().transpose(0, 2, 1)\n",
    "                    elapsed_time=time.time()-start_time\n",
    "                    explanation_save_dict_update(backbone_type, 'ours', path_list=paths, explanation_list=explanation_ours, elapsed_time_list=[elapsed_time/len(paths) for i in range(len(paths))], shape=(_config[\"output_dim\"], 196))                                \n",
    "                    \n",
    "                elif explanation_method==\"kernelshap\":                    \n",
    "                    explanation_kernelshap_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    path_list=[]\n",
    "                    surrogate_SHAP_wrapped_dict[backbone_type].eval()\n",
    "                    for path,image in zip(paths, images):\n",
    "                        if path not in kernelshap_sample_path_list_all:\n",
    "                            continue\n",
    "                        print(path)\n",
    "                        start_time=time.time()                        \n",
    "                        explanation_kernelshap_ret = get_shap(surrogate_SHAP_wrapped_dict[backbone_type], image, thresh=0.2)\n",
    "                        explanation_kernelshap = explanation_kernelshap_ret.values.T\n",
    "                        explanation_kernelshap_list.append(explanation_kernelshap)\n",
    "                        path_list.append(path)\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    explanation_save_dict_update(backbone_type, 'kernelshap', path_list=path_list, explanation_list=explanation_kernelshap_list, elapsed_time_list=elapsed_time_list, shape=(_config[\"output_dim\"], 196))                    \n",
    "                else:\n",
    "                    raise\n",
    "\n",
    "        try:\n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                for explanation_method, explanation_save_dict_backbone_method in explanation_save_dict[backbone_type].items():\n",
    "                    if explanation_method not in updated_signal_list:\n",
    "                        continue\n",
    "                    explanation_save_dict_path=f'results/3_explanation_generate/{_config[\"datasets\"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'\n",
    "\n",
    "                    if os.path.isfile(explanation_save_dict_path):\n",
    "                        try:\n",
    "                            with open(explanation_save_dict_path, 'rb') as f:\n",
    "                                explanation_save_dict_loaded=pickle.load(f)\n",
    "                        except:\n",
    "                            explanation_save_dict_loaded={}\n",
    "                    else:\n",
    "                        explanation_save_dict_loaded={}\n",
    "\n",
    "                    len_original=len(explanation_save_dict_backbone_method)            \n",
    "                    len_loaded=len(explanation_save_dict_loaded)\n",
    "                    explanation_save_dict_backbone_method.update(explanation_save_dict_loaded)\n",
    "                    len_updated=len(explanation_save_dict_backbone_method)\n",
    "\n",
    "                    print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')\n",
    "\n",
    "                    with open(explanation_save_dict_path, \"wb\") as f:\n",
    "                        pickle.dump(explanation_save_dict_backbone_method, f)\n",
    "        except:\n",
    "            pass\n",
    "            \n",
    "            \n",
    "            \n",
    "    #         # Get ours\n",
    "    #         start_time=time.time();explainer_dict[backbone_type].eval()\n",
    "    #         with torch.no_grad():\n",
    "    #             values, _, _=explainer_dict[backbone_type](torch.Tensor(image).unsqueeze(0).to(explainer_dict[backbone_type].device))\n",
    "    #             values=values.cpu().numpy().transpose(0,2,1).squeeze(0)\n",
    "    #         values_ours=(values, time.time()-start_time)        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b41b8d67",
   "metadata": {},
   "outputs": [],
   "source": [
    "if evaluation_stage==\"3_explanation_generate\":\n",
    "    for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "        for explanation_method, explanation_save_dict_backbone_method in explanation_save_dict[backbone_type].items():\n",
    "            explanation_save_dict_path=f'results/3_explanation_generate/{_config[\"datasets\"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'\n",
    "\n",
    "            if os.path.isfile(explanation_save_dict_path):\n",
    "                with open(explanation_save_dict_path, 'rb') as f:\n",
    "                    explanation_save_dict_loaded=pickle.load(f)\n",
    "            else:\n",
    "                explanation_save_dict_loaded={}\n",
    "\n",
    "            len_original=len(explanation_save_dict_backbone_method)            \n",
    "            len_loaded=len(explanation_save_dict_loaded)\n",
    "            explanation_save_dict_backbone_method.update(explanation_save_dict_loaded)\n",
    "            len_updated=len(explanation_save_dict_backbone_method)\n",
    "\n",
    "            print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')\n",
    "\n",
    "            with open(explanation_save_dict_path, \"wb\") as f:\n",
    "                pickle.dump(explanation_save_dict_backbone_method, f)       "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dfcab1c5",
   "metadata": {},
   "source": [
    "# 4_insert_delete"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1c78578",
   "metadata": {},
   "outputs": [],
   "source": [
    "def explanation_to_mask(explanation, mode='insertion'):\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        explanation: (num_batches, num_players)\n",
    "    Returns:\n",
    "        explanation_expaned_bool: (num_batches, num_players+1, num_players)\n",
    "    \"\"\"\n",
    "    \n",
    "    explanation_expaned=np.repeat(explanation[:,np.newaxis,:], explanation.shape[-1], axis=1) # (num_batches, num_players, num_players)\n",
    "    \n",
    "    if mode=='insertion':\n",
    "        explanation_expaned_bool = explanation_expaned > ((np.sort(explanation, axis=-1)[:, : :-1])[:, :, np.newaxis]) # (num_batches, num_players, num_players)\n",
    "        explanation_expaned_bool = np.concatenate([explanation_expaned_bool,\n",
    "                                                   np.ones(shape=(explanation_expaned_bool.shape[0], 1, explanation_expaned_bool.shape[2]))==1], axis=1) # (num_batches, num_players+1, num_players)\n",
    "        #print(explanation_expaned_bool.shape)\n",
    "    elif mode=='deletion':\n",
    "        explanation_expaned_bool = explanation_expaned < ((np.sort(explanation, axis=-1)[:, : :-1])[:, :, np.newaxis]) # (num_batches, num_players, num_players)\n",
    "        explanation_expaned_bool = np.concatenate([np.ones(shape=(explanation_expaned_bool.shape[0], 1, explanation_expaned_bool.shape[2]))==1,\n",
    "                                                   explanation_expaned_bool],axis=1) # (num_batches, num_players+1, num_players)        \n",
    "        \n",
    "    else:\n",
    "        raise ValueError(f'{mode} should be insertion or deletion.')\n",
    "    \n",
    "    \n",
    "    return explanation_expaned_bool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "835e5d77",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluation_stage=\"4_insert_delete\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "161dc89e",
   "metadata": {},
   "outputs": [],
   "source": [
    "explanation_method_to_run=[\"kernelshap\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86e30e33",
   "metadata": {},
   "outputs": [],
   "source": [
    "estimationerror_sample_path_list=pd.DataFrame(data_loader.dataset.data).groupby(\"label\").apply(lambda x: x.sample(n=10, random_state=42))[\"img_path\"].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f37cb8db",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80a1e097",
   "metadata": {},
   "outputs": [],
   "source": [
    "parallel_mode=(0,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2b8727f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "if evaluation_stage==\"4_insert_delete\":\n",
    "    # Insertion/deletion evaluation: for every image and explanation method, masks are built from the explanation\n",
    "    # with explanation_to_mask (mode 'insertion' / 'deletion') and the surrogate's probabilities are recorded for\n",
    "    # every mask. Results are merged into results/4_insert_delete/<datasets>/<backbone>_<method>_<split>.pickle.\n",
    "    num_players=196\n",
    "    for idx, batch in enumerate(tqdm(data_loader, total = int(1000/data_loader.batch_size+0.5) if dataset_split==\"test\" and (len(explanation_method_to_run)!=1 or explanation_method_to_run[0]!=\"kernelshap\") else None)):\n",
    "        # On the test split, stop after batch index int(1000/batch_size+0.5) (about 1000 images),\n",
    "        # unless KernelSHAP is the only method being evaluated.\n",
    "        if dataset_split==\"test\" and (len(explanation_method_to_run)!=1 or explanation_method_to_run[0]!=\"kernelshap\"):\n",
    "            if idx>int(1000/data_loader.batch_size+0.5):\n",
    "                break\n",
    "        # Shard batches across processes: parallel_mode = (shard_index, num_shards).\n",
    "        if (idx%parallel_mode[1])!=parallel_mode[0]:\n",
    "            continue\n",
    "\n",
    "        images=batch['images']\n",
    "        labels=batch['labels']\n",
    "        paths=batch['path']\n",
    "\n",
    "        # Methods with at least one missing image in this batch; only their pickles are rewritten below.\n",
    "        updated_signal_list=[]\n",
    "\n",
    "        for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "\n",
    "            for explanation_method in explanation_method_to_run:\n",
    "                data_keys=insertdelete_save_dict[backbone_type][explanation_method].keys()\n",
    "                data_keys=[adapt_path(data_keys_path, paths) for data_keys_path in data_keys]\n",
    "                if explanation_method=='kernelshap':\n",
    "                    # KernelSHAP results are only required for the images in estimationerror_sample_path_list.\n",
    "                    if all((path in data_keys) or (path not in estimationerror_sample_path_list) for path in paths):\n",
    "                        continue\n",
    "                    print(explanation_method,'not exist')\n",
    "                    updated_signal_list.append(explanation_method)\n",
    "                else:\n",
    "                    if all(path in data_keys for path in paths):\n",
    "                        continue\n",
    "                    print(explanation_method,'not exist')\n",
    "                    updated_signal_list.append(explanation_method)\n",
    "\n",
    "                if explanation_method==\"random\":\n",
    "                    # Random baseline: 10 random explanations per image, seeded by the image's position in the batch.\n",
    "                    explanations=[np.random.RandomState(position).uniform(low=0, high=1e-40, size=(10, num_players)) for position, path in enumerate(paths)]\n",
    "                else:\n",
    "                    explanations=[explanation_save_dict[backbone_type][explanation_method][adapt_path(path, list(explanation_save_dict[backbone_type][explanation_method].keys()))]['explanation']\n",
    "                                  for path in paths]\n",
    "\n",
    "                insertdelete_dict={'insertion': [], 'deletion': []}\n",
    "                for image, explanation in zip(images, explanations):\n",
    "                    # One copy of the image per mask row (num_players+1 rows per curve).\n",
    "                    image_loaded=image[np.newaxis, :].repeat(num_players+1, 1, 1, 1).to(surrogate_dict[backbone_type][\"pre-softmax\"].device)\n",
    "                    if np.isnan(explanation).any():\n",
    "                        print(explanation_method, \"Null found\")\n",
    "\n",
    "                    if explanation_method==\"random\":\n",
    "                        # One curve per random explanation row.\n",
    "                        for metric_mode in insertdelete_dict.keys():\n",
    "                            mask=explanation_to_mask(explanation=np.array([get_relative_value(explanation_) for explanation_ in explanation]), mode=metric_mode)\n",
    "                            prob_=[]\n",
    "                            for random_iter in range(mask.shape[0]):\n",
    "                                surrogate_dict[backbone_type][\"pre-softmax\"].eval()\n",
    "                                with torch.no_grad():\n",
    "                                    output = surrogate_dict[backbone_type][\"pre-softmax\"](image_loaded,\n",
    "                                                                                          masks=torch.Tensor(mask[random_iter]).to(surrogate_dict[backbone_type][\"pre-softmax\"].device))\n",
    "                                if _config[\"output_dim\"]==1:\n",
    "                                    prob=output['logits'].sigmoid().cpu().numpy()\n",
    "                                else:\n",
    "                                    prob=output['logits'].softmax(dim=-1).cpu().numpy()\n",
    "                                prob_.append(prob.T)\n",
    "                            prob=np.array(prob_) # (10, output_dim, num_players+1)\n",
    "                            insertdelete_dict[metric_mode].append(prob)\n",
    "\n",
    "                    elif len(explanation.shape)==1:\n",
    "                        # A single attribution vector: one curve, probabilities of every class recorded.\n",
    "                        for metric_mode in insertdelete_dict.keys():\n",
    "                            mask=explanation_to_mask(explanation=get_relative_value(explanation)[np.newaxis,:], mode=metric_mode)\n",
    "                            surrogate_dict[backbone_type][\"pre-softmax\"].eval()\n",
    "                            with torch.no_grad():\n",
    "                                output = surrogate_dict[backbone_type][\"pre-softmax\"](image_loaded,\n",
    "                                                                                      masks=torch.Tensor(mask[0]).to(surrogate_dict[backbone_type][\"pre-softmax\"].device))\n",
    "                            if _config[\"output_dim\"]==1:\n",
    "                                prob=output['logits'].sigmoid().cpu().numpy()\n",
    "                            else:\n",
    "                                prob=output['logits'].softmax(dim=-1).cpu().numpy()\n",
    "                            prob=prob.T # (output_dim, num_players+1)\n",
    "                            insertdelete_dict[metric_mode].append(prob)\n",
    "\n",
    "                    elif len(explanation.shape)==2:\n",
    "                        # One explanation row per class: row c's curve is scored with the probability of class c.\n",
    "                        for metric_mode in insertdelete_dict.keys():\n",
    "                            mask=explanation_to_mask(explanation=np.array([get_relative_value(explanation_) for explanation_ in explanation]), mode=metric_mode)\n",
    "                            prob_=[]\n",
    "                            for class_idx in range(mask.shape[0]):\n",
    "                                surrogate_dict[backbone_type][\"pre-softmax\"].eval()\n",
    "                                with torch.no_grad():\n",
    "                                    output = surrogate_dict[backbone_type][\"pre-softmax\"](image_loaded,\n",
    "                                                                                          masks=torch.Tensor(mask[class_idx]).to(surrogate_dict[backbone_type][\"pre-softmax\"].device))\n",
    "                                if _config[\"output_dim\"]==1:\n",
    "                                    prob=output['logits'].sigmoid().cpu().numpy()\n",
    "                                else:\n",
    "                                    prob=output['logits'].softmax(dim=-1).cpu().numpy()\n",
    "                                prob_.append(prob[:, class_idx])\n",
    "                            prob=np.array(prob_) # (number of explanation rows, num_players+1)\n",
    "                            insertdelete_dict[metric_mode].append(prob)\n",
    "                    else:\n",
    "                        raise ValueError(f'{explanation_method}: unsupported explanation shape {explanation.shape}')\n",
    "\n",
    "                if explanation_method==\"random\":\n",
    "                    insertdelete_save_dict_update(backbone_type, explanation_method,\n",
    "                                                  path_list=paths,\n",
    "                                                  insert_list=insertdelete_dict[\"insertion\"],\n",
    "                                                  delete_list=insertdelete_dict[\"deletion\"],\n",
    "                                                  shape=(10, _config[\"output_dim\"], num_players+1))\n",
    "                elif len(explanations[0].shape) in (1, 2):\n",
    "                    insertdelete_save_dict_update(backbone_type, explanation_method,\n",
    "                                                  path_list=paths,\n",
    "                                                  insert_list=insertdelete_dict[\"insertion\"],\n",
    "                                                  delete_list=insertdelete_dict[\"deletion\"],\n",
    "                                                  shape=(_config[\"output_dim\"], num_players+1))\n",
    "                else:\n",
    "                    raise ValueError(f'{explanation_method}: unsupported explanation shape {explanations[0].shape}')\n",
    "\n",
    "        # Merge in-memory results with the pickle on disk (possibly written by another shard) and write the\n",
    "        # union back; on key collisions the on-disk entry wins.\n",
    "        for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "            for explanation_method, insertdelete_save_dict_backbone_method in insertdelete_save_dict[backbone_type].items():\n",
    "\n",
    "                if explanation_method not in updated_signal_list:\n",
    "                    continue\n",
    "\n",
    "                insertdelete_save_dict_path=f'results/4_insert_delete/{_config[\"datasets\"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'\n",
    "\n",
    "                if os.path.isfile(insertdelete_save_dict_path):\n",
    "                    with open(insertdelete_save_dict_path, 'rb') as f:\n",
    "                        insertdelete_save_dict_loaded=pickle.load(f)\n",
    "                else:\n",
    "                    insertdelete_save_dict_loaded={}\n",
    "\n",
    "                len_original=len(insertdelete_save_dict_backbone_method)\n",
    "                len_loaded=len(insertdelete_save_dict_loaded)\n",
    "                insertdelete_save_dict_backbone_method.update(insertdelete_save_dict_loaded)\n",
    "                len_updated=len(insertdelete_save_dict_backbone_method)\n",
    "\n",
    "                print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')\n",
    "\n",
    "                # Create the output directory on first use so the write below does not fail.\n",
    "                os.makedirs(os.path.dirname(insertdelete_save_dict_path), exist_ok=True)\n",
    "                with open(insertdelete_save_dict_path, \"wb\") as f:\n",
    "                    pickle.dump(insertdelete_save_dict_backbone_method, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecfa8ab2",
   "metadata": {},
   "outputs": [],
   "source": [
    "if evaluation_stage==\"4_insert_delete\":\n",
    "    # Merge every in-memory insertion/deletion result with the pickle on disk and write the union back\n",
    "    # for every method in insertdelete_save_dict (on key collisions the on-disk entry wins).\n",
    "    for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "        for explanation_method, insertdelete_save_dict_backbone_method in insertdelete_save_dict[backbone_type].items():\n",
    "            insertdelete_save_dict_path=f'results/4_insert_delete/{_config[\"datasets\"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'\n",
    "\n",
    "            if os.path.isfile(insertdelete_save_dict_path):\n",
    "                with open(insertdelete_save_dict_path, 'rb') as f:\n",
    "                    insertdelete_save_dict_loaded=pickle.load(f)\n",
    "            else:\n",
    "                insertdelete_save_dict_loaded={}\n",
    "\n",
    "            len_original=len(insertdelete_save_dict_backbone_method)\n",
    "            len_loaded=len(insertdelete_save_dict_loaded)\n",
    "            insertdelete_save_dict_backbone_method.update(insertdelete_save_dict_loaded)\n",
    "            len_updated=len(insertdelete_save_dict_backbone_method)\n",
    "\n",
    "            print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')\n",
    "\n",
    "            # Create the output directory on first use so the write below does not fail.\n",
    "            os.makedirs(os.path.dirname(insertdelete_save_dict_path), exist_ok=True)\n",
    "            with open(insertdelete_save_dict_path, \"wb\") as f:\n",
    "                pickle.dump(insertdelete_save_dict_backbone_method, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e19f433b",
   "metadata": {},
   "source": [
    "# 5_sensitivity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "415b792b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_mask(num_players: int, num_mask_samples: int or None = None, paired_mask_samples: bool = True,\n",
    "                  mode: str = 'uniform', random_state: np.random.RandomState or None = None) -> np.array:\n",
    "    \"\"\"\n",
    "    Sample random binary coalition masks (1 = player kept, 0 = player masked).\n",
    "\n",
    "    Args:\n",
    "        num_players: the number of players in the coalitional game\n",
    "        num_mask_samples: the number of masks to generate; None returns a single mask without the leading axis\n",
    "        paired_mask_samples: if True, the generated masks are pairs of x and 1-x (consecutive rows)\n",
    "        mode: the distribution that the number of kept players follows ('uniform' or 'shapley')\n",
    "        random_state: random generator; the global numpy RNG is used when None\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray of ints with shape\n",
    "        (num_mask_samples, num_players) if num_mask_samples is an int\n",
    "        (num_players,) if num_mask_samples is None\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if paired_mask_samples is True and the number of masks is odd\n",
    "            (this includes num_mask_samples=None, which requests a single mask)\n",
    "        ValueError: if mode is neither 'uniform' nor 'shapley'\n",
    "    \"\"\"\n",
    "    random_state = random_state or np.random\n",
    "\n",
    "    num_samples_ = num_mask_samples or 1\n",
    "\n",
    "    if paired_mask_samples:\n",
    "        assert num_samples_ % 2 == 0, \"'num_samples' must be a multiple of 2 if 'paired' is True\"\n",
    "        # Draw half of the masks; their complements are added below.\n",
    "        num_samples_ = num_samples_ // 2\n",
    "\n",
    "    if mode == 'uniform':\n",
    "        # One inclusion threshold per row, so the number of kept players is uniform over 0..num_players.\n",
    "        masks = (random_state.rand(num_samples_, num_players) > random_state.rand(num_samples_, 1)).astype('int')\n",
    "    elif mode == 'shapley':\n",
    "        # Coalition sizes s = 1..num_players-1 weighted by the Shapley kernel 1 / (s * (num_players - s)).\n",
    "        sizes = np.arange(1, num_players)\n",
    "        probs = 1 / (sizes * (num_players - sizes))\n",
    "        probs = probs / probs.sum()\n",
    "        # Draw s itself (not s - 1): each player is then kept with probability 1 - s / num_players, and a zero\n",
    "        # threshold (which keeps every player and has no Shapley-kernel weight) can no longer be drawn.\n",
    "        masks = (random_state.rand(num_samples_, num_players) > 1 / num_players * random_state.choice(\n",
    "            sizes, p=probs, size=[num_samples_, 1])).astype('int')\n",
    "    else:\n",
    "        raise ValueError(\"'mode' must be 'uniform' or 'shapley'\")\n",
    "\n",
    "    if paired_mask_samples:\n",
    "        # Interleave each mask with its complement: rows 2k and 2k+1 are x and 1-x.\n",
    "        masks = np.stack([masks, 1 - masks], axis=1).reshape(num_samples_ * 2, num_players)\n",
    "\n",
    "    if num_mask_samples is None:\n",
    "        return masks.squeeze(0)  # (num_players,)\n",
    "    return masks  # (num_mask_samples, num_players)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8a4bf39",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "if evaluation_stage==\"5_sensitivity\":\n",
    "    # Sensitivity: for each image the surrogate is evaluated on 20 x 50 = 1000 random masks and, per class, the\n",
    "    # Spearman correlation between the explanation's total attribution over the kept patches and the surrogate's\n",
    "    # probability is stored. Done for unrestricted masks ('all') and for each fixed size 14, 28, ..., 182.\n",
    "    num_players=196\n",
    "    for idx, batch in enumerate(tqdm(data_loader, total = int(1000/data_loader.batch_size+0.5) if dataset_split==\"test\" else None)):\n",
    "        # On the test split, stop after batch index int(1000/batch_size+0.5) (about 1000 images).\n",
    "        if dataset_split==\"test\":\n",
    "            if idx>int(1000/data_loader.batch_size+0.5):\n",
    "                break\n",
    "        # Shard batches across processes: parallel_mode = (shard_index, num_shards).\n",
    "        if (idx%parallel_mode[1])!=parallel_mode[0]:\n",
    "            continue\n",
    "\n",
    "        images=batch['images']\n",
    "        labels=batch['labels']\n",
    "        paths=batch['path']\n",
    "\n",
    "        for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "            for image, path in zip(images, paths):\n",
    "                # The 50-copy batch of this image is the same for every mask draw, so it is built once.\n",
    "                image_loaded=image[np.newaxis, :].repeat(50, 1, 1, 1).to(surrogate_dict[backbone_type][\"pre-softmax\"].device)\n",
    "                for num_included_players in [\"all\"] + list(range(14, 196, 14)):\n",
    "                    prob_all=[]\n",
    "                    mask_all=[]\n",
    "                    for random_iter in range(20):\n",
    "                        if num_included_players==\"all\":\n",
    "                            # Masks of any size, drawn with generate_mask(mode='uniform').\n",
    "                            mask=generate_mask(num_players=num_players,\n",
    "                                               num_mask_samples=50,\n",
    "                                               paired_mask_samples=False,\n",
    "                                               mode=\"uniform\",\n",
    "                                               random_state=np.random.RandomState(random_iter))\n",
    "                        else:\n",
    "                            # Masks keeping exactly num_included_players randomly chosen patches.\n",
    "                            mask=np.zeros((50, num_players))\n",
    "                            mask[:, :num_included_players]=1\n",
    "                            for i in range(len(mask)):\n",
    "                                # Seed 42 + 50*random_iter + i is distinct for every mask of every draw; a stride\n",
    "                                # of 10 would make draws reuse each other's seeds (and thus permutations).\n",
    "                                mask[i]=np.random.RandomState(42+len(mask)*random_iter+i).permutation(mask[i])\n",
    "\n",
    "                        surrogate_dict[backbone_type][\"pre-softmax\"].eval()\n",
    "                        with torch.no_grad():\n",
    "                            output = surrogate_dict[backbone_type][\"pre-softmax\"](image_loaded,\n",
    "                                                                                  masks=torch.Tensor(mask).to(surrogate_dict[backbone_type][\"pre-softmax\"].device))\n",
    "\n",
    "                        if _config[\"output_dim\"]==1:\n",
    "                            prob=output['logits'].sigmoid().cpu().numpy()\n",
    "                        else:\n",
    "                            prob=output['logits'].softmax(dim=-1).cpu().numpy()\n",
    "                        prob_all.append(prob)\n",
    "                        mask_all.append(mask)\n",
    "                    prob_all=np.concatenate(prob_all, axis=0) # (1000, output_dim)\n",
    "                    mask_all=np.concatenate(mask_all, axis=0) # (1000, num_players)\n",
    "\n",
    "                    for explanation_method in explanation_method_to_run:\n",
    "                        if explanation_method==\"random\":\n",
    "                            continue\n",
    "                        # NOTE(review): the other stages pass the full key list to adapt_path, here only the\n",
    "                        # first key is passed - confirm this is what adapt_path expects.\n",
    "                        explanation=explanation_save_dict[backbone_type][explanation_method][adapt_path(path, list(explanation_save_dict[backbone_type][explanation_method].keys())[0])]['explanation']\n",
    "                        # Tiny seeded noise (< 1e-40), presumably to break ties between equal attributions.\n",
    "                        explanation=explanation+np.random.RandomState(42).uniform(low=0, high=1e-40, size=explanation.shape)\n",
    "                        # Total attribution of the kept patches of every mask: (1000,) or (rows, 1000).\n",
    "                        explanation_mask=explanation@(mask_all.T)\n",
    "\n",
    "                        if len(explanation.shape)==1:\n",
    "                            correlation = np.array([stats.spearmanr(explanation_mask, prob_all[:, i]).correlation for i in range(prob_all.shape[1])])\n",
    "                        elif len(explanation.shape)==2:\n",
    "                            # Class-specific explanations: row i is correlated with the probability of class i.\n",
    "                            correlation = np.array([stats.spearmanr(explanation_mask[i], prob_all[:, i]).correlation for i in range(prob_all.shape[1])])\n",
    "                        else:\n",
    "                            raise ValueError(f'{explanation_method}: unsupported explanation shape {explanation.shape}')\n",
    "                        assert correlation.shape==(_config[\"output_dim\"],)\n",
    "                        sensitivity_save_dit_update(backbone_type, explanation_method,\n",
    "                                                     num_included_players=num_included_players,\n",
    "                                                     path_list=[path], sensitivity_list=[correlation],\n",
    "                                                     shape=(_config[\"output_dim\"],))\n",
    "\n",
    "        # Merge in-memory results with the pickle on disk and write the union back;\n",
    "        # on key collisions the on-disk entry wins.\n",
    "        for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "            for explanation_method, sensitivity_save_dit_backbone_method in sensitivity_save_dit[backbone_type].items():\n",
    "                sensitivity_save_dit_path=f'results/5_sensitivity/{_config[\"datasets\"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'\n",
    "\n",
    "                if os.path.isfile(sensitivity_save_dit_path):\n",
    "                    with open(sensitivity_save_dit_path, 'rb') as f:\n",
    "                        sensitivity_save_dit_loaded=pickle.load(f)\n",
    "                else:\n",
    "                    sensitivity_save_dit_loaded={}\n",
    "\n",
    "                len_original=len(sensitivity_save_dit_backbone_method)\n",
    "                len_loaded=len(sensitivity_save_dit_loaded)\n",
    "                sensitivity_save_dit_backbone_method.update(sensitivity_save_dit_loaded)\n",
    "                len_updated=len(sensitivity_save_dit_backbone_method)\n",
    "\n",
    "                print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')\n",
    "\n",
    "                # Create the output directory on first use so the write below does not fail.\n",
    "                os.makedirs(os.path.dirname(sensitivity_save_dit_path), exist_ok=True)\n",
    "                with open(sensitivity_save_dit_path, \"wb\") as f:\n",
    "                    pickle.dump(sensitivity_save_dit_backbone_method, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89881f09",
   "metadata": {},
   "source": [
    "# 6_noretraining"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cca77a7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "if evaluation_stage==\"6_noretraining\":\n",
    "    # Insertion/deletion curves of the surrogate without retraining, always including the 'random' baseline.\n",
    "    # Two-dimensional explanations are reduced to a single row: the only row, or the row of the true label.\n",
    "    num_players=196\n",
    "    for idx, batch in enumerate(tqdm(data_loader, total = int(1000/data_loader.batch_size+0.5) if dataset_split==\"test\" else None)):\n",
    "        # On the test split, stop after batch index int(1000/batch_size+0.5) (about 1000 images).\n",
    "        if dataset_split==\"test\":\n",
    "            if idx>int(1000/data_loader.batch_size+0.5):\n",
    "                break\n",
    "        # Shard batches across processes: parallel_mode = (shard_index, num_shards).\n",
    "        if (idx%parallel_mode[1])!=parallel_mode[0]:\n",
    "            continue\n",
    "\n",
    "        images=batch['images']\n",
    "        labels=batch['labels']\n",
    "        paths=batch['path']\n",
    "        # Methods with at least one missing image in this batch; only their pickles are rewritten below.\n",
    "        updated_signal_list=[]\n",
    "\n",
    "        for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "\n",
    "            for explanation_method in [\"random\"]+explanation_method_to_run:\n",
    "                data_keys=noretraining_save_dict[backbone_type][explanation_method].keys()\n",
    "                data_keys=[adapt_path(data_keys_path, paths) for data_keys_path in data_keys]\n",
    "                if all(path in data_keys for path in paths):\n",
    "                    continue\n",
    "                print(explanation_method,'not exist')\n",
    "                updated_signal_list.append(explanation_method)\n",
    "\n",
    "                if explanation_method==\"random\":\n",
    "                    # Random baseline: 10 random explanations per image, seeded by the image's position in the batch.\n",
    "                    explanations=[np.random.RandomState(position).uniform(low=0, high=1e-40, size=(10, num_players)) for position, path in enumerate(paths)]\n",
    "                else:\n",
    "                    explanations=[explanation_save_dict[backbone_type][explanation_method][adapt_path(path, list(explanation_save_dict[backbone_type][explanation_method].keys()))]['explanation']\n",
    "                                  for path in paths]\n",
    "                # Raw probabilities are stored; accuracy (argmax == label) can be derived from them afterwards.\n",
    "                noretraining_dict={'insertion': [], 'deletion': []}\n",
    "                for image, explanation, label in zip(images, explanations, labels):\n",
    "                    image_loaded=image[np.newaxis, :].repeat(num_players+1, 1, 1, 1).to(surrogate_dict[backbone_type][\"pre-softmax\"].device)\n",
    "                    if np.isnan(explanation).any():\n",
    "                        print(explanation_method, \"Null found\")\n",
    "\n",
    "                    if explanation_method==\"random\":\n",
    "                        for metric_mode in noretraining_dict.keys():\n",
    "                            mask=explanation_to_mask(explanation=np.array([get_relative_value(explanation_) for explanation_ in explanation]), mode=metric_mode)\n",
    "                            prob_=[]\n",
    "                            for random_iter in range(mask.shape[0]):\n",
    "                                surrogate_dict[backbone_type][\"pre-softmax\"].eval()\n",
    "                                with torch.no_grad():\n",
    "                                    output = surrogate_dict[backbone_type][\"pre-softmax\"](image_loaded,\n",
    "                                                                                          masks=torch.Tensor(mask[random_iter]).to(surrogate_dict[backbone_type][\"pre-softmax\"].device))\n",
    "                                if _config[\"output_dim\"]==1:\n",
    "                                    prob=output['logits'].sigmoid().cpu().numpy()\n",
    "                                else:\n",
    "                                    prob=output['logits'].softmax(dim=-1).cpu().numpy()\n",
    "                                prob_.append(prob.T)\n",
    "                            prob=np.array(prob_) # (10, num_classes, num_players+1)\n",
    "                            noretraining_dict[metric_mode].append(prob)\n",
    "\n",
    "                    elif len(explanation.shape)==1:\n",
    "                        for metric_mode in noretraining_dict.keys():\n",
    "                            mask=explanation_to_mask(explanation=get_relative_value(explanation)[np.newaxis,:], mode=metric_mode)\n",
    "                            surrogate_dict[backbone_type][\"pre-softmax\"].eval()\n",
    "                            with torch.no_grad():\n",
    "                                output = surrogate_dict[backbone_type][\"pre-softmax\"](image_loaded,\n",
    "                                                                                      masks=torch.Tensor(mask[0]).to(surrogate_dict[backbone_type][\"pre-softmax\"].device))\n",
    "                            if _config[\"output_dim\"]==1:\n",
    "                                prob=output['logits'].sigmoid().cpu().numpy()\n",
    "                            else:\n",
    "                                prob=output['logits'].softmax(dim=-1).cpu().numpy()\n",
    "                            prob=prob.T # (num_classes, num_players+1)\n",
    "                            noretraining_dict[metric_mode].append(prob)\n",
    "\n",
    "                    elif len(explanation.shape)==2:\n",
    "                        for metric_mode in noretraining_dict.keys():\n",
    "                            # Use the only row if there is one, otherwise the row of the true label.\n",
    "                            if len(explanation)==1:\n",
    "                                mask=explanation_to_mask(explanation=get_relative_value(explanation[0])[np.newaxis,:], mode=metric_mode)\n",
    "                            else:\n",
    "                                mask=explanation_to_mask(explanation=get_relative_value(explanation[label.item()])[np.newaxis,:], mode=metric_mode)\n",
    "\n",
    "                            surrogate_dict[backbone_type][\"pre-softmax\"].eval()\n",
    "                            with torch.no_grad():\n",
    "                                output = surrogate_dict[backbone_type][\"pre-softmax\"](image_loaded,\n",
    "                                                                                      masks=torch.Tensor(mask[0]).to(surrogate_dict[backbone_type][\"pre-softmax\"].device))\n",
    "                            if _config[\"output_dim\"]==1:\n",
    "                                prob=output['logits'].sigmoid().cpu().numpy()\n",
    "                            else:\n",
    "                                prob=output['logits'].softmax(dim=-1).cpu().numpy()\n",
    "                            prob=prob.T # (num_classes, num_players+1)\n",
    "                            noretraining_dict[metric_mode].append(prob)\n",
    "                    else:\n",
    "                        raise ValueError(f'{explanation_method}: unsupported explanation shape {explanation.shape}')\n",
    "\n",
    "                if explanation_method==\"random\":\n",
    "                    noretraining_save_dict_update(backbone_type, explanation_method,\n",
    "                                                  path_list=paths,\n",
    "                                                  insert_list=noretraining_dict[\"insertion\"],\n",
    "                                                  delete_list=noretraining_dict[\"deletion\"],\n",
    "                                                  shape=(10, _config[\"output_dim\"], num_players+1))\n",
    "                elif len(explanations[0].shape) in (1, 2):\n",
    "                    noretraining_save_dict_update(backbone_type, explanation_method,\n",
    "                                                  path_list=paths,\n",
    "                                                  insert_list=noretraining_dict[\"insertion\"],\n",
    "                                                  delete_list=noretraining_dict[\"deletion\"],\n",
    "                                                  shape=(_config[\"output_dim\"], num_players+1))\n",
    "                else:\n",
    "                    raise ValueError(f'{explanation_method}: unsupported explanation shape {explanations[0].shape}')\n",
    "\n",
    "        # Best-effort persistence: merge in-memory results with the pickle on disk and write the union back\n",
    "        # (on key collisions the on-disk entry wins). Failures are reported per file and do not stop the loop;\n",
    "        # only Exception subclasses are caught, so KeyboardInterrupt still interrupts the run.\n",
    "        for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "            for explanation_method, noretraining_save_dict_backbone_method in noretraining_save_dict[backbone_type].items():\n",
    "\n",
    "                if explanation_method not in updated_signal_list:\n",
    "                    continue\n",
    "\n",
    "                noretraining_save_dict_path=f'results/6_noretraining/{_config[\"datasets\"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'\n",
    "                try:\n",
    "                    noretraining_save_dict_loaded={}\n",
    "                    if os.path.isfile(noretraining_save_dict_path):\n",
    "                        try:\n",
    "                            with open(noretraining_save_dict_path, 'rb') as f:\n",
    "                                noretraining_save_dict_loaded=pickle.load(f)\n",
    "                        except Exception as load_error:\n",
    "                            # Unreadable file (e.g. truncated by a concurrent writer): continue with the\n",
    "                            # in-memory results only, which overwrite the file below.\n",
    "                            print(f'could not read {noretraining_save_dict_path}: {load_error!r}')\n",
    "\n",
    "                    len_original=len(noretraining_save_dict_backbone_method)\n",
    "                    len_loaded=len(noretraining_save_dict_loaded)\n",
    "                    noretraining_save_dict_backbone_method.update(noretraining_save_dict_loaded)\n",
    "                    len_updated=len(noretraining_save_dict_backbone_method)\n",
    "\n",
    "                    print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')\n",
    "\n",
    "                    # Create the output directory on first use so the write below does not fail.\n",
    "                    os.makedirs(os.path.dirname(noretraining_save_dict_path), exist_ok=True)\n",
    "                    with open(noretraining_save_dict_path, \"wb\") as f:\n",
    "                        pickle.dump(noretraining_save_dict_backbone_method, f)\n",
    "                except Exception as save_error:\n",
    "                    print(f'failed to update {noretraining_save_dict_path}: {save_error!r}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db8f74b5",
   "metadata": {},
   "source": [
    "# 7_classifiermasked"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83ff594c",
   "metadata": {},
   "outputs": [],
   "source": [
    "if evaluation_stage==\"7_classifiermasked\":\n",
    "    num_players=196\n",
    "    for idx, batch in enumerate(tqdm(data_loader, total = int(1000/data_loader.batch_size+0.5) if dataset_split==\"test\" else None)):   \n",
    "        if dataset_split==\"test\":\n",
    "            if idx>int(1000/data_loader.batch_size+0.5):\n",
    "                break\n",
    "        if (idx%parallel_mode[1])!=parallel_mode[0]:\n",
    "            continue                                \n",
    "\n",
    "        images=batch['images']\n",
    "        labels=batch['labels']\n",
    "        paths=batch['path']\n",
    "        updated_signal_list=[]\n",
    "\n",
    "        for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "\n",
    "            for explanation_method in [\"random\"]+explanation_method_to_run:\n",
    "                data_keys=classifiermasked_save_dict[backbone_type][explanation_method].keys()\n",
    "                data_keys=[adapt_path(data_keys_path, paths) for data_keys_path in data_keys]\n",
    "                if all([path in data_keys for path in paths]):\n",
    "                    continue\n",
    "                else:\n",
    "                    print(explanation_method,'not exist')\n",
    "                    updated_signal_list.append(explanation_method)                \n",
    "                \n",
    "                if explanation_method==\"random\":\n",
    "                    explanations=[np.random.RandomState(idx).uniform(low=0, high=1e-40, size=(10, num_players)) for idx, path in enumerate(paths)]\n",
    "                else:\n",
    "                    explanations=[explanation_save_dict[backbone_type][explanation_method][adapt_path(path, list(explanation_save_dict[backbone_type][explanation_method].keys()))]['explanation']\n",
    "                                  for path in paths]\n",
    "                classifiermasked_dict={'insertion': [], 'deletion': []}\n",
    "                for image, explanation, label in zip(images, explanations, labels):\n",
    "                    image_loaded=image[np.newaxis, :].repeat(num_players+1, 1, 1, 1).to(classifier_masked_dict[backbone_type].device)\n",
    "                    if np.isnan(explanation).any():\n",
    "                        print(explanation_method, \"Null found\")\n",
    "                      \n",
    "                    if explanation_method==\"random\":\n",
    "                        for metric_mode in classifiermasked_dict.keys():\n",
    "                            mask=explanation_to_mask(explanation=np.array([get_relative_value(explanation_) for explanation_ in explanation]), mode=metric_mode)\n",
    "                            prob_=[]\n",
    "                            for random_iter in range(mask.shape[0]):\n",
    "                                classifier_masked_dict[backbone_type].eval()\n",
    "                                with torch.no_grad():\n",
    "                                    output = classifier_masked_dict[backbone_type](image_loaded,\n",
    "                                                                                          masks=torch.Tensor(mask[random_iter]).to(classifier_masked_dict[backbone_type].device))\n",
    "                                if _config[\"output_dim\"]==1:\n",
    "                                    prob=output['logits'].sigmoid().cpu().numpy()\n",
    "                                else:\n",
    "                                    prob=output['logits'].softmax(dim=-1).cpu().numpy()\n",
    "                                prob_.append(prob.T)\n",
    "                            prob=np.array(prob_) # (10, num_classes, num_players+1)\n",
    "                            classifiermasked_dict[metric_mode].append(prob)#(prob.argmax(axis=1)==label.item()).astype(float))\n",
    "                        \n",
    "                    elif len(explanation.shape)==1:\n",
    "                        for metric_mode in classifiermasked_dict.keys():\n",
    "                            mask=explanation_to_mask(explanation=get_relative_value(explanation)[np.newaxis,:], mode=metric_mode)\n",
    "                            classifier_masked_dict[backbone_type].eval()\n",
    "                            with torch.no_grad():\n",
    "                                output = classifier_masked_dict[backbone_type](image_loaded,\n",
    "                                                                                      masks=torch.Tensor(mask[0]).to(classifier_masked_dict[backbone_type].device))\n",
    "                            if _config[\"output_dim\"]==1:\n",
    "                                prob=output['logits'].sigmoid().cpu().numpy()\n",
    "                            else:\n",
    "                                prob=output['logits'].softmax(dim=-1).cpu().numpy()\n",
    "                            prob=prob.T # (num_classes, num_players+1)                            \n",
    "                            #classifiermasked_dict[metric_mode].append((prob.argmax(axis=0)==label.item()).astype(float)) \n",
    "                            classifiermasked_dict[metric_mode].append(prob) \n",
    "\n",
    "                    elif len(explanation.shape)==2:\n",
    "                        for metric_mode in classifiermasked_dict.keys():\n",
    "                            #mask=explanation_to_mask(explanation=explanation[label.item()], mode=metric_mode)\n",
    "                            if len(explanation)==1:\n",
    "                                mask=explanation_to_mask(explanation=get_relative_value(explanation[0])[np.newaxis,:], mode=metric_mode)\n",
    "                            else:\n",
    "                                mask=explanation_to_mask(explanation=get_relative_value(explanation[label.item()])[np.newaxis,:], mode=metric_mode)\n",
    "                            \n",
    "                            classifier_masked_dict[backbone_type].eval()\n",
    "                            with torch.no_grad():\n",
    "                                output = classifier_masked_dict[backbone_type](image_loaded,\n",
    "                                                                                      masks=torch.Tensor(mask[0]).to(classifier_masked_dict[backbone_type].device))\n",
    "                            if _config[\"output_dim\"]==1:\n",
    "                                prob=output['logits'].sigmoid().cpu().numpy()\n",
    "                            else:\n",
    "                                prob=output['logits'].softmax(dim=-1).cpu().numpy()\n",
    "                            prob=prob.T\n",
    "                            #classifiermasked_dict[metric_mode].append((prob.argmax(axis=0)==label.item()).astype(float))\n",
    "                            classifiermasked_dict[metric_mode].append(prob)\n",
    "                    else:\n",
    "                        raise\n",
    "                          \n",
    "                if explanation_method==\"random\":\n",
    "                    classifiermasked_save_dict_update(backbone_type, explanation_method,\n",
    "                                                  path_list=paths, \n",
    "                                                  insert_list=classifiermasked_dict[\"insertion\"], \n",
    "                                                  delete_list=classifiermasked_dict[\"deletion\"], \n",
    "                                                  shape=(10, _config[\"output_dim\"], num_players+1))                       \n",
    "                elif len(explanations[0].shape)==1:\n",
    "                    classifiermasked_save_dict_update(backbone_type, explanation_method,\n",
    "                                                  path_list=paths, \n",
    "                                                  insert_list=classifiermasked_dict[\"insertion\"], \n",
    "                                                  delete_list=classifiermasked_dict[\"deletion\"], \n",
    "                                                  shape=(_config[\"output_dim\"], num_players+1))                    \n",
    "                elif len(explanations[0].shape)==2:\n",
    "                    classifiermasked_save_dict_update(backbone_type, explanation_method,\n",
    "                                                  path_list=paths, \n",
    "                                                  insert_list=classifiermasked_dict[\"insertion\"], \n",
    "                                                  delete_list=classifiermasked_dict[\"deletion\"], \n",
    "                                                  shape=(_config[\"output_dim\"], num_players+1))\n",
    "                else:\n",
    "                    raise\n",
    "        \n",
    "        try:        \n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                for explanation_method, classifiermasked_save_dict_backbone_method in classifiermasked_save_dict[backbone_type].items():\n",
    "\n",
    "                    if explanation_method not in updated_signal_list:\n",
    "                        continue                \n",
    "\n",
    "                    classifiermasked_save_dict_path=f'results/7_classifiermasked/{_config[\"datasets\"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'\n",
    "\n",
    "                    if os.path.isfile(classifiermasked_save_dict_path):\n",
    "                        try:\n",
    "                            with open(classifiermasked_save_dict_path, 'rb') as f:\n",
    "                                classifiermasked_save_dict_loaded=pickle.load(f)\n",
    "                        except:\n",
    "                            classifiermasked_save_dict_loaded={}\n",
    "                    else:\n",
    "                        classifiermasked_save_dict_loaded={}\n",
    "\n",
    "                    len_original=len(classifiermasked_save_dict_backbone_method)            \n",
    "                    len_loaded=len(classifiermasked_save_dict_loaded)\n",
    "                    classifiermasked_save_dict_backbone_method.update(classifiermasked_save_dict_loaded)\n",
    "                    len_updated=len(classifiermasked_save_dict_backbone_method)\n",
    "\n",
    "                    print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')\n",
    "\n",
    "\n",
    "                    with open(classifiermasked_save_dict_path, \"wb\") as f:\n",
    "                        pickle.dump(classifiermasked_save_dict_backbone_method, f)\n",
    "        except:\n",
    "            pass\n",
    "        \n",
    "                            \n",
    "                    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fedcf947",
   "metadata": {},
   "source": [
    "# 8_elapsedtime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e5e1873",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def get_random_explanation(num_players, num_samples=None):\n",
    "    if num_samples is None:\n",
    "        return np.random.RandomState(42).uniform(low=0, high=1e-40, size=(num_players,))\n",
    "    else:\n",
    "        return np.random.RandomState(42).uniform(low=0, high=1e-40, size=(num_samples, num_players))\n",
    "\n",
    "\n",
    "if evaluation_stage==\"8_elapsedtime\":\n",
    "    for idx, batch in enumerate(tqdm(data_loader)):#, total = int(1000/data_loader.batch_size+0.5) if dataset_split==\"test\" else None)):\n",
    "        if dataset_split==\"test\":\n",
    "            if idx>int(1000/data_loader.batch_size+0.5):\n",
    "                break\n",
    "        if (idx%parallel_mode[1])!=parallel_mode[0]:\n",
    "            continue\n",
    "            \n",
    "        images=batch['images']\n",
    "        labels=batch['labels']\n",
    "        paths=batch['path']\n",
    "        updated_signal_list=[]\n",
    "        \n",
    "        for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "            # Get classifier output\n",
    "            classifier_dict[backbone_type].eval()\n",
    "            with torch.no_grad():\n",
    "                classifier_output=classifier_dict[backbone_type](images.to(classifier_dict[backbone_type].device),\n",
    "                                                                 output_attentions=True)        \n",
    "            for explanation_method in explanation_method_to_run:\n",
    "                data_keys=elapsedtime_save_dict[backbone_type][explanation_method].keys()\n",
    "                data_keys=[adapt_path(data_keys_path, paths) for data_keys_path in data_keys]\n",
    "                if all([path in data_keys for path in paths]):\n",
    "                    continue\n",
    "                else:\n",
    "                    print(explanation_method,'not exist')\n",
    "                    updated_signal_list.append(explanation_method)\n",
    "                    \n",
    "                if explanation_method==\"random\":\n",
    "                    start_time=time.time()\n",
    "                    explanation_random_list=[get_random_explanation(num_players=196) for path in paths]\n",
    "                    elapsed_time=time.time()-start_time\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'random', path_list=paths, elapsed_time_list=[elapsed_time/len(paths) for i in range(len(paths))])\n",
    "                \n",
    "                elif explanation_method==\"attention_rollout\":\n",
    "                    attentions = np.asarray([att.cpu().detach().numpy() for att in classifier_output['self_attentions']]).transpose(1,0,2,3,4)\n",
    "                    start_time=time.time()\n",
    "                    explanation_attention_rollout_list=attentions_to_explanation(attentions, mode='rollout')\n",
    "                    elapsed_time=time.time()-start_time\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'attention_rollout', path_list=paths, elapsed_time_list=[elapsed_time/len(paths) for i in range(len(paths))])\n",
    "\n",
    "                elif explanation_method==\"attention_last\":\n",
    "                    attentions = np.asarray([att.cpu().detach().numpy() for att in classifier_output['self_attentions']]).transpose(1,0,2,3,4)\n",
    "                    start_time=time.time()\n",
    "                    explanation_attention_last_list=attentions_to_explanation(attentions, mode=-1)\n",
    "                    elapsed_time=time.time()-start_time\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'attention_last', path_list=paths, elapsed_time_list=[elapsed_time/len(paths) for i in range(len(paths))])            \n",
    "\n",
    "                elif explanation_method==\"LRP\":\n",
    "                    explanation_lrp_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        explanation_lrp_list.append(np.concatenate([get_lrp_module_explanation(backbone_type=backbone_type,\n",
    "                                                                              original_image=image.squeeze(0),\n",
    "                                                                              class_index=i,\n",
    "                                                                              mode='transformer_attribution').cpu().numpy() for i in range(_config[\"output_dim\"])], axis=0))\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'LRP', path_list=paths, elapsed_time_list=elapsed_time_list)\n",
    "\n",
    "                elif explanation_method==\"gradcam\":\n",
    "                    explanation_gradcam_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        explanation_gradcam = np.concatenate([get_lrp_module_explanation(backbone_type=backbone_type,\n",
    "                                                                                              original_image=image.squeeze(0),\n",
    "                                                                                              class_index=i,\n",
    "                                                                                              mode='attn_gradcam').cpu().numpy() for i in range(_config[\"output_dim\"])], axis=0)\n",
    "                        explanation_gradcam = np.nan_to_num(explanation_gradcam,nan=0)+np.random.uniform(low=0, high=1e-20, size=explanation_gradcam.shape)                    \n",
    "                        explanation_gradcam_list.append(explanation_gradcam)\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'gradcam', path_list=paths, elapsed_time_list=elapsed_time_list)\n",
    "\n",
    "\n",
    "                elif explanation_method==\"gradcamgithub\":\n",
    "                    explanation_gradcamgithub_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        explanation_gradcamgithub = np.concatenate([cam_dict[backbone_type](input_tensor=image.unsqueeze(0).to(next(cam_dict[backbone_type].model.parameters()).device),\n",
    "                                                                                            targets=[ClassifierOutputTarget(i)], resize=False).flatten()[np.newaxis,:] for i in range(_config[\"output_dim\"])], axis=0)\n",
    "                        explanation_gradcamgithub_list.append(explanation_gradcamgithub)\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'gradcamgithub', path_list=paths, elapsed_time_list=elapsed_time_list)\n",
    "\n",
    "                elif explanation_method==\"vanillapixel\":\n",
    "                    explanation_vanillapixel_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        grad=get_vanilla(image, saliency_pixel=saliency_pixel_dict[backbone_type])\n",
    "                        explanation_vanillapixel_list.append(grad[\"attributions_pixel_patchabssum\"].reshape(_config[\"output_dim\"],196).numpy())\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'vanillapixel', path_list=paths, elapsed_time_list=elapsed_time_list)\n",
    "\n",
    "                elif explanation_method==\"vanillaembedding\":\n",
    "                    explanation_vanillaembedding_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:\n",
    "                        start_time=time.time()\n",
    "                        grad=get_vanilla(image, saliency_embedding=saliency_embedding_dict[backbone_type])\n",
    "                        explanation_vanillaembedding_list.append(grad[\"attributions_embedding_abssum\"].reshape(_config[\"output_dim\"],196).numpy())\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'vanillaembedding', path_list=paths, elapsed_time_list=elapsed_time_list)\n",
    "\n",
    "                elif explanation_method==\"sgpixel\":\n",
    "                    explanation_sgpixel_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        grad=get_sg(image, noisetunnel_pixel=noisetunnel_pixel_dict[backbone_type])\n",
    "                        explanation_sgpixel_list.append(grad[\"attributions_pixel_patchabssum\"].reshape(_config[\"output_dim\"],196).numpy())\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'sgpixel', path_list=paths, elapsed_time_list=elapsed_time_list)\n",
    "\n",
    "\n",
    "                elif explanation_method==\"sgembedding\":\n",
    "                    explanation_sgembedding_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        grad=get_sg(image, noisetunnel_embedding=noisetunnel_embedding_dict[backbone_type])\n",
    "                        explanation_sgembedding_list.append(grad[\"attributions_embedding_abssum\"].reshape(_config[\"output_dim\"],196).numpy())\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'sgembedding', path_list=paths, elapsed_time_list=elapsed_time_list)                                \n",
    "\n",
    "                elif explanation_method==\"vargradpixel\":\n",
    "                    explanation_vargradpixel_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        grad=get_vargrad(image, noisetunnel_pixel=noisetunnel_pixel_dict[backbone_type])\n",
    "                        explanation_vargradpixel_list.append(grad[\"attributions_pixel_patchabssum\"].reshape(_config[\"output_dim\"],196).numpy())\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'vargradpixel', path_list=paths, elapsed_time_list=elapsed_time_list)\n",
    "\n",
    "                elif explanation_method==\"vargradembedding\":\n",
    "                    explanation_vargradembedding_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        grad=get_vargrad(image, noisetunnel_embedding=noisetunnel_embedding_dict[backbone_type])\n",
    "                        explanation_vargradembedding_list.append(grad[\"attributions_embedding_abssum\"].reshape(_config[\"output_dim\"],196).numpy())\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'vargradembedding', path_list=paths, elapsed_time_list=elapsed_time_list)                                \n",
    "\n",
    "                elif explanation_method==\"igpixel\":\n",
    "                    explanation_igpixel_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        grad=get_ig(image, ig_pixel=ig_pixel_dict[backbone_type])\n",
    "                        explanation_igpixel_list.append(grad[\"attributions_pixel_patchsum\"].reshape(_config[\"output_dim\"],196).numpy())\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'igpixel', path_list=paths, elapsed_time_list=elapsed_time_list)\n",
    "\n",
    "                elif explanation_method==\"igembedding\":\n",
    "                    explanation_igembedding_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        grad=get_ig(image, ig_embedding=ig_embedding_dict[backbone_type])\n",
    "                        explanation_igembedding_list.append(grad[\"attributions_embedding_sum\"].reshape(_config[\"output_dim\"],196).numpy())\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'igembedding', path_list=paths, elapsed_time_list=elapsed_time_list)                \n",
    "\n",
    "                elif explanation_method==\"leaveoneoutclassifier\":\n",
    "                    explanation_leaveoneoutclassifier_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        explanation_leaveoneoutclassifier = leave_one_out(classifier=classifier_dict_[backbone_type], image=image).reshape(_config[\"output_dim\"], 196)\n",
    "                        explanation_leaveoneoutclassifier_list.append(explanation_leaveoneoutclassifier)\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'leaveoneoutclassifier', path_list=paths, elapsed_time_list=elapsed_time_list)\n",
    "\n",
    "                elif explanation_method==\"leaveoneoutsurrogate\":\n",
    "                    explanation_leaveoneoutsurrogate_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        explanation_leaveoneoutsurrogate = leave_one_out(surrogate=surrogate_dict[backbone_type][\"pre-softmax\"], image=image).reshape(_config[\"output_dim\"], 196)\n",
    "                        explanation_leaveoneoutsurrogate_list.append(explanation_leaveoneoutsurrogate)\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'leaveoneoutsurrogate', path_list=paths, elapsed_time_list=elapsed_time_list)                \n",
    "\n",
    "                elif explanation_method==\"riseclassifier\":\n",
    "                    explanation_riseclassifier_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        explanation_riseclassifier = rise(classifier=classifier_dict_[backbone_type], image=image, N=2000, include_prob=0.5).reshape(_config[\"output_dim\"], 196)\n",
    "                        explanation_riseclassifier_list.append(explanation_riseclassifier)\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'riseclassifier', path_list=paths, elapsed_time_list=elapsed_time_list)\n",
    "\n",
    "                elif explanation_method==\"risesurrogate\":\n",
    "                    explanation_risesurrogate_list=[]\n",
    "                    elapsed_time_list=[]\n",
    "                    for image in images:               \n",
    "                        start_time=time.time()\n",
    "                        explanation_risesurrogate = rise(surrogate=surrogate_dict[backbone_type][\"pre-softmax\"], image=image, N=2000, include_prob=0.5).reshape(_config[\"output_dim\"], 196)\n",
    "                        explanation_risesurrogate_list.append(explanation_risesurrogate)\n",
    "                        elapsed_time_list.append(time.time()-start_time)\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'risesurrogate', path_list=paths, elapsed_time_list=elapsed_time_list)                \n",
    "                    \n",
    "                elif explanation_method==\"ours\":\n",
    "                    start_time=time.time()\n",
    "                    explainer_dict[backbone_type].eval()\n",
    "                    with torch.no_grad():\n",
    "                        explanation_ours = explainer_dict[backbone_type](images.to(explainer_dict[backbone_type].device))[0].detach().cpu().numpy().transpose(0, 2, 1)\n",
    "                    elapsed_time=time.time()-start_time\n",
    "                    elapsedtime_save_dict_update(backbone_type, 'ours', path_list=paths, elapsed_time_list=[elapsed_time/len(paths) for i in range(len(paths))])                                \n",
    "                else:\n",
    "                    raise\n",
    "\n",
    "        try:\n",
    "            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "                for explanation_method, elapsedtime_save_dict_backbone_method in elapsedtime_save_dict[backbone_type].items():\n",
    "                    if explanation_method not in updated_signal_list:\n",
    "                        continue\n",
    "                    elapsedtime_save_dict_path=f'results/8_elapsedtime/{_config[\"datasets\"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'\n",
    "\n",
    "                    if os.path.isfile(elapsedtime_save_dict_path):\n",
    "                        try:\n",
    "                            with open(elapsedtime_save_dict_path, 'rb') as f:\n",
    "                                elapsedtime_save_dict_loaded=pickle.load(f)\n",
    "                        except:\n",
    "                            elapsedtime_save_dict_loaded={}\n",
    "                    else:\n",
    "                        elapsedtime_save_dict_loaded={}\n",
    "\n",
    "                    len_original=len(elapsedtime_save_dict_backbone_method)            \n",
    "                    len_loaded=len(elapsedtime_save_dict_loaded)\n",
    "                    elapsedtime_save_dict_backbone_method.update(elapsedtime_save_dict_loaded)\n",
    "                    len_updated=len(elapsedtime_save_dict_backbone_method)\n",
    "\n",
    "                    print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')\n",
    "\n",
    "                    with open(elapsedtime_save_dict_path, \"wb\") as f:\n",
    "                        pickle.dump(elapsedtime_save_dict_backbone_method, f)\n",
    "        except:\n",
    "            pass\n",
    "            \n",
    "            \n",
    "            \n",
    "    #         # Get ours\n",
    "    #         start_time=time.time();explainer_dict[backbone_type].eval()\n",
    "    #         with torch.no_grad():\n",
    "    #             values, _, _=explainer_dict[backbone_type](torch.Tensor(image).unsqueeze(0).to(explainer_dict[backbone_type].device))\n",
    "    #             values=values.cpu().numpy().transpose(0,2,1).squeeze(0)\n",
    "    #         values_ours=(values, time.time()-start_time)                "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68fe2ef4",
   "metadata": {},
   "source": [
    "# 9_estimationerror"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f3afb49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed, reproducible sample (random_state=42) of up to 10 images per label. Only these\n",
    "# paths are processed by the 9_estimationerror stage; min() keeps sample() from raising\n",
    "# when a label has fewer than 10 images.\n",
    "estimationerror_sample_path_list=(\n",
    "    pd.DataFrame(data_loader.dataset.data)\n",
    "    .groupby(\"label\")\n",
    "    .apply(lambda group: group.sample(n=min(10, len(group)), random_state=42))\n",
    "    [\"img_path\"]\n",
    "    .tolist()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e9d02b8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "if evaluation_stage==\"9_estimationerror\":\n",
    "    # Reference KernelSHAP estimates for the images in estimationerror_sample_path_list.\n",
    "    # Only batches with idx % parallel_mode[1] == parallel_mode[0] are handled here, so several\n",
    "    # copies of this notebook can split the work; results are merged into pickles on disk.\n",
    "    for idx, batch in enumerate(tqdm(data_loader)):#, total = int(1000/data_loader.batch_size+0.5) if dataset_split==\"test\" else None)):\n",
    "#         if dataset_split==\"test\":\n",
    "#             if idx>int(1000/data_loader.batch_size+0.5):\n",
    "#                 break\n",
    "        if (idx%parallel_mode[1])!=parallel_mode[0]:\n",
    "            continue\n",
    "\n",
    "        images=batch['images']\n",
    "        labels=batch['labels']\n",
    "        paths=batch['path']\n",
    "        label_values=labels.cpu().numpy().tolist()\n",
    "        # Methods with missing entries in this batch; they are computed and then saved below.\n",
    "        updated_signal_list=[]\n",
    "\n",
    "        for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "            # Get classifier output (NOTE(review): classifier_output is not used in this cell)\n",
    "            classifier_dict[backbone_type].eval()\n",
    "            with torch.no_grad():\n",
    "                classifier_output=classifier_dict[backbone_type](images.to(classifier_dict[backbone_type].device),\n",
    "                                                                 output_attentions=True)\n",
    "            for explanation_method in [\"kernelshapnopair\"]:\n",
    "                data_keys=estimationerror_save_dict[backbone_type][explanation_method].keys()\n",
    "                data_keys=[adapt_path(data_keys_path, paths) for data_keys_path in data_keys]\n",
    "                # Nothing to do when every sampled image of this batch already has an estimate.\n",
    "                if all((path in data_keys) or (path not in estimationerror_sample_path_list) for path in paths):\n",
    "                    continue\n",
    "                print(explanation_method,'not exist')\n",
    "                updated_signal_list.append(explanation_method)\n",
    "\n",
    "                # Batch positions of the images that belong to the estimation-error sample.\n",
    "                sampled_idx=[i for i, path in enumerate(paths) if path in estimationerror_sample_path_list]\n",
    "\n",
    "                if explanation_method==\"ours\":\n",
    "                    explainer_dict[backbone_type].eval()\n",
    "                    with torch.no_grad():\n",
    "                        explanation_ours = explainer_dict[backbone_type](images.to(explainer_dict[backbone_type].device))[0].detach().cpu().numpy().transpose(0, 2, 1)\n",
    "                    estimationerror_save_dict_update(backbone_type,\n",
    "                                                     'ours',\n",
    "                                                     path_list=[paths[i] for i in sampled_idx],\n",
    "                                                     estimation_list=[explanation_ours[i] for i in sampled_idx],\n",
    "                                                     label_list=[label_values[i] for i in sampled_idx],\n",
    "                                                     shape=(_config[\"output_dim\"], 196))\n",
    "\n",
    "                elif explanation_method in (\"kernelshap\", \"kernelshapnopair\"):\n",
    "                    # \"kernelshap\" passes thresh=0.1 and otherwise the ShapleyRegression defaults;\n",
    "                    # \"kernelshapnopair\" disables convergence detection and paired sampling and\n",
    "                    # always draws n_samples=200000.\n",
    "                    shapley_kwargs={\n",
    "                        \"kernelshap\": dict(batch_size=64,\n",
    "                                           thresh=0.1,\n",
    "                                           variance_batches=60,\n",
    "                                           return_all=True),\n",
    "                        \"kernelshapnopair\": dict(batch_size=128,\n",
    "                                                 detect_convergence=False,\n",
    "                                                 paired_sampling=False,\n",
    "                                                 n_samples=200000,\n",
    "                                                 variance_batches=60,\n",
    "                                                 return_all=True),\n",
    "                    }[explanation_method]\n",
    "\n",
    "                    surrogate_SHAP_wrapped_dict[backbone_type].eval()\n",
    "                    estimation_list=[]\n",
    "                    for i in sampled_idx:\n",
    "                        game = games.PredictionGame_torchimagetensor(surrogate_SHAP_wrapped_dict[backbone_type],\n",
    "                                                                     images[i])\n",
    "                        estimation_list.append(shapley.ShapleyRegression(game, **shapley_kwargs))\n",
    "\n",
    "                    estimationerror_save_dict_update(backbone_type,\n",
    "                                                     explanation_method,\n",
    "                                                     path_list=[paths[i] for i in sampled_idx],\n",
    "                                                     estimation_list=estimation_list,\n",
    "                                                     label_list=[label_values[i] for i in sampled_idx],\n",
    "                                                     shape=None)\n",
    "\n",
    "                else:\n",
    "                    raise ValueError(f\"Unknown explanation_method: {explanation_method}\")\n",
    "\n",
    "        # Save each method updated in this batch, merged with whatever is already on disk.\n",
    "        for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "            for explanation_method, estimationerror_save_dict_backbone_method in estimationerror_save_dict[backbone_type].items():\n",
    "                if explanation_method not in updated_signal_list:\n",
    "                    continue\n",
    "\n",
    "                estimationerror_save_dict_path=f'results/9_estimationerror/{_config[\"datasets\"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'\n",
    "                try:\n",
    "                    # Entries on disk (e.g. written by another parallel worker) win for duplicate keys.\n",
    "                    estimationerror_save_dict_loaded={}\n",
    "                    if os.path.isfile(estimationerror_save_dict_path):\n",
    "                        try:\n",
    "                            # pickle.load can execute code: only load result files you produced.\n",
    "                            with open(estimationerror_save_dict_path, 'rb') as f:\n",
    "                                estimationerror_save_dict_loaded=pickle.load(f)\n",
    "                        except Exception as e:\n",
    "                            print(f'warning: could not read {estimationerror_save_dict_path} ({e!r}); it will be overwritten')\n",
    "                            estimationerror_save_dict_loaded={}\n",
    "\n",
    "                    len_original=len(estimationerror_save_dict_backbone_method)\n",
    "                    len_loaded=len(estimationerror_save_dict_loaded)\n",
    "                    estimationerror_save_dict_backbone_method.update(estimationerror_save_dict_loaded)\n",
    "                    len_updated=len(estimationerror_save_dict_backbone_method)\n",
    "\n",
    "                    print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')\n",
    "\n",
    "                    # Create the output directory first; open(..., 'wb') raises if it is missing.\n",
    "                    os.makedirs(os.path.dirname(estimationerror_save_dict_path), exist_ok=True)\n",
    "                    with open(estimationerror_save_dict_path, \"wb\") as f:\n",
    "                        pickle.dump(estimationerror_save_dict_backbone_method, f)\n",
    "                except Exception as e:\n",
    "                    # Saving is best-effort so the loop keeps going, but failures are reported.\n",
    "                    print(f'warning: failed to save {estimationerror_save_dict_path}: {e!r}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46c14574",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Copy the KernelSHAP estimates from estimationerror_save_dict into explanation_save_dict\n",
    "# (transposed with .T; no elapsed time is available for them, hence NaN).\n",
    "for backbone_type in backbone_type_config_dict:\n",
    "    kernelshap_estimates=estimationerror_save_dict[backbone_type][\"kernelshap\"]\n",
    "    for path, data in kernelshap_estimates.items():\n",
    "        kernelshap_explanation=data['estimation'][0].values.T\n",
    "        explanation_save_dict[backbone_type]['kernelshap'][path]={\"explanation\": kernelshap_explanation,\n",
    "                                                                  \"elapsed_time\": np.nan}\n",
    "        print(path, kernelshap_explanation.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02f1edf9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of stored estimates per backbone and explanation method.\n",
    "for backbone_type in backbone_type_config_dict:\n",
    "    for explanation_method, method_entries in estimationerror_save_dict[backbone_type].items():\n",
    "        print(backbone_type, explanation_method, len(method_entries))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1639263",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-image share of the last measured batch time (elapsed_time split evenly over paths).\n",
    "elapsed_time/len(paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f4297a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the source of elapsedtime_save_dict_update (explicit form of `elapsedtime_save_dict_update??`).\n",
    "%pinfo2 elapsedtime_save_dict_update"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1da58468",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect which keys classifier_masked_dict currently holds.\n",
    "classifier_masked_dict.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "529f6928",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a86f61e3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c839705",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%javascript\n",
    "// Deletes this notebook's session via the classic Notebook front-end API (Jupyter.notebook),\n",
    "// which shuts down the kernel. NOTE(review): under \"Run All\" the cells below never execute.\n",
    "Jupyter.notebook.session.delete();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b51e3403",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "f4b2d9d0",
   "metadata": {},
   "source": [
    "# sensitivity-n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04d6d44f",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Sensitivity-n: for random player masks, rank-correlate (Spearman) the attribution summed over\n",
    "# the kept players (explanation @ mask) with the surrogate's output for every class.\n",
    "# NOTE(review): the trailing `break`s stop after the first image of the first backbone and batch\n",
    "# (debug mode) and the save call is commented out; remove/restore both for a full evaluation.\n",
    "num_players=196\n",
    "max_batch_idx=int(1000/data_loader.batch_size+0.5)\n",
    "# On the test split batches 0..max_batch_idx (inclusive) are processed, hence total=max_batch_idx+1.\n",
    "for idx, batch in enumerate(tqdm(data_loader, total = max_batch_idx+1 if dataset_split==\"test\" else None)):\n",
    "    if dataset_split==\"test\":\n",
    "        if idx>max_batch_idx:\n",
    "            break\n",
    "    if (idx%parallel_mode[1])!=parallel_mode[0]:\n",
    "        continue\n",
    "\n",
    "    images=batch['images']\n",
    "    labels=batch['labels']\n",
    "    paths=batch['path']\n",
    "\n",
    "    for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "        sensitivity_surrogate=surrogate_dict[backbone_type][\"pre-softmax\"]\n",
    "        sensitivity_surrogate.eval()\n",
    "        for image, path, label in zip(images, paths, labels):\n",
    "            label=int(label)  # plain int so it can index numpy arrays and compare with class indices\n",
    "            for num_included_players in [\"all\"]:\n",
    "                print(num_included_players)\n",
    "                if num_included_players==\"all\":\n",
    "                    # 20 batches of 50 masks each, seeded by random_iter so they are reproducible.\n",
    "                    prob_all=[]\n",
    "                    mask_all=[]\n",
    "                    for random_iter in range(20):\n",
    "                        image_loaded=image[np.newaxis, :].repeat(50, 1, 1, 1).to(sensitivity_surrogate.device)\n",
    "                        mask=generate_mask(num_players=num_players,\n",
    "                                           num_mask_samples=50,\n",
    "                                           paired_mask_samples=False,\n",
    "                                           mode=\"uniform\",\n",
    "                                           random_state=np.random.RandomState(random_iter))\n",
    "\n",
    "                        with torch.no_grad():\n",
    "                            output = sensitivity_surrogate(image_loaded,\n",
    "                                                           masks=torch.Tensor(mask).to(sensitivity_surrogate.device))\n",
    "\n",
    "                        if _config[\"output_dim\"]==1:\n",
    "                            prob=output['logits'].sigmoid().cpu().numpy()\n",
    "                        else:\n",
    "                            prob=output['logits'].softmax(dim=-1).cpu().numpy()\n",
    "                        prob_all.append(prob)\n",
    "                        mask_all.append(mask)\n",
    "                    prob_all=np.concatenate(prob_all, axis=0)\n",
    "                    mask_all=np.concatenate(mask_all, axis=0)\n",
    "\n",
    "                for explanation_method in explanation_method_to_run:\n",
    "                    print(explanation_method)\n",
    "                    explanation=explanation_save_dict[backbone_type][explanation_method][adapt_path(path, list(explanation_save_dict[backbone_type][explanation_method].keys())[0])]['explanation']\n",
    "                    # NOTE(review): noise below 1e-40 is lost to float64 rounding for attributions of\n",
    "                    # ordinary magnitude, so this tie-breaker only affects (near-)zero entries.\n",
    "                    explanation=explanation+np.random.RandomState(42).uniform(low=0, high=1e-40, size=explanation.shape)\n",
    "                    explanation_mask=explanation@(mask_all.T)\n",
    "\n",
    "                    if len(explanation.shape)==1:\n",
    "                        correlation = np.array([stats.spearmanr(explanation_mask, prob_all[:, i]).correlation for i in range(prob_all.shape[1])])\n",
    "                        assert correlation.shape==(_config[\"output_dim\"],)\n",
    "\n",
    "                    elif len(explanation.shape)==2:\n",
    "                        # One row of attributions per class: correlate class i's sums with class i's output.\n",
    "                        correlation = np.array([stats.spearmanr(explanation_mask[i], prob_all[:, i]).correlation for i in range(prob_all.shape[1])])\n",
    "                        assert correlation.shape==(_config[\"output_dim\"],)\n",
    "\n",
    "                        fig=plt.figure(figsize=(20,5))\n",
    "                        ax=fig.add_subplot(121)\n",
    "                        ax.scatter(explanation_mask[label], prob_all[:,label])\n",
    "                        ax.set_xlabel(\"sum_explanation\")\n",
    "                        ax.set_ylabel(\"model output\")\n",
    "                        ax.set_title(explanation_method)\n",
    "\n",
    "                        # Right panel: the first (up to) four classes, excluding the true label.\n",
    "                        ax=fig.add_subplot(122)\n",
    "                        for i in range(min(prob_all.shape[1], 4)):\n",
    "                            if i!=label:\n",
    "                                ax.scatter(explanation_mask[i], prob_all[:,i])\n",
    "                        ax.set_xlabel(\"sum_explanation\")\n",
    "                        ax.set_ylabel(\"model output\")\n",
    "                        ax.set_title(explanation_method)\n",
    "\n",
    "                        plt.show()\n",
    "\n",
    "                    else:\n",
    "                        raise ValueError(f\"Unexpected explanation shape {explanation.shape} for {explanation_method}\")\n",
    "#                     sensitivity_save_dit_update(backbone_type, explanation_method,\n",
    "#                                                  num_included_players=num_included_players,\n",
    "#                                                  path_list=[path], sensitivity_list=[correlation],\n",
    "#                                                  shape=(_config[\"output_dim\"],))     \n",
    "            break\n",
    "\n",
    "        break\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1b35df3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-plot the true-label scatter using variables left over from the sensitivity-n cell.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(explanation_mask[label], prob_all[:, label])\n",
    "ax.set(xlabel=\"sum_explanation\", ylabel=\"model output\")\n",
    "ax.set_title(explanation_method)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fcd97c6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "eb3e1c7e",
   "metadata": {},
   "source": [
    "# Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52a068a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_result(x, values, pred=None, vmin_vmax='separate',\n",
    "                     image_labels=('normal','abnormal')*5,\n",
    "                     class_labels=('normal','abnormal')*5):\n",
    "    \"\"\"Plot every image next to one explanation heatmap per class.\n",
    "\n",
    "    Args:\n",
    "        x: batch of normalized images; x[row].numpy() is expected to be (C, H, W) with 3 channels.\n",
    "        values: indexed as values[row, class] -> a 2-D heatmap, e.g. (N, num_classes, 14, 14).\n",
    "        pred: optional (N, num_classes) array of model outputs, printed under each heatmap.\n",
    "        vmin_vmax: 'separate' scales every heatmap on its own; a (vmin, vmax) pair uses one shared\n",
    "            scale and additionally draws a colorbar and a histogram of `values`.\n",
    "        image_labels: one row label per image (length must equal values.shape[0]).\n",
    "        class_labels: one column title per class (length must equal values.shape[1]).\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if vmin >= vmax or label/prediction sizes do not match `values`.\n",
    "        ValueError: if vmin_vmax is neither 'separate' nor a (vmin, vmax) pair.\n",
    "    \"\"\"\n",
    "    from matplotlib.colors import ListedColormap\n",
    "\n",
    "    color_num = 1000\n",
    "    # Undo the input normalization for display.\n",
    "    # NOTE(review): these look like CIFAR-10 statistics; confirm they match the datamodule's transform.\n",
    "    img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]\n",
    "    img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]\n",
    "\n",
    "    # plt.get_cmap instead of matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.\n",
    "    seismic = plt.get_cmap('seismic', color_num)\n",
    "\n",
    "    shared_scale = isinstance(vmin_vmax, (tuple, list))\n",
    "    if shared_scale:\n",
    "        vmin, vmax = vmin_vmax\n",
    "        assert vmin < vmax\n",
    "        if vmin < 0 < vmax:\n",
    "            # Stretch each half of the diverging map separately so that white sits exactly at 0.\n",
    "            ratio = vmax/(-vmin+vmax)\n",
    "            n_neg = int(color_num*(1-ratio))\n",
    "            n_pos = int(color_num*ratio)+1\n",
    "            newcolors = seismic(np.linspace(0, 1, color_num))\n",
    "            newcolors[:n_neg] = seismic(np.linspace(0, 0.5, n_neg))\n",
    "            newcolors[-n_pos:] = seismic(np.linspace(0.5, 1, n_pos))\n",
    "        elif vmin >= 0:\n",
    "            newcolors = seismic(np.linspace(0.5, 1, color_num))  # white -> red\n",
    "        else:  # vmax <= 0\n",
    "            newcolors = seismic(np.linspace(0, 0.5, color_num))  # blue -> white\n",
    "    elif vmin_vmax == \"separate\":\n",
    "        all_positive = values.min() > 0\n",
    "        # All-positive maps use the white->red half; otherwise the full blue-white-red map.\n",
    "        newcolors = seismic(np.linspace(0.5 if all_positive else 0, 1, color_num))\n",
    "    else:\n",
    "        raise ValueError(f\"vmin_vmax must be 'separate' or a (vmin, vmax) pair, got {vmin_vmax!r}\")\n",
    "    newcmp = ListedColormap(newcolors)\n",
    "\n",
    "    n_images, n_classes = values.shape[0], values.shape[1]\n",
    "    assert len(image_labels)==n_images==(len(image_labels) if pred is None else pred.shape[0])\n",
    "    assert len(class_labels)==n_classes==(len(class_labels) if pred is None else pred.shape[1])\n",
    "\n",
    "    # squeeze=False keeps `axes` 2-D even when a single image is plotted.\n",
    "    fig, axes = plt.subplots(n_images, 1+n_classes, figsize=(2*(1+n_classes), 2*(n_images+1)), squeeze=False)\n",
    "\n",
    "    for row in range(n_images):\n",
    "        # Column 0: the de-normalized input image.\n",
    "        im = x[row].numpy() * img_std + img_mean  # (C, H, W)\n",
    "        im = im.transpose(1, 2, 0).astype(float)  # (H, W, C)\n",
    "        im = np.clip(im, a_min=0, a_max=1)\n",
    "        axes[row, 0].imshow(im, vmin=0, vmax=1)\n",
    "        axes[row, 0].set_ylabel('{}'.format(image_labels[row]), fontsize=12)\n",
    "\n",
    "        # Columns 1..n_classes: one explanation heatmap per class.\n",
    "        for col in range(1, 1+n_classes):\n",
    "            values_select = values[row, col-1]\n",
    "            values_select_min, values_select_max = values_select.min(), values_select.max()\n",
    "\n",
    "            if shared_scale:\n",
    "                axes[row, col].imshow(values_select, cmap=newcmp, vmin=vmin, vmax=vmax)\n",
    "            elif all_positive:\n",
    "                axes[row, col].imshow(values_select, cmap=newcmp, vmin=values_select_min, vmax=values_select_max)\n",
    "            else:\n",
    "                # Symmetric limits so that 0 maps to the white centre of the colormap.\n",
    "                bound = max(abs(values_select_min), abs(values_select_max))\n",
    "                axes[row, col].imshow(values_select, cmap=newcmp, vmin=-bound, vmax=bound)\n",
    "\n",
    "            if pred is None:\n",
    "                axes[row, col].set_xlabel('{:.2f}/{:.2f}'.format(values_select_min, values_select_max), fontsize=12)\n",
    "            else:\n",
    "                axes[row, col].set_xlabel('{:.2f} {:.2f}/{:.2f}'.format(pred[row, col-1], values_select_min, values_select_max), fontsize=12)\n",
    "\n",
    "            if row == 0:\n",
    "                axes[row, col].set_title('{}'.format(class_labels[col-1]), fontsize=12)\n",
    "\n",
    "        for ax in axes[row]:\n",
    "            ax.set_xticks([])\n",
    "            ax.set_yticks([])\n",
    "\n",
    "    if shared_scale:\n",
    "        # Stand-alone horizontal colorbar for the shared scale. seaborn needs a host axes for the\n",
    "        # heatmap itself, so it is drawn on a throwaway figure that is closed before showing.\n",
    "        cbar_fig = plt.figure(figsize=(5, 0.5))\n",
    "        cbar_ax = cbar_fig.add_subplot()\n",
    "        scratch_fig, scratch_ax = plt.subplots()\n",
    "        sns.heatmap([[0,0],[0,0]],\n",
    "                    ax=scratch_ax,\n",
    "                    cmap=newcmp,\n",
    "                    vmin=vmin,\n",
    "                    vmax=vmax,\n",
    "                    xticklabels=True,\n",
    "                    linewidths=1,\n",
    "                    linecolor=np.array([220,220,220,256])/256,\n",
    "                    cbar_ax=cbar_ax,\n",
    "                    cbar_kws={'fraction':0.1, \"ticks\":np.linspace(vmin, vmax, 5), \"orientation\": \"horizontal\"},\n",
    "                    cbar=True,\n",
    "                    alpha=1,edgecolor='black')\n",
    "        plt.close(scratch_fig)\n",
    "        plt.show()\n",
    "\n",
    "        hist_fig = plt.figure()\n",
    "        hist_ax = hist_fig.add_subplot()\n",
    "        hist_ax.hist(values.flatten(), bins=np.linspace(vmin, vmax, 20))\n",
    "        hist_ax.set_yscale('log')\n",
    "        print(values.flatten().min(), values.flatten().max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03b53e07",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# For every image in `x`, collect surrogate outputs for 20 batches of 50 seeded random masks\n",
    "# (same procedure as the sensitivity-n cell; relies on backbone_type and num_players from earlier cells).\n",
    "prob_all_all=[]\n",
    "mask_all_all=[]\n",
    "for idx, image in enumerate(x):\n",
    "    print(idx)\n",
    "    probs_per_round, masks_per_round = [], []\n",
    "    for random_iter in range(20):\n",
    "        masked_model=surrogate_dict[backbone_type][\"pre-softmax\"]\n",
    "        image_loaded=image[np.newaxis, :].repeat(50, 1, 1, 1).to(masked_model.device)\n",
    "        mask=generate_mask(num_players=num_players,\n",
    "                           num_mask_samples=50,\n",
    "                           paired_mask_samples=False,\n",
    "                           mode=\"uniform\",\n",
    "                           random_state=np.random.RandomState(random_iter))\n",
    "\n",
    "        masked_model.eval()\n",
    "        with torch.no_grad():\n",
    "            output = masked_model(image_loaded, masks=torch.Tensor(mask).to(masked_model.device))\n",
    "\n",
    "        logits=output['logits']\n",
    "        prob=(logits.sigmoid() if _config[\"output_dim\"]==1 else logits.softmax(dim=-1)).cpu().numpy()\n",
    "        probs_per_round.append(prob)\n",
    "        masks_per_round.append(mask)\n",
    "    prob_all=np.concatenate(probs_per_round, axis=0)\n",
    "    mask_all=np.concatenate(masks_per_round, axis=0)\n",
    "    prob_all_all.append(prob_all)\n",
    "    mask_all_all.append(mask_all)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc0270ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# View `values` with everything after the first two axes flattened into one axis of length 196.\n",
    "values.reshape(*values.shape[:2], 196)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd15b98b",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# For each backbone, show the LRP (transformer_attribution) maps and our explainer's maps for the\n",
    "# images in `x`; the flattened maps are kept in values_lrp and values_ours.\n",
    "for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    #get classifier\n",
    "    classifier_dict[backbone_type].eval()\n",
    "    with torch.no_grad():\n",
    "        output = classifier_dict[backbone_type](x.to(next(classifier_dict[backbone_type].parameters()).device), output_attentions=False)\n",
    "        if _config[\"output_dim\"]==1:\n",
    "            pred=output['logits'].detach().sigmoid().cpu().data.numpy()\n",
    "        else:\n",
    "            # Tensor.softmax needs an explicit dim; class scores are on the last axis.\n",
    "            pred=output['logits'].detach().softmax(dim=-1).cpu().data.numpy()\n",
    "    del output\n",
    "\n",
    "    # Get explanation (modified_LRP)\n",
    "    values=np.concatenate([np.concatenate([get_lrp_module_explanation(backbone_type, image.squeeze(0), class_index=i, mode='transformer_attribution').cpu() for i in range(_config[\"output_dim\"])], axis=0)[np.newaxis,:] for image in x])\n",
    "    values=values.reshape(values.shape[0], values.shape[1], 14, 14)\n",
    "\n",
    "    visualize_result(x, pred=pred, values=values,\n",
    "                     class_labels=y_labels[1:2] if _config[\"output_dim\"]==1 else y_labels,\n",
    "                     image_labels=y_labels,\n",
    "                     vmin_vmax=\"separate\")\n",
    "\n",
    "    values_lrp=values.reshape(values.shape[0], values.shape[1], 196)\n",
    "\n",
    "    explainer_dict[backbone_type].eval()\n",
    "    with torch.no_grad():\n",
    "        values=explainer_dict[backbone_type](x.to(next(explainer_dict[backbone_type].parameters()).device))\n",
    "        # values[0] is presumably (N, 196, num_classes), matching the transpose(0, 2, 1) applied to it\n",
    "        # in the estimation-error cell (confirm); move classes before players so the reshape keeps\n",
    "        # each class's 14x14 map intact instead of interleaving classes.\n",
    "        values=values[0].transpose(1, 2).reshape(-1, _config[\"output_dim\"], 14, 14).cpu().numpy()\n",
    "\n",
    "    visualize_result(x, pred=pred, values=values,\n",
    "                     class_labels=y_labels[1:2] if _config[\"output_dim\"]==1 else y_labels,\n",
    "                     image_labels=y_labels,\n",
    "                     vmin_vmax=(-0.2, 0.2))\n",
    "\n",
    "    values_ours=values.reshape(values.shape[0], values.shape[1], 196)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vitshapley",
   "language": "python",
   "name": "vitshapley"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
