{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "292837dc-320c-4975-9aee-50ff7a562adf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import torch\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from typing import Tuple, List, Dict\n",
    "\n",
    "from cl_explain.metrics.ablation import compute_auc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e71332d3-f889-49bc-a192-9ba94fac8850",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root of the result tree: RESULT_PATH/<dataset>/<encoder>/<method>/<seed>/<file>.\n",
    "RESULT_PATH = \"results\"\n",
    "# Attribution methods whose eval filenames also encode superpixel_dim.\n",
    "SUPERPIXEL_ATTRIBUTION_METHODS = [\"kernel_shap\"]\n",
    "# Seeds whose per-seed AUCs are averaged (and used for the CI) in get_auc_stats.\n",
    "SEED_LIST = [123, 456, 789, 42, 91]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbfd4084-4793-44b7-9371-4dddd543b073",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_eval_filename(\n",
    "    different_classes: bool,\n",
    "    comprehensive: bool,\n",
    "    corpus_size: int,\n",
    "    explanation_name: str,\n",
    "    foil_size: int,\n",
    "    explicand_size: int,\n",
    "    attribution_name: str,\n",
    "    superpixel_dim: int,\n",
    "    removal: str,\n",
    "    blur_strength: float,\n",
    "    eval_superpixel_dim: int,\n",
    "    eval_foil_size: int,\n",
    "    take_attribution_abs: bool,\n",
    ") -> str:\n",
    "    \"\"\"Build the name of the pickle file holding one run's evaluation results.\n",
    "\n",
    "    The name starts with the class setting (``same_class``/``diff_class``) and\n",
    "    then encodes each hyperparameter that applies to the given explanation,\n",
    "    attribution method and removal strategy; inapplicable ones are left out.\n",
    "\n",
    "    Returns:\n",
    "        A bare filename (no directory) ending in ``.pkl``.\n",
    "    \"\"\"\n",
    "    parts = [\"diff_class\" if different_classes else \"same_class\"]\n",
    "    if comprehensive:\n",
    "        parts.append(\"_comprehensive\")\n",
    "    parts.append(\"_eval_results\")\n",
    "    parts.append(f\"_explicand_size={explicand_size}\")\n",
    "\n",
    "    # Size settings only appear for explanation targets whose name mentions them.\n",
    "    if \"corpus\" in explanation_name:\n",
    "        parts.append(f\"_corpus_size={corpus_size}\")\n",
    "    if \"contrastive\" in explanation_name:\n",
    "        parts.append(f\"_foil_size={foil_size}\")\n",
    "    if attribution_name in SUPERPIXEL_ATTRIBUTION_METHODS:\n",
    "        parts.append(f\"_superpixel_dim={superpixel_dim}\")\n",
    "\n",
    "    parts.append(f\"_removal={removal}\")\n",
    "    if removal == \"blurring\":\n",
    "        parts.append(f\"_blur_strength={blur_strength:.1f}\")\n",
    "    parts.append(f\"_eval_superpixel_dim={eval_superpixel_dim}\")\n",
    "    # The eval foil size is only part of the name for non-comprehensive runs.\n",
    "    if not comprehensive:\n",
    "        parts.append(f\"_eval_foil_size={eval_foil_size}\")\n",
    "    if take_attribution_abs:\n",
    "        parts.append(\"_abs\")\n",
    "    parts.append(\".pkl\")\n",
    "    return \"\".join(parts)\n",
    "\n",
    "\n",
    "def get_mean_curves(\n",
    "    outputs: Dict, curve_kind: str\n",
    ") -> Tuple[Dict[str, torch.Tensor], int]:\n",
    "    \"\"\"Average the insertion or deletion curves of every eval name over all targets.\n",
    "\n",
    "    Args:\n",
    "        outputs: Mapping from target to that target's evaluation output dict. Each\n",
    "            output must provide ``\"eval_model_names\"``, ``\"eval_measure_names\"``,\n",
    "            ``f\"model_{curve_kind}_curves\"``, ``f\"measure_{curve_kind}_curves\"``\n",
    "            and ``f\"{curve_kind}_num_features\"``.\n",
    "        curve_kind: Either ``\"insertion\"`` or ``\"deletion\"``.\n",
    "\n",
    "    Returns:\n",
    "        ``(eval_mean_curve_dict, num_features)``: a dict mapping each eval name\n",
    "        (model names first, then measure names, taken from the first target) to\n",
    "        its mean curve on CPU, and the ``num_features`` reported by the last target.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If ``curve_kind`` is not supported.\n",
    "        ValueError: If ``outputs`` is empty.\n",
    "    \"\"\"\n",
    "    available_curve_kinds = [\"insertion\", \"deletion\"]\n",
    "    assert curve_kind in available_curve_kinds, (\n",
    "        f\"curve_kind={curve_kind} is not one of {available_curve_kinds}!\"\n",
    "    )\n",
    "    if not outputs:\n",
    "        raise ValueError(\"outputs is empty: there are no curves to average!\")\n",
    "\n",
    "    first_output = next(iter(outputs.values()))\n",
    "    eval_name_list = (\n",
    "        first_output[\"eval_model_names\"] + first_output[\"eval_measure_names\"]\n",
    "    )\n",
    "\n",
    "    # Eval names are positionally aligned with model curves followed by measure\n",
    "    # curves; build each target's combined list once instead of once per eval name.\n",
    "    per_target_curves = []\n",
    "    num_features = None\n",
    "    for output in outputs.values():\n",
    "        per_target_curves.append(\n",
    "            output[f\"model_{curve_kind}_curves\"]\n",
    "            + output[f\"measure_{curve_kind}_curves\"]\n",
    "        )\n",
    "        # Every target reports a count; the last one wins.\n",
    "        num_features = output[f\"{curve_kind}_num_features\"]\n",
    "\n",
    "    eval_mean_curve_dict = {}\n",
    "    for j, eval_name in enumerate(eval_name_list):\n",
    "        # Concatenate the j-th curves of all targets along dim 0, then average over it.\n",
    "        curves = torch.cat([target_curves[j] for target_curves in per_target_curves])\n",
    "        eval_mean_curve_dict[eval_name] = curves.mean(dim=0).cpu()\n",
    "    return eval_mean_curve_dict, num_features\n",
    "\n",
    "\n",
    "def _mean_and_ci95(values: List[float]) -> Tuple[float, float]:\n",
    "    \"\"\"Return the mean of ``values`` and the half-width of its ~95% confidence interval.\n",
    "\n",
    "    The half-width is ``1.96 * np.std(values) / sqrt(len(values))``, i.e. a normal\n",
    "    approximation using the population standard deviation (``ddof=0``).\n",
    "    \"\"\"\n",
    "    # NOTE(review): with only a handful of seeds, a sample std (ddof=1) and a\n",
    "    # t-quantile would give a wider, more conventional interval. Kept unchanged so\n",
    "    # reported numbers stay comparable.\n",
    "    mean = np.mean(values)\n",
    "    ci = 1.96 * np.std(values) / np.sqrt(len(values))\n",
    "    return mean, ci\n",
    "\n",
    "\n",
    "def get_auc_stats(\n",
    "    dataset: str,\n",
    "    encoder: str,\n",
    "    attribution: str,\n",
    "    eval_name: str,\n",
    "    normalize_similarity: bool,\n",
    "    different_classes: bool,\n",
    "    comprehensive: bool = False,\n",
    "    explicand_size: int = 25,\n",
    "    removal: str = \"blurring\",\n",
    "    blur_strength: float = 5.0,\n",
    "    superpixel_dim: int = 1,\n",
    "    eval_superpixel_dim: int = 1,\n",
    "    foil_size: int = 1500,\n",
    "    corpus_size: int = 100,\n",
    "    eval_foil_size: int = 1500,\n",
    "    take_attribution_abs: bool = False,\n",
    ") -> Dict[str, Dict[str, List]]:\n",
    "    \"\"\"Summarize insertion and deletion AUCs over seeds for each explanation target.\n",
    "\n",
    "    For every explanation target and every seed in ``SEED_LIST``, the pickled\n",
    "    evaluation outputs are loaded from\n",
    "    ``RESULT_PATH/<dataset>/<encoder>/<method_name>/<seed>/<eval_filename>``; the\n",
    "    ``eval_name`` curves are averaged over targets with ``get_mean_curves`` and\n",
    "    reduced to an AUC with ``compute_auc``. The per-seed AUCs are then summarized\n",
    "    by their mean and ~95% CI half-width.\n",
    "\n",
    "    Returns:\n",
    "        ``{\"insertion\": {\"mean\": [...], \"ci\": [...]},\n",
    "        \"deletion\": {\"mean\": [...], \"ci\": [...]}}`` with one entry per\n",
    "        explanation target, in the order of ``explanation_list`` below.\n",
    "    \"\"\"\n",
    "    if attribution == \"random_baseline\":\n",
    "        explanation_list = [\"self_weighted\"]\n",
    "    else:\n",
    "        explanation_list = [  # Make sure to order this way.\n",
    "            \"self_weighted\",\n",
    "            \"contrastive_self_weighted\",\n",
    "            \"corpus\",\n",
    "            \"contrastive_corpus\",\n",
    "        ]\n",
    "    similarity_prefix = \"normalized\" if normalize_similarity else \"unnormalized\"\n",
    "\n",
    "    insertion_mean_list = []\n",
    "    insertion_ci_list = []\n",
    "    deletion_mean_list = []\n",
    "    deletion_ci_list = []\n",
    "\n",
    "    for explanation in explanation_list:\n",
    "        # Neither the eval filename nor the method directory depends on the seed.\n",
    "        eval_filename = get_eval_filename(\n",
    "            different_classes=different_classes,\n",
    "            comprehensive=comprehensive,\n",
    "            corpus_size=corpus_size,\n",
    "            explanation_name=explanation,\n",
    "            foil_size=foil_size,\n",
    "            explicand_size=explicand_size,\n",
    "            attribution_name=attribution,\n",
    "            superpixel_dim=superpixel_dim,\n",
    "            removal=removal,\n",
    "            blur_strength=blur_strength,\n",
    "            eval_superpixel_dim=eval_superpixel_dim,\n",
    "            eval_foil_size=eval_foil_size,\n",
    "            take_attribution_abs=take_attribution_abs,\n",
    "        )\n",
    "        method_name = f\"{similarity_prefix}_{explanation}_{attribution}\"\n",
    "\n",
    "        insertion_list = []\n",
    "        deletion_list = []\n",
    "        for seed in SEED_LIST:\n",
    "            result_file = os.path.join(\n",
    "                RESULT_PATH, dataset, encoder, method_name, f\"{seed}\", eval_filename\n",
    "            )\n",
    "            # pickle.load can execute arbitrary code; only load trusted local results.\n",
    "            with open(result_file, \"rb\") as handle:\n",
    "                outputs = pickle.load(handle)\n",
    "            insertion_curve_dict, insertion_num_features = get_mean_curves(\n",
    "                outputs, \"insertion\"\n",
    "            )\n",
    "            deletion_curve_dict, deletion_num_features = get_mean_curves(\n",
    "                outputs, \"deletion\"\n",
    "            )\n",
    "            insertion_list.append(\n",
    "                compute_auc(\n",
    "                    curve=insertion_curve_dict[eval_name],\n",
    "                    num_features=insertion_num_features,\n",
    "                )\n",
    "            )\n",
    "            deletion_list.append(\n",
    "                compute_auc(\n",
    "                    curve=deletion_curve_dict[eval_name],\n",
    "                    num_features=deletion_num_features,\n",
    "                )\n",
    "            )\n",
    "\n",
    "        insertion_mean, insertion_ci = _mean_and_ci95(insertion_list)\n",
    "        insertion_mean_list.append(insertion_mean)\n",
    "        insertion_ci_list.append(insertion_ci)\n",
    "        deletion_mean, deletion_ci = _mean_and_ci95(deletion_list)\n",
    "        deletion_mean_list.append(deletion_mean)\n",
    "        deletion_ci_list.append(deletion_ci)\n",
    "    return {\n",
    "        \"insertion\": {\"mean\": insertion_mean_list, \"ci\": insertion_ci_list},\n",
    "        \"deletion\": {\"mean\": deletion_mean_list, \"ci\": deletion_ci_list},\n",
    "    }\n",
    "\n",
    "\n",
    "def _check_direction(name: str, direction: str) -> None:\n",
    "    \"\"\"Raise ValueError unless ``direction`` is ``max`` or ``min``.\"\"\"\n",
    "    if direction not in (\"max\", \"min\"):\n",
    "        raise ValueError(f\"{name}={direction} should be max or min!\")\n",
    "\n",
    "\n",
    "def _best_index(values: List[float], direction: str) -> int:\n",
    "    \"\"\"Index of the largest (``max``) or smallest (``min``) entry of ``values``.\"\"\"\n",
    "    return int(np.argmax(values) if direction == \"max\" else np.argmin(values))\n",
    "\n",
    "\n",
    "def _format_mean_ci(mean: float, ci: float, bold: bool = False) -> str:\n",
    "    \"\"\"Render a mean and its CI half-width as one LaTeX table cell.\"\"\"\n",
    "    if np.abs(mean) < 0.01:\n",
    "        # Three fixed decimals would print ~0.000 here; use scientific notation.\n",
    "        text = \"{:.2e}\".format(mean) + r\" $\\pm$ \" + \"{:.2e}\".format(ci)\n",
    "    else:\n",
    "        text = f\"{mean:.3f} ({ci:.3f})\"\n",
    "    if bold:\n",
    "        text = r\"\\textbf{\" + text + \"}\"\n",
    "    return text\n",
    "\n",
    "\n",
    "def get_formatted_aucs(\n",
    "    insertion_direction: str,\n",
    "    deletion_direction: str,\n",
    "    bold_best: bool = True,\n",
    "    **kwargs,\n",
    ") -> List[str]:\n",
    "    \"\"\"Format the AUC stats of each explanation target as LaTeX table cells.\n",
    "\n",
    "    Args:\n",
    "        insertion_direction: ``max`` or ``min``; which insertion AUC counts as best.\n",
    "        deletion_direction: ``max`` or ``min``; which deletion AUC counts as best.\n",
    "        bold_best: Whether to bold the best mean of each column.\n",
    "        **kwargs: Forwarded unchanged to ``get_auc_stats``.\n",
    "\n",
    "    Returns:\n",
    "        One ``<insertion cell> & <deletion cell>`` string per explanation target.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If either direction is not ``max`` or ``min``.\n",
    "    \"\"\"\n",
    "    # Validate up front so a typo fails before any result files are loaded.\n",
    "    _check_direction(\"insertion_direction\", insertion_direction)\n",
    "    _check_direction(\"deletion_direction\", deletion_direction)\n",
    "\n",
    "    auc_stats = get_auc_stats(**kwargs)\n",
    "    insertion_stats = auc_stats[\"insertion\"]\n",
    "    deletion_stats = auc_stats[\"deletion\"]\n",
    "    insertion_best_idx = _best_index(insertion_stats[\"mean\"], insertion_direction)\n",
    "    deletion_best_idx = _best_index(deletion_stats[\"mean\"], deletion_direction)\n",
    "\n",
    "    text_list = []\n",
    "    for i in range(len(insertion_stats[\"mean\"])):\n",
    "        insertion_text = _format_mean_ci(\n",
    "            insertion_stats[\"mean\"][i],\n",
    "            insertion_stats[\"ci\"][i],\n",
    "            bold=bold_best and i == insertion_best_idx,\n",
    "        )\n",
    "        deletion_text = _format_mean_ci(\n",
    "            deletion_stats[\"mean\"][i],\n",
    "            deletion_stats[\"ci\"][i],\n",
    "            bold=bold_best and i == deletion_best_idx,\n",
    "        )\n",
    "        text_list.append(insertion_text + \" & \" + deletion_text)\n",
    "    return text_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13745d86-38e1-44b8-a15e-090a4107192b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_aucs(\n",
    "    eval_name: str,\n",
    "    normalize_similarity: bool,\n",
    "    different_classes: bool,\n",
    "    insertion_direction: str = \"max\",\n",
    "    deletion_direction: str = \"min\",\n",
    "    comprehensive: bool = False,\n",
    "):\n",
    "    \"\"\"Print LaTeX table rows with insertion/deletion AUCs for each attribution method.\n",
    "\n",
    "    For each attribution method, a header is printed, followed by one row per\n",
    "    explanation target (Label-Free, Contrastive Label-Free, Corpus, COCOA). Each row\n",
    "    holds one insertion/deletion cell pair per dataset/encoder combination.\n",
    "    A final Random row reports the random-baseline attribution, without bolding.\n",
    "    \"\"\"\n",
    "    row_labels = [\"Label-Free\", \"Contrastive Label-Free\", \"Corpus\", \"COCOA\"]\n",
    "    dataset_encoder_combos = [\n",
    "        (\"imagenet\", \"simclr_x1\"),\n",
    "        (\"cifar\", \"simsiam_18\"),\n",
    "        (\"mura\", \"classifier_18\"),\n",
    "    ]\n",
    "    row_end = r\" \\\\\"  # LaTeX end-of-row marker.\n",
    "\n",
    "    for attribution in [\"int_grad\", \"gradient_shap\", \"rise\"]:\n",
    "        print(attribution)\n",
    "        print(\"-\" * len(attribution))\n",
    "        row_cells = [[] for _ in row_labels]\n",
    "        for dataset, encoder in dataset_encoder_combos:\n",
    "            text_list = get_formatted_aucs(\n",
    "                insertion_direction=insertion_direction,\n",
    "                deletion_direction=deletion_direction,\n",
    "                dataset=dataset,\n",
    "                encoder=encoder,\n",
    "                attribution=attribution,\n",
    "                eval_name=eval_name,\n",
    "                normalize_similarity=normalize_similarity,\n",
    "                different_classes=different_classes,\n",
    "                comprehensive=comprehensive,\n",
    "            )\n",
    "            # text_list is ordered like row_labels (one entry per explanation target).\n",
    "            for row_idx, cells in enumerate(row_cells):\n",
    "                cells.append(text_list[row_idx])\n",
    "        for label, cells in zip(row_labels, row_cells):\n",
    "            print(label + \"\".join(f\" & {cell}\" for cell in cells) + row_end)\n",
    "        print(\"\")\n",
    "\n",
    "    print(\"random\")\n",
    "    print(\"------\")\n",
    "    random_cells = []\n",
    "    for dataset, encoder in dataset_encoder_combos:\n",
    "        text_list = get_formatted_aucs(\n",
    "            insertion_direction=insertion_direction,\n",
    "            deletion_direction=deletion_direction,\n",
    "            bold_best=False,\n",
    "            dataset=dataset,\n",
    "            encoder=encoder,\n",
    "            attribution=\"random_baseline\",\n",
    "            eval_name=eval_name,\n",
    "            # Random-baseline results are always read from the normalized_* directory.\n",
    "            normalize_similarity=True,\n",
    "            different_classes=different_classes,\n",
    "            comprehensive=comprehensive,\n",
    "        )\n",
    "        random_cells.append(text_list[0])\n",
    "    print(\"Random\" + \"\".join(f\" & {cell}\" for cell in random_cells) + row_end)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61bd35bd-28ce-4547-b576-cc4edf54d3e6",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Corpus Majority Probability (Cosine Similarity & Same Class)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74daec49-1e13-44be-8b72-6bf225b122aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cosine similarity (normalized), same-class setting, corpus majority probability.\n",
    "print_aucs(\n",
    "    different_classes=False,\n",
    "    normalize_similarity=True,\n",
    "    eval_name=\"corpus_majority_prob\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6d558e1-19aa-404a-90fc-9fceb057500c",
   "metadata": {},
   "source": [
    "## Contrastive Corpus Similarity (Cosine Similarity & Same Class)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "084d2256-6d59-4693-be11-4a95ed4dd873",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cosine similarity (normalized), same-class setting, comprehensive evaluation.\n",
    "print_aucs(\n",
    "    comprehensive=True,\n",
    "    different_classes=False,\n",
    "    normalize_similarity=True,\n",
    "    eval_name=\"contrastive_corpus_cosine_similarity\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f6e4dba-34c0-44e9-a980-552d6146d8c3",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Corpus Majority Probability (Cosine Similarity & Different Classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "084d4916-7070-47e9-8697-50cd0104dda6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cosine similarity (normalized), different-classes setting, corpus majority probability.\n",
    "print_aucs(\n",
    "    different_classes=True,\n",
    "    normalize_similarity=True,\n",
    "    eval_name=\"corpus_majority_prob\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "131acfd0-bcd9-4f9a-b2ed-95fda88cf5f9",
   "metadata": {},
   "source": [
    "## Contrastive Corpus Similarity (Cosine Similarity & Different Classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b034d40-2413-46be-9e49-95b64cce4119",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cosine similarity (normalized), different-classes setting, comprehensive evaluation.\n",
    "print_aucs(\n",
    "    comprehensive=True,\n",
    "    different_classes=True,\n",
    "    normalize_similarity=True,\n",
    "    eval_name=\"contrastive_corpus_cosine_similarity\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de7954e4-5aba-4919-bb44-dc9dbdc1bd2d",
   "metadata": {},
   "source": [
    "## Corpus Majority Probability (Dot Product & Same Class)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85309019-6998-4361-9361-8ef0ba11193c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dot product (unnormalized), same-class setting, corpus majority probability.\n",
    "print_aucs(\n",
    "    different_classes=False,\n",
    "    normalize_similarity=False,\n",
    "    eval_name=\"corpus_majority_prob\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bd5593b-1d93-4c5e-8f30-1349707f4e43",
   "metadata": {},
   "source": [
    "## Contrastive Corpus Similarity (Dot Product & Same Class)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c133583b-5c80-44d1-b789-deef5a8d3fd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dot product (unnormalized), same-class setting, comprehensive evaluation.\n",
    "print_aucs(\n",
    "    comprehensive=True,\n",
    "    different_classes=False,\n",
    "    normalize_similarity=False,\n",
    "    eval_name=\"contrastive_corpus_cosine_similarity\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "394012d8-419e-415d-a537-e62b4b707019",
   "metadata": {},
   "source": [
    "## Corpus Majority Probability (Dot Product & Different Classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46f0cd14-5942-4de4-b4e2-b320345ff04d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dot product (unnormalized), different-classes setting, corpus majority probability.\n",
    "print_aucs(\n",
    "    different_classes=True,\n",
    "    normalize_similarity=False,\n",
    "    eval_name=\"corpus_majority_prob\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c327edbd-3ffe-4253-9c4c-a12fdd114f8d",
   "metadata": {},
   "source": [
    "## Contrastive Corpus Similarity (Dot Product & Different Classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9781a50b-2cad-4075-a724-eb4977f5428d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dot product (unnormalized), different-classes setting, comprehensive evaluation.\n",
    "print_aucs(\n",
    "    comprehensive=True,\n",
    "    different_classes=True,\n",
    "    normalize_similarity=False,\n",
    "    eval_name=\"contrastive_corpus_cosine_similarity\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64c8b922-989e-48c2-8c1a-2495c1437b8e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
