{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "292837dc-320c-4975-9aee-50ff7a562adf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import torch\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from typing import Tuple, List, Dict\n",
    "\n",
    "from cl_explain.metrics.ablation import compute_auc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e71332d3-f889-49bc-a192-9ba94fac8850",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directory under which the evaluation result pickles are looked up.\n",
    "RESULT_PATH = \"results\"\n",
    "\n",
    "# Attribution methods whose eval filenames carry a `superpixel_dim` component.\n",
    "SUPERPIXEL_ATTRIBUTION_METHODS = [\n",
    "    \"kernel_shap\",\n",
    "]\n",
    "\n",
    "# Seeds of the runs that are aggregated into mean / std / CI statistics.\n",
    "SEED_LIST = [\n",
    "    123,\n",
    "    456,\n",
    "    789,\n",
    "    42,\n",
    "    91,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbfd4084-4793-44b7-9371-4dddd543b073",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_eval_filename(\n",
    "    same_class: bool,\n",
    "    comprehensive: bool,\n",
    "    corpus_size: int,\n",
    "    explanation_name: str,\n",
    "    foil_size: int,\n",
    "    explicand_size: int,\n",
    "    attribution_name: str,\n",
    "    superpixel_dim: int,\n",
    "    removal: str,\n",
    "    blur_strength: float,\n",
    "    eval_superpixel_dim: int,\n",
    "    eval_foil_size: int,\n",
    "    take_attribution_abs: bool,\n",
    ") -> str:\n",
    "    \"\"\"Assemble the name of the pickle file holding one evaluation run.\n",
    "\n",
    "    The name is a sequence of components joined by underscores and ending in\n",
    "    `.pkl`. Several components are only present when the matching setting\n",
    "    applies; see the inline comments.\n",
    "    \"\"\"\n",
    "    parts = [\"same_class\" if same_class else \"diff_class\"]\n",
    "    if comprehensive:\n",
    "        parts.append(\"comprehensive\")\n",
    "    parts.append(\"eval_results\")\n",
    "\n",
    "    parts.append(f\"explicand_size={explicand_size}\")\n",
    "    # Corpus / foil sizes only appear for explanations whose name mentions them.\n",
    "    if \"corpus\" in explanation_name:\n",
    "        parts.append(f\"corpus_size={corpus_size}\")\n",
    "    if \"contrastive\" in explanation_name:\n",
    "        parts.append(f\"foil_size={foil_size}\")\n",
    "    # The attribution superpixel size is only recorded for superpixel methods.\n",
    "    if attribution_name in SUPERPIXEL_ATTRIBUTION_METHODS:\n",
    "        parts.append(f\"superpixel_dim={superpixel_dim}\")\n",
    "\n",
    "    parts.append(f\"removal={removal}\")\n",
    "    if removal == \"blurring\":\n",
    "        parts.append(f\"blur_strength={blur_strength:.1f}\")\n",
    "    parts.append(f\"eval_superpixel_dim={eval_superpixel_dim}\")\n",
    "    # The eval foil size is omitted for comprehensive evaluations.\n",
    "    if not comprehensive:\n",
    "        parts.append(f\"eval_foil_size={eval_foil_size}\")\n",
    "    if take_attribution_abs:\n",
    "        parts.append(\"abs\")\n",
    "    return \"_\".join(parts) + \".pkl\"\n",
    "\n",
    "\n",
    "def get_mean_curves(outputs, curve_kind) -> Tuple[Dict[str, torch.Tensor], int]:\n",
    "    \"\"\"Average insertion or deletion curves over all targets, per eval name.\n",
    "\n",
    "    Args:\n",
    "        outputs: Mapping from target to that target's result dict. Each result\n",
    "            dict is read at the keys `eval_model_names`, `eval_measure_names`,\n",
    "            `model_{curve_kind}_curves`, `measure_{curve_kind}_curves` and\n",
    "            `{curve_kind}_num_features`.\n",
    "        curve_kind: Either \"insertion\" or \"deletion\".\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(mean_curves, num_features)`. `mean_curves` maps every eval\n",
    "        name to the mean over dim 0 of the concatenated per-target curves,\n",
    "        moved to CPU. `num_features` is the `{curve_kind}_num_features` entry\n",
    "        of the last target in `outputs`.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If `curve_kind` is not \"insertion\" or \"deletion\".\n",
    "        IndexError: If `outputs` is empty.\n",
    "    \"\"\"\n",
    "    available_curve_kinds = [\"insertion\", \"deletion\"]\n",
    "    assert curve_kind in available_curve_kinds, (\n",
    "        f\"curve_kind={curve_kind} is not one of {available_curve_kinds}!\"\n",
    "    )\n",
    "    output_list = list(outputs.values())\n",
    "    # Eval names are taken from the first target only.\n",
    "    # NOTE(review): assumes every target lists the same eval names in the same\n",
    "    # order - confirm against the code that writes these result files.\n",
    "    eval_name_list = (\n",
    "        output_list[0][\"eval_model_names\"] + output_list[0][\"eval_measure_names\"]\n",
    "    )\n",
    "    # Read once, outside the eval-name loop, so the value is defined even when\n",
    "    # there are no eval names (it used to be bound only inside the loop).\n",
    "    num_features = output_list[-1][f\"{curve_kind}_num_features\"]\n",
    "\n",
    "    eval_mean_curve_dict = {}\n",
    "    for j, eval_name in enumerate(eval_name_list):\n",
    "        # Presumably both entries are lists of tensors that `+` concatenates,\n",
    "        # ordered like the eval names - TODO confirm.\n",
    "        curve_list = [\n",
    "            (\n",
    "                output[f\"model_{curve_kind}_curves\"]\n",
    "                + output[f\"measure_{curve_kind}_curves\"]\n",
    "            )[j]\n",
    "            for output in output_list\n",
    "        ]\n",
    "        eval_mean_curve_dict[eval_name] = torch.cat(curve_list).mean(dim=0).cpu()\n",
    "\n",
    "    return eval_mean_curve_dict, num_features\n",
    "\n",
    "\n",
    "def get_mean_aucs(outputs, curve_kind) -> Dict[str, np.ndarray]:\n",
    "    \"\"\"Compute the AUC of the target-averaged curve for every eval name.\n",
    "\n",
    "    Args:\n",
    "        outputs: Mapping from target to result dict, as for `get_mean_curves`.\n",
    "        curve_kind: Either \"insertion\" or \"deletion\".\n",
    "\n",
    "    Returns:\n",
    "        Mapping from eval name to `compute_auc(mean_curve, num_features)`\n",
    "        converted with `.numpy()`.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If `curve_kind` is not \"insertion\" or \"deletion\".\n",
    "        IndexError: If `outputs` is empty.\n",
    "    \"\"\"\n",
    "    # Reuse the curve aggregation instead of repeating the same loop here.\n",
    "    mean_curve_dict, num_features = get_mean_curves(outputs, curve_kind)\n",
    "    return {\n",
    "        eval_name: compute_auc(mean_curve, num_features).numpy()\n",
    "        for eval_name, mean_curve in mean_curve_dict.items()\n",
    "    }\n",
    "\n",
    "\n",
    "def _summarize_aucs(auc_values, num_seeds: int) -> Dict[str, float]:\n",
    "    \"\"\"Return the mean, std and `1.96 * std / sqrt(num_seeds)` of AUC values.\"\"\"\n",
    "    # np.std is called with its default ddof=0, exactly as before.\n",
    "    std = np.std(auc_values)\n",
    "    return {\n",
    "        \"mean\": np.mean(auc_values),\n",
    "        \"std\": std,\n",
    "        # Half-width of a normal-approximation 95% interval (z = 1.96).\n",
    "        \"ci\": 1.96 * std / np.sqrt(num_seeds),\n",
    "    }\n",
    "\n",
    "\n",
    "def get_auc_stats(\n",
    "    dataset: str,\n",
    "    encoder: str,\n",
    "    explanation: str,\n",
    "    attribution: str,\n",
    "    eval_name_list: List[str],\n",
    "    seed_list: List[int],\n",
    "    normalize_similarity: bool,\n",
    "    **get_eval_filename_kwargs,\n",
    ") -> Tuple[Dict[str, Dict[str, float]], Dict[str, Dict[str, float]]]:\n",
    "    \"\"\"Aggregate insertion and deletion AUCs over seeds for one configuration.\n",
    "\n",
    "    For every seed, the result pickle at\n",
    "    `RESULT_PATH/dataset/encoder/<method_name>/<seed>/<eval_filename>` is\n",
    "    loaded and reduced with `get_mean_aucs`.\n",
    "\n",
    "    Args:\n",
    "        dataset: Dataset directory name.\n",
    "        encoder: Encoder directory name.\n",
    "        explanation: Explanation name; part of the method directory name and\n",
    "            forwarded to `get_eval_filename` as `explanation_name`.\n",
    "        attribution: Attribution name; part of the method directory name and\n",
    "            forwarded to `get_eval_filename` as `attribution_name`.\n",
    "        eval_name_list: Eval names for which statistics are collected.\n",
    "        seed_list: Seeds whose result files are loaded.\n",
    "        normalize_similarity: Selects the `normalized_` (True) or\n",
    "            `unnormalized_` (False) method directory prefix.\n",
    "        **get_eval_filename_kwargs: Remaining arguments of `get_eval_filename`.\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(insertion_stats, deletion_stats)`. Each maps an eval name to\n",
    "        a dict with the keys \"mean\", \"std\" and \"ci\".\n",
    "    \"\"\"\n",
    "    # The filename and method directory do not depend on the seed, so they are\n",
    "    # computed once instead of once per loop iteration.\n",
    "    eval_filename = get_eval_filename(\n",
    "        explanation_name=explanation,\n",
    "        attribution_name=attribution,\n",
    "        **get_eval_filename_kwargs,\n",
    "    )\n",
    "    prefix = \"normalized\" if normalize_similarity else \"unnormalized\"\n",
    "    method_dir = os.path.join(\n",
    "        RESULT_PATH, dataset, encoder, f\"{prefix}_{explanation}_{attribution}\"\n",
    "    )\n",
    "\n",
    "    all_insertion_auc_dict = {eval_name: [] for eval_name in eval_name_list}\n",
    "    all_deletion_auc_dict = {eval_name: [] for eval_name in eval_name_list}\n",
    "    for seed in seed_list:\n",
    "        # NOTE: pickle.load can execute arbitrary code. Only point RESULT_PATH\n",
    "        # at result files that you produced yourself or otherwise trust.\n",
    "        with open(os.path.join(method_dir, f\"{seed}\", eval_filename), \"rb\") as handle:\n",
    "            outputs = pickle.load(handle)\n",
    "\n",
    "        insertion_auc_dict = get_mean_aucs(outputs, \"insertion\")\n",
    "        deletion_auc_dict = get_mean_aucs(outputs, \"deletion\")\n",
    "        for eval_name in eval_name_list:\n",
    "            all_insertion_auc_dict[eval_name].append(insertion_auc_dict[eval_name])\n",
    "            all_deletion_auc_dict[eval_name].append(deletion_auc_dict[eval_name])\n",
    "\n",
    "    num_seeds = len(seed_list)\n",
    "    insertion_auc_stats_dict = {\n",
    "        eval_name: _summarize_aucs(all_insertion_auc_dict[eval_name], num_seeds)\n",
    "        for eval_name in eval_name_list\n",
    "    }\n",
    "    deletion_auc_stats_dict = {\n",
    "        eval_name: _summarize_aucs(all_deletion_auc_dict[eval_name], num_seeds)\n",
    "        for eval_name in eval_name_list\n",
    "    }\n",
    "    return insertion_auc_stats_dict, deletion_auc_stats_dict\n",
    "\n",
    "\n",
    "def format_eval_name(name):\n",
    "    \"\"\"Map an internal evaluation-metric key to its display label.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If `name` is not one of the known metric keys.\n",
    "    \"\"\"\n",
    "    key_label_pairs = (\n",
    "        (\"corpus_cosine_similarity\", \"Cosine similarity to corpus\"),\n",
    "        (\"contrastive_corpus_cosine_similarity\", \"Contrastive corpus similarity\"),\n",
    "        (\"corpus_majority_prob\", \"Predicted probability of corpus majority prediction\"),\n",
    "        (\"explicand_pred_prob\", \"Predicted probability of explicand\"),\n",
    "        (\"explicand_rep_shift\", \"Representation shift of explicand\"),\n",
    "    )\n",
    "    return dict(key_label_pairs)[name]\n",
    "\n",
    "\n",
    "def format_explanation_name(name):\n",
    "    \"\"\"Map an internal explanation-method key to its display label.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If `name` is not one of the known explanation keys.\n",
    "    \"\"\"\n",
    "    format_map = {\n",
    "        # Was misspelled \"Labe-Free\".\n",
    "        \"self_weighted\": \"Label-Free\",\n",
    "        \"contrastive_self_weighted\": \"Contrastive\",\n",
    "        \"corpus\": \"Corpus\",\n",
    "        \"contrastive_corpus\": \"COCOA\",\n",
    "    }\n",
    "    return format_map[name]\n",
    "\n",
    "\n",
    "def format_attribution_name(name):\n",
    "    \"\"\"Translate an attribution-method key into the label shown on plots.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If `name` is not a known attribution-method key.\n",
    "    \"\"\"\n",
    "    labels_by_key = dict(\n",
    "        gradient_shap=\"Gradient Shap\",\n",
    "        int_grad=\"Integrated Gradient\",\n",
    "        random_baseline=\"Random\",\n",
    "        rise=\"RISE\",\n",
    "    )\n",
    "    return labels_by_key[name]\n",
    "\n",
    "\n",
    "def format_dataset_name(name):\n",
    "    \"\"\"Return the display label for a dataset key.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If `name` is not a known dataset key.\n",
    "    \"\"\"\n",
    "    labels_by_dataset = dict(\n",
    "        imagenet=\"Imagenet & SimCLR\",\n",
    "        cifar=\"CIFAR-10 & SimSiam\",\n",
    "        mura=\"MURA & ResNet\",\n",
    "    )\n",
    "    return labels_by_dataset[name]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9850f1dc-3306-4292-b1d3-894438224006",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Matplotlib color name per explanation key.\n",
    "# NOTE(review): no other cell visible in this notebook reads this dict\n",
    "# (lineplot_results picks its colors itself) - confirm it is still needed.\n",
    "explanation_colors = dict(\n",
    "    corpus=\"tab:red\",\n",
    "    contrastive_corpus=\"tab:blue\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d694eb2-3a56-47bd-b9e2-e8963aad8f85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Corpus sizes that are swept over; they form the x-axis of the line plots.\n",
    "corpus_size_list = [1, 5, 20, 50, 100, 200]\n",
    "\n",
    "# Encoder directory name and feature-removal scheme for every dataset.\n",
    "dataset_meta_dict = {\n",
    "    dataset_key: {\"encoder\": encoder_key, \"removal\": \"blurring\"}\n",
    "    for dataset_key, encoder_key in [\n",
    "        (\"imagenet\", \"simclr_x1\"),\n",
    "        (\"cifar\", \"simsiam_18\"),\n",
    "        (\"mura\", \"classifier_18\"),\n",
    "    ]\n",
    "}\n",
    "\n",
    "# Settings that select which result files are loaded.\n",
    "explicand_size = 25\n",
    "blur_strength = 5.0\n",
    "superpixel_dim = eval_superpixel_dim = 1\n",
    "foil_size = eval_foil_size = 1500\n",
    "take_attribution_abs = False\n",
    "# True reads the `normalized_*` result directories, False the `unnormalized_*` ones.\n",
    "normalize_similarity = True\n",
    "\n",
    "# Explanation methods, attribution methods and evaluation metrics to compare.\n",
    "explanation_list = [\n",
    "    \"corpus\",\n",
    "    \"contrastive_corpus\",\n",
    "]\n",
    "attribution_list = [\n",
    "    \"int_grad\",\n",
    "    \"gradient_shap\",\n",
    "    \"rise\",\n",
    "]\n",
    "eval_name_list = [\n",
    "    \"contrastive_corpus_cosine_similarity\",\n",
    "    \"corpus_majority_prob\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0484d5ee-f6f8-4989-a44c-13f2522546df",
   "metadata": {},
   "outputs": [],
   "source": [
    "def lineplot_results(\n",
    "    dataset,\n",
    "    eval_name,\n",
    "    insertion_auc_dict,\n",
    "    deletion_auc_dict,\n",
    "    same_class,\n",
    "    fontsize=14,\n",
    "    plot_overall_title=True,\n",
    "    plot_legend=True,\n",
    "):\n",
    "    \"\"\"Plot insertion and deletion AUCs against corpus size for one dataset.\n",
    "\n",
    "    Draws a grid with two rows (top: insertion, bottom: deletion) and one\n",
    "    column per entry of the notebook-level `attribution_list`. Every panel\n",
    "    shows one error-bar line (mean with \"ci\" error bars) per entry of the\n",
    "    notebook-level `explanation_list`, with the notebook-level\n",
    "    `corpus_size_list` on the x-axis. The \"corpus\" explanation is drawn from\n",
    "    its \"random_baseline\" attribution as a dotted black line labelled\n",
    "    \"Random\" in every column.\n",
    "\n",
    "    Args:\n",
    "        dataset: Dataset key; indexes the AUC dicts and selects the y-label.\n",
    "        eval_name: Evaluation-metric key to plot.\n",
    "        insertion_auc_dict: Nested dict indexed as\n",
    "            `[dataset][explanation][attribution][corpus_size][eval_name]`,\n",
    "            whose leaves provide the keys \"mean\" and \"ci\".\n",
    "        deletion_auc_dict: Same structure, for deletion AUCs.\n",
    "        same_class: Selects the line color (\"tab:blue\" if True, otherwise\n",
    "            \"tab:orange\") and the wording of the overall title.\n",
    "        fontsize: Font size of labels, titles and legend; the overall title\n",
    "            uses `fontsize + 4`.\n",
    "        plot_overall_title: Whether to draw the figure-level title.\n",
    "        plot_legend: Whether to draw a legend in the top-left panel.\n",
    "    \"\"\"\n",
    "    capsize = 5\n",
    "    tick_labelsize = 11\n",
    "    # squeeze=False keeps `axes` two-dimensional even when there is only one\n",
    "    # attribution column, so `axes[row, col]` indexing always works.\n",
    "    fig, axes = plt.subplots(\n",
    "        ncols=len(attribution_list), nrows=2, figsize=(10, 6), squeeze=False\n",
    "    )\n",
    "    formatted_dataset = format_dataset_name(dataset)\n",
    "    formatted_eval_name = format_eval_name(eval_name)\n",
    "    # (curve kind, data source) for the top and bottom row, in that order.\n",
    "    panel_rows = [(\"insertion\", insertion_auc_dict), (\"deletion\", deletion_auc_dict)]\n",
    "\n",
    "    for explanation in explanation_list:\n",
    "        if explanation == \"corpus\":\n",
    "            # The same random-baseline curve is repeated in every column.\n",
    "            column_attributions = [\"random_baseline\"] * len(attribution_list)\n",
    "            color, linestyle, label = \"black\", \"dotted\", \"Random\"\n",
    "        else:\n",
    "            column_attributions = attribution_list\n",
    "            color = \"tab:blue\" if same_class else \"tab:orange\"\n",
    "            linestyle = \"solid\"\n",
    "            label = format_explanation_name(explanation)\n",
    "\n",
    "        for col, attribution in enumerate(column_attributions):\n",
    "            for row, (_, auc_dict) in enumerate(panel_rows):\n",
    "                stats_by_size = [\n",
    "                    auc_dict[dataset][explanation][attribution][corpus_size][eval_name]\n",
    "                    for corpus_size in corpus_size_list\n",
    "                ]\n",
    "                axes[row, col].errorbar(\n",
    "                    corpus_size_list,\n",
    "                    [stats[\"mean\"] for stats in stats_by_size],\n",
    "                    yerr=[stats[\"ci\"] for stats in stats_by_size],\n",
    "                    color=color,\n",
    "                    linestyle=linestyle,\n",
    "                    marker=\"o\",\n",
    "                    capsize=capsize,\n",
    "                    label=label,\n",
    "                )\n",
    "\n",
    "    # Decorate the panels once, from `attribution_list`, so the column titles\n",
    "    # do not depend on which explanation happens to be drawn last.\n",
    "    for col, attribution in enumerate(attribution_list):\n",
    "        axes[0, col].set_title(format_attribution_name(attribution), fontsize=fontsize)\n",
    "        axes[1, col].set_xlabel(\"Corpus size\", fontsize=fontsize)\n",
    "    for row, (curve_kind, _) in enumerate(panel_rows):\n",
    "        axes[row, 0].set_ylabel(f\"{formatted_dataset}\\n{curve_kind}\", fontsize=fontsize)\n",
    "    for ax in axes.flat:\n",
    "        ax.tick_params(axis=\"both\", labelsize=tick_labelsize)\n",
    "\n",
    "    if plot_legend:\n",
    "        axes[0, 0].legend(loc=\"lower right\", fontsize=fontsize)\n",
    "    if plot_overall_title:\n",
    "        if same_class:\n",
    "            title = formatted_eval_name + \" (explicands are in the corpus class)\\n\"\n",
    "        else:\n",
    "            title = formatted_eval_name + \" (explicands are not in the corpus class)\\n\"\n",
    "        fig.suptitle(title, fontsize=fontsize + 4)\n",
    "    fig.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "214e4a50-9f17-4bd3-8e88-5711b1b3bee1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather AUC statistics for explicands in the corpus class (same_class=True).\n",
    "# Both dicts are indexed as [dataset][explanation][attribution][corpus_size].\n",
    "insertion_auc_dict = {}\n",
    "deletion_auc_dict = {}\n",
    "\n",
    "for dataset, meta_dict in dataset_meta_dict.items():\n",
    "    insertion_auc_dict[dataset] = dataset_insertion = {}\n",
    "    deletion_auc_dict[dataset] = dataset_deletion = {}\n",
    "    encoder, removal = meta_dict[\"encoder\"], meta_dict[\"removal\"]\n",
    "\n",
    "    for explanation in explanation_list:\n",
    "        dataset_insertion[explanation] = explanation_insertion = {}\n",
    "        dataset_deletion[explanation] = explanation_deletion = {}\n",
    "\n",
    "        # For the \"corpus\" explanation only the random baseline is loaded.\n",
    "        available_attribution_list = (\n",
    "            [\"random_baseline\"] if explanation == \"corpus\" else attribution_list\n",
    "        )\n",
    "\n",
    "        for attribution in available_attribution_list:\n",
    "            explanation_insertion[attribution] = attribution_insertion = {}\n",
    "            explanation_deletion[attribution] = attribution_deletion = {}\n",
    "\n",
    "            for corpus_size in corpus_size_list:\n",
    "                insertion_stats_dict, deletion_stats_dict = get_auc_stats(\n",
    "                    dataset=dataset,\n",
    "                    encoder=encoder,\n",
    "                    explanation=explanation,\n",
    "                    attribution=attribution,\n",
    "                    eval_name_list=eval_name_list,\n",
    "                    seed_list=SEED_LIST,\n",
    "                    normalize_similarity=normalize_similarity,\n",
    "                    same_class=True,  # Make sure this is True.\n",
    "                    comprehensive=True,\n",
    "                    corpus_size=corpus_size,\n",
    "                    foil_size=foil_size,\n",
    "                    explicand_size=explicand_size,\n",
    "                    superpixel_dim=superpixel_dim,\n",
    "                    removal=removal,\n",
    "                    blur_strength=blur_strength,\n",
    "                    eval_superpixel_dim=eval_superpixel_dim,\n",
    "                    eval_foil_size=eval_foil_size,\n",
    "                    take_attribution_abs=take_attribution_abs,\n",
    "                )\n",
    "                attribution_insertion[corpus_size] = insertion_stats_dict\n",
    "                attribution_deletion[corpus_size] = deletion_stats_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27d9d741-2363-4c7b-9497-8ed9bc413d62",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure per dataset: corpus majority probability, same-class explicands.\n",
    "# The loop index was never used, so iterate over the dataset keys directly.\n",
    "for dataset in dataset_meta_dict:\n",
    "    lineplot_results(\n",
    "        dataset=dataset,\n",
    "        eval_name=\"corpus_majority_prob\",\n",
    "        insertion_auc_dict=insertion_auc_dict,\n",
    "        deletion_auc_dict=deletion_auc_dict,\n",
    "        same_class=True,\n",
    "        fontsize=14,\n",
    "        plot_overall_title=True,\n",
    "        plot_legend=False,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59972dc9-a270-454c-a6f9-b7cd2188cbda",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure per dataset: contrastive corpus similarity, same-class explicands.\n",
    "# The loop index was never used, so iterate over the dataset keys directly.\n",
    "for dataset in dataset_meta_dict:\n",
    "    lineplot_results(\n",
    "        dataset=dataset,\n",
    "        eval_name=\"contrastive_corpus_cosine_similarity\",\n",
    "        insertion_auc_dict=insertion_auc_dict,\n",
    "        deletion_auc_dict=deletion_auc_dict,\n",
    "        same_class=True,\n",
    "        fontsize=14,\n",
    "        plot_overall_title=True,\n",
    "        plot_legend=False,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ada5fdf4-7eee-4eb0-927a-c9e34a5f41c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather AUC statistics for explicands outside the corpus class (same_class=False).\n",
    "# Both dicts are indexed as [dataset][explanation][attribution][corpus_size].\n",
    "# Note: this rebinds insertion_auc_dict / deletion_auc_dict, replacing the\n",
    "# same-class statistics loaded earlier in the notebook.\n",
    "insertion_auc_dict = {}\n",
    "deletion_auc_dict = {}\n",
    "\n",
    "for dataset, meta_dict in dataset_meta_dict.items():\n",
    "    insertion_auc_dict[dataset] = dataset_insertion = {}\n",
    "    deletion_auc_dict[dataset] = dataset_deletion = {}\n",
    "    encoder, removal = meta_dict[\"encoder\"], meta_dict[\"removal\"]\n",
    "\n",
    "    for explanation in explanation_list:\n",
    "        dataset_insertion[explanation] = explanation_insertion = {}\n",
    "        dataset_deletion[explanation] = explanation_deletion = {}\n",
    "\n",
    "        # For the \"corpus\" explanation only the random baseline is loaded.\n",
    "        available_attribution_list = (\n",
    "            [\"random_baseline\"] if explanation == \"corpus\" else attribution_list\n",
    "        )\n",
    "\n",
    "        for attribution in available_attribution_list:\n",
    "            explanation_insertion[attribution] = attribution_insertion = {}\n",
    "            explanation_deletion[attribution] = attribution_deletion = {}\n",
    "\n",
    "            for corpus_size in corpus_size_list:\n",
    "                insertion_stats_dict, deletion_stats_dict = get_auc_stats(\n",
    "                    dataset=dataset,\n",
    "                    encoder=encoder,\n",
    "                    explanation=explanation,\n",
    "                    attribution=attribution,\n",
    "                    eval_name_list=eval_name_list,\n",
    "                    seed_list=SEED_LIST,\n",
    "                    normalize_similarity=normalize_similarity,\n",
    "                    same_class=False,  # Make sure this is False.\n",
    "                    comprehensive=True,\n",
    "                    corpus_size=corpus_size,\n",
    "                    foil_size=foil_size,\n",
    "                    explicand_size=explicand_size,\n",
    "                    superpixel_dim=superpixel_dim,\n",
    "                    removal=removal,\n",
    "                    blur_strength=blur_strength,\n",
    "                    eval_superpixel_dim=eval_superpixel_dim,\n",
    "                    eval_foil_size=eval_foil_size,\n",
    "                    take_attribution_abs=take_attribution_abs,\n",
    "                )\n",
    "                attribution_insertion[corpus_size] = insertion_stats_dict\n",
    "                attribution_deletion[corpus_size] = deletion_stats_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dfd6c3a-2a80-4491-abcb-3fc397ee6789",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure per dataset: corpus majority probability, different-class explicands.\n",
    "# The loop index was never used, so iterate over the dataset keys directly.\n",
    "for dataset in dataset_meta_dict:\n",
    "    lineplot_results(\n",
    "        dataset=dataset,\n",
    "        eval_name=\"corpus_majority_prob\",\n",
    "        insertion_auc_dict=insertion_auc_dict,\n",
    "        deletion_auc_dict=deletion_auc_dict,\n",
    "        same_class=False,\n",
    "        fontsize=14,\n",
    "        plot_overall_title=True,\n",
    "        plot_legend=False,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb41f68e-9c58-4d27-af80-da5d3c38d930",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure per dataset: contrastive corpus similarity, different-class explicands.\n",
    "# The loop index was never used, so iterate over the dataset keys directly.\n",
    "for dataset in dataset_meta_dict:\n",
    "    lineplot_results(\n",
    "        dataset=dataset,\n",
    "        eval_name=\"contrastive_corpus_cosine_similarity\",\n",
    "        insertion_auc_dict=insertion_auc_dict,\n",
    "        deletion_auc_dict=deletion_auc_dict,\n",
    "        same_class=False,\n",
    "        fontsize=14,\n",
    "        plot_overall_title=True,\n",
    "        plot_legend=False,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28c14156-f40b-4692-b45b-041d7b5ffa5a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
