{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfafc2ce-df1a-42c1-b85d-863b93d510bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Libraries used by the rest of the notebook.\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "# Input table; it is read with pandas in the next cell.\n",
    "file_path = \"./tables/full_mis_rocauc_only_risks.csv\"\n",
    "\n",
    "# Apply seaborn's \"whitegrid\" style to the figures drawn below.\n",
    "sns.set_theme(style=\"whitegrid\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3a15810-8f13-4520-b95f-e15e12203eb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the results table and drop every row whose UQMetric ends in \"Inner Inner\".\n",
    "# `na=False` makes a missing UQMetric count as \"no match\": such rows are kept and\n",
    "# the negation below operates on a clean boolean mask instead of failing on NaN.\n",
    "mis_data_raw = pd.read_csv(file_path)\n",
    "is_inner_inner = mis_data_raw[\"UQMetric\"].str.endswith(\"Inner Inner\", na=False)\n",
    "mis_data = mis_data_raw[~is_inner_inner]\n",
    "\n",
    "# Optional extra filter, currently disabled:\n",
    "# mis_data = mis_data[mis_data.base_rule != 'Neglog']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83edbcdc-6520-49c1-b334-b869f0bbb193",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Descriptive statistics of the score column (stored, not displayed by this cell).\n",
    "roc_auc_scores = mis_data[\"RocAucScore\"]\n",
    "roc_auc_summary = roc_auc_scores.describe()\n",
    "\n",
    "# Histogram (30 bins) with a KDE overlay of all scores.\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "sns.histplot(roc_auc_scores, bins=30, kde=True, ax=ax)\n",
    "ax.set_title(\"Distribution of ROC AUC Scores\")\n",
    "ax.set_xlabel(\"ROC AUC Score\")\n",
    "ax.set_ylabel(\"Frequency\")\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d32b54be-31a6-45ff-a4eb-c7c1ae553849",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One box per uncertainty metric; hue repeats the x variable only to colour the\n",
    "# boxes, so the legend is switched off.\n",
    "fig, ax = plt.subplots(figsize=(14, 8))\n",
    "sns.boxplot(\n",
    "    data=mis_data,\n",
    "    x=\"UQMetric\",\n",
    "    y=\"RocAucScore\",\n",
    "    hue=\"UQMetric\",\n",
    "    legend=False,\n",
    "    ax=ax,\n",
    ")\n",
    "\n",
    "# Tilt the metric names and right-align them.\n",
    "for tick_label in ax.get_xticklabels():\n",
    "    tick_label.set(rotation=45, ha=\"right\")\n",
    "\n",
    "ax.set_title(\"Distribution of RocAucScore Across Different Uncertainty Metrics\")\n",
    "ax.set_xlabel(\"Uncertainty Metric (UQMetric)\")\n",
    "ax.set_ylabel(\"ROC AUC Score\")\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0065f1c5-410a-42e8-a68f-a712ed976f38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Box plots of RocAucScore per loss function, coloured by architecture.\n",
    "# `sns.catplot` is a figure-level function that creates its own figure (sized via\n",
    "# `height` / `aspect`), so no `plt.figure(...)` call precedes it: that would only\n",
    "# leave an extra, empty figure in the cell output.\n",
    "grid = sns.catplot(\n",
    "    data=mis_data,\n",
    "    x=\"LossFunction\",\n",
    "    y=\"RocAucScore\",\n",
    "    hue=\"architecture\",\n",
    "    kind=\"box\",\n",
    "    height=6,\n",
    "    aspect=2,\n",
    "    palette=\"muted\",\n",
    ")\n",
    "\n",
    "# No row/col faceting is used, so the grid holds a single Axes (`grid.ax`).\n",
    "grid.ax.set_title(\"ROC AUC Score Distribution by Architecture and Loss Function\")\n",
    "grid.set_axis_labels(\"Loss Function\", \"ROC AUC Score\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f37429cf-1749-479e-86cf-74ff3e1aef90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean RocAucScore per (LossFunction, training_dataset) pair; unstacking moves the\n",
    "# last group key (training_dataset) into the columns.\n",
    "mean_auc_by_loss = mis_data.groupby([\"LossFunction\", \"training_dataset\"])[\n",
    "    \"RocAucScore\"\n",
    "].mean()\n",
    "loss_function_comparison = mean_auc_by_loss.unstack()\n",
    "\n",
    "# Annotated heatmap of the averages.\n",
    "fig, ax = plt.subplots(figsize=(14, 8))\n",
    "sns.heatmap(\n",
    "    loss_function_comparison,\n",
    "    annot=True,\n",
    "    fmt=\".2f\",\n",
    "    cmap=\"coolwarm\",\n",
    "    linewidths=0.5,\n",
    "    ax=ax,\n",
    ")\n",
    "ax.set_title(\"Average ROC AUC Score by Loss Functions Across Datasets\")\n",
    "ax.set_xlabel(\"Dataset\")\n",
    "ax.set_ylabel(\"Loss Function\")\n",
    "for tick_label in ax.get_xticklabels():\n",
    "    tick_label.set_rotation(45)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26b2f4da-cb78-4a05-9f98-c9a8b218054b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean RocAucScore per (UQMetric, training_dataset) pair; unstacking moves the\n",
    "# last group key (training_dataset) into the columns.\n",
    "mean_auc_by_metric = mis_data.groupby([\"UQMetric\", \"training_dataset\"])[\n",
    "    \"RocAucScore\"\n",
    "].mean()\n",
    "performance_comparison = mean_auc_by_metric.unstack()\n",
    "\n",
    "# Annotated heatmap of the averages.\n",
    "fig, ax = plt.subplots(figsize=(14, 8))\n",
    "sns.heatmap(\n",
    "    performance_comparison,\n",
    "    annot=True,\n",
    "    fmt=\".2f\",\n",
    "    cmap=\"viridis\",\n",
    "    linewidths=0.5,\n",
    "    ax=ax,\n",
    ")\n",
    "ax.set_title(\"Average ROC AUC Score for Uncertainty Metrics Across Datasets\")\n",
    "ax.set_xlabel(\"Dataset\")\n",
    "ax.set_ylabel(\"Uncertainty Metric\")\n",
    "for tick_label in ax.get_xticklabels():\n",
    "    tick_label.set_rotation(45)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12636f19-5109-4984-bf42-ee3f3b7d20d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "\n",
    "# Work on a copy of the filtered table so later column edits leave `mis_data` intact.\n",
    "df = mis_data.copy()\n",
    "\n",
    "# Rows where LossFunction equals base_rule ...\n",
    "same_loss_base = df.loc[df[\"LossFunction\"].eq(df[\"base_rule\"])]\n",
    "# ... and rows where the two columns differ.\n",
    "diff_loss_base = df.loc[df[\"LossFunction\"].ne(df[\"base_rule\"])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "410592f8-f038-4f01-946e-13496618e287",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Descriptive statistics of RocAucScore for each of the two groups.\n",
    "summary_same = same_loss_base[\"RocAucScore\"].describe()\n",
    "summary_diff = diff_loss_base[\"RocAucScore\"].describe()\n",
    "\n",
    "# Place both summaries side by side, one labelled column per group.\n",
    "summary_table = pd.concat(\n",
    "    [summary_same, summary_diff],\n",
    "    axis=1,\n",
    "    keys=[\n",
    "        \"Same LossFunction and base_rule\",\n",
    "        \"Different LossFunction and base_rule\",\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7bc8f34-7857-402e-8895-dbbd37223c8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Common bin edges: 50 evenly spaced points spanning the full score range, so both\n",
    "# panels (and the later cells that reuse `bins_`) are binned identically.\n",
    "values = df[\"RocAucScore\"].values\n",
    "bins_ = np.linspace(values.min(), values.max(), 50)\n",
    "\n",
    "# Left panel: LossFunction equals base_rule; right panel: they differ.\n",
    "fig, (ax_same, ax_diff) = plt.subplots(1, 2, figsize=(14, 7))\n",
    "panels = [\n",
    "    (ax_same, same_loss_base, \"blue\", \"Same LossFunction and base_rule\"),\n",
    "    (ax_diff, diff_loss_base, \"red\", \"Different LossFunction and base_rule\"),\n",
    "]\n",
    "for ax, frame, color, description in panels:\n",
    "    sns.histplot(frame[\"RocAucScore\"], kde=True, color=color, bins=bins_, ax=ax)\n",
    "    ax.set_title(f\"RocAucScore Distribution ({description})\")\n",
    "    ax.set_xlabel(\"RocAucScore\")\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Display the summary statistics computed in the previous cell.\n",
    "summary_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b09fd9c1-5127-4fd0-827c-88fa87a3b0f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row of histograms per LossFunction value: rows where LossFunction equals\n",
    "# base_rule on the left, all other rows on the right.\n",
    "# Relies on `bins_` from an earlier cell so every panel uses the same bin edges.\n",
    "loss_functions = df[\"LossFunction\"].unique()\n",
    "n = len(loss_functions)\n",
    "\n",
    "# squeeze=False keeps `axes` two-dimensional even when there is only one loss\n",
    "# function; otherwise `axes[i, j]` fails on the resulting 1-D array.\n",
    "fig, axes = plt.subplots(\n",
    "    n, 2, figsize=(14, n * 3), sharex=True, sharey=True, squeeze=False\n",
    ")\n",
    "fig.suptitle(\"RocAucScore Distribution for Each LossFunction\")\n",
    "\n",
    "matches_base_rule = df[\"LossFunction\"] == df[\"base_rule\"]\n",
    "panels = (\n",
    "    (\"Same\", matches_base_rule, \"blue\"),\n",
    "    (\"Different\", ~matches_base_rule, \"red\"),\n",
    ")\n",
    "\n",
    "for i, loss_function in enumerate(loss_functions):\n",
    "    has_loss = df[\"LossFunction\"] == loss_function\n",
    "    for j, (label, mask, color) in enumerate(panels):\n",
    "        # Local selection only: the notebook-level same_loss_base / diff_loss_base\n",
    "        # frames built earlier are deliberately not overwritten here.\n",
    "        scores = df.loc[has_loss & mask, \"RocAucScore\"]\n",
    "        sns.histplot(scores, kde=True, color=color, ax=axes[i, j], bins=bins_)\n",
    "        axes[i, j].set_title(f\"{loss_function} ({label})\")\n",
    "        axes[i, j].set_xlabel(\"RocAucScore\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc9cb15c-597f-4ef2-b709-83d31a73c38e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Boolean flag per row: True when LossFunction equals base_rule.\n",
    "df[\"Same\"] = df[\"LossFunction\"].eq(df[\"base_rule\"])\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(14, 10))\n",
    "\n",
    "# Boxes per LossFunction, split by the flag.\n",
    "sns.boxplot(\n",
    "    data=df,\n",
    "    x=\"LossFunction\",\n",
    "    y=\"RocAucScore\",\n",
    "    hue=\"Same\",\n",
    "    palette=[\"red\", \"blue\"],\n",
    "    ax=ax,\n",
    ")\n",
    "\n",
    "# Titles, labels, legend and tick rotation.\n",
    "ax.set_title(\"RocAucScore Distribution for Each LossFunction\")\n",
    "ax.set_xlabel(\"LossFunction\")\n",
    "ax.set_ylabel(\"RocAucScore\")\n",
    "ax.legend(title=\"Same LossFunction and base_rule\", loc=\"upper right\")\n",
    "for tick_label in ax.get_xticklabels():\n",
    "    tick_label.set_rotation(45)\n",
    "fig.tight_layout()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4597820c-cc5c-42d5-bd4e-21c2803aea98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row of histograms per base_rule value: rows where LossFunction equals\n",
    "# base_rule on the left, all other rows on the right.\n",
    "# Relies on `bins_` from an earlier cell so every panel uses the same bin edges.\n",
    "base_rules = df[\"base_rule\"].unique()\n",
    "n = len(base_rules)\n",
    "\n",
    "# squeeze=False keeps `axes` two-dimensional even when there is only one\n",
    "# base_rule; otherwise `axes[i, j]` fails on the resulting 1-D array.\n",
    "fig, axes = plt.subplots(\n",
    "    n, 2, figsize=(10, n * 3), sharex=True, sharey=True, squeeze=False\n",
    ")\n",
    "fig.suptitle(\"RocAucScore Distribution for Each base_rule\")\n",
    "\n",
    "matches_loss = df[\"LossFunction\"] == df[\"base_rule\"]\n",
    "panels = (\n",
    "    (\"Same\", matches_loss, \"blue\"),\n",
    "    (\"Different\", ~matches_loss, \"red\"),\n",
    ")\n",
    "\n",
    "for i, base_rule in enumerate(base_rules):\n",
    "    has_rule = df[\"base_rule\"] == base_rule\n",
    "    for j, (label, mask, color) in enumerate(panels):\n",
    "        # Local selection only: the notebook-level same_loss_base / diff_loss_base\n",
    "        # frames built earlier are deliberately not overwritten here.\n",
    "        scores = df.loc[has_rule & mask, \"RocAucScore\"]\n",
    "        sns.histplot(scores, kde=True, color=color, ax=axes[i, j], bins=bins_)\n",
    "        axes[i, j].set_title(f\"{base_rule} ({label})\")\n",
    "        axes[i, j].set_xlabel(\"RocAucScore\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06dae4a7-1154-4eee-acf8-8ec2922d59b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Label each row \"Same\" / \"Different\" depending on whether LossFunction equals\n",
    "# base_rule. The comparison has to happen before base_rule is relabelled below.\n",
    "is_same = df[\"LossFunction\"] == df[\"base_rule\"]\n",
    "df[\"Same\"] = is_same.map({True: \"Same\", False: \"Different\"})\n",
    "\n",
    "# Display name used in the figure for the \"Maxprob\" rule.\n",
    "df[\"base_rule\"] = df[\"base_rule\"].replace({\"Maxprob\": \"Zero-one\"})\n",
    "\n",
    "# `pretty_matplotlib_config` is neither defined nor imported anywhere in this\n",
    "# notebook, so calling it unconditionally raises NameError on a fresh kernel.\n",
    "# Apply it only when it has been made available in the session.\n",
    "_pretty_config = globals().get(\"pretty_matplotlib_config\")\n",
    "if _pretty_config is not None:\n",
    "    _pretty_config(fontsize=50)\n",
    "sns.set_theme(font_scale=5)  # very large fonts for the exported figure\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(16, 12), dpi=150)\n",
    "\n",
    "# Colours are pinned per label (Same = blue, Different = red, as in the histogram\n",
    "# cells) instead of depending on the order in which the labels occur in the data.\n",
    "sns.boxplot(\n",
    "    data=df,\n",
    "    x=\"base_rule\",\n",
    "    y=\"RocAucScore\",\n",
    "    hue=\"Same\",\n",
    "    palette={\"Same\": \"blue\", \"Different\": \"red\"},\n",
    "    ax=ax,\n",
    ")\n",
    "\n",
    "ax.set_title(\"AUROC stratified by plug-in\")\n",
    "ax.set_xlabel(\"Plug-in function\")\n",
    "ax.set_ylabel(\"AUROC\")\n",
    "for tick_label in ax.get_xticklabels():\n",
    "    tick_label.set_rotation(45)\n",
    "\n",
    "# A single legend call: an earlier titled legend would be replaced by this one anyway.\n",
    "ax.legend(prop={\"size\": 30}, loc=4)\n",
    "fig.tight_layout()\n",
    "\n",
    "# savefig does not create missing directories.\n",
    "os.makedirs(\"imgs\", exist_ok=True)\n",
    "fig.savefig(\"imgs/mis_auroc_stratified_by_plugin.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e25d0c0-4a15-4bc0-96c9-1cbfdff72bf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of histograms: one row per base_rule, one column per LossFunction.\n",
    "# Relies on `bins_` from an earlier cell so every panel uses the same bin edges.\n",
    "loss_functions = df[\"LossFunction\"].unique()\n",
    "base_rules = df[\"base_rule\"].unique()\n",
    "\n",
    "# `Same` holds either booleans or the \"Same\" / \"Different\" labels written by the\n",
    "# preceding cell. Strings cannot be used (or negated with ~) as a row mask, so\n",
    "# normalise the column to a boolean Series first.\n",
    "same_flag = df[\"Same\"]\n",
    "if not pd.api.types.is_bool_dtype(same_flag):\n",
    "    same_flag = same_flag == \"Same\"\n",
    "\n",
    "# squeeze=False keeps `axes` two-dimensional even for a single row or column.\n",
    "fig, axes = plt.subplots(\n",
    "    len(base_rules),\n",
    "    len(loss_functions),\n",
    "    figsize=(20, 15),\n",
    "    sharex=True,\n",
    "    sharey=True,\n",
    "    squeeze=False,\n",
    ")\n",
    "fig.suptitle(\"RocAucScore Distribution by LossFunction and base_rule\")\n",
    "\n",
    "for i, base_rule in enumerate(base_rules):\n",
    "    for j, loss_function in enumerate(loss_functions):\n",
    "        ax = axes[i, j]\n",
    "        in_panel = (df[\"base_rule\"] == base_rule) & (\n",
    "            df[\"LossFunction\"] == loss_function\n",
    "        )\n",
    "        same_subset = df[in_panel & same_flag]\n",
    "        diff_subset = df[in_panel & ~same_flag]\n",
    "\n",
    "        if not same_subset.empty:\n",
    "            sns.histplot(\n",
    "                same_subset[\"RocAucScore\"],\n",
    "                kde=True,\n",
    "                color=\"blue\",\n",
    "                ax=ax,\n",
    "                label=\"Same\",\n",
    "                bins=bins_,\n",
    "            )\n",
    "        if not diff_subset.empty:\n",
    "            sns.histplot(\n",
    "                diff_subset[\"RocAucScore\"],\n",
    "                kde=True,\n",
    "                color=\"red\",\n",
    "                ax=ax,\n",
    "                label=\"Different\",\n",
    "                bins=bins_,\n",
    "            )\n",
    "\n",
    "        ax.set_title(f\"{base_rule} - {loss_function}\")\n",
    "        ax.set_xlabel(\"RocAucScore\")\n",
    "        ax.set_ylabel(\"\")\n",
    "        # Panels with no data have nothing to list; skip the legend there.\n",
    "        if ax.get_legend_handles_labels()[0]:\n",
    "            ax.legend()\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
