{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "292837dc-320c-4975-9aee-50ff7a562adf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import torch\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from typing import Tuple, List, Dict\n",
    "\n",
    "from cl_explain.metrics.ablation import compute_auc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e71332d3-f889-49bc-a192-9ba94fac8850",
   "metadata": {},
   "outputs": [],
   "source": [
    "RANDOMIZED_RESULT_PATH = \"results\"\n",
    "ORIGINAL_RESULT_PATH = \"results\"\n",
    "SUPERPIXEL_ATTRIBUTION_METHODS = [\"kernel_shap\"]\n",
    "SEED_LIST = [123, 456, 789, 42, 91]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbfd4084-4793-44b7-9371-4dddd543b073",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_eval_filename(\n",
    "    different_classes: bool,\n",
    "    comprehensive: bool,\n",
    "    corpus_size: int,\n",
    "    explanation_name: str,\n",
    "    foil_size: int,\n",
    "    explicand_size: int,\n",
    "    attribution_name: str,\n",
    "    superpixel_dim: int,\n",
    "    removal: str,\n",
    "    blur_strength: float,\n",
    "    eval_superpixel_dim: int,\n",
    "    eval_foil_size: int,\n",
    "    take_attribution_abs: bool,\n",
    ") -> str:\n",
    "    \"\"\"Get eval filename.\"\"\"\n",
    "    if different_classes:\n",
    "        eval_filename = \"diff_class\"\n",
    "    else:\n",
    "        eval_filename = \"same_class\"\n",
    "    if comprehensive:\n",
    "        eval_filename += \"_comprehensive\"\n",
    "        \n",
    "    eval_filename += \"_eval_results\"\n",
    "    eval_filename += f\"_explicand_size={explicand_size}\"\n",
    "    if \"corpus\" in explanation_name:\n",
    "        eval_filename += f\"_corpus_size={corpus_size}\"\n",
    "    if \"contrastive\" in explanation_name:\n",
    "        eval_filename += f\"_foil_size={foil_size}\"\n",
    "    if attribution_name in SUPERPIXEL_ATTRIBUTION_METHODS:\n",
    "        eval_filename += f\"_superpixel_dim={superpixel_dim}\"\n",
    "    eval_filename += f\"_removal={removal}\"\n",
    "    if removal == \"blurring\":\n",
    "        eval_filename += f\"_blur_strength={blur_strength:.1f}\"\n",
    "    eval_filename += f\"_eval_superpixel_dim={eval_superpixel_dim}\"\n",
    "    if not comprehensive:\n",
    "        eval_filename += f\"_eval_foil_size={eval_foil_size}\"\n",
    "    if take_attribution_abs:\n",
    "        eval_filename += \"_abs\"\n",
    "    eval_filename += \".pkl\"\n",
    "    return eval_filename\n",
    "\n",
    "\n",
    "def get_mean_curves(outputs, curve_kind) -> Tuple[List[torch.Tensor], int]:\n",
    "    available_curve_kinds = [\"insertion\", \"deletion\"]\n",
    "    assert curve_kind in available_curve_kinds, (\n",
    "        f\"curve_kind={curve_kind} is not one of {available_curve_kinds}!\"\n",
    "    )\n",
    "    target_list = [key for key in outputs.keys()]\n",
    "    eval_name_list = (\n",
    "        outputs[target_list[0]][\"eval_model_names\"]\n",
    "        + outputs[target_list[0]][\"eval_measure_names\"]\n",
    "    )\n",
    "    eval_mean_curve_dict = {}\n",
    "    for j, eval_name in enumerate(eval_name_list):\n",
    "        \n",
    "        curve_list = []\n",
    "        num_features = None\n",
    "\n",
    "        for target, output in outputs.items():\n",
    "            target_curve_list = (\n",
    "                output[f\"model_{curve_kind}_curves\"]\n",
    "                + output[f\"measure_{curve_kind}_curves\"]\n",
    "            )\n",
    "            curve_list.append(target_curve_list[j])\n",
    "            num_features = output[f\"{curve_kind}_num_features\"]\n",
    "        \n",
    "        curves = torch.cat(curve_list)\n",
    "        mean_curve = curves.mean(dim=0).cpu()\n",
    "        eval_mean_curve_dict[eval_name] = mean_curve\n",
    "        \n",
    "    return eval_mean_curve_dict, num_features\n",
    "\n",
    "\n",
    "def get_auc_stats(\n",
    "    dataset: str,\n",
    "    encoder: str,\n",
    "    attribution: str,\n",
    "    eval_name: str,\n",
    "    normalize_similarity: bool,\n",
    "    different_classes: bool,\n",
    "    comprehensive: bool = False,\n",
    "    explicand_size: int = 25,\n",
    "    removal: str = \"blurring\",\n",
    "    blur_strength: float = 5.0,\n",
    "    superpixel_dim: int = 1,\n",
    "    eval_superpixel_dim: int = 1,\n",
    "    foil_size: int = 1500,\n",
    "    corpus_size: int = 100,\n",
    "    eval_foil_size: int = 1500,\n",
    "    take_attribution_abs: bool = False,\n",
    ") -> Dict[str, Dict[str, List]]:\n",
    "    if attribution == \"random_baseline\":\n",
    "        explanation_list = [\"self_weighted\"]\n",
    "    else:\n",
    "        explanation_list = [\n",
    "            \"contrastive_corpus\",\n",
    "            \"randomized_model_contrastive_corpus\"\n",
    "        ]\n",
    "        \n",
    "    insertion_mean_list = []\n",
    "    insertion_ci_list = []\n",
    "    deletion_mean_list = []\n",
    "    deletion_ci_list = []\n",
    "\n",
    "    for explanation in explanation_list:\n",
    "        if explanation.startswith(\"randomized_model_\"):\n",
    "            explanation_name = explanation.replace(\"randomized_model_\", \"\")\n",
    "            randomize_model = True\n",
    "            result_path = RANDOMIZED_RESULT_PATH\n",
    "        else:\n",
    "            explanation_name = explanation\n",
    "            randomize_model = False\n",
    "            result_path = ORIGINAL_RESULT_PATH\n",
    "            \n",
    "        insertion_list = []\n",
    "        deletion_list = []\n",
    "        for seed in SEED_LIST:            \n",
    "            eval_filename = get_eval_filename(\n",
    "                different_classes=different_classes,\n",
    "                comprehensive=comprehensive,\n",
    "                corpus_size=corpus_size,\n",
    "                explanation_name=explanation_name,\n",
    "                foil_size=foil_size,\n",
    "                explicand_size=explicand_size,\n",
    "                attribution_name=attribution,\n",
    "                superpixel_dim=superpixel_dim,\n",
    "                removal=removal,\n",
    "                blur_strength=blur_strength,\n",
    "                eval_superpixel_dim=eval_superpixel_dim,\n",
    "                eval_foil_size=eval_foil_size,\n",
    "                take_attribution_abs=take_attribution_abs,\n",
    "            )\n",
    "\n",
    "            if normalize_similarity:\n",
    "                method_name = \"normalized\"\n",
    "            else:\n",
    "                method_name = \"unnormalized\"\n",
    "            if randomize_model:\n",
    "                method_name = \"randomized_model_\" + method_name\n",
    "            method_name += f\"_{explanation_name}_{attribution}\"\n",
    "            with open(\n",
    "                os.path.join(\n",
    "                    result_path,\n",
    "                    dataset,\n",
    "                    encoder,\n",
    "                    method_name,\n",
    "                    f\"{seed}\",\n",
    "                    eval_filename,\n",
    "                ),\n",
    "                \"rb\",\n",
    "            ) as handle:\n",
    "                outputs = pickle.load(handle)\n",
    "            insertion_curve_dict, insertion_num_features = get_mean_curves(\n",
    "                outputs, \"insertion\"\n",
    "            )\n",
    "            deletion_curve_dict, deletion_num_features = get_mean_curves(\n",
    "                outputs, \"deletion\"\n",
    "            )\n",
    "            insertion_list.append(\n",
    "                compute_auc(\n",
    "                    curve=insertion_curve_dict[eval_name],\n",
    "                    num_features=insertion_num_features,\n",
    "                )\n",
    "            )\n",
    "            deletion_list.append(\n",
    "                compute_auc(\n",
    "                    curve=deletion_curve_dict[eval_name],\n",
    "                    num_features=deletion_num_features,\n",
    "                )\n",
    "            )\n",
    "        insertion_mean_list.append(np.mean(insertion_list))\n",
    "        insertion_ci_list.append(1.96 * np.std(insertion_list) / np.sqrt(len(SEED_LIST)))\n",
    "        deletion_mean_list.append(np.mean(deletion_list))\n",
    "        deletion_ci_list.append(1.96 * np.std(deletion_list) / np.sqrt(len(SEED_LIST)))\n",
    "    return {\n",
    "        \"insertion\": {\"mean\": insertion_mean_list, \"ci\": insertion_ci_list},\n",
    "        \"deletion\": {\"mean\": deletion_mean_list, \"ci\": deletion_ci_list},\n",
    "    }\n",
    "\n",
    "\n",
    "def get_formatted_aucs(\n",
    "    insertion_direction: str,\n",
    "    deletion_direction: str,\n",
    "    bold_best: bool = True,\n",
    "    **kwargs,\n",
    "):\n",
    "    auc_stats = get_auc_stats(**kwargs)\n",
    "    \n",
    "    insertion_mean_list = auc_stats[\"insertion\"][\"mean\"]\n",
    "    insertion_ci_list = auc_stats[\"insertion\"][\"ci\"]\n",
    "    if insertion_direction == \"max\":\n",
    "        insertion_best_idx = np.argmax(insertion_mean_list)\n",
    "    elif insertion_direction == \"min\":\n",
    "        insertion_best_idx = np.argmin(insertion_mean_list)\n",
    "    else:\n",
    "        raise ValueError(\n",
    "            f\"insertion_direction={insertion_direction} should be max or min!\"\n",
    "        )\n",
    "        \n",
    "    deletion_mean_list = auc_stats[\"deletion\"][\"mean\"]\n",
    "    deletion_ci_list = auc_stats[\"deletion\"][\"ci\"]\n",
    "    if deletion_direction == \"max\":\n",
    "        deletion_best_idx = np.argmax(deletion_mean_list)\n",
    "    elif deletion_direction == \"min\":\n",
    "        deletion_best_idx = np.argmin(deletion_mean_list)\n",
    "    else:\n",
    "        raise ValueError(\n",
    "            f\"deietion_direction={deietion_direction} should be max or min!\"\n",
    "        )\n",
    "    \n",
    "    text_list = []\n",
    "    for i in range(len(insertion_mean_list)):\n",
    "        insertion_mean = insertion_mean_list[i]\n",
    "        insertion_ci = insertion_ci_list[i]\n",
    "        if np.abs(insertion_mean) < 0.01:\n",
    "            insertion_text = (\n",
    "                \"{:.2e}\".format(insertion_mean)\n",
    "                + \" $\\pm$ \"\n",
    "                + \"{:.2e}\".format(insertion_ci)\n",
    "            )\n",
    "        else:\n",
    "            insertion_text = f\"{insertion_mean:.3f} ({insertion_ci:.3f})\"\n",
    "        if i == insertion_best_idx and bold_best:\n",
    "            insertion_text = \"\\\\textbf{\" + insertion_text + \"}\"\n",
    "            \n",
    "        deletion_mean = deletion_mean_list[i]\n",
    "        deletion_ci = deletion_ci_list[i]\n",
    "        if np.abs(deletion_mean) < 0.01:\n",
    "            deletion_text = (\n",
    "                \"{:.2e}\".format(deletion_mean)\n",
    "                + \" $\\pm$ \"\n",
    "                + \"{:.2e}\".format(deletion_ci)\n",
    "            )\n",
    "        else:\n",
    "            deletion_text = f\"{deletion_mean:.3f} ({deletion_ci:.3f})\"\n",
    "        if i == deletion_best_idx and bold_best:\n",
    "            deletion_text = \"\\\\textbf{\" + deletion_text + \"}\"\n",
    "            \n",
    "        text = insertion_text + \" & \" + deletion_text\n",
    "        text_list.append(text)\n",
    "    return text_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13745d86-38e1-44b8-a15e-090a4107192b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_aucs(\n",
    "    eval_name: str,\n",
    "    normalize_similarity: bool,\n",
    "    different_classes: bool,\n",
    "    insertion_direction: str = \"max\",\n",
    "    deletion_direction: str = \"min\",\n",
    "    comprehensive: bool = True,\n",
    "):\n",
    "    attribution_list = [\"int_grad\", \"gradient_shap\", \"rise\"]\n",
    "    dataset_encoder_combos = [\n",
    "        (\"imagenet\", \"simclr_x1\"),\n",
    "        (\"cifar\", \"simsiam_18\"),\n",
    "        (\"mura\", \"classifier_18\"),\n",
    "    ]\n",
    "    for attribution in attribution_list:\n",
    "        print(attribution)\n",
    "        print(\"-\" * len(attribution))\n",
    "        original_model_cocoa_text = \"COCOA (trained model)\"\n",
    "        randomized_model_cocoa_text = \"COCOA (randomized model)\"\n",
    "        for dataset_encoder in dataset_encoder_combos:\n",
    "            text_list = get_formatted_aucs(\n",
    "                insertion_direction=insertion_direction,\n",
    "                deletion_direction=deletion_direction,\n",
    "                dataset=dataset_encoder[0],\n",
    "                encoder=dataset_encoder[1],\n",
    "                attribution=attribution,\n",
    "                eval_name=eval_name,\n",
    "                normalize_similarity=normalize_similarity,\n",
    "                different_classes=different_classes,\n",
    "                comprehensive=comprehensive,\n",
    "            )\n",
    "            original_model_cocoa_text += f\" & {text_list[0]}\"\n",
    "            randomized_model_cocoa_text += f\" & {text_list[1]}\"\n",
    "        print(original_model_cocoa_text + \" \\\\\\\\\")\n",
    "        print(randomized_model_cocoa_text + \"\\\\\\\\\")\n",
    "        print(\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa9684f8-eed0-4005-a493-cd3576936ed2",
   "metadata": {},
   "source": [
    "## Randomized Model Corpus Majority Probability (Cosine Similarity & Same Class)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3307d15-6d46-49a1-a1ba-f2ecc99fde9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "print_aucs(\n",
    "    eval_name=\"corpus_majority_prob\",\n",
    "    normalize_similarity=True,\n",
    "    different_classes=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ed292bb-4445-47f5-851a-9327544b480a",
   "metadata": {},
   "source": [
    "## Randomized Model Corpus Majority Probability (Cosine Similarity & Different Classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c705516f-3e4f-48e0-a741-e514546a35ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "print_aucs(\n",
    "    eval_name=\"corpus_majority_prob\",\n",
    "    normalize_similarity=True,\n",
    "    different_classes=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e90c844c-0dd3-4984-9a79-9e998e2e9c7d",
   "metadata": {},
   "source": [
    "## Randomized Model Contrastive Corpus Similarity (Cosine Similarity & Same Class)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2d997f9-9082-4a8c-9d5e-8c06b975594a",
   "metadata": {},
   "outputs": [],
   "source": [
    "print_aucs(\n",
    "    eval_name=\"contrastive_corpus_cosine_similarity\",\n",
    "    normalize_similarity=True,\n",
    "    different_classes=False,\n",
    "    comprehensive=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5505ba0c-124b-4c77-8649-8c24ac8b63a7",
   "metadata": {},
   "source": [
    "## Randomized Model Contrastive Corpus Similarity (Cosine Similarity & Different Classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9805ba2e-1802-410a-aaca-bb06bc7d9efe",
   "metadata": {},
   "outputs": [],
   "source": [
    "print_aucs(\n",
    "    eval_name=\"contrastive_corpus_cosine_similarity\",\n",
    "    normalize_similarity=True,\n",
    "    different_classes=True,\n",
    "    comprehensive=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64c8b922-989e-48c2-8c1a-2495c1437b8e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
