{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8141aca4915caa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import warnings\n",
    "from itertools import chain\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.transforms import v2\n",
    "\n",
    "from datasets.funny_birds import FunnyBirds\n",
    "from models.ViT.ViT_new import vit_base_patch16_224\n",
    "from models.resnet import resnet50\n",
    "from models.resnet_own import own_resnet50\n",
    "from models.vgg import vgg16\n",
    "from models.vision_transformer import own_vit_b_16\n",
    "from train import AverageMeter, Summary, ProgressMeter, accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49cc466f13d58024",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Options are parsed from the `params` string further down (not from sys.argv),\n",
    "# so the notebook can be configured without a command line.\n",
    "parser = argparse.ArgumentParser(description=\"FunnyBirds - Model Evaluation\")\n",
    "parser.add_argument(\n",
    "    \"--data\", metavar=\"DIR\", required=True, help=\"path to the FunnyBirds dataset\"\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--model\",\n",
    "    required=True,\n",
    "    choices=[\"resnet50\", \"vgg16\", \"vit_b_16\", \"own_resnet50\", \"own_vit_b_16\"],\n",
    "    help=\"model architecture\",\n",
    ")\n",
    "# Required options: a `default` would never be used, so none is given.\n",
    "parser.add_argument(\n",
    "    \"--checkpoint_dir\",\n",
    "    metavar=\"DIR\",\n",
    "    required=True,\n",
    "    help=\"path to checkpoints\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--checkpoint_prefix\",\n",
    "    type=str,\n",
    "    required=True,\n",
    "    help=\"checkpoint prefix\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--epochs\", default=120, type=int, metavar=\"N\", help=\"number of total epochs to run\"\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--step_size\",\n",
    "    default=60,\n",
    "    type=int,\n",
    "    metavar=\"N\",\n",
    "    help=\"learning-rate scheduler step size in epochs\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"-b\",\n",
    "    \"--batch-size\",\n",
    "    default=64,\n",
    "    type=int,\n",
    "    metavar=\"N\",\n",
    "    help=\"mini-batch size of the test loader (default: 64)\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--lr\",\n",
    "    \"--learning-rate\",\n",
    "    default=0.1,\n",
    "    type=float,\n",
    "    metavar=\"LR\",\n",
    "    help=\"initial learning rate\",\n",
    "    dest=\"lr\",\n",
    ")\n",
    "parser.add_argument(\"--momentum\", default=0.9, type=float, metavar=\"M\", help=\"momentum\")\n",
    "parser.add_argument(\n",
    "    \"--wd\",\n",
    "    \"--weight-decay\",\n",
    "    default=1e-4,\n",
    "    type=float,\n",
    "    metavar=\"W\",\n",
    "    help=\"weight decay (default: 1e-4)\",\n",
    "    dest=\"weight_decay\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"-p\",\n",
    "    \"--print-freq\",\n",
    "    default=10,\n",
    "    type=int,\n",
    "    metavar=\"N\",\n",
    "    help=\"print frequency (default: 10)\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--pretrained\", dest=\"pretrained\", action=\"store_true\", help=\"use pre-trained model\"\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--pretrained_ckpt\", type=str, help=\"path to a pre-trained checkpoint (unused here)\"\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--multi_target\", action=\"store_true\", help=\"enable multi-target mode (unused here)\"\n",
    ")\n",
    "parser.add_argument(\"--seed\", default=0, type=int, help=\"random seed (default: 0)\")\n",
    "parser.add_argument(\"--gpu\", default=0, type=int, help=\"GPU id to use.\")\n",
    "# -----------------------------------\n",
    "# Parameters for gumbel trick\n",
    "# -----------------------------------\n",
    "parser.add_argument(\n",
    "    \"--gumbel-dim\",\n",
    "    default=-1,\n",
    "    type=int,\n",
    "    choices=[1, -1],\n",
    "    help=\"value passed to the own_* models as gumbel_dim\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--gumbel_tau\",\n",
    "    type=float,\n",
    "    nargs=2,\n",
    "    default=(1, 0.2),\n",
    "    help=\"two tau values; the larger one is passed to the own_* models as tau\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--gumbel_range\",\n",
    "    type=int,\n",
    "    nargs=2,\n",
    "    default=(20, 90),\n",
    "    help=\"epoch range in which the gumbel trick is active\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--gumbel_annealing_strategy\",\n",
    "    default=\"cosine\",\n",
    "    choices=[\"linear\", \"constant\", \"exponential\", \"cosine\"],\n",
    "    help=\"annealing schedule for gumbel tau (unused in this notebook)\",\n",
    ")\n",
    "# ----------------------\n",
    "parser.add_argument(\n",
    "    \"--finetuning\",\n",
    "    action=\"store_true\",\n",
    "    help=\"only keep the layers listed in model.changed_layers trainable\",\n",
    ")\n",
    "parser.add_argument(\"--resume\", default=\"\", type=str, help=\"path of the checkpoint to load\")\n",
    "parser.add_argument(\n",
    "    \"--img_size\", default=256, type=int, help=\"side length of the square input image\"\n",
    ")\n",
    "\n",
    "# Edit the options below. Everything after a '#' on a line is ignored, so\n",
    "# alternative configurations can be kept commented out.\n",
    "params = \"\"\"\n",
    "--data ?  # Set the path to the dataset\n",
    "--checkpoint_dir ?  # Set the path to checkpoints\n",
    "--checkpoint_prefix resnet50_default\n",
    "# --pretrained\n",
    "# =========================================\n",
    "# # --model resnet50\n",
    "# # --resume /results/funnybirds-models/resnet50_final_0_checkpoint_best.pth.tar\n",
    "# # \n",
    "# --model own_resnet50\n",
    "# --resume ? # Set the path to the checkpoint\n",
    "# --gumbel-dim -1\n",
    "# #\n",
    "# --img_size 256\n",
    "# =========================================\n",
    "# --model vit_b_16\n",
    "# --resume /results/funnybirds-models/vit_base_patch16_224_final_1_checkpoint_best.pth.tar\n",
    "#\n",
    "--model own_vit_b_16\n",
    "--resume ? # Set the path to the checkpoint\n",
    "--gumbel-dim 1\n",
    "#\n",
    "--img_size 224\n",
    "# =========================================\n",
    "# --finetuning\n",
    "\"\"\"\n",
    "\n",
    "# Turn `params` into an argv-style token list (comments dropped, split on whitespace).\n",
    "args_params = [\n",
    "    token\n",
    "    for line in params.split(\"\\n\")\n",
    "    if line.strip() and not line.startswith(\"#\")\n",
    "    for token in line.split(\"#\")[0].split()\n",
    "]\n",
    "\n",
    "args = parser.parse_args(args=args_params)\n",
    "\n",
    "# Store tau as (high, low) and the epoch range as (start, end).\n",
    "args.gumbel_tau = tuple(sorted(args.gumbel_tau, reverse=True))\n",
    "args.gumbel_range = tuple(sorted(args.gumbel_range))\n",
    "# v2.Resize below receives this as a (height, width) pair.\n",
    "args.img_size = (args.img_size,) * 2\n",
    "\n",
    "# Ensure --gumbel_range does not extend beyond --epochs\n",
    "if args.gumbel_range[1] > args.epochs:\n",
    "    parser.error(\"--gumbel_range must not contain values larger than --epochs\")\n",
    "\n",
    "print(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "466b4ce86edb89fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Honour --gpu when CUDA is available; otherwise fall back to the CPU.\n",
    "device = torch.device(f\"cuda:{args.gpu}\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Device: {device}\")\n",
    "\n",
    "# create model; every branch ends with a 50-class output layer\n",
    "if args.model == \"resnet50\":\n",
    "    model = resnet50(pretrained=args.pretrained)\n",
    "    model.fc = torch.nn.Linear(2048, 50)\n",
    "elif args.model == \"own_resnet50\":\n",
    "    model = own_resnet50(\n",
    "        pretrained=args.pretrained,\n",
    "        num_classes=50,\n",
    "        gumbel_dim=args.gumbel_dim,\n",
    "        tau=args.gumbel_tau[0],\n",
    "    )\n",
    "elif args.model == \"vgg16\":\n",
    "    model = vgg16(pretrained=args.pretrained)\n",
    "    model.classifier[-1] = torch.nn.Linear(4096, 50)\n",
    "elif args.model == \"vit_b_16\":\n",
    "    model = vit_base_patch16_224(pretrained=args.pretrained)\n",
    "    model.head = torch.nn.Linear(768, 50)\n",
    "elif args.model == \"own_vit_b_16\":\n",
    "    model = own_vit_b_16(\n",
    "        pretrained=args.pretrained,\n",
    "        num_classes=50,\n",
    "        gumbel_dim=args.gumbel_dim,\n",
    "        tau=args.gumbel_tau[0],\n",
    "    )\n",
    "else:\n",
    "    # Fail loudly instead of leaving `model` undefined for the cells below.\n",
    "    raise ValueError(f\"Model not implemented: {args.model}\")\n",
    "\n",
    "if args.resume:\n",
    "    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.\n",
    "    checkpoint = torch.load(args.resume, map_location=\"cpu\", weights_only=False)\n",
    "    print(\n",
    "        f\"Model: {checkpoint['model']}, accuracy: {checkpoint['best_acc1']:.2f}, epoch: {checkpoint['epoch']}\"\n",
    "    )\n",
    "\n",
    "    # strict=False: report key mismatches as a warning instead of raising.\n",
    "    missing_keys, unexpected_keys = model.load_state_dict(\n",
    "        checkpoint[\"state_dict\"], strict=False\n",
    "    )\n",
    "    warning_message = \"\"\n",
    "    if missing_keys:\n",
    "        warning_message += f\"\\033[0;1;33mMissing keys: {missing_keys}\\033[0m \"\n",
    "    if unexpected_keys:\n",
    "        warning_message += f\"\\033[0;1;36mUnexpected keys: {unexpected_keys}\\033[0m\"\n",
    "    if warning_message:\n",
    "        warnings.warn(warning_message)\n",
    "\n",
    "    print(f\"Resuming model from file '{args.resume}'.\")\n",
    "\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2365249a94d936ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally freeze every parameter outside the model's `changed_layers`.\n",
    "if args.finetuning:\n",
    "    if hasattr(model, \"changed_layers\"):\n",
    "        print(\n",
    "            f\"\\033[0;1;33mFinetuning the selected layer of the model {model.changed_layers}\\033[0m\"\n",
    "        )\n",
    "        for name, param in model.named_parameters():\n",
    "            # Keep a parameter trainable only if its top-level submodule is listed.\n",
    "            param.requires_grad = name.split(\".\")[0] in model.changed_layers\n",
    "    else:\n",
    "        warnings.warn(\n",
    "            \"--finetuning was requested but the model has no `changed_layers`; \"\n",
    "            \"requires_grad flags are left unchanged.\"\n",
    "        )\n",
    "\n",
    "# Count the parameters that would be optimized (requires_grad=True).\n",
    "parameters_to_optimize = [param for param in model.parameters() if param.requires_grad]\n",
    "total_params = sum(p.numel() for p in parameters_to_optimize)\n",
    "print(\"Total trainable parameters:\", total_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2589bbda00c2374",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-time preprocessing: only resize to args.img_size.\n",
    "transforms = v2.Compose(\n",
    "    [\n",
    "        v2.Resize(size=args.img_size),\n",
    "    ]\n",
    ")\n",
    "\n",
    "test_dataset = FunnyBirds(args.data, \"test\", transform=transforms)\n",
    "test_loader = DataLoader(\n",
    "    test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8\n",
    ")\n",
    "\n",
    "top1 = AverageMeter(\"Acc@1\", \":6.2f\", Summary.AVERAGE)\n",
    "top5 = AverageMeter(\"Acc@5\", \":6.2f\", Summary.AVERAGE)\n",
    "# Single-process evaluation: the meter simply counts the batches of test_loader.\n",
    "progress = ProgressMeter(len(test_loader), [top1, top5], prefix=\"Test: \")\n",
    "\n",
    "# NOTE: `accuracy` is the helper imported from train. The explainability section\n",
    "# rebinds this name to a number, so re-run the import cell before re-running this one.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for i, samples in enumerate(test_loader):\n",
    "        images = samples[\"image\"].to(device, non_blocking=True)\n",
    "        target = samples[\"class_idx\"].to(device, non_blocking=True)\n",
    "\n",
    "        # compute output\n",
    "        output = model(images)\n",
    "\n",
    "        # measure top-1 / top-5 accuracy (no loss is computed here)\n",
    "        acc1, acc5 = accuracy(output, target, topk=(1, 5))\n",
    "        top1.update(acc1[0], images.size(0))\n",
    "        top5.update(acc5[0], images.size(0))\n",
    "\n",
    "        if i % args.print_freq == 0:\n",
    "            progress.display(i + 1)\n",
    "\n",
    "print(\"=\" * 50)\n",
    "progress.display_summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de4f29bf661f76d5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "3ed007a2f20865",
   "metadata": {},
   "source": [
    "# Evaluate explainability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6e0797f912fc475",
   "metadata": {},
   "outputs": [],
   "source": [
    "from types import MethodType\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "from captum.attr import IntegratedGradients, InputXGradient\n",
    "\n",
    "from models.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP\n",
    "from models.model_wrapper import StandardModel, ViTModel\n",
    "from evaluation_protocols import (\n",
    "    accuracy_protocol,\n",
    "    controlled_synthetic_data_check_protocol,\n",
    "    single_deletion_protocol,\n",
    "    preservation_check_protocol,\n",
    "    deletion_check_protocol,\n",
    "    target_sensitivity_protocol,\n",
    "    distractibility_protocol,\n",
    "    background_independence_protocol,\n",
    ")\n",
    "from explainers.explainer_wrapper import (\n",
    "    CaptumAttributionExplainer,\n",
    "    ViTRolloutExplainer,\n",
    "    ViTCheferLRPExplainer,\n",
    "    AbstractAttributionExplainer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31c4578f0d6a03e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Options are parsed from the `params` string below (not from sys.argv).\n",
    "parser = argparse.ArgumentParser(description=\"FunnyBirds - Explanation Evaluation\")\n",
    "parser.add_argument(\n",
    "    \"--data\", metavar=\"DIR\", required=True, help=\"path to the FunnyBirds dataset\"\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--model\",\n",
    "    required=True,\n",
    "    choices=[\"resnet50\", \"vgg16\", \"vit_b_16\", \"own_resnet50\", \"own_vit_b_16\"],\n",
    "    help=\"model architecture\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--explainer\",\n",
    "    required=True,\n",
    "    choices=[\n",
    "        \"IntegratedGradients\",\n",
    "        \"InputXGradient\",\n",
    "        \"Rollout\",\n",
    "        \"CheferLRP\",\n",
    "        \"CustomExplainer\",\n",
    "    ],\n",
    "    help=\"explainer\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--checkpoint_name\",\n",
    "    type=str,\n",
    "    required=False,\n",
    "    default=None,\n",
    "    help=\"checkpoint name (including dir)\",\n",
    ")\n",
    "\n",
    "parser.add_argument(\"--gpu\", default=0, type=int, help=\"GPU id to use.\")\n",
    "parser.add_argument(\"--seed\", default=0, type=int, help=\"seed\")\n",
    "parser.add_argument(\n",
    "    \"--batch_size\",\n",
    "    default=32,\n",
    "    type=int,\n",
    "    help=\"batch size for protocols that do not require custom BS such as accuracy\",\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--nr_itrs\",\n",
    "    default=2501,\n",
    "    type=int,\n",
    "    help=\"number of iterations\",\n",
    ")\n",
    "\n",
    "# One boolean switch per evaluation protocol; each defaults to off.\n",
    "for protocol_flag, description in [\n",
    "    (\"--accuracy\", \"accuracy\"),\n",
    "    (\"--controlled_synthetic_data_check\", \"controlled synthetic data check\"),\n",
    "    (\"--single_deletion\", \"single deletion\"),\n",
    "    (\"--preservation_check\", \"preservation check\"),\n",
    "    (\"--deletion_check\", \"deletion check\"),\n",
    "    (\"--target_sensitivity\", \"target sensitivity\"),\n",
    "    (\"--distractibility\", \"distractibility\"),\n",
    "    (\"--background_independence\", \"background independence\"),\n",
    "]:\n",
    "    parser.add_argument(\n",
    "        protocol_flag, default=False, action=\"store_true\", help=f\"compute {description}\"\n",
    "    )\n",
    "# -----------------------------------\n",
    "# Parameters for gumbel trick\n",
    "# -----------------------------------\n",
    "parser.add_argument(\n",
    "    \"--gumbel-dim\",\n",
    "    default=-1,\n",
    "    type=int,\n",
    "    choices=[1, -1],\n",
    "    help=\"value passed to the own_* models as gumbel_dim\",\n",
    ")\n",
    "# ----------------------\n",
    "parser.add_argument(\"--resume\", default=\"\", type=str, help=\"path of the checkpoint to load\")\n",
    "parser.add_argument(\n",
    "    \"--img_size\", default=256, type=int, help=\"side length of the square input image\"\n",
    ")\n",
    "\n",
    "\n",
    "# Edit the options below; everything after a '#' on a line is ignored.\n",
    "params = \"\"\"\n",
    "--data ?  # Set the path to the dataset\n",
    "--batch_size 24\n",
    "# =========================================\n",
    "# # --model resnet50\n",
    "# # --resume /results/funnybirds-models/resnet50_final_0_checkpoint_best.pth.tar\n",
    "# # \n",
    "# --model own_resnet50\n",
    "# --resume ? # Set the path to the checkpoint\n",
    "# --gumbel-dim -1\n",
    "# # ------------------------\n",
    "# --img_size 256\n",
    "# =========================================\n",
    "# --model vit_b_16\n",
    "# --resume /results/funnybirds-models/vit_base_patch16_224_final_1_checkpoint_best.pth.tar\n",
    "#\n",
    "--model own_vit_b_16\n",
    "--resume ? # Set the path to the checkpoint\n",
    "--gumbel-dim 1\n",
    "# ------------------------\n",
    "--img_size 224\n",
    "# =========================================\n",
    "# --explainer InputXGradient \n",
    "--explainer IntegratedGradients\n",
    "--accuracy\n",
    "--controlled_synthetic_data_check \n",
    "--target_sensitivity \n",
    "--single_deletion \n",
    "--preservation_check \n",
    "--deletion_check \n",
    "--distractibility \n",
    "--background_independence\n",
    "# =========================================\n",
    "\"\"\"\n",
    "\n",
    "# Turn `params` into an argv-style token list: drop comments, split on whitespace\n",
    "# (blank and comment-only lines therefore contribute no tokens).\n",
    "args_params = [token for line in params.split(\"\\n\") for token in line.split(\"#\")[0].split()]\n",
    "\n",
    "args = parser.parse_args(args=args_params)\n",
    "# Expand to a (height, width) pair for the square input image.\n",
    "args.img_size = (args.img_size,) * 2\n",
    "\n",
    "# Honour --gpu when CUDA is available; otherwise fall back to the CPU.\n",
    "args.device = torch.device(f\"cuda:{args.gpu}\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Device: {args.device}\")\n",
    "\n",
    "print(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7255c81afb544c1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the Python, NumPy and PyTorch RNGs for reproducible evaluation.\n",
    "random.seed(args.seed)\n",
    "np.random.seed(args.seed)\n",
    "torch.manual_seed(args.seed)\n",
    "\n",
    "# create model (tau=0 for the own_* variants) and wrap it for the explainers\n",
    "if args.model == \"resnet50\":\n",
    "    model = resnet50(num_classes=50)\n",
    "    model = StandardModel(model)\n",
    "elif args.model == \"own_resnet50\":\n",
    "    model = own_resnet50(\n",
    "        num_classes=50,\n",
    "        gumbel_dim=args.gumbel_dim,\n",
    "        tau=0,\n",
    "    )\n",
    "    model = StandardModel(model)\n",
    "elif args.model == \"vgg16\":\n",
    "    model = vgg16(num_classes=50)\n",
    "    model = StandardModel(model)\n",
    "elif args.model.endswith(\"vit_b_16\"):\n",
    "    if args.model.startswith(\"own_\"):\n",
    "        model = own_vit_b_16(\n",
    "            num_classes=50,\n",
    "            gumbel_dim=args.gumbel_dim,\n",
    "            tau=0,\n",
    "        )\n",
    "    elif args.explainer == \"CheferLRP\":\n",
    "        # CheferLRP uses the LRP-enabled ViT variant.\n",
    "        model = vit_LRP(num_classes=50)\n",
    "    else:\n",
    "        model = vit_base_patch16_224(num_classes=50)\n",
    "    model = ViTModel(model)\n",
    "else:\n",
    "    # Fail loudly instead of leaving `model` undefined below.\n",
    "    raise ValueError(f\"Model not implemented: {args.model}\")\n",
    "\n",
    "if args.resume:\n",
    "    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.\n",
    "    checkpoint = torch.load(args.resume, map_location=\"cpu\", weights_only=False)\n",
    "    print(\n",
    "        f\"Model: {checkpoint['model']}, accuracy: {checkpoint['best_acc1']:.2f}, epoch: {checkpoint['epoch']}\"\n",
    "    )\n",
    "\n",
    "    # Strict loading: any missing/unexpected key raises.\n",
    "    model.load_state_dict(checkpoint[\"state_dict\"])\n",
    "    print(f\"Resuming model from file '{args.resume}'.\")\n",
    "\n",
    "model = model.to(args.device)\n",
    "model.eval()\n",
    "\n",
    "# Gradient-based explainers need autograd; used with torch.set_grad_enabled below.\n",
    "gradient_calculation = True\n",
    "\n",
    "# create explainer\n",
    "if args.explainer == \"InputXGradient\":\n",
    "    explainer = InputXGradient(model)\n",
    "    explainer = CaptumAttributionExplainer(explainer)\n",
    "elif args.explainer == \"IntegratedGradients\":\n",
    "    explainer = IntegratedGradients(model)\n",
    "    # All-zero baseline at args.img_size (not a fixed 256x256) so it also matches\n",
    "    # 224x224 ViT inputs. Assumes the protocols feed images at args.img_size.\n",
    "    baseline = torch.zeros((1, 3, *args.img_size), device=args.device)\n",
    "    explainer = CaptumAttributionExplainer(explainer, baseline=baseline)\n",
    "elif args.explainer == \"Rollout\":\n",
    "    explainer = ViTRolloutExplainer(model)\n",
    "elif args.explainer == \"CheferLRP\":\n",
    "    explainer = ViTCheferLRPExplainer(model)\n",
    "elif args.explainer == \"CustomExplainer\":\n",
    "    ...  # TODO: build your own explainer here (`explainer` stays undefined until then)\n",
    "else:\n",
    "    raise ValueError(f\"Explainer not implemented: {args.explainer}\")\n",
    "\n",
    "# -1 marks a metric whose protocol was not requested.\n",
    "accuracy, csdc, pc, dc, distractibility, sd, ts = -1, -1, -1, -1, -1, -1, -1\n",
    "background_independence = -1  # also printed/plotted, so it needs a default too\n",
    "# NOTE: this rebinds `accuracy`, shadowing the accuracy() helper imported from train.\n",
    "\n",
    "\n",
    "with torch.set_grad_enabled(gradient_calculation):\n",
    "    if args.accuracy:\n",
    "        print(\"Computing accuracy...\")\n",
    "        accuracy = accuracy_protocol(model, args)\n",
    "        accuracy = round(accuracy, 5)\n",
    "\n",
    "    if args.controlled_synthetic_data_check:\n",
    "        print(\"Computing controlled synthetic data check...\")\n",
    "        csdc = controlled_synthetic_data_check_protocol(model, explainer, args)\n",
    "\n",
    "    if args.target_sensitivity:\n",
    "        print(\"Computing target sensitivity...\")\n",
    "        ts = target_sensitivity_protocol(model, explainer, args)\n",
    "        ts = round(ts, 5)\n",
    "\n",
    "    if args.single_deletion:\n",
    "        print(\"Computing single deletion...\")\n",
    "        sd = single_deletion_protocol(model, explainer, args)\n",
    "        sd = round(sd, 5)\n",
    "\n",
    "    if args.preservation_check:\n",
    "        print(\"Computing preservation check...\")\n",
    "        pc = preservation_check_protocol(model, explainer, args)\n",
    "\n",
    "    if args.deletion_check:\n",
    "        print(\"Computing deletion check...\")\n",
    "        dc = deletion_check_protocol(model, explainer, args)\n",
    "\n",
    "    if args.distractibility:\n",
    "        print(\"Computing distractibility...\")\n",
    "        distractibility = distractibility_protocol(model, explainer, args)\n",
    "\n",
    "    if args.background_independence:\n",
    "        print(\"Computing background independence...\")\n",
    "        background_independence = background_independence_protocol(model, args)\n",
    "        background_independence = round(background_independence, 5)\n",
    "\n",
    "# Select the threshold that maximises the combined completeness score (mean of\n",
    "# CSDC, PC and DC) plus distractibility. This needs all four threshold-indexed\n",
    "# results; if any of them was not computed, those columns are reported as -1.\n",
    "threshold_metrics = (csdc, pc, dc, distractibility)\n",
    "best_threshold = -1\n",
    "if all(hasattr(metric, \"keys\") for metric in threshold_metrics):\n",
    "    max_score = float(\"-inf\")  # so a threshold is chosen even if every score is <= 0\n",
    "    for threshold in csdc.keys():\n",
    "        max_score_tmp = (\n",
    "            csdc[threshold] / 3.0\n",
    "            + pc[threshold] / 3.0\n",
    "            + dc[threshold] / 3.0\n",
    "            + distractibility[threshold]\n",
    "        )\n",
    "        if max_score_tmp > max_score:\n",
    "            max_score = max_score_tmp\n",
    "            best_threshold = threshold\n",
    "    csdc_best, pc_best, dc_best, distractibility_best = (\n",
    "        round(metric[best_threshold], 5) for metric in threshold_metrics\n",
    "    )\n",
    "else:\n",
    "    csdc_best = pc_best = dc_best = distractibility_best = -1\n",
    "\n",
    "print(\"FINAL RESULTS:\")\n",
    "print(\"Accuracy, CSDC, PC, DC, Distractability, Background independence, SD, TS\")\n",
    "print(\n",
    "    \"{}\\t{}\\t{}\\t{}\\t{}\\t{}\\t{}\\t{}\".format(\n",
    "        accuracy,\n",
    "        csdc_best,\n",
    "        pc_best,\n",
    "        dc_best,\n",
    "        distractibility_best,\n",
    "        background_independence,\n",
    "        sd,\n",
    "        ts,\n",
    "    )\n",
    ")\n",
    "print(\"Best threshold:\", best_threshold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "771a0d990eaf2d54",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "1aec0d25a36cf11f",
   "metadata": {},
   "source": [
    "# Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdfbe4398f00b15c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from math import pi\n",
    "import seaborn as sns\n",
    "\n",
    "\n",
    "sns.set_theme(style=\"whitegrid\", palette=\"colorblind\")\n",
    "sns.set_context(\"paper\", font_scale=2.1, rc={\"lines.linewidth\": 2.5})\n",
    "\n",
    "# Text is rendered with LaTeX (requires a LaTeX installation). The second update\n",
    "# overrides the font family, so the serif Palatino font is what ends up used;\n",
    "# remove it to render with sans-serif Helvetica instead.\n",
    "plt.rcParams.update(\n",
    "    {\"text.usetex\": True, \"font.family\": \"sans-serif\", \"font.sans-serif\": [\"Helvetica\"]}\n",
    ")\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        \"text.usetex\": True,\n",
    "        \"font.family\": \"serif\",\n",
    "        \"font.serif\": [\"Palatino\"],\n",
    "    }\n",
    ")\n",
    "\n",
    "# Raw metrics, in the order printed by the evaluation cell:\n",
    "# Accuracy, CSDC, PC, DC, Distractability, Background independence, SD, TS.\n",
    "# To plot previously printed numbers instead, assign them here, e.g.\n",
    "# raw_results = [0.998, 0.7353, 0.602, 0.532, 0.54372, 0.99826, 0.54592, 0.806]\n",
    "raw_results = [\n",
    "    accuracy,\n",
    "    csdc[best_threshold],\n",
    "    pc[best_threshold],\n",
    "    dc[best_threshold],\n",
    "    distractibility[best_threshold],\n",
    "    background_independence,\n",
    "    sd,\n",
    "    ts,\n",
    "]\n",
    "\n",
    "# Collapse the raw metrics into the five plotted dimensions.\n",
    "categories = [\"Acc.\", \"B.I.\", \"Com.\", \"Cor.\", \"Con.\"]\n",
    "completeness = (\n",
    "    (raw_results[1] + raw_results[2] + raw_results[3]) / 3 + raw_results[4]\n",
    ") / 2  # mean of (mean of CSDC/PC/DC) and distractibility\n",
    "scores = [raw_results[0], raw_results[5], completeness, raw_results[6], raw_results[7]]\n",
    "\n",
    "# The number shown in the centre is the mean of Com., Cor. and Con.\n",
    "average = sum(scores[2:]) / len(scores[2:])\n",
    "\n",
    "# One evenly spaced axis per category; repeat the first point to close the polygon.\n",
    "angles = [n / float(len(categories)) * 2 * pi for n in range(len(categories))]\n",
    "closed_angles = angles + angles[:1]\n",
    "closed_scores = scores + scores[:1]\n",
    "\n",
    "fig, ax = plt.subplots(subplot_kw={\"polar\": True})\n",
    "\n",
    "# Two decimals without the leading zero (e.g. .73); values >= 1 keep their digit.\n",
    "ax.text(\n",
    "    0,\n",
    "    0,\n",
    "    f\"{average:.2f}\".lstrip(\"0\"),\n",
    "    horizontalalignment=\"center\",\n",
    "    verticalalignment=\"center\",\n",
    "    size=36,\n",
    ")\n",
    "\n",
    "# One tick (with label) per category.\n",
    "ax.set_xticks(angles, minor=False)\n",
    "ax.set_xticklabels(categories, fontdict=None, minor=False)\n",
    "ax.tick_params(axis=\"x\", pad=10)\n",
    "\n",
    "# Radial axis fixed to [0, 1] with ticks at 0.5 and 1.\n",
    "ax.set_rlabel_position(36)\n",
    "ax.set_yticks([0.5, 1])\n",
    "ax.set_ylim(0, 1)\n",
    "\n",
    "color_string = \"#555599\"\n",
    "ax.plot(closed_angles, closed_scores, linestyle=\"solid\", color=color_string)\n",
    "ax.fill(closed_angles, closed_scores, color_string, alpha=0.2)\n",
    "\n",
    "ax.grid(color=\"#888888\", linewidth=1.5)\n",
    "# Bold outer ring at r = 1. NOTE: transData._b is a private matplotlib\n",
    "# attribute and may break across matplotlib versions.\n",
    "circle = plt.Circle(\n",
    "    (0, 0),\n",
    "    1,\n",
    "    transform=ax.transData._b,\n",
    "    fill=False,\n",
    "    edgecolor=\"k\",\n",
    "    linewidth=3.5,\n",
    "    zorder=10,\n",
    ")\n",
    "ax.add_artist(circle)\n",
    "\n",
    "fig.savefig(f\"{args.model}_{args.explainer}.png\", bbox_inches=\"tight\", dpi=300)\n",
    "plt.show()\n",
    "plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "635f3c73-1863-47e3-9115-bf585d936c3c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
