{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.colors import to_rgba"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "color_map = {'auto_label_opt_v0': '#A12230',\n",
    "            'dirichlet': '#FB75AF',\n",
    "            'histogram_binning_top_label': '#4F7152',\n",
    "            'scaling_binning': '#306689',\n",
    "            'temp_scaling': '#88548E',\n",
    "            'None': '#7BC276'}\n",
    "\n",
    "marker_map = {'auto_label_opt_v0': '*',\n",
    "            'dirichlet': 'o',\n",
    "            'histogram_binning_top_label': '^',\n",
    "            'scaling_binning': 's',\n",
    "            'temp_scaling': 'd',\n",
    "            'None': 'P'}\n",
    "\n",
    "legend_map = {'auto_label_opt_v0': \"Ours\",\n",
    "            'dirichlet': 'Dirichlet',\n",
    "            'histogram_binning_top_label': 'Top-HB',\n",
    "            'scaling_binning': 'SB',\n",
    "            'temp_scaling': 'TS',\n",
    "            'None': 'Softmax'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fontsize=18\n",
    "markersize=12\n",
    "linewidth=2.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_ablation(ax, X_axis, Y_axis, std_axis):\n",
    "    df_sorted = df.sort_values(by=[X_axis])\n",
    "    for calib_conf, group in df_sorted.groupby('calib_conf'):\n",
    "        # if calib_conf == 'auto_label_opt_v0' and X_axis == 'N_t':\n",
    "        #     index_5000 = group[group[f'{X_axis}'] == 5000].index[0]\n",
    "        #     ax.errorbar(group[f'{X_axis}'][index_5000], group[f'{Y_axis}'][index_5000], \n",
    "        #                 yerr=group[f'{std_axis}'][index_5000], \n",
    "        #                 label=legend_map[calib_conf], marker=marker_map[calib_conf], \n",
    "        #                 linestyle='None', color=color_map[calib_conf],\n",
    "        #                 markersize=markersize, linewidth=linewidth)\n",
    "        #     # ax.text(group[f'{X_axis}'][index_5000], group[f'{Y_axis}'][index_5000] + 2, \n",
    "        #     #         f'{group[f\"{X_axis}\"][index_5000]}', ha='center', va='bottom',\n",
    "        #     #          color=color_map[calib_conf])\n",
    "        # elif calib_conf == 'auto_label_opt_v0' and X_axis in ['N_v', 'calib_val_frac']:\n",
    "        #     # Error bar plot!\n",
    "        #     #ax.errorbar(group[f'{X_axis}'], group[f'{Y_axis}'], yerr=group[f'{std_axis}'], label=calib_conf, marker='o')\n",
    "        #     ax.plot(group[f'{X_axis}'], group[f'{Y_axis}'], label=legend_map[calib_conf], \n",
    "        #             marker=marker_map[calib_conf], color=color_map[calib_conf],\n",
    "        #             markersize=markersize, linewidth=linewidth)\n",
    "        #     ax.fill_between(group[f'{X_axis}'], group[f'{Y_axis}'] - group[f'{std_axis}'], \n",
    "        #                     group[f'{Y_axis}'] + group[f'{std_axis}'], alpha=0.2, \n",
    "        #                     color=color_map[calib_conf])\n",
    "        # else:\n",
    "        ax.plot(group[f'{X_axis}'], group[f'{Y_axis}'], label=legend_map[calib_conf], \n",
    "                marker=marker_map[calib_conf], color=color_map[calib_conf],\n",
    "                markersize=markersize, linewidth=linewidth)\n",
    "        ax.fill_between(group[f'{X_axis}'], group[f'{Y_axis}'] - group[f'{std_axis}'], \n",
    "                        group[f'{Y_axis}'] + group[f'{std_axis}'], alpha=0.2, \n",
    "                        color=color_map[calib_conf])\n",
    "        ax.set_xticks(group[f'{X_axis}'][::1])\n",
    "        # ax.set_xticklabels(group[f'{X_axis}'][::1], rotation=45, ha='right')\n",
    "        \n",
    "        ax.set_xlabel(f'{X_axis}', fontsize=fontsize)\n",
    "        ax.set_ylabel(f'{Y_axis}', fontsize=fontsize)\n",
    "        ax.set_title(f'{Y_axis} vs {X_axis}', fontsize=fontsize)\n",
    "        # ax.legend(fontsize=fontsize, loc='lower center', bbox_to_anchor=(0.5, -0.3), ncol=len(df['calib_conf'].unique()))\n",
    "        ax.grid(True, linestyle='--', alpha=0.7)  \n",
    "        # ax.set_facecolor('lightgray')\n",
    "\n",
    "        for spine in ax.spines.values():\n",
    "            spine.set_linewidth(2)\n",
    "        \n",
    "        # ax.xticks(fontsize=fontsize)\n",
    "        # ax.yticks(fontsize=fontsize)\n",
    "        ax.tick_params(axis='both', which='major', labelsize=fontsize)\n",
    "\n",
    "        # plt.savefig(f'{X_axis}_{Y_axis}.png', dpi=1600, bbox_inches='tight')\n",
    "        # plt.savefig(f'{X_axis}_{Y_axis}.pdf', dpi=1600)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = \"cifar10_tbal_squentropy_Nt.xlsx\"\n",
    "\n",
    "df = pd.read_excel(filename)\n",
    "\n",
    "# tab20 color map\n",
    "tab20_colors = plt.cm.get_cmap('tab10').colors\n",
    "color_map = {calib_conf: to_rgba(tab20_colors[i % len(tab20_colors)]) \n",
    "            for i, calib_conf in enumerate(df['calib_conf'].unique())}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "plot_ablation(ax, 'N_t', \"Coverage-Mean\", \"Coverage-Std\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "plot_ablation(ax, 'N_t', \"Auto-Labeling-Err-Mean\", \"Auto-Labeling-Err-Std\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = \"cifar10_tbal_squentropy_Nt.xlsx\"\n",
    "\n",
    "df = pd.read_excel(filename)\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))\n",
    "\n",
    "plot_ablation(ax1, 'N_t', \"Coverage-Mean\", \"Coverage-Std\")\n",
    "plot_ablation(ax2, 'N_t', \"Auto-Labeling-Err-Mean\", \"Auto-Labeling-Err-Std\")\n",
    "\n",
    "handles, labels = ax1.get_legend_handles_labels()\n",
    "fig.legend(handles, labels, fontsize=fontsize, loc='lower center', bbox_to_anchor=(0.5, -0.15), ncol=len(df['calib_conf'].unique()))\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# plt.savefig('Nt.png', dpi=1600, bbox_inches='tight')\n",
    "plt.savefig('Nt.pdf', dpi=1600, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calib val frac"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = \"cifar10_tbal_squentropy_calib_val_frac.xlsx\"\n",
    "\n",
    "df = pd.read_excel(filename)\n",
    "df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_ablation('calib_val_frac', \"Coverage-Mean\", \"Coverage-Std\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_ablation('calib_val_frac', \"Auto-Labeling-Err-Mean\", \"Auto-Labeling-Err-Std\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = \"calib_val_frac_additional_exp.xlsx\"\n",
    "\n",
    "df = pd.read_excel(filename)\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))\n",
    "\n",
    "plot_ablation(ax1, 'calib_val_frac', \"Coverage-Mean\", \"Coverage-Std\")\n",
    "plot_ablation(ax2, 'calib_val_frac', \"Auto-Labeling-Err-Mean\", \"Auto-Labeling-Err-Std\")\n",
    "\n",
    "handles, labels = ax1.get_legend_handles_labels()\n",
    "fig.legend(handles, labels, fontsize=fontsize, loc='lower center', bbox_to_anchor=(0.5, -0.15), ncol=len(df['calib_conf'].unique()))\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# plt.savefig('Nt.png', dpi=1600, bbox_inches='tight')\n",
    "plt.savefig('calib_val_frac.pdf', dpi=1600, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "N_v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = \"cifar10_tbal_squentropy_Nv.xlsx\"\n",
    "\n",
    "df = pd.read_excel(filename)\n",
    "df.columns\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_ablation('N_v', \"Coverage-Mean\", \"Coverage-Std\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_ablation('N_v', \"Auto-Labeling-Err-Mean\", \"Auto-Labeling-Err-Std\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = \"cifar10_tbal_squentropy_Nv.xlsx\"\n",
    "\n",
    "df = pd.read_excel(filename)\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))\n",
    "\n",
    "plot_ablation(ax1, 'N_v', \"Coverage-Mean\", \"Coverage-Std\")\n",
    "plot_ablation(ax2, 'N_v', \"Auto-Labeling-Err-Mean\", \"Auto-Labeling-Err-Std\")\n",
    "\n",
    "handles, labels = ax1.get_legend_handles_labels()\n",
    "fig.legend(handles, labels, fontsize=fontsize, loc='lower center', bbox_to_anchor=(0.5, -0.15), ncol=len(df['calib_conf'].unique()))\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig('N_v.pdf', dpi=1600, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
