{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2b96013",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "# Work from the parent of the directory the kernel started in. The start directory is\n",
    "# captured only on the first execution, so re-running this cell lands in the same place\n",
    "# instead of climbing one more level each time.\n",
    "if 'NOTEBOOK_START_DIR' not in globals():\n",
    "    NOTEBOOK_START_DIR = os.getcwd()\n",
    "print(os.getcwd())\n",
    "os.chdir(os.path.join(NOTEBOOK_START_DIR, os.pardir))\n",
    "print(os.getcwd())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "190e98ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "from collections import OrderedDict\n",
    "from typing import Optional\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from vitshapley.datamodules.ImageNette_datamodule import ImageNetteDataModule\n",
    "\n",
    "from vitshapley.modules.classifier import Classifier\n",
    "from vitshapley.modules.surrogate import Surrogate\n",
    "from vitshapley.modules.explainer import Explainer\n",
    "\n",
    "from vitshapley.config import ex\n",
    "from vitshapley.config import config, env_username, dataset_ImageNette\n",
    "\n",
    "# Which split to evaluate on and which backbones to load.\n",
    "dataset_split = \"test\"\n",
    "backbone_to_use = [\"vit_base_patch16_224\"]\n",
    "\n",
    "# Build the run configuration: start from the defaults, then layer overrides on top.\n",
    "# The order matters because every later update wins on key clashes.\n",
    "_config = config()\n",
    "_config.update(env_username())\n",
    "_config.update({\n",
    "    'gpus_classifier': [0],\n",
    "    'gpus_surrogate': [0],\n",
    "    'gpus_explainer': [0],\n",
    "})\n",
    "_config.update(dataset_ImageNette())\n",
    "# Notebook-specific overrides (applied after the dataset config).\n",
    "_config.update({\n",
    "    'classifier_backbone_type': None,\n",
    "    'classifier_download_weight': False,\n",
    "    'classifier_load_path': None,\n",
    "    'surrogate_mask_location': \"pre-softmax\",\n",
    "    'surrogate_backbone_type': None,\n",
    "    'surrogate_download_weight': False,\n",
    "    'surrogate_load_path': None,\n",
    "    'explainer_num_mask_samples': 2,\n",
    "    'explainer_paired_mask_samples': True,\n",
    "})\n",
    "print('done')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c99b9071",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from collections import OrderedDict\n",
    "import copy\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb20155b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Surrogate checkpoints known to this notebook, keyed by backbone name.\n",
    "# Layout: backbone name -> \"surrogate_path\" -> masking variant -> checkpoint path (relative).\n",
    "# The deit_* entries list no checkpoints, so looking up any masking variant for them raises KeyError.\n",
    "backbone_type_config_dict_=OrderedDict({\n",
    "    \"vit_small_patch16_224\":{\n",
    "        \"surrogate_path\":{\"original\": \"results/transformer_interpretability/17inn4ht/checkpoints/epoch=14-step=2204.ckpt\",\n",
    "                          \"pre-softmax\": \"results/transformer_interpretability/3kv2ns41/checkpoints/epoch=29-step=4409.ckpt\",\n",
    "                          \"post-softmax\": \"results/transformer_interpretability/31as48v7/checkpoints/epoch=32-step=4850.ckpt\",\n",
    "                          \"zero-input\": \"results/transformer_interpretability/j8sihn8t/checkpoints/epoch=33-step=4997.ckpt\"},\n",
    "    },\n",
    "    \"deit_small_patch16_224\":{\n",
    "        \"surrogate_path\": {},\n",
    "    },\n",
    "    \"vit_base_patch16_224\":{\n",
    "        \"surrogate_path\": {\"original\": \"results/transformer_interpretability/3f67z73f/checkpoints/epoch=11-step=1763.ckpt\",\n",
    "                           \"pre-softmax\": \"results/transformer_interpretability/zeydyraj/checkpoints/epoch=15-step=2351.ckpt\",\n",
    "                           \"post-softmax\": \"results/transformer_interpretability/1ijt5xox/checkpoints/epoch=33-step=4997.ckpt\",\n",
    "                           \"zero-input\": \"results/transformer_interpretability/1w1sgm9q/checkpoints/epoch=15-step=2351.ckpt\"\n",
    "                          },\n",
    "    },\n",
    "    \"deit_base_patch16_224\":{\n",
    "        \"surrogate_path\": {},\n",
    "    }\n",
    "})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87361649",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_mask(num_players: int, num_mask_samples: Optional[int] = None, paired_mask_samples: bool = True,\n",
    "                  mode: str = 'uniform', random_state: Optional[np.random.RandomState] = None) -> np.ndarray:\n",
    "    \"\"\"Sample binary coalition masks, one column per player.\n",
    "\n",
    "    Each row is built by comparing per-player uniform draws against a single per-row\n",
    "    threshold, so the threshold controls how many entries of the row are 1.\n",
    "\n",
    "    Args:\n",
    "        num_players: the number of players in the coalitional game\n",
    "        num_mask_samples: the number of masks to generate. ``None`` requests one\n",
    "            unbatched mask; that only works with ``paired_mask_samples=False``,\n",
    "            because a pair needs an even number of samples.\n",
    "        paired_mask_samples: if True, rows come in consecutive pairs (x, 1 - x).\n",
    "        mode: how the per-row threshold is drawn. 'uniform' draws it uniformly from\n",
    "            [0, 1); 'shapley' draws k with probability proportional to\n",
    "            1 / (i * (num_players - i)) and uses k / num_players.\n",
    "        random_state: source of randomness exposing ``rand`` and ``choice``;\n",
    "            defaults to the global ``np.random`` module.\n",
    "\n",
    "    Returns:\n",
    "        Integer np.ndarray of shape\n",
    "        (num_mask_samples, num_players) if num_mask_samples is an int\n",
    "        (num_players,) if num_mask_samples is None\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if paired_mask_samples is True and the sample count is odd.\n",
    "        ValueError: if mode is neither 'uniform' nor 'shapley'.\n",
    "    \"\"\"\n",
    "    random_state = random_state or np.random\n",
    "\n",
    "    # None (and 0) fall back to a single sample.\n",
    "    num_samples_ = num_mask_samples or 1\n",
    "\n",
    "    if paired_mask_samples:\n",
    "        assert num_samples_ % 2 == 0, \\\n",
    "            \"'num_mask_samples' must be a multiple of 2 if 'paired_mask_samples' is True\"\n",
    "        # Only half of the rows are sampled; the complements are appended below.\n",
    "        num_samples_ = num_samples_ // 2\n",
    "\n",
    "    if mode == 'uniform':\n",
    "        masks = (random_state.rand(num_samples_, num_players) > random_state.rand(num_samples_, 1)).astype('int')\n",
    "    elif mode == 'shapley':\n",
    "        probs = 1 / (np.arange(1, num_players) * (num_players - np.arange(1, num_players)))\n",
    "        probs = probs / probs.sum()\n",
    "        # NOTE(review): probs is indexed by i = 1..num_players-1 while the draw below ranges over\n",
    "        # 0..num_players-2 -- confirm this one-step shift is intended.\n",
    "        masks = (random_state.rand(num_samples_, num_players) > 1 / num_players * random_state.choice(\n",
    "            np.arange(num_players - 1), p=probs, size=[num_samples_, 1])).astype('int')\n",
    "    else:\n",
    "        raise ValueError(\"'mode' must be 'uniform' or 'shapley'\")\n",
    "\n",
    "    if paired_mask_samples:\n",
    "        # Interleave every mask with its complement: rows 2k and 2k + 1 sum to all ones.\n",
    "        masks = np.stack([masks, 1 - masks], axis=1).reshape(num_samples_ * 2, num_players)\n",
    "\n",
    "    if num_mask_samples is None:\n",
    "        return masks.squeeze(0)  # (num_players,)\n",
    "    return masks  # (num_mask_samples, num_players)\n",
    "\n",
    "\n",
    "def set_datamodule(datasets,\n",
    "                   dataset_location,\n",
    "                   transforms_train,\n",
    "                   transforms_val,\n",
    "                   transforms_test,\n",
    "                   num_workers,\n",
    "                   per_gpu_batch_size,\n",
    "                   test_data_split):\n",
    "    \"\"\"Instantiate the datamodule class selected by `datasets`.\n",
    "\n",
    "    Args:\n",
    "        datasets: dataset name; one of \"APTOS2019\", \"CheXpert\", \"MIMIC\", \"CIFAR\", \"ImageNette\".\n",
    "        dataset_location, transforms_train, transforms_val, transforms_test,\n",
    "        num_workers, per_gpu_batch_size, test_data_split: forwarded unchanged, as\n",
    "            keyword arguments of the same name, to the datamodule constructor.\n",
    "\n",
    "    Returns:\n",
    "        The constructed datamodule.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `datasets` is not one of the names listed above.\n",
    "    \"\"\"\n",
    "    dataset_parameters = {\n",
    "        \"dataset_location\": dataset_location,\n",
    "        \"transforms_train\": transforms_train,\n",
    "        \"transforms_val\": transforms_val,\n",
    "        \"transforms_test\": transforms_test,\n",
    "        \"num_workers\": num_workers,\n",
    "        \"per_gpu_batch_size\": per_gpu_batch_size,\n",
    "        \"test_data_split\": test_data_split\n",
    "    }\n",
    "\n",
    "    # Among the imports visible in this notebook only ImageNetteDataModule is available;\n",
    "    # the other branches need their datamodule class imported before they can be used.\n",
    "    if datasets == \"APTOS2019\":\n",
    "        datamodule = APTOS2019DataModule(**dataset_parameters)\n",
    "    elif datasets == \"CheXpert\":\n",
    "        datamodule = CheXpertDataModule(**dataset_parameters)\n",
    "    elif datasets == \"MIMIC\":\n",
    "        datamodule = MIMICDataModule(**dataset_parameters)\n",
    "    elif datasets == \"CIFAR\":\n",
    "        datamodule = CIFARDataModule(**dataset_parameters)\n",
    "    elif datasets == \"ImageNette\":\n",
    "        datamodule = ImageNetteDataModule(**dataset_parameters)\n",
    "    else:\n",
    "        # Must be raised: otherwise the function falls through to `return datamodule`\n",
    "        # with the name unbound and fails with an unrelated UnboundLocalError.\n",
    "        raise ValueError(\"Invalid 'datasets' configuration\")\n",
    "    return datamodule\n",
    "\n",
    "# Every datamodule argument is read from the identically named key of `_config`.\n",
    "datamodule = set_datamodule(**{\n",
    "    key: _config[key]\n",
    "    for key in (\"datasets\", \"dataset_location\",\n",
    "                \"transforms_train\", \"transforms_val\", \"transforms_test\",\n",
    "                \"num_workers\", \"per_gpu_batch_size\", \"test_data_split\")\n",
    "})\n",
    "\n",
    "# The batch for training classifier consists of images and labels, but the batch for training explainer consists of images and masks.\n",
    "# The masks are generated to follow the Shapley distribution.\n",
    "# NOTE: the patch below is disabled -- it is wrapped in a bare string literal, so none of it executes.\n",
    "# It reads `surrogate.num_players`, and `surrogate` is only assigned in a later cell of this notebook;\n",
    "# it also calls original_getitem twice per item (once for images, once for labels).\n",
    "\"\"\"\n",
    "original_getitem = copy.deepcopy(datamodule.dataset_cls.__getitem__)\n",
    "def __getitem__(self, idx):\n",
    "    if self.split == 'train':\n",
    "        masks = generate_mask(num_players=surrogate.num_players,\n",
    "                              num_mask_samples=_config[\"explainer_num_mask_samples\"],\n",
    "                              paired_mask_samples=_config[\"explainer_paired_mask_samples\"], mode='shapley')\n",
    "    elif self.split == 'val' or self.split == 'test':\n",
    "        # get cached if available\n",
    "        if not hasattr(self, \"masks_cached\"):\n",
    "            self.masks_cached = {}\n",
    "        masks = self.masks_cached.setdefault(idx, generate_mask(num_players=surrogate.num_players,\n",
    "                                                                num_mask_samples=_config[\n",
    "                                                                    \"explainer_num_mask_samples\"],\n",
    "                                                                paired_mask_samples=_config[\n",
    "                                                                    \"explainer_paired_mask_samples\"],\n",
    "                                                                mode='shapley'))\n",
    "    else:\n",
    "        raise ValueError(\"'split' variable must be train, val or test.\")\n",
    "    return {\"images\": original_getitem(self, idx)[\"images\"],\n",
    "            \"labels\": original_getitem(self, idx)[\"labels\"],\n",
    "            \"masks\": masks}\n",
    "datamodule.dataset_cls.__getitem__ = __getitem__\n",
    "\"\"\"\n",
    "\n",
    "# Materialise all three splits on the datamodule.\n",
    "datamodule.set_train_dataset()\n",
    "datamodule.set_val_dataset()\n",
    "datamodule.set_test_dataset()\n",
    "\n",
    "train_dataset = datamodule.train_dataset\n",
    "val_dataset = datamodule.val_dataset\n",
    "test_dataset = datamodule.test_dataset\n",
    "\n",
    "# Despite its name, `classidx` is a position *within* each class: the sample at this\n",
    "# position is picked from every class below.\n",
    "classidx = 4\n",
    "\n",
    "# Select the split to work on. An unknown split name raises a descriptive error; a bare\n",
    "# `raise` here would only yield \"RuntimeError: No active exception to reraise\".\n",
    "split_to_dataset = {\"train\": train_dataset, \"val\": val_dataset, \"test\": test_dataset}\n",
    "if dataset_split not in split_to_dataset:\n",
    "    raise ValueError(f\"'dataset_split' must be 'train', 'val' or 'test', got {dataset_split!r}\")\n",
    "dset = split_to_dataset[dataset_split]\n",
    "\n",
    "labels = np.array([i['label'] for i in dset.data])\n",
    "num_classes = labels.max() + 1\n",
    "\n",
    "# For each class, the dataset indices of its samples; then one sample per class.\n",
    "images_idx_list = [np.where(labels == category)[0] for category in range(num_classes)]\n",
    "images_idx = [category_idx[classidx] for category_idx in images_idx_list]\n",
    "\n",
    "xy = [dset[idx] for idx in images_idx]\n",
    "x, y = zip(*[(i['images'], i['labels']) for i in xy])\n",
    "x = torch.stack(x)\n",
    "len(dset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "521ba68f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the backbones listed in `backbone_to_use`, preserving their original order.\n",
    "backbone_type_config_dict = OrderedDict()\n",
    "for backbone_type, backbone_type_config in backbone_type_config_dict_.items():\n",
    "    if backbone_type not in backbone_to_use:\n",
    "        continue\n",
    "    print(backbone_type)\n",
    "    backbone_type_config_dict[backbone_type] = backbone_type_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8794a43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# surrogate_dict[backbone][variant] -> Surrogate loaded from that variant's checkpoint.\n",
    "surrogate_dict = OrderedDict()\n",
    "\n",
    "for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    mask_method_dict = OrderedDict()\n",
    "    for mask_location in [\"original\", \"pre-softmax\", \"post-softmax\", \"zero-input\"]:\n",
    "        # The \"original\" checkpoint is loaded into a surrogate configured for pre-softmax masking.\n",
    "        effective_mask_location = \"pre-softmax\" if mask_location == \"original\" else mask_location\n",
    "        surrogate_model = Surrogate(\n",
    "            mask_location=effective_mask_location,\n",
    "            backbone_type=backbone_type,\n",
    "            download_weight=_config['surrogate_download_weight'],\n",
    "            load_path=backbone_type_config[\"surrogate_path\"][mask_location],\n",
    "            target_type=_config[\"target_type\"],\n",
    "            output_dim=_config[\"output_dim\"],\n",
    "            # Training-only arguments are passed as None here.\n",
    "            target_model=None,\n",
    "            checkpoint_metric=None,\n",
    "            optim_type=None,\n",
    "            learning_rate=None,\n",
    "            weight_decay=None,\n",
    "            decay_power=None,\n",
    "            warmup_steps=None,\n",
    "        )\n",
    "        # The idx-th backbone goes to the idx-th entry of `gpus_surrogate`.\n",
    "        mask_method_dict[mask_location] = surrogate_model.to(_config[\"gpus_surrogate\"][idx])\n",
    "    surrogate_dict[backbone_type] = mask_method_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc939d68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed-order loader over the selected split; the final incomplete batch is dropped.\n",
    "dset_loader = DataLoader(\n",
    "    dset,\n",
    "    batch_size=64,\n",
    "    shuffle=False,\n",
    "    num_workers=4,\n",
    "    drop_last=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f42b38e1",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import copy\n",
    "\n",
    "# One random keep/drop pattern over the 196 (= 14 x 14) patch positions, shared by every image.\n",
    "mask = (torch.rand(1, 196)>0.5).int()\n",
    "kept_patches = mask[0] == 1\n",
    "# Pixel-level version of the mask: each of the 14 x 14 entries becomes a 16 x 16 square.\n",
    "# Its shape is (1, 1, 224, 224), so it broadcasts over any batch size and channel count\n",
    "# instead of being tied to a hard-coded batch size of 64.\n",
    "pixel_mask = torch.repeat_interleave(torch.repeat_interleave(mask.reshape(1, 1, 14, 14), 16, dim=2), 16, dim=3)\n",
    "\n",
    "for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n",
    "    surrogate = surrogate_dict[backbone_type]['pre-softmax']\n",
    "    # Copy of the surrogate whose positional embedding is cut down to the kept patches, so it\n",
    "    # can be run on a sequence from which the dropped patches were physically removed.\n",
    "    surrogate_ = copy.deepcopy(surrogate)\n",
    "    with torch.no_grad():\n",
    "        # Depends only on `mask`, so it is built once rather than once per batch.\n",
    "        # Position 0 is kept as is (presumably the class token -- TODO confirm).\n",
    "        surrogate_.backbone.pos_embed = torch.nn.Parameter(torch.concat([surrogate.backbone.pos_embed[:, 0:1],\n",
    "                                                                         surrogate.backbone.pos_embed[:, 1:][:, kept_patches]], dim=1))\n",
    "\n",
    "    for batch_idx, batch in enumerate(tqdm(dset_loader, unit='batch')):\n",
    "        with torch.no_grad():\n",
    "            images = batch[\"images\"].to(surrogate.device)\n",
    "            batch_mask = torch.repeat_interleave(mask, len(images), dim=0).to(surrogate.device)\n",
    "\n",
    "            # (1) Reference: surrogate applied to the full images together with the mask.\n",
    "            logits = surrogate(images, batch_mask)['logits']\n",
    "\n",
    "            # (2) Held-out: only the kept patch embeddings are fed in, with an all-ones mask.\n",
    "            image_patchified = surrogate.backbone.patch_embed(images)\n",
    "            image_patchified_attention = surrogate_.backbone.forward_features(\n",
    "                image_patchified[:, kept_patches, :],\n",
    "                torch.ones(len(image_patchified), kept_patches.sum().item()).to(surrogate_.device),\n",
    "                'pre-softmax')\n",
    "            logits_held_out = surrogate.head(image_patchified_attention['x'])\n",
    "\n",
    "            # (3) Perturbed: pixels of dropped patches are overwritten with an arbitrary value (4242)\n",
    "            # before running the masked surrogate again.\n",
    "            images_perturbed = batch[\"images\"].masked_fill(pixel_mask == 0, 4242)\n",
    "            logits_perturbed = surrogate(images_perturbed.to(surrogate.device), batch_mask)['logits']\n",
    "\n",
    "            assert torch.isclose(logits, logits_held_out, rtol=1e-2).all()\n",
    "            assert torch.isclose(logits, logits_perturbed, rtol=1e-2).all()\n",
    "\n",
    "    # Only the first selected backbone is checked.\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e656aee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Held-out forward pass with the 'pre-softmax' option, reusing `mask`, `image_patchified`\n",
    "# and `surrogate_` left over from the check loop above (i.e. its last processed batch).\n",
    "with torch.no_grad():\n",
    "    kept_patches = mask[0] == 1\n",
    "    all_visible = torch.ones(len(image_patchified), kept_patches.sum().item()).to(surrogate_.device)\n",
    "    held_out_features = surrogate_.backbone.forward_features(image_patchified[:, kept_patches, :],\n",
    "                                                             all_visible, 'pre-softmax')\n",
    "    b = held_out_features['x']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84461d4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same held-out forward pass, but with the option string 'pre-softmax_'; it reuses `mask`,\n",
    "# `image_patchified` and `surrogate_` left over from the check loop above.\n",
    "with torch.no_grad():\n",
    "    kept_patches = mask[0] == 1\n",
    "    all_visible = torch.ones(len(image_patchified), kept_patches.sum().item()).to(surrogate_.device)\n",
    "    held_out_features = surrogate_.backbone.forward_features(image_patchified[:, kept_patches, :],\n",
    "                                                             all_visible, 'pre-softmax_')\n",
    "    a = held_out_features['x']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90a7c804",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# True only if the 'pre-softmax_' (a) and 'pre-softmax' (b) feature tensors match element for element.\n",
    "torch.eq(a, b).all()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vitshapley",
   "language": "python",
   "name": "vitshapley"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
