{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3501b06d-e918-46c7-b20d-220f65e9245a",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "838991a0-8f2c-4b05-b30b-212de9e35d31",
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "import numpy as np\n",
    "import os\n",
    "import pickle\n",
    "import random\n",
    "import relplot as rp\n",
    "import timm\n",
    "import torch\n",
    "\n",
    "from datasets import load_dataset\n",
    "from huggingface_hub import login\n",
    "\n",
    "from utils.calibration import SharpCal\n",
    "from utils.recal import TemperatureScaler, VectorScaler, MatrixScaler\n",
    "from netcal.binning import HistogramBinning, IsotonicRegression\n",
    "from netcal.metrics import ACE, ECE\n",
    "from sklearn.metrics import roc_auc_score, roc_curve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "41099512-cbda-49d8-b79c-f705cb743e8e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device count:  1\n",
      "GPU being used: Tesla V100-PCIE-32GB\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "if device != \"cpu\":\n",
    "    print(\"Device count: \", torch.cuda.device_count())\n",
    "    print(\"GPU being used: {}\".format(torch.cuda.get_device_name(0)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1f6ec189-3d99-4298-b43b-16c1339fb3a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 42\n",
    "random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "447ec673-6c87-4e41-939d-c73b7dbaf606",
   "metadata": {},
   "source": [
    "## Model and Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ca723ac1-7449-4bf7-a757-529ff0cc3959",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.\n",
      "Token is valid (permission: read).\n",
      "Your token has been saved to /home/users/mc696/.cache/huggingface/token\n",
      "Login successful\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "34a86f4b0c794e1b884919dcd15bda8c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2606d23ece9b46569f5ce526f2159c46",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "login(token=os.environ[\"HF_READ\"])\n",
    "dataset = load_dataset(\"timm/imagenet-1k-wds\", split=\"validation\", streaming=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "24786dd0-2065-4ca9-adf5-9fb3d565b855",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"vit_base_patch16_224.augreg2_in21k_ft_in1k\"\n",
    "model = timm.create_model(model_name, pretrained=True).to(device)\n",
    "data_config = timm.data.resolve_model_data_config(model)\n",
    "transforms = timm.data.create_transform(**data_config, is_training=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b830c093-c489-4edd-8aa7-9c94fa805792",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_logits_and_labels_stream(\n",
    "    model: torch.nn.Module,\n",
    "    dataset: datasets.IterableDataset,\n",
    "    transforms,\n",
    "    cutoff: int = None,\n",
    "):\n",
    "    model.eval()\n",
    "    model.to(device)\n",
    "    logits, labels = [], []\n",
    "    seen = 0\n",
    "    with torch.no_grad():\n",
    "        for data in dataset:\n",
    "            img = data[\"jpg\"].convert(\"RGB\")\n",
    "            x, target = transforms(img).unsqueeze(0).to(device), data[\"cls\"]\n",
    "            logits.append(model(x))\n",
    "            labels.append(target)\n",
    "            seen += 1\n",
    "            if cutoff is not None and seen == cutoff:\n",
    "                break\n",
    "\n",
    "    logits = torch.cat(logits)\n",
    "    labels = torch.LongTensor(labels).to(device)\n",
    "    return logits, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c761ed1e-b56f-4a5e-be3b-6dfd037e2c59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%time\n",
    "logits, labels = get_logits_and_labels_stream(model, dataset, transforms, cutoff=None)\n",
    "pickle.dump(logits, open(\"stored_logits/vit_logits.p\", \"wb\"))\n",
    "pickle.dump(labels, open(\"stored_logits/vit_labels.p\", \"wb\"))\n",
    "# logits = pickle.load(open(\"stored_logits/vit_logits.p\", \"rb\"))\n",
    "# labels = pickle.load(open(\"stored_logits/vit_labels.p\", \"rb\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea843733-7d34-439a-a528-461b9b74bbaa",
   "metadata": {},
   "source": [
    "## Compare Recalibration Strategies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "75872f2d-53b0-4c0e-87f2-a72c8e112c07",
   "metadata": {},
   "outputs": [],
   "source": [
    "cal_cutoff = int(0.2 * len(logits))\n",
    "\n",
    "cal_logits, cal_labels = logits[:cal_cutoff], labels[:cal_cutoff]\n",
    "cal_probs = torch.nn.functional.softmax(cal_logits, dim=1)\n",
    "\n",
    "test_logits, test_labels = logits[cal_cutoff:], labels[cal_cutoff:]\n",
    "test_probs = torch.nn.functional.softmax(test_logits, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "45f6806a-b1c8-477f-b43f-d6e66edd91fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_metrics(probs, labels):\n",
    "    confs = probs.max(dim=1).values.cpu().numpy()\n",
    "    bin_labels = (probs.argmax(dim=1) == labels).long().cpu().numpy()\n",
    "\n",
    "    acc = bin_labels.sum() / len(bin_labels) * 100\n",
    "    print(f\"Accuracy: {acc:.2f}\")\n",
    "\n",
    "    roc = roc_auc_score(bin_labels, confs) * 100\n",
    "    print(f\"ROC-AUC: {roc:.2f}\")\n",
    "    \n",
    "    ece = ECE(bins=15)\n",
    "    ece_val = ece.measure(confs, bin_labels) * 100\n",
    "    print(f\"Binned ECE: {ece_val:.2f}\")\n",
    "\n",
    "    ace = ACE(bins=15)\n",
    "    ace_val = ace.measure(confs, bin_labels) * 100\n",
    "    print(f\"Binned ACE: {ace_val:.2f}\")\n",
    "\n",
    "    smece = rp.smECE(confs, bin_labels) * 100\n",
    "    print(f\"Smooth ECE: {smece:.2f}\")\n",
    "\n",
    "    nll = torch.nn.NLLLoss()\n",
    "    nll_val = nll(torch.log(probs), labels) * 100\n",
    "    print(f\"NLL: {nll_val:.2f}\")\n",
    "    \n",
    "    one_hot_labels = torch.zeros_like(probs)\n",
    "    one_hot_labels[torch.arange(len(probs)), labels] = 1\n",
    "    bs_val = ((probs - one_hot_labels) ** 2).sum(dim=1).mean() * 100\n",
    "    print(f\"MSE: {bs_val:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b584b98e-a806-4fd4-b35d-c22e48cfa74d",
   "metadata": {},
   "source": [
    "### Baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "81cbec58-04db-4e39-ae10-0a85b26e278d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 85.14\n",
      "ROC-AUC: 86.41\n",
      "Binned ECE: 9.42\n",
      "Binned ACE: 9.44\n",
      "Smooth ECE: 9.44\n",
      "NLL: 65.35\n",
      "MSE: 22.73\n"
     ]
    }
   ],
   "source": [
    "print_metrics(test_probs, test_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "836cf7e0-e602-4e11-9679-cd1247f8f30c",
   "metadata": {},
   "source": [
    "### Temperature Scaling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "92c13800-2110-412e-84a3-dbc80755d22c",
   "metadata": {},
   "outputs": [],
   "source": [
    "tscale = TemperatureScaler(init_T=1.5, device=device, use_mse=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "98debade-8cf0-4e8b-b789-44ac5971270f",
   "metadata": {},
   "outputs": [],
   "source": [
    "tscale.fit(cal_logits, cal_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "02513a17-838e-4f03-b41e-aff7accfa0e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "ts_probs = tscale.predict_probs(test_logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "c153ee0a-943f-4f5f-9c72-a78f19af57b8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 85.14\n",
      "ROC-AUC: 87.07\n",
      "Binned ECE: 2.70\n",
      "Binned ACE: 5.65\n",
      "Smooth ECE: 2.68\n",
      "NLL: 56.83\n",
      "MSE: 21.96\n"
     ]
    }
   ],
   "source": [
    "print_metrics(ts_probs, test_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15cc14da-6642-4c09-9968-b0ac8f2e3f69",
   "metadata": {},
   "source": [
    "### Histogram Binning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "7ba1ad6c-3788-492d-b466-10384cf10148",
   "metadata": {},
   "outputs": [],
   "source": [
    "binner = HistogramBinning(bins=15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "f6dd4a74-5bee-42f3-bf2f-96a36b85f0f1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-1 {\n",
       "  /* Definition of color scheme common for light and dark mode */\n",
       "  --sklearn-color-text: black;\n",
       "  --sklearn-color-line: gray;\n",
       "  /* Definition of color scheme for unfitted estimators */\n",
       "  --sklearn-color-unfitted-level-0: #fff5e6;\n",
       "  --sklearn-color-unfitted-level-1: #f6e4d2;\n",
       "  --sklearn-color-unfitted-level-2: #ffe0b3;\n",
       "  --sklearn-color-unfitted-level-3: chocolate;\n",
       "  /* Definition of color scheme for fitted estimators */\n",
       "  --sklearn-color-fitted-level-0: #f0f8ff;\n",
       "  --sklearn-color-fitted-level-1: #d4ebff;\n",
       "  --sklearn-color-fitted-level-2: #b3dbfd;\n",
       "  --sklearn-color-fitted-level-3: cornflowerblue;\n",
       "\n",
       "  /* Specific color for light theme */\n",
       "  --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
       "  --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-icon: #696969;\n",
       "\n",
       "  @media (prefers-color-scheme: dark) {\n",
       "    /* Redefinition of color scheme for dark theme */\n",
       "    --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
       "    --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-icon: #878787;\n",
       "  }\n",
       "}\n",
       "\n",
       "#sk-container-id-1 {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 pre {\n",
       "  padding: 0;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-hidden--visually {\n",
       "  border: 0;\n",
       "  clip: rect(1px 1px 1px 1px);\n",
       "  clip: rect(1px, 1px, 1px, 1px);\n",
       "  height: 1px;\n",
       "  margin: -1px;\n",
       "  overflow: hidden;\n",
       "  padding: 0;\n",
       "  position: absolute;\n",
       "  width: 1px;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-dashed-wrapped {\n",
       "  border: 1px dashed var(--sklearn-color-line);\n",
       "  margin: 0 0.4em 0.5em 0.4em;\n",
       "  box-sizing: border-box;\n",
       "  padding-bottom: 0.4em;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-container {\n",
       "  /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
       "     but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
       "     so we also need the `!important` here to be able to override the\n",
       "     default hidden behavior on the sphinx rendered scikit-learn.org.\n",
       "     See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
       "  display: inline-block !important;\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-text-repr-fallback {\n",
       "  display: none;\n",
       "}\n",
       "\n",
       "div.sk-parallel-item,\n",
       "div.sk-serial,\n",
       "div.sk-item {\n",
       "  /* draw centered vertical line to link estimators */\n",
       "  background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
       "  background-size: 2px 100%;\n",
       "  background-repeat: no-repeat;\n",
       "  background-position: center center;\n",
       "}\n",
       "\n",
       "/* Parallel-specific style estimator block */\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item::after {\n",
       "  content: \"\";\n",
       "  width: 100%;\n",
       "  border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
       "  flex-grow: 1;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel {\n",
       "  display: flex;\n",
       "  align-items: stretch;\n",
       "  justify-content: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:first-child::after {\n",
       "  align-self: flex-end;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:last-child::after {\n",
       "  align-self: flex-start;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:only-child::after {\n",
       "  width: 0;\n",
       "}\n",
       "\n",
       "/* Serial-specific style estimator block */\n",
       "\n",
       "#sk-container-id-1 div.sk-serial {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "  align-items: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  padding-right: 1em;\n",
       "  padding-left: 1em;\n",
       "}\n",
       "\n",
       "\n",
       "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
       "clickable and can be expanded/collapsed.\n",
       "- Pipeline and ColumnTransformer use this feature and define the default style\n",
       "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
       "*/\n",
       "\n",
       "/* Pipeline and ColumnTransformer style (default) */\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable {\n",
       "  /* Default theme specific background. It is overwritten whether we have a\n",
       "  specific estimator or a Pipeline/ColumnTransformer */\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "/* Toggleable label */\n",
       "#sk-container-id-1 label.sk-toggleable__label {\n",
       "  cursor: pointer;\n",
       "  display: block;\n",
       "  width: 100%;\n",
       "  margin-bottom: 0;\n",
       "  padding: 0.5em;\n",
       "  box-sizing: border-box;\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n",
       "  /* Arrow on the left of the label */\n",
       "  content: \"▸\";\n",
       "  float: left;\n",
       "  margin-right: 0.25em;\n",
       "  color: var(--sklearn-color-icon);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "/* Toggleable content - dropdown */\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content {\n",
       "  max-height: 0;\n",
       "  max-width: 0;\n",
       "  overflow: hidden;\n",
       "  text-align: left;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content pre {\n",
       "  margin: 0.2em;\n",
       "  border-radius: 0.25em;\n",
       "  color: var(--sklearn-color-text);\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
       "  /* Expand drop-down */\n",
       "  max-height: 200px;\n",
       "  max-width: 100%;\n",
       "  overflow: auto;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
       "  content: \"▾\";\n",
       "}\n",
       "\n",
       "/* Pipeline/ColumnTransformer-specific style */\n",
       "\n",
       "#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator-specific style */\n",
       "\n",
       "/* Colorize estimator box */\n",
       "#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n",
       "#sk-container-id-1 div.sk-label label {\n",
       "  /* The background is the default theme color */\n",
       "  color: var(--sklearn-color-text-on-default-background);\n",
       "}\n",
       "\n",
       "/* On hover, darken the color of the background */\n",
       "#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "/* Label box, darken color on hover, fitted */\n",
       "#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator label */\n",
       "\n",
       "#sk-container-id-1 div.sk-label label {\n",
       "  font-family: monospace;\n",
       "  font-weight: bold;\n",
       "  display: inline-block;\n",
       "  line-height: 1.2em;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label-container {\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "/* Estimator-specific */\n",
       "#sk-container-id-1 div.sk-estimator {\n",
       "  font-family: monospace;\n",
       "  border: 1px dotted var(--sklearn-color-border-box);\n",
       "  border-radius: 0.25em;\n",
       "  box-sizing: border-box;\n",
       "  margin-bottom: 0.5em;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "/* on hover */\n",
       "#sk-container-id-1 div.sk-estimator:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
       "\n",
       "/* Common style for \"i\" and \"?\" */\n",
       "\n",
       ".sk-estimator-doc-link,\n",
       "a:link.sk-estimator-doc-link,\n",
       "a:visited.sk-estimator-doc-link {\n",
       "  float: right;\n",
       "  font-size: smaller;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1em;\n",
       "  height: 1em;\n",
       "  width: 1em;\n",
       "  text-decoration: none !important;\n",
       "  margin-left: 1ex;\n",
       "  /* unfitted */\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted,\n",
       "a:link.sk-estimator-doc-link.fitted,\n",
       "a:visited.sk-estimator-doc-link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "/* Span, style for the box shown on hovering the info icon */\n",
       ".sk-estimator-doc-link span {\n",
       "  display: none;\n",
       "  z-index: 9999;\n",
       "  position: relative;\n",
       "  font-weight: normal;\n",
       "  right: .2ex;\n",
       "  padding: .5ex;\n",
       "  margin: .5ex;\n",
       "  width: min-content;\n",
       "  min-width: 20ex;\n",
       "  max-width: 50ex;\n",
       "  color: var(--sklearn-color-text);\n",
       "  box-shadow: 2pt 2pt 4pt #999;\n",
       "  /* unfitted */\n",
       "  background: var(--sklearn-color-unfitted-level-0);\n",
       "  border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted span {\n",
       "  /* fitted */\n",
       "  background: var(--sklearn-color-fitted-level-0);\n",
       "  border: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link:hover span {\n",
       "  display: block;\n",
       "}\n",
       "\n",
       "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link {\n",
       "  float: right;\n",
       "  font-size: 1rem;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1rem;\n",
       "  height: 1rem;\n",
       "  width: 1rem;\n",
       "  text-decoration: none;\n",
       "  /* unfitted */\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "#sk-container-id-1 a.estimator_doc_link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>HistogramBinning(_bin_bounds=None, _bin_map=None, _bins_internal=None,\n",
       "                 _default_independent_probabilities=False,\n",
       "                 _multiclass_instances=[(0,\n",
       "                                         HistogramBinning(_bin_bounds=[array([0.        , 0.06666667, 0.13333333, 0.2       , 0.26666667,\n",
       "       0.33333333, 0.4       , 0.46666667, 0.53333333, 0.6       ,\n",
       "       0.66666667, 0.73333333, 0.8       , 0.86666667, 0.93333333,\n",
       "       1.        ])],\n",
       "                                                          _bin_map=array([0.        , 0.1       , 0.16666667, 0.2333333...\n",
       "       0.66666667, 0.73333333, 0.8       , 0.86666667, 0.93333333,\n",
       "       1.        ])],\n",
       "                                                          _bin_map=array([0.        , 0.        , 0.16666667, 0.23333333, 0.3       ,\n",
       "       0.36666667, 0.43333333, 0.5       , 0.56666667, 0.63333333,\n",
       "       1.        , 0.76666667, 1.        , 1.        , 0.96666667]),\n",
       "                                                          _bins_internal=[15],\n",
       "                                                          _default_independent_probabilities=False,\n",
       "                                                          _multiclass_instances=[],\n",
       "                                                          _num_combination=0,\n",
       "                                                          bins=15,\n",
       "                                                          num_classes=2)), ...],\n",
       "                 _num_combination=0, bins=15, num_classes=1000)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator  sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label  sk-toggleable__label-arrow \">&nbsp;HistogramBinning<span class=\"sk-estimator-doc-link \">i<span>Not fitted</span></span></label><div class=\"sk-toggleable__content \"><pre>HistogramBinning(_bin_bounds=None, _bin_map=None, _bins_internal=None,\n",
       "                 _default_independent_probabilities=False,\n",
       "                 _multiclass_instances=[(0,\n",
       "                                         HistogramBinning(_bin_bounds=[array([0.        , 0.06666667, 0.13333333, 0.2       , 0.26666667,\n",
       "       0.33333333, 0.4       , 0.46666667, 0.53333333, 0.6       ,\n",
       "       0.66666667, 0.73333333, 0.8       , 0.86666667, 0.93333333,\n",
       "       1.        ])],\n",
       "                                                          _bin_map=array([0.        , 0.1       , 0.16666667, 0.2333333...\n",
       "       0.66666667, 0.73333333, 0.8       , 0.86666667, 0.93333333,\n",
       "       1.        ])],\n",
       "                                                          _bin_map=array([0.        , 0.        , 0.16666667, 0.23333333, 0.3       ,\n",
       "       0.36666667, 0.43333333, 0.5       , 0.56666667, 0.63333333,\n",
       "       1.        , 0.76666667, 1.        , 1.        , 0.96666667]),\n",
       "                                                          _bins_internal=[15],\n",
       "                                                          _default_independent_probabilities=False,\n",
       "                                                          _multiclass_instances=[],\n",
       "                                                          _num_combination=0,\n",
       "                                                          bins=15,\n",
       "                                                          num_classes=2)), ...],\n",
       "                 _num_combination=0, bins=15, num_classes=1000)</pre></div> </div></div></div></div>"
      ],
      "text/plain": [
       "HistogramBinning(_bin_bounds=None, _bin_map=None, _bins_internal=None,\n",
       "                 _default_independent_probabilities=False,\n",
       "                 _multiclass_instances=[(0,\n",
       "                                         HistogramBinning(_bin_bounds=[array([0.        , 0.06666667, 0.13333333, 0.2       , 0.26666667,\n",
       "       0.33333333, 0.4       , 0.46666667, 0.53333333, 0.6       ,\n",
       "       0.66666667, 0.73333333, 0.8       , 0.86666667, 0.93333333,\n",
       "       1.        ])],\n",
       "                                                          _bin_map=array([0.        , 0.1       , 0.16666667, 0.2333333...\n",
       "       0.66666667, 0.73333333, 0.8       , 0.86666667, 0.93333333,\n",
       "       1.        ])],\n",
       "                                                          _bin_map=array([0.        , 0.        , 0.16666667, 0.23333333, 0.3       ,\n",
       "       0.36666667, 0.43333333, 0.5       , 0.56666667, 0.63333333,\n",
       "       1.        , 0.76666667, 1.        , 1.        , 0.96666667]),\n",
       "                                                          _bins_internal=[15],\n",
       "                                                          _default_independent_probabilities=False,\n",
       "                                                          _multiclass_instances=[],\n",
       "                                                          _num_combination=0,\n",
       "                                                          bins=15,\n",
       "                                                          num_classes=2)), ...],\n",
       "                 _num_combination=0, bins=15, num_classes=1000)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit the histogram-binning calibrator on the calibration-split probabilities and labels\n",
    "# (cal_probs, cal_labels), passed to netcal as NumPy arrays.\n",
    "# The trailing ';' keeps Jupyter from echoing the returned estimator, whose repr\n",
    "# (one nested sub-calibrator per class) is several hundred lines of HTML/CSS.\n",
    "binner.fit(cal_probs.cpu().numpy(), cal_labels.cpu().numpy());"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "471a53c9-ee3a-4317-aa12-5946c426ed37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calibrate the test-split probabilities with the fitted binner. The calibrator is fed a\n",
    "# NumPy array, hence the CPU round-trip; the result is rebuilt as a float32 tensor on `device`.\n",
    "# torch.tensor(..., dtype=..., device=...) replaces the legacy torch.FloatTensor(...).to(device)\n",
    "# constructor and yields the same float32 copy.\n",
    "binner_probs = torch.tensor(\n",
    "    binner.transform(test_probs.cpu().numpy()),\n",
    "    dtype=torch.float32,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "d8ad07ed-d60e-4703-9359-edb82bbc440d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 82.03\n",
      "ROC-AUC: 81.67\n",
      "Binned ECE: 5.06\n",
      "Binned ACE: 11.97\n",
      "Smooth ECE: 4.03\n",
      "NLL: inf\n",
      "MSE: 27.17\n"
     ]
    }
   ],
   "source": [
    "# Report the metric suite for the histogram-binned test probabilities against the test labels.\n",
    "# NOTE(review): the recorded output shows NLL: inf - presumably some calibrated probabilities\n",
    "# are exactly 0 (the fitted bin map contains 0 entries); confirm how print_metrics computes NLL.\n",
    "print_metrics(binner_probs, test_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "387fb1b6-ce97-40fa-a2d3-d5002541b238",
   "metadata": {},
   "source": [
    "### Isotonic Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "84381b75-ffb4-49c1-b64d-33ad2504adcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calibrator from netcal.binning (see the imports in the setup cell), built with default arguments.\n",
    "iso_reg = IsotonicRegression()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "b4c64b87-a0e0-473f-be44-49d6b97ad7a1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-2 {\n",
       "  /* Definition of color scheme common for light and dark mode */\n",
       "  --sklearn-color-text: black;\n",
       "  --sklearn-color-line: gray;\n",
       "  /* Definition of color scheme for unfitted estimators */\n",
       "  --sklearn-color-unfitted-level-0: #fff5e6;\n",
       "  --sklearn-color-unfitted-level-1: #f6e4d2;\n",
       "  --sklearn-color-unfitted-level-2: #ffe0b3;\n",
       "  --sklearn-color-unfitted-level-3: chocolate;\n",
       "  /* Definition of color scheme for fitted estimators */\n",
       "  --sklearn-color-fitted-level-0: #f0f8ff;\n",
       "  --sklearn-color-fitted-level-1: #d4ebff;\n",
       "  --sklearn-color-fitted-level-2: #b3dbfd;\n",
       "  --sklearn-color-fitted-level-3: cornflowerblue;\n",
       "\n",
       "  /* Specific color for light theme */\n",
       "  --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
       "  --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-icon: #696969;\n",
       "\n",
       "  @media (prefers-color-scheme: dark) {\n",
       "    /* Redefinition of color scheme for dark theme */\n",
       "    --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
       "    --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-icon: #878787;\n",
       "  }\n",
       "}\n",
       "\n",
       "#sk-container-id-2 {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 pre {\n",
       "  padding: 0;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 input.sk-hidden--visually {\n",
       "  border: 0;\n",
       "  clip: rect(1px 1px 1px 1px);\n",
       "  clip: rect(1px, 1px, 1px, 1px);\n",
       "  height: 1px;\n",
       "  margin: -1px;\n",
       "  overflow: hidden;\n",
       "  padding: 0;\n",
       "  position: absolute;\n",
       "  width: 1px;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-dashed-wrapped {\n",
       "  border: 1px dashed var(--sklearn-color-line);\n",
       "  margin: 0 0.4em 0.5em 0.4em;\n",
       "  box-sizing: border-box;\n",
       "  padding-bottom: 0.4em;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-container {\n",
       "  /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
       "     but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
       "     so we also need the `!important` here to be able to override the\n",
       "     default hidden behavior on the sphinx rendered scikit-learn.org.\n",
       "     See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
       "  display: inline-block !important;\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-text-repr-fallback {\n",
       "  display: none;\n",
       "}\n",
       "\n",
       "div.sk-parallel-item,\n",
       "div.sk-serial,\n",
       "div.sk-item {\n",
       "  /* draw centered vertical line to link estimators */\n",
       "  background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
       "  background-size: 2px 100%;\n",
       "  background-repeat: no-repeat;\n",
       "  background-position: center center;\n",
       "}\n",
       "\n",
       "/* Parallel-specific style estimator block */\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item::after {\n",
       "  content: \"\";\n",
       "  width: 100%;\n",
       "  border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
       "  flex-grow: 1;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel {\n",
       "  display: flex;\n",
       "  align-items: stretch;\n",
       "  justify-content: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item:first-child::after {\n",
       "  align-self: flex-end;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item:last-child::after {\n",
       "  align-self: flex-start;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-parallel-item:only-child::after {\n",
       "  width: 0;\n",
       "}\n",
       "\n",
       "/* Serial-specific style estimator block */\n",
       "\n",
       "#sk-container-id-2 div.sk-serial {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "  align-items: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  padding-right: 1em;\n",
       "  padding-left: 1em;\n",
       "}\n",
       "\n",
       "\n",
       "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
       "clickable and can be expanded/collapsed.\n",
       "- Pipeline and ColumnTransformer use this feature and define the default style\n",
       "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
       "*/\n",
       "\n",
       "/* Pipeline and ColumnTransformer style (default) */\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable {\n",
       "  /* Default theme specific background. It is overwritten whether we have a\n",
       "  specific estimator or a Pipeline/ColumnTransformer */\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "/* Toggleable label */\n",
       "#sk-container-id-2 label.sk-toggleable__label {\n",
       "  cursor: pointer;\n",
       "  display: block;\n",
       "  width: 100%;\n",
       "  margin-bottom: 0;\n",
       "  padding: 0.5em;\n",
       "  box-sizing: border-box;\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 label.sk-toggleable__label-arrow:before {\n",
       "  /* Arrow on the left of the label */\n",
       "  content: \"▸\";\n",
       "  float: left;\n",
       "  margin-right: 0.25em;\n",
       "  color: var(--sklearn-color-icon);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "/* Toggleable content - dropdown */\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable__content {\n",
       "  max-height: 0;\n",
       "  max-width: 0;\n",
       "  overflow: hidden;\n",
       "  text-align: left;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable__content.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable__content pre {\n",
       "  margin: 0.2em;\n",
       "  border-radius: 0.25em;\n",
       "  color: var(--sklearn-color-text);\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-toggleable__content.fitted pre {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
       "  /* Expand drop-down */\n",
       "  max-height: 200px;\n",
       "  max-width: 100%;\n",
       "  overflow: auto;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
       "  content: \"▾\";\n",
       "}\n",
       "\n",
       "/* Pipeline/ColumnTransformer-specific style */\n",
       "\n",
       "#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator-specific style */\n",
       "\n",
       "/* Colorize estimator box */\n",
       "#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-label label.sk-toggleable__label,\n",
       "#sk-container-id-2 div.sk-label label {\n",
       "  /* The background is the default theme color */\n",
       "  color: var(--sklearn-color-text-on-default-background);\n",
       "}\n",
       "\n",
       "/* On hover, darken the color of the background */\n",
       "#sk-container-id-2 div.sk-label:hover label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "/* Label box, darken color on hover, fitted */\n",
       "#sk-container-id-2 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator label */\n",
       "\n",
       "#sk-container-id-2 div.sk-label label {\n",
       "  font-family: monospace;\n",
       "  font-weight: bold;\n",
       "  display: inline-block;\n",
       "  line-height: 1.2em;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-label-container {\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "/* Estimator-specific */\n",
       "#sk-container-id-2 div.sk-estimator {\n",
       "  font-family: monospace;\n",
       "  border: 1px dotted var(--sklearn-color-border-box);\n",
       "  border-radius: 0.25em;\n",
       "  box-sizing: border-box;\n",
       "  margin-bottom: 0.5em;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-estimator.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "/* on hover */\n",
       "#sk-container-id-2 div.sk-estimator:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-2 div.sk-estimator.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
       "\n",
       "/* Common style for \"i\" and \"?\" */\n",
       "\n",
       ".sk-estimator-doc-link,\n",
       "a:link.sk-estimator-doc-link,\n",
       "a:visited.sk-estimator-doc-link {\n",
       "  float: right;\n",
       "  font-size: smaller;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1em;\n",
       "  height: 1em;\n",
       "  width: 1em;\n",
       "  text-decoration: none !important;\n",
       "  margin-left: 1ex;\n",
       "  /* unfitted */\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted,\n",
       "a:link.sk-estimator-doc-link.fitted,\n",
       "a:visited.sk-estimator-doc-link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "/* Span, style for the box shown on hovering the info icon */\n",
       ".sk-estimator-doc-link span {\n",
       "  display: none;\n",
       "  z-index: 9999;\n",
       "  position: relative;\n",
       "  font-weight: normal;\n",
       "  right: .2ex;\n",
       "  padding: .5ex;\n",
       "  margin: .5ex;\n",
       "  width: min-content;\n",
       "  min-width: 20ex;\n",
       "  max-width: 50ex;\n",
       "  color: var(--sklearn-color-text);\n",
       "  box-shadow: 2pt 2pt 4pt #999;\n",
       "  /* unfitted */\n",
       "  background: var(--sklearn-color-unfitted-level-0);\n",
       "  border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted span {\n",
       "  /* fitted */\n",
       "  background: var(--sklearn-color-fitted-level-0);\n",
       "  border: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link:hover span {\n",
       "  display: block;\n",
       "}\n",
       "\n",
       "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
       "\n",
       "#sk-container-id-2 a.estimator_doc_link {\n",
       "  float: right;\n",
       "  font-size: 1rem;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1rem;\n",
       "  height: 1rem;\n",
       "  width: 1rem;\n",
       "  text-decoration: none;\n",
       "  /* unfitted */\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 a.estimator_doc_link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "#sk-container-id-2 a.estimator_doc_link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "#sk-container-id-2 a.estimator_doc_link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "</style><div id=\"sk-container-id-2\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>IsotonicRegression(_default_independent_probabilities=False, _iso=None,\n",
       "                   _multiclass_instances=[(0,\n",
       "                                           IsotonicRegression(_default_independent_probabilities=False,\n",
       "                                                              _iso=IsotonicRegression(out_of_bounds=&#x27;clip&#x27;,\n",
       "                                                                                      y_max=1.0,\n",
       "                                                                                      y_min=0.0),\n",
       "                                                              _multiclass_instances=[],\n",
       "                                                              num_classes=2)),\n",
       "                                          (1,\n",
       "                                           IsotonicRegression(_default_independent_probabilities=False,\n",
       "                                                              _iso=IsotonicRegression(...\n",
       "                                           IsotonicRegression(_default_independent_probabilities=False,\n",
       "                                                              _iso=IsotonicRegression(out_of_bounds=&#x27;clip&#x27;,\n",
       "                                                                                      y_max=1.0,\n",
       "                                                                                      y_min=0.0),\n",
       "                                                              _multiclass_instances=[],\n",
       "                                                              num_classes=2)),\n",
       "                                          (29,\n",
       "                                           IsotonicRegression(_default_independent_probabilities=False,\n",
       "                                                              _iso=IsotonicRegression(out_of_bounds=&#x27;clip&#x27;,\n",
       "                                                                                      y_max=1.0,\n",
       "                                                                                      y_min=0.0),\n",
       "                                                              _multiclass_instances=[],\n",
       "                                                              num_classes=2)), ...],\n",
       "                   num_classes=1000)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator  sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" checked><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label  sk-toggleable__label-arrow \">&nbsp;IsotonicRegression<span class=\"sk-estimator-doc-link \">i<span>Not fitted</span></span></label><div class=\"sk-toggleable__content \"><pre>IsotonicRegression(_default_independent_probabilities=False, _iso=None,\n",
       "                   _multiclass_instances=[(0,\n",
       "                                           IsotonicRegression(_default_independent_probabilities=False,\n",
       "                                                              _iso=IsotonicRegression(out_of_bounds=&#x27;clip&#x27;,\n",
       "                                                                                      y_max=1.0,\n",
       "                                                                                      y_min=0.0),\n",
       "                                                              _multiclass_instances=[],\n",
       "                                                              num_classes=2)),\n",
       "                                          (1,\n",
       "                                           IsotonicRegression(_default_independent_probabilities=False,\n",
       "                                                              _iso=IsotonicRegression(...\n",
       "                                           IsotonicRegression(_default_independent_probabilities=False,\n",
       "                                                              _iso=IsotonicRegression(out_of_bounds=&#x27;clip&#x27;,\n",
       "                                                                                      y_max=1.0,\n",
       "                                                                                      y_min=0.0),\n",
       "                                                              _multiclass_instances=[],\n",
       "                                                              num_classes=2)),\n",
       "                                          (29,\n",
       "                                           IsotonicRegression(_default_independent_probabilities=False,\n",
       "                                                              _iso=IsotonicRegression(out_of_bounds=&#x27;clip&#x27;,\n",
       "                                                                                      y_max=1.0,\n",
       "                                                                                      y_min=0.0),\n",
       "                                                              _multiclass_instances=[],\n",
       "                                                              num_classes=2)), ...],\n",
       "                   num_classes=1000)</pre></div> </div></div></div></div>"
      ],
      "text/plain": [
       "IsotonicRegression(_default_independent_probabilities=False, _iso=None,\n",
       "                   _multiclass_instances=[(0,\n",
       "                                           IsotonicRegression(_default_independent_probabilities=False,\n",
       "                                                              _iso=IsotonicRegression(out_of_bounds='clip',\n",
       "                                                                                      y_max=1.0,\n",
       "                                                                                      y_min=0.0),\n",
       "                                                              _multiclass_instances=[],\n",
       "                                                              num_classes=2)),\n",
       "                                          (1,\n",
       "                                           IsotonicRegression(_default_independent_probabilities=False,\n",
       "                                                              _iso=IsotonicRegression(...\n",
       "                                           IsotonicRegression(_default_independent_probabilities=False,\n",
       "                                                              _iso=IsotonicRegression(out_of_bounds='clip',\n",
       "                                                                                      y_max=1.0,\n",
       "                                                                                      y_min=0.0),\n",
       "                                                              _multiclass_instances=[],\n",
       "                                                              num_classes=2)),\n",
       "                                          (29,\n",
       "                                           IsotonicRegression(_default_independent_probabilities=False,\n",
       "                                                              _iso=IsotonicRegression(out_of_bounds='clip',\n",
       "                                                                                      y_max=1.0,\n",
       "                                                                                      y_min=0.0),\n",
       "                                                              _multiclass_instances=[],\n",
       "                                                              num_classes=2)), ...],\n",
       "                   num_classes=1000)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit the isotonic-regression calibrator on the calibration-split probabilities and labels\n",
    "# (cal_probs, cal_labels), passed to netcal as NumPy arrays.\n",
    "# The trailing ';' keeps Jupyter from echoing the returned estimator, whose repr\n",
    "# (one nested sub-calibrator per class) is several hundred lines of HTML/CSS.\n",
    "iso_reg.fit(cal_probs.cpu().numpy(), cal_labels.cpu().numpy());"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "cb95288b-a541-425b-9cce-a00e77c972ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calibrate the test-split probabilities with the fitted isotonic calibrator. It is fed a\n",
    "# NumPy array, hence the CPU round-trip; the result is rebuilt as a float32 tensor on `device`.\n",
    "# torch.tensor(..., dtype=..., device=...) replaces the legacy torch.FloatTensor(...).to(device)\n",
    "# constructor and yields the same float32 copy.\n",
    "iso_probs = torch.tensor(\n",
    "    iso_reg.transform(test_probs.cpu().numpy()),\n",
    "    dtype=torch.float32,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "aefda0bc-66e9-4538-81e3-689f261a0072",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 83.82\n",
      "ROC-AUC: 85.95\n",
      "Binned ECE: 4.53\n",
      "Binned ACE: 8.13\n",
      "Smooth ECE: 4.48\n",
      "NLL: inf\n",
      "MSE: 24.13\n"
     ]
    }
   ],
   "source": [
    "# Report the metric suite for the isotonic-calibrated test probabilities against the test labels.\n",
    "# NOTE(review): the recorded output shows NLL: inf here as well - presumably exact-zero\n",
    "# probabilities after calibration; confirm how print_metrics computes NLL.\n",
    "print_metrics(iso_probs, test_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e26683b-5392-4902-9f91-f4505704826b",
   "metadata": {},
   "source": [
    "### Mean Replacement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "093f8c7a-9b4a-4305-9b18-10e60688a579",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean-replacement baseline: every test prediction gets one shared confidence value,\n",
    "# the top-1 accuracy measured on (cal_probs, cal_labels).\n",
    "mean_cal_prob = cal_probs.argmax(dim=1).eq(cal_labels).float().mean()\n",
    "test_preds = test_probs.argmax(dim=1)\n",
    "\n",
    "# Fill the whole matrix with an equal share of the leftover mass (1 - confidence),\n",
    "# split over the remaining classes so that each row still sums to 1 ...\n",
    "mr_probs = torch.ones_like(test_probs) * (\n",
    "    (1 - mean_cal_prob) / (test_probs.size(1) - 1)\n",
    ")\n",
    "# ... then overwrite the predicted-class entry of every row with the shared confidence.\n",
    "mr_probs[torch.arange(test_probs.size(0)), test_preds] = mean_cal_prob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "a6a76d7b-2185-4cae-8bc5-931f58bc256d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 85.14\n",
      "ROC-AUC: 50.00\n",
      "Binned ECE: 0.18\n",
      "Binned ACE: 0.18\n",
      "Smooth ECE: 0.18\n",
      "NLL: 144.66\n",
      "MSE: 27.51\n"
     ]
    }
   ],
   "source": [
    "print_metrics(mr_probs, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "f3ea112a-8986-4d83-ad0b-e6d375659220",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Degenerate inputs for the roc_curve check in the next cell: 100 all-zero labels and\n",
    "# 100 constant scores of 1.0 (despite its name, rand_probs is not random).\n",
    "zeros = np.zeros(100)\n",
    "rand_probs = np.ones_like(zeros)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "499deaec-0547-4871-9a67-912ebac6bd3b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/users/mc696/.local/lib/python3.10/site-packages/sklearn/metrics/_ranking.py:1146: UndefinedMetricWarning: No positive samples in y_true, true positive value should be meaningless\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(array([0., 1.]), array([nan, nan]), array([inf,  1.]))"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check of sklearn's roc_curve on a degenerate input: every label is 0 and every\n",
    "# score is 1, so y_true has no positive samples. The recorded output shows an\n",
    "# UndefinedMetricWarning and NaN values in the second returned array.\n",
    "roc_curve(zeros, rand_probs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cf649c5-1852-4b22-8f47-e0f6a174f3fd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
