{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "# Exclude third-party libraries from autoreload. A comma list after a single '-'\n",
    "# (e.g. '%aimport -a, b') does not exclude every listed module, so use one\n",
    "# '-module' per line, which behaves the same across IPython versions.\n",
    "%aimport -sklearn\n",
    "%aimport -matplotlib\n",
    "%aimport -numpy\n",
    "%aimport -seaborn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "# import sklearn\n",
    "# from sklearn.datasets import fetch_openml\n",
    "# from sklearn import preprocessing\n",
    "# from sklearn.model_selection import train_test_split\n",
    "# from sklearn.linear_model import LogisticRegression\n",
    "# from sklearn.ensemble import RandomForestClassifier\n",
    "# from sklearn.ensemble import GradientBoostingClassifier\n",
    "import seaborn as sns\n",
    "# from utils import binary_assessment\n",
    "# from utils import binary_calibration\n",
    "# import lightgbm as lgb\n",
    "# import toplabel_confidence\n",
    "import classwise\n",
    "import toplabel\n",
    "import toplabel_conditional\n",
    "# import conditional\n",
    "#sns.set_theme(style=\"ticks\", color_codes=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global plotting style. The plotting functions below call sns.set_context(...)\n",
    "# themselves for every figure.\n",
    "# NOTE(review): sns.reset_orig() at the end of this cell appears to revert the rc and\n",
    "# style settings applied by sns.set(...) and sns.set_style(...) just above, so those\n",
    "# two calls may have no lasting effect. Confirm whether the reset is intended.\n",
    "sns.set(style=\"whitegrid\",\n",
    "        font_scale=1.5,\n",
    "        rc={\n",
    "            \"lines.linewidth\": 4,\n",
    "            \"axes.facecolor\": \"1.0\",\n",
    "            'figure.figsize': (6, 6)\n",
    "        })\n",
    "sns.set_style({'font.family':'serif'})\n",
    "sns.reset_orig()\n",
    "#sns.set_palette(\"husl\", 9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "title_dict = {'wide_resnet': 'Wide-ResNet-26-10', 'resnet50': 'ResNet-50', 'resnet110': 'ResNet-110'}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Top-label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_top_label(datafile_prefix, output_figure_prefix, bin_lower=5, bin_upper=25):\n",
    "    \"\"\"Plot estimated top-label-ECE vs. number of bins for each model in `title_dict`.\n",
    "\n",
    "    Args:\n",
    "        datafile_prefix: format string with one `{}` slot, filled with each\n",
    "            `title_dict` key to locate that model's data for `toplabel.toplabel`.\n",
    "        output_figure_prefix: format string with one `{}` slot, filled with each\n",
    "            `title_dict` key to give the figure path; missing parent directories\n",
    "            are created before saving.\n",
    "        bin_lower, bin_upper: inclusive range of bin counts passed to\n",
    "            `toplabel.toplabel`; also used as the x-axis.\n",
    "    \"\"\"\n",
    "    X = np.arange(bin_lower, bin_upper + 1)\n",
    "    for u in title_dict:\n",
    "        # NOTE(review): series meaning (A=base, B=TS, C=HB, D=VS, E=DS, F=normalized HB)\n",
    "        # is taken from the plot labels -- confirm against toplabel.toplabel.\n",
    "        A, B, C, D, E, F = toplabel.toplabel(datafile_prefix.format(u), bin_lower, bin_upper, True)\n",
    "\n",
    "        sns.set_context('notebook', font_scale=1.5, rc={'lines.linewidth': 2.5})\n",
    "        palette = sns.color_palette()\n",
    "        # Labels are model-agnostic: the loop covers every architecture, not only ResNet-110.\n",
    "        series = [\n",
    "            (A, 'Base model top-label-ECE'),\n",
    "            (B, 'TS top-label-ECE'),\n",
    "            (C, 'HB top-label-ECE'),\n",
    "            (F, 'Normalized HB top-label-ECE'),\n",
    "            (D, 'VS top-label-ECE'),\n",
    "            (E, 'DS top-label-ECE'),\n",
    "        ]\n",
    "\n",
    "        fig, ax = plt.subplots(figsize=(6, 6))\n",
    "        for i, (y, label) in enumerate(series):\n",
    "            ax.plot(X, y, '-', color=palette[i], label=label)\n",
    "        ax.set_xticks(np.linspace(bin_lower, bin_upper, 5))\n",
    "        ax.set_xlabel('Number of bins')\n",
    "        ax.set_ylabel('Estimated ECE')\n",
    "        ax.grid(True, linestyle='--')\n",
    "        ax.set_title(title_dict[u])\n",
    "\n",
    "        out_path = output_figure_prefix.format(u)\n",
    "        out_dir = os.path.dirname(out_path)\n",
    "        if out_dir:\n",
    "            # savefig does not create directories; needed on a fresh checkout.\n",
    "            os.makedirs(out_dir, exist_ok=True)\n",
    "        # Save this figure explicitly instead of relying on pyplot's current figure.\n",
    "        fig.savefig(out_path, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "# Silences all warnings for the rest of the kernel session.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# (data prefix, figure prefix) pairs, processed in this order.\n",
    "toplabel_runs = [\n",
    "    (\"data/cifar10_{}\", \"iclr_final/cifar10/brier_toplabel/{}.pdf\"),\n",
    "    (\"data/focal_cifar10_{}\", \"iclr_final/cifar10/focal_toplabel/{}.pdf\"),\n",
    "    (\"data/cifar100_{}\", \"iclr_final/cifar100/brier_toplabel/{}.pdf\"),\n",
    "    (\"data/focal_cifar100_{}\", \"iclr_final/cifar100/focal_toplabel/{}.pdf\"),\n",
    "]\n",
    "for datafile_prefix, output_figure_prefix in toplabel_runs:\n",
    "    plot_top_label(datafile_prefix, output_figure_prefix)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Top-label conditional"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_top_label_conditional(datafile_prefix, output_figure_prefix, bin_lower=5, bin_upper=25):\n",
    "    \"\"\"Plot estimated top-label-conditional ECE vs. number of bins for each model in `title_dict`.\n",
    "\n",
    "    Args:\n",
    "        datafile_prefix: format string with one `{}` slot, filled with each\n",
    "            `title_dict` key to locate that model's data for\n",
    "            `toplabel_conditional.toplabel_conditional`.\n",
    "        output_figure_prefix: format string with one `{}` slot, filled with each\n",
    "            `title_dict` key to give the figure path; missing parent directories\n",
    "            are created before saving.\n",
    "        bin_lower, bin_upper: inclusive range of bin counts; also used as the x-axis.\n",
    "            The range must span more than one value: with a single x value the\n",
    "            solid '-' lines draw nothing and the saved figure is empty.\n",
    "    \"\"\"\n",
    "    X = np.arange(bin_lower, bin_upper + 1)\n",
    "    for u in title_dict:\n",
    "        # NOTE(review): series meaning (A=base, B=TS, C=HB, D=VS, E=DS, F=normalized HB)\n",
    "        # is taken from the plot labels -- confirm against toplabel_conditional.\n",
    "        A, B, C, D, E, F = toplabel_conditional.toplabel_conditional(datafile_prefix.format(u), bin_lower, bin_upper, True)\n",
    "\n",
    "        sns.set_context('notebook', font_scale=1.5, rc={'lines.linewidth': 2.5})\n",
    "        palette = sns.color_palette()\n",
    "        series = [\n",
    "            (A, 'Base model'),\n",
    "            (B, 'Temperature scaling'),\n",
    "            (C, 'Class-wise-HB'),\n",
    "            (F, 'Normalized-HB'),\n",
    "            (D, 'Vector scaling'),\n",
    "            (E, 'Dirichlet scaling'),\n",
    "        ]\n",
    "\n",
    "        fig, ax = plt.subplots(figsize=(6, 6))\n",
    "        for i, (y, label) in enumerate(series):\n",
    "            ax.plot(X, y, '-', color=palette[i], label=label)\n",
    "        ax.set_xticks(np.linspace(bin_lower, bin_upper, 5))\n",
    "        ax.set_xlabel('Number of bins')\n",
    "        ax.set_ylabel('Estimated ECE')\n",
    "        ax.grid(True, linestyle='--')\n",
    "        ax.set_title(title_dict[u])\n",
    "\n",
    "        out_path = output_figure_prefix.format(u)\n",
    "        out_dir = os.path.dirname(out_path)\n",
    "        if out_dir:\n",
    "            # savefig does not create directories; needed on a fresh checkout.\n",
    "            os.makedirs(out_dir, exist_ok=True)\n",
    "        # Save this figure explicitly instead of relying on pyplot's current figure.\n",
    "        fig.savefig(out_path, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (data prefix, figure prefix) pairs, processed in this order.\n",
    "conditional_runs = [\n",
    "    (\"data/cifar10_{}\", \"iclr_final/cifar10/brier_conditional/{}.pdf\"),\n",
    "    (\"data/focal_cifar10_{}\", \"iclr_final/cifar10/focal_conditional/{}.pdf\"),\n",
    "    (\"data/cifar100_{}\", \"iclr_final/cifar100/brier_conditional/{}.pdf\"),\n",
    "    (\"data/focal_cifar100_{}\", \"iclr_final/cifar100/focal_conditional/{}.pdf\"),\n",
    "]\n",
    "for datafile_prefix, output_figure_prefix in conditional_runs:\n",
    "    plot_top_label_conditional(datafile_prefix, output_figure_prefix)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classwise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_classwise(datafile_prefix, output_figure_prefix, bin_lower=5, bin_upper=25):\n",
    "    \"\"\"Plot estimated class-wise-ECE vs. number of bins for each model in `title_dict`.\n",
    "\n",
    "    Args:\n",
    "        datafile_prefix: format string with one `{}` slot, filled with each\n",
    "            `title_dict` key to locate that model's data for `classwise.classwise_ece`.\n",
    "        output_figure_prefix: format string with one `{}` slot, filled with each\n",
    "            `title_dict` key to give the figure path; missing parent directories\n",
    "            are created before saving.\n",
    "        bin_lower, bin_upper: inclusive range of bin counts passed to\n",
    "            `classwise.classwise_ece`; also used as the x-axis.\n",
    "    \"\"\"\n",
    "    X = np.arange(bin_lower, bin_upper + 1)\n",
    "    for u in title_dict:\n",
    "        # NOTE(review): series meaning (A=base, B=TS, C=HB, D=normalized HB, E=VS, F=DS)\n",
    "        # is taken from the plot labels -- confirm against classwise.classwise_ece.\n",
    "        A, B, C, D, E, F = classwise.classwise_ece(datafile_prefix.format(u), bin_lower, bin_upper)\n",
    "\n",
    "        sns.set_context('notebook', font_scale=1.5, rc={'lines.linewidth': 2.5})\n",
    "        palette = sns.color_palette()\n",
    "        # These are class-wise curves for every architecture, so the labels say\n",
    "        # class-wise-ECE and do not hard-code one model.\n",
    "        series = [\n",
    "            (A, 'Base model class-wise-ECE'),\n",
    "            (B, 'TS class-wise-ECE'),\n",
    "            (C, 'HB class-wise-ECE'),\n",
    "            (D, 'Normalized HB class-wise-ECE'),\n",
    "            (E, 'VS class-wise-ECE'),\n",
    "            (F, 'DS class-wise-ECE'),\n",
    "        ]\n",
    "\n",
    "        fig, ax = plt.subplots(figsize=(6, 6))\n",
    "        for i, (y, label) in enumerate(series):\n",
    "            ax.plot(X, y, '-', color=palette[i], label=label)\n",
    "        ax.set_xticks(np.linspace(bin_lower, bin_upper, 5))\n",
    "        ax.set_xlabel('Number of bins')\n",
    "        ax.set_ylabel('Estimated ECE')\n",
    "        ax.grid(True, linestyle='--')\n",
    "        ax.set_title(title_dict[u])\n",
    "\n",
    "        out_path = output_figure_prefix.format(u)\n",
    "        out_dir = os.path.dirname(out_path)\n",
    "        if out_dir:\n",
    "            # savefig does not create directories; needed on a fresh checkout.\n",
    "            os.makedirs(out_dir, exist_ok=True)\n",
    "        # Save this figure explicitly instead of relying on pyplot's current figure.\n",
    "        fig.savefig(out_path, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "# Silences all warnings for the rest of the kernel session.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# (data prefix, figure prefix) pairs, processed in this order.\n",
    "classwise_runs = [\n",
    "    (\"data/cifar10_{}\", \"iclr_final/cifar10/brier_classwise/{}.pdf\"),\n",
    "    (\"data/focal_cifar10_{}\", \"iclr_final/cifar10/focal_classwise/{}.pdf\"),\n",
    "    (\"data/cifar100_{}\", \"iclr_final/cifar100/brier_classwise/{}.pdf\"),\n",
    "    (\"data/focal_cifar100_{}\", \"iclr_final/cifar100/focal_classwise/{}.pdf\"),\n",
    "]\n",
    "for datafile_prefix, output_figure_prefix in classwise_runs:\n",
    "    plot_classwise(datafile_prefix, output_figure_prefix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def table_top_label(datafile_prefix, tag):\n",
    "    \"\"\"Print one ' & '-separated table row of top-label-ECE values per model in `title_dict`.\n",
    "\n",
    "    For each model key, prints `tag` and the key. It then prints the first element of each\n",
    "    of the six series returned by `toplabel.toplabel` (15 bins for both bounds),\n",
    "    formatted to 3 decimals, joined by ' & ' and prefixed with a single space.\n",
    "    \"\"\"\n",
    "    n_bins = 15\n",
    "    for model_key in title_dict:\n",
    "        a, b, c, d, e, f = toplabel.toplabel(datafile_prefix.format(model_key), n_bins, n_bins, True)\n",
    "        print(tag, model_key)\n",
    "        cells = ('{:.3f}'.format(s[0]) for s in (a, b, c, d, e, f))\n",
    "        print(' ' + ' & '.join(cells))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (data prefix, row tag) pairs, printed in this order.\n",
    "toplabel_table_runs = [\n",
    "    (\"data/cifar10_{}\", \"cifar10+brier+toplabel\"),\n",
    "    (\"data/focal_cifar10_{}\", \"cifar10+focal+toplabel\"),\n",
    "    (\"data/cifar100_{}\", \"cifar100+brier+toplabel\"),\n",
    "    (\"data/focal_cifar100_{}\", \"cifar100+focal+toplabel\"),\n",
    "]\n",
    "for datafile_prefix, tag in toplabel_table_runs:\n",
    "    table_top_label(datafile_prefix, tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def table_class_wise(datafile_prefix, tag):\n",
    "    \"\"\"Print one ' & '-separated table row of class-wise-ECE values per model in `title_dict`.\n",
    "\n",
    "    For each model key, prints `tag` and the key. It then prints the first element of each\n",
    "    of the six series returned by `classwise.classwise_ece` (15 bins for both bounds),\n",
    "    formatted to 4 decimals, joined by ' & ' and prefixed with a single space.\n",
    "    \"\"\"\n",
    "    n_bins = 15\n",
    "    for model_key in title_dict:\n",
    "        a, b, c, d, e, f = classwise.classwise_ece(datafile_prefix.format(model_key), n_bins, n_bins)\n",
    "        print(tag, model_key)\n",
    "        cells = ('{:.4f}'.format(s[0]) for s in (a, b, c, d, e, f))\n",
    "        print(' ' + ' & '.join(cells))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (data prefix, row tag) pairs, printed in this order.\n",
    "classwise_table_runs = [\n",
    "    (\"data/cifar10_{}\", \"cifar10+brier+classwise\"),\n",
    "    (\"data/focal_cifar10_{}\", \"cifar10+focal+classwise\"),\n",
    "    (\"data/cifar100_{}\", \"cifar100+brier+classwise\"),\n",
    "    (\"data/focal_cifar100_{}\", \"cifar100+focal+classwise\"),\n",
    "]\n",
    "for datafile_prefix, tag in classwise_table_runs:\n",
    "    table_class_wise(datafile_prefix, tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def table_top_label_conditional(datafile_prefix, tag):\n",
    "    \"\"\"Print one ' & '-separated table row of top-label-conditional ECE values per model.\n",
    "\n",
    "    For each key of `title_dict`, prints `tag` and the key. It then prints the first\n",
    "    element of each of the six series returned by\n",
    "    `toplabel_conditional.toplabel_conditional` (15 bins for both bounds), formatted\n",
    "    to 3 decimals, joined by ' & ' and prefixed with a single space.\n",
    "    \"\"\"\n",
    "    n_bins = 15\n",
    "    for model_key in title_dict:\n",
    "        a, b, c, d, e, f = toplabel_conditional.toplabel_conditional(datafile_prefix.format(model_key), n_bins, n_bins, True)\n",
    "        print(tag, model_key)\n",
    "        cells = ('{:.3f}'.format(s[0]) for s in (a, b, c, d, e, f))\n",
    "        print(' ' + ' & '.join(cells))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (data prefix, row tag) pairs, printed in this order.\n",
    "conditional_table_runs = [\n",
    "    (\"data/cifar10_{}\", \"cifar10+brier+toplabel-conditional\"),\n",
    "    (\"data/focal_cifar10_{}\", \"cifar10+focal+toplabel-conditional\"),\n",
    "    (\"data/cifar100_{}\", \"cifar100+brier+toplabel-conditional\"),\n",
    "    (\"data/focal_cifar100_{}\", \"cifar100+focal+toplabel-conditional\"),\n",
    "]\n",
    "for datafile_prefix, tag in conditional_table_runs:\n",
    "    table_top_label_conditional(datafile_prefix, tag)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
