{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "71e4b6e8-32b7-4b85-8680-7436b3ab9104",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from nnlib.nnlib import utils\n",
    "from nnlib.nnlib.matplotlib_utils import import_matplotlib\n",
    "\n",
    "matplotlib, plt = import_matplotlib()\n",
    "\n",
    "from modules.bound_utils import estimate_fcmi_bound_classification, estimate_default_bound_classification, estimate_proposed_bound_classification, estimate_sgld_bound, estimate_kl_bound_classification, estimate_lg_bound_classification, estimate_interp_bound_classification\n",
    "from scripts.fcmi_train_classifier import mnist_ld_schedule, \\\n",
    "    cifar_resnet50_ld_schedule  # for pickle to be able to load LD methods\n",
    "import methods\n",
    "from torchmetrics.classification import MulticlassCalibrationError as ECE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "3cfa622a-75e7-4196-9b94-31113cae4f54",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=0/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.58it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=1/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.59it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 28.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=2/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.76it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 28.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=3/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.62it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=4/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 17.87it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=5/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.88it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=6/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.82it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.14it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=7/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.76it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=8/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.05it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=9/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.03it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=10/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.64it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=11/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.32it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=12/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 18.79it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.32it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=13/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 28.31it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=14/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.55it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=15/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.19it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.34it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=16/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.41it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=17/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.62it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=18/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.52it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=19/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.46it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=20/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.96it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 20.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=21/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.00it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=22/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.19it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=23/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.36it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=24/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.89it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=25/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.26it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=26/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.27it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=27/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.43it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=28/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.59it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=29/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.78it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.68it/s]\n"
     ]
    }
   ],
   "source": [
    "n = 75\n",
    "lr = 0.001\n",
    "seed = 0\n",
    "n_S_seeds = 30\n",
    "epoch = 20 ## full: 200 ; 200 epochsに対する結果\n",
    "n_batch = 256\n",
    "\n",
    "preds = []\n",
    "labels = []\n",
    "masks = []\n",
    "\n",
    "for S_seed in range(n_S_seeds):\n",
    "    dir_name = f'n={n},lr={lr},seed={seed},S_seed={S_seed}'\n",
    "    dir_path = os.path.join(os.getcwd() + \"/results/\", \"fcmi-mnist-4vs9-CNN\", dir_name)\n",
    "    if not os.path.exists(dir_path):\n",
    "        print(f\"Did not find results for {dir_name}\")\n",
    "        continue\n",
    "    \n",
    "    with open(os.path.join(dir_path, 'saved_data.pkl'), 'rb') as f:\n",
    "        saved_data = pickle.load(f)\n",
    "    \n",
    "    model = utils.load(path=os.path.join(dir_path, 'checkpoints', f'epoch{epoch - 1}.mdl'), methods=methods, device=\"cpu\")\n",
    "    if 'all_examples_wo_data_aug' in saved_data:\n",
    "        all_examples = saved_data['all_examples_wo_data_aug']\n",
    "    else:\n",
    "        all_examples = saved_data['all_examples']\n",
    "    \n",
    "    cur_preds = utils.apply_on_dataset(model=model, dataset=all_examples, batch_size=n_batch)['pred'] ## (2*n, num_classes)\n",
    "    cur_labels = utils.apply_on_dataset(model=model, dataset=all_examples, batch_size=n_batch)['label_0'] ## (2*n, num_classes)\n",
    "    cur_mask = saved_data['mask'] #(n_labels)\n",
    "    \n",
    "    preds.append(cur_preds)\n",
    "    labels.append(cur_labels)\n",
    "    masks.append(cur_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 284,
   "id": "0932885f-337d-4216-a981-56026b01f2df",
   "metadata": {},
   "outputs": [],
   "source": [
    "def digitize(x, bins):\n",
    "    \"\"\"\n",
    "    Similar to np.digitize(x, bins, right=False) - 1, but the last bin unbounded on the right\n",
    "    e.g. bins = [0, 0.5, 1] -> bin 0: [0, 0.5), bin 1: [0.5, inf]\n",
    "    if the input is a number, return a number\n",
    "    \"\"\"\n",
    "    binids = np.minimum(np.digitize(x, bins), len(bins) - 1)\n",
    "    binids -= 1\n",
    "    if isinstance(x, (int, float)):\n",
    "        binids = binids.item()\n",
    "    return binids\n",
    "\n",
    "import warnings\n",
    "class BaseCalibrator:\n",
    "    def fit(self, z, y):\n",
    "        return self\n",
    "\n",
    "    def predict(self, z):\n",
    "        pass\n",
    "\n",
    "    def __call__(self, z):\n",
    "        return self.predict(z)\n",
    "\n",
    "    def plot(self, *args, ax=None, set_layout=None, **kwargs):\n",
    "        if ax is None:\n",
    "            if plt.get_fignums():\n",
    "                ax = plt.gca()\n",
    "            else:\n",
    "                _, ax = plt.subplots(figsize=(3.5, 3.5))\n",
    "\n",
    "        if hasattr(self, 'bins'):\n",
    "            ax.stairs(self.bin_mean, *args, edges=self.bins, baseline=None, **kwargs)\n",
    "            ax.set_ylim(bottom=-0.05)\n",
    "        elif hasattr(self, 'z_'):\n",
    "            ax.plot(self.z_, self.y_, *args, **kwargs)\n",
    "        else:\n",
    "            zz = np.linspace(0, 1, 1000)\n",
    "            ax.plot(zz, self.predict(zz), *args, **kwargs)\n",
    "\n",
    "        if set_layout:\n",
    "            ax.plot([0, 1], [0, 1], 'k--', alpha=0.3)\n",
    "            plt.xlabel('Predicted Probability')\n",
    "            plt.ylabel('Empirical Frequency')\n",
    "            ax.set_aspect('equal', 'box')  # plt.axis('scaled'): ylim bottom change back to 0\n",
    "            plt.grid()\n",
    "            plt.legend(framealpha=0.4)\n",
    "        return ax\n",
    "\n",
    "class HistogramCalibrator(BaseCalibrator):\n",
    "    def __init__(self, n_bins=10, strategy='uniform'):\n",
    "        self.n_bins = n_bins\n",
    "        self.strategy = strategy\n",
    "        self.bins = None  # edges\n",
    "        self.bin_mean = None  # heights\n",
    "\n",
    "    def fit(self, z, y):\n",
    "        if self.strategy == 'uniform':\n",
    "            bins = np.linspace(0, 1, self.n_bins + 1)\n",
    "        elif self.strategy == 'quantile':\n",
    "            bins = np.quantile(z, np.linspace(0, 1, self.n_bins + 1))\n",
    "        else:\n",
    "            raise ValueError('strategy should be either \"uniform\" or \"quantile\".')\n",
    "        bins[0], bins[-1] = 0., 1.\n",
    "        binids = digitize(z, bins)\n",
    "        print(len(binids), len(y))\n",
    "        bin_total = np.bincount(binids, minlength=self.n_bins) ## the number of samples per bins\n",
    "        bin_true = np.bincount(binids, weights=y, minlength=self.n_bins) ## the number of samples per bins weighted by labels\n",
    "        with warnings.catch_warnings():\n",
    "            warnings.filterwarnings('ignore')\n",
    "            # fill nan by interpolation assuming smoothness\n",
    "            self.bin_mean = interpolate_nan(bin_true / bin_total) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "            self.frac_data = bin_total / len(z)\n",
    "            \n",
    "        self.bins = bins\n",
    "        return self\n",
    "\n",
    "    def predict(self, z):\n",
    "        binids = digitize(z, self.bins)\n",
    "        return self.bin_mean[binids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 285,
   "id": "8032ca91-d827-4ee2-aa13-0f21949cca9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import softmax\n",
    "import warnings\n",
    "\n",
    "def multiclass_calibration_error(preds, target, num_classes, n_bins=15, ignore_index=None, norm='l1', method='uniform'):\n",
    "    \"\"\"\n",
    "    Final code to compute multiclass calibration error\n",
    "    \"\"\"\n",
    "    preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index, convert_to_labels=False)\n",
    "    confidences, accuracies = _multiclass_calibration_error_update(preds, target)\n",
    "    return _ce_compute(confidences, accuracies, n_bins, norm, method, target)\n",
    "\n",
    "def _multiclass_confusion_matrix_format(preds, target, ignore_index=None, convert_to_labels=True, method='uniform'):\n",
    "    \"\"\"\n",
    "    Formatting the predictive values and labels to calculate the calibration error\n",
    "    \"\"\"\n",
    "    # Apply argmax if we have one more dimension\n",
    "    if preds.ndim == target.ndim + 1 and convert_to_labels:\n",
    "        preds = preds.argmax(dim=1)\n",
    "    preds = preds.flatten() if convert_to_labels else torch.movedim(preds, 1, -1).reshape(-1, preds.shape[1])\n",
    "    target = target.flatten()\n",
    "    \n",
    "    if ignore_index is not None:\n",
    "        idx = target != ignore_index\n",
    "        preds = preds[idx]\n",
    "        target = target[idx]\n",
    "\n",
    "    return preds, target\n",
    "\n",
    "def _multiclass_calibration_error_update(preds, target):\n",
    "    \"\"\"\n",
    "    Preparing the predictive probability vectors and accuracies for CE calculation\n",
    "    \"\"\"\n",
    "    if not torch.all((preds >= 0) * (preds <= 1)):\n",
    "        preds = preds.softmax(1) ## convert to prob. vector if preds are not probabilistic outputs\n",
    "    confidences, predictions = preds.max(dim=1)\n",
    "    accuracies = predictions.eq(target)\n",
    "    return confidences.float(), accuracies.float()\n",
    "\n",
    "def interpolate_nan(a):\n",
    "    \"\"\"Linear interpolation for nan values in a 1d array.\n",
    "    Nans on the boundary are filled with the nearest non-nan value.\n",
    "    \"\"\"\n",
    "    #b = a.copy()\n",
    "    b = np.copy(a)\n",
    "    nans = np.isnan(b)\n",
    "    i = np.arange(len(b))\n",
    "    b[nans] = np.interp(i[nans], i[~nans], b[~nans])\n",
    "    return torch.tensor(b)\n",
    "\n",
    "#def _recalibration(confidences, bin_boundaries, indices, target, target_val):\n",
    "#    \"\"\"\n",
    "#    Implementing the calibration value when we use UMB (see Eq.(9)).\n",
    "#    --> 方針：まずbin_mean（\\hat{mu}）を計算し，predictという形で予測確率をrecalibrationし，そのもとでacc_bin，conf_binなどを計算させる．\n",
    "#    -->  construct a \"recalibration function\" h(z), where z = f(X) (pred. prob.)\n",
    "#    \"\"\"\n",
    "#    calibrator = HistogramCalibrator(n_bins=bin_boundaries, strategy='quantile')\n",
    "#    calibrator.fit(confidences_val, target_val) ## fitting to the predictive condifence on the \"calibration data\".\n",
    "#    recab_conf = torch.tensor(calibrator.predict(confidences)).float().to(confidences.device)\n",
    "    \n",
    "    #torch.tensor(calibrator.predict(torch.max(p,1).values))\n",
    "    #bin_total = torch.bincount(indices, minlength=len(bin_boundaries)-1).to(confidences.device) ## the number of samples per bins\n",
    "    #if target == None:\n",
    "    #    raise ValueError('target labels are required.')\n",
    "    #else:\n",
    "    #    bin_true = torch.bincount(indices, weights=target, minlength=len(bin_boundaries)-1).float().to(confidences.device) ## the number of samples per bins weighted by labels        \n",
    "    #with warnings.catch_warnings():\n",
    "    #    warnings.filterwarnings('ignore')\n",
    "    #    bin_mean = interpolate_nan(bin_true / bin_total) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "\n",
    "    #recab_conf = bin_mean[indices]\n",
    "    \n",
    "#    return recab_conf, calibrator\n",
    "        \n",
    "\n",
    "def _binning_bucketize(confidences, accuracies, bin_boundaries, method='uniform', target=None):\n",
    "    \"\"\"\n",
    "    Calculate the contents count, accuracies, confidences, and proportions per bins.\n",
    "    \"\"\"\n",
    "    accuracies = accuracies.to(dtype=confidences.dtype)\n",
    "    acc_bin = torch.zeros(len(bin_boundaries), device=confidences.device, dtype=confidences.dtype)\n",
    "    conf_bin = torch.zeros(len(bin_boundaries), device=confidences.device, dtype=confidences.dtype)\n",
    "    count_bin = torch.zeros(len(bin_boundaries), device=confidences.device, dtype=confidences.dtype)\n",
    "\n",
    "    indices = torch.bucketize(confidences, bin_boundaries, right=True) - 1\n",
    "    if method == 'quantile':\n",
    "        ## confidence via recalibration\n",
    "        confidences, recalibrator = _recalibration(confidences, bin_boundaries, indices, target)\n",
    "\n",
    "    count_bin.scatter_add_(dim=0, index=indices, src=torch.ones_like(confidences))\n",
    "    conf_bin.scatter_add_(dim=0, index=indices, src=confidences)\n",
    "    conf_bin = torch.nan_to_num(conf_bin / count_bin)\n",
    "\n",
    "    acc_bin.scatter_add_(dim=0, index=indices, src=accuracies)\n",
    "    acc_bin = torch.nan_to_num(acc_bin / count_bin)\n",
    "    prop_bin = count_bin / count_bin.sum()\n",
    "\n",
    "    return acc_bin, conf_bin, prop_bin\n",
    "\n",
    "def _ce_compute(confidences, accuracies, bin_boundaries, norm='l1', method='uniform', target=None):\n",
    "    \"\"\"\n",
    "    Computing CE values\n",
    "    \"\"\"\n",
    "    ## bin setting (uniform or uniform-mass)\n",
    "    if isinstance(bin_boundaries, int):\n",
    "        if method == 'uniform':\n",
    "            bin_boundaries = torch.linspace(0, 1, bin_boundaries + 1, dtype=confidences.dtype, device=confidences.device)\n",
    "            bin_boundaries[0], bin_boundaries[-1] = 0., 1.\n",
    "        elif method == 'quantile':\n",
    "            bin_boundaries = torch.tensor(np.quantile(confidences, np.linspace(0, 1, bin_boundaries + 1)), dtype=confidences.dtype, device=confidences.device)\n",
    "            bin_boundaries[0], bin_boundaries[-1] = 0., 1.\n",
    "        else:\n",
    "            raise ValueError('strategy should be either \"uniform\" or \"quantile\".')\n",
    "\n",
    "    with torch.no_grad():\n",
    "        acc_bin, conf_bin, prop_bin = _binning_bucketize(confidences, accuracies, bin_boundaries, method, target)\n",
    "\n",
    "    if norm == \"l1\":\n",
    "        ce = torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin)\n",
    "    if norm == \"max\":\n",
    "        ce = torch.max(torch.abs(acc_bin - conf_bin))\n",
    "    if norm == \"l2\":\n",
    "        ce = torch.sqrt(torch.sum(torch.pow(acc_bin - conf_bin, 2) * prop_bin))\n",
    "\n",
    "    return ce"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 286,
   "id": "f219d856-bbab-416e-9f18-d8d22977bd04",
   "metadata": {},
   "outputs": [],
   "source": [
    "def multiclass_recalibration_error(preds, preds_val, target, target_val, num_classes, n_bins=15, ignore_index=None, norm='l1', method='uniform'):\n",
    "\n",
    "    #preds, target, preds_ignore, target_ignore = _multiclass_confusion_matrix_format_recalibration(preds, target, ignore_index, convert_to_labels=False, method='uniform')\n",
    "    preds, target = _multiclass_confusion_matrix_format_recalibration(preds, target, ignore_index, convert_to_labels=False)\n",
    "    preds_val, target_val = _multiclass_confusion_matrix_format_recalibration(preds_val, target_val, ignore_index, convert_to_labels=False)\n",
    "    confidences, accuracies = _multiclass_calibration_error_update(preds, target)\n",
    "    confidences_val, accuracies_val = _multiclass_calibration_error_update(preds_val, target_val)\n",
    "\n",
    "    return _ce_compute_recalibrate(confidences, confidences_val, accuracies, accuracies_val, n_bins, norm, method, target, target_val)\n",
    "\n",
    "def _multiclass_confusion_matrix_format_recalibration(preds, target, ignore_index=None, convert_to_labels=True):\n",
    "    \"\"\"\n",
    "    Formatting the predictive values and labels to calculate the calibration error\n",
    "    \"\"\"\n",
    "    # Apply argmax if we have one more dimension\n",
    "    if preds.ndim == target.ndim + 1 and convert_to_labels:\n",
    "        preds = preds.argmax(dim=1)\n",
    "    preds = preds.flatten() if convert_to_labels else torch.movedim(preds, 1, -1).reshape(-1, preds.shape[1])\n",
    "    target = target.flatten()\n",
    "\n",
    "    if ignore_index is not None:\n",
    "        idx = target != ignore_index\n",
    "        preds = preds[idx]\n",
    "        target = target[idx]\n",
    "\n",
    "    #return preds, target, preds_ignore, target_ignore\n",
    "    return preds, target\n",
    "\n",
    "def _recalibration(confidences, confidences_val, bin_boundaries, target, target_val):\n",
    "    \"\"\"\n",
    "    Implementing the calibration value when we use UMB (see Eq.(9)).\n",
    "    --> 方針：まずbin_mean（\\hat{mu}）を計算し，predictという形で予測確率をrecalibrationし，そのもとでacc_bin，conf_binなどを計算させる．\n",
    "    -->  construct a \"recalibration function\" h(z), where z = f(X) (pred. prob.)\n",
    "    \"\"\"\n",
    "    calibrator = HistogramCalibrator(n_bins=bin_boundaries, strategy='quantile')\n",
    "    #print(target_val.shape)\n",
    "    print(confidences_val.shape, target_val.shape)\n",
    "    calibrator.fit(confidences_val, target_val) ## fitting to the predictive condifence on the \"calibration data\".\n",
    "    recab_conf = torch.tensor(calibrator.predict(confidences)).float().to(confidences.device)\n",
    "    bin_boundaries = torch.tensor(calibrator.bins).float().to(confidences.dtype)\n",
    "    \n",
    "    #torch.tensor(calibrator.predict(torch.max(p,1).values))\n",
    "    #bin_total = torch.bincount(indices, minlength=len(bin_boundaries)-1).to(confidences.device) ## the number of samples per bins\n",
    "    #if target == None:\n",
    "    #    raise ValueError('target labels are required.')\n",
    "    #else:\n",
    "    #    bin_true = torch.bincount(indices, weights=target, minlength=len(bin_boundaries)-1).float().to(confidences.device) ## the number of samples per bins weighted by labels        \n",
    "    #with warnings.catch_warnings():\n",
    "    #    warnings.filterwarnings('ignore')\n",
    "    #    bin_mean = interpolate_nan(bin_true / bin_total) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "\n",
    "    #recab_conf = bin_mean[indices]\n",
    "    \n",
    "    return recab_conf, bin_boundaries, calibrator\n",
    "        \n",
    "\n",
    "def _binning_bucketize_recalibrate(confidences, confidences_val, accuracies, accuracies_val, bin_boundaries, method='uniform', target=None, target_val=None):\n",
    "    \"\"\"\n",
    "    Calculate the contents count, accuracies, confidences, and proportions per bins.\n",
    "    \"\"\"\n",
    "    accuracies = accuracies.to(dtype=confidences.dtype)\n",
    "    if method == 'quantile':\n",
    "        ## confidence via recalibration\n",
    "        confidences, bin_boundaries, recalibrator = _recalibration(confidences, confidences_val, bin_boundaries, target, target_val)\n",
    "\n",
    "    indices = torch.bucketize(confidences, bin_boundaries, right=True) - 1\n",
    "    \n",
    "    acc_bin = torch.zeros(len(bin_boundaries), device=confidences.device, dtype=confidences.dtype)\n",
    "    conf_bin = torch.zeros(len(bin_boundaries), device=confidences.device, dtype=confidences.dtype)\n",
    "    count_bin = torch.zeros(len(bin_boundaries), device=confidences.device, dtype=confidences.dtype)\n",
    "\n",
    "\n",
    "    count_bin.scatter_add_(dim=0, index=indices, src=torch.ones_like(confidences))\n",
    "    conf_bin.scatter_add_(dim=0, index=indices, src=confidences)\n",
    "    conf_bin = torch.nan_to_num(conf_bin / count_bin)\n",
    "\n",
    "    acc_bin.scatter_add_(dim=0, index=indices, src=accuracies)\n",
    "    acc_bin = torch.nan_to_num(acc_bin / count_bin)\n",
    "    prop_bin = count_bin / count_bin.sum()\n",
    "\n",
    "    return acc_bin, conf_bin, prop_bin\n",
    "\n",
    "def _ce_compute_recalibrate(confidences, accuracies, confidences_val, accuracies_val, bin_boundaries, norm='l1', method='uniform', target=None, target_val=None):\n",
    "    \"\"\"\n",
    "    Computing CE values\n",
    "    \"\"\"\n",
    "    ## bin setting (uniform or uniform-mass)\n",
    "    if isinstance(bin_boundaries, int):\n",
    "        if method == 'uniform':\n",
    "            bin_boundaries = torch.linspace(0, 1, bin_boundaries + 1, dtype=confidences.dtype, device=confidences.device)\n",
    "            bin_boundaries[0], bin_boundaries[-1] = 0., 1.\n",
    "        elif method == 'quantile':\n",
    "            pass\n",
    "            #bin_boundaries = torch.tensor(np.quantile(confidences, np.linspace(0, 1, bin_boundaries + 1)), dtype=confidences.dtype, device=confidences.device)\n",
    "            #bin_boundaries[0], bin_boundaries[-1] = 0., 1.\n",
    "        else:\n",
    "            raise ValueError('strategy should be either \"uniform\" or \"quantile\".')\n",
    "\n",
    "    with torch.no_grad():\n",
    "        acc_bin, conf_bin, prop_bin = _binning_bucketize_recalibrate(confidences, confidences_val, accuracies, accuracies_val, bin_boundaries, method, target, target_val)\n",
    "\n",
    "    if norm == \"l1\":\n",
    "        ce = torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin)\n",
    "    if norm == \"max\":\n",
    "        ce = torch.max(torch.abs(acc_bin - conf_bin))\n",
    "    if norm == \"l2\":\n",
    "        ce = torch.sqrt(torch.sum(torch.pow(acc_bin - conf_bin, 2) * prop_bin))\n",
    "\n",
    "    return ce"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca8a28a8-09ab-40b2-849c-ee19a6f73cb1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 322,
   "id": "8a021908-79b1-4592-a636-9a36b6337584",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0025)\n",
      "tensor(0.0025)\n",
      "tensor(0.0025)\n"
     ]
    }
   ],
   "source": [
    "labels = [y for x, y in all_examples]\n",
    "labels = torch.tensor(labels).long()\n",
    "indices = 2*np.arange(len(masks[0])) + masks[0] ## train dataset indices from supersamples\n",
    "indices_test = 2*np.arange(len(masks[0])) + (1 - masks[0]) ## test dataset indices from supersamples\n",
    "n_class = preds[0].shape[1]\n",
    "n_bins = 4\n",
    "\n",
    "print(multiclass_calibration_error(preds[0][indices], labels[indices], num_classes=n_class, n_bins=n_bins, norm='l1',method='uniform'))\n",
    "print(multiclass_calibration_error(preds[0][indices], labels[indices], num_classes=n_class, n_bins=n_bins, norm='max',method='uniform'))\n",
    "print(multiclass_calibration_error(preds[0][indices], labels[indices], num_classes=n_class, n_bins=n_bins, norm='l2',method='uniform'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3badc5b-bb09-4a89-98f8-6458ae9825dc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 323,
   "id": "6d2ae95e-139d-46b1-830a-d1c609df9d54",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "preds_test, preds_val, labels_test, labels_val = train_test_split(preds[0][indices_test], labels[indices_test], test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 324,
   "id": "9098224f-bb13-4fd1-9135-bda994517aa9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([75]) torch.Size([75])\n",
      "75 75\n",
      "tensor(0.7174)\n",
      "torch.Size([75]) torch.Size([75])\n",
      "75 75\n",
      "tensor(0.7174)\n",
      "torch.Size([75]) torch.Size([75])\n",
      "75 75\n",
      "tensor(0.7174)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/m_/mrnp6k153dq_mxbk9450fv5c0000gn/T/ipykernel_54108/3587375794.py:39: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  recab_conf = torch.tensor(calibrator.predict(confidences)).float().to(confidences.device)\n",
      "/var/folders/m_/mrnp6k153dq_mxbk9450fv5c0000gn/T/ipykernel_54108/3587375794.py:39: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  recab_conf = torch.tensor(calibrator.predict(confidences)).float().to(confidences.device)\n",
      "/var/folders/m_/mrnp6k153dq_mxbk9450fv5c0000gn/T/ipykernel_54108/3587375794.py:39: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  recab_conf = torch.tensor(calibrator.predict(confidences)).float().to(confidences.device)\n"
     ]
    }
   ],
   "source": [
    "print(multiclass_recalibration_error(preds[0][indices_test], preds[0][indices_test], labels[indices_test], labels[indices_test], num_classes=n_class, n_bins=n_bins, norm='l1',method='quantile'))\n",
    "print(multiclass_recalibration_error(preds[0][indices_test], preds[0][indices_test], labels[indices_test], labels[indices_test], num_classes=n_class, n_bins=n_bins, norm='max',method='quantile'))\n",
    "print(multiclass_recalibration_error(preds[0][indices_test], preds[0][indices_test], labels[indices_test], labels[indices_test], num_classes=n_class, n_bins=n_bins, norm='l2',method='quantile'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 307,
   "id": "e8de9658-e96d-4c5e-a4ba-8286965fd6a3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([60]) torch.Size([15])\n",
      "60 15\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "The weights and list don't have the same length.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[307], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mmulticlass_recalibration_error\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpreds_test\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpreds_val\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels_test\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels_val\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_classes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_class\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_bins\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_bins\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnorm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43ml1\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43mmethod\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mquantile\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m)\n\u001b[1;32m      2\u001b[0m \u001b[38;5;66;03m#print(multiclass_recalibration_error(preds[0][indices], labels[indices], num_classes=n_class, n_bins=n_bins, norm='max',method='quantile'))\u001b[39;00m\n\u001b[1;32m      3\u001b[0m \u001b[38;5;66;03m#print(multiclass_recalibration_error(preds[0][indices], labels[indices], num_classes=n_class, n_bins=n_bins, norm='l2',method='quantile'))\u001b[39;00m\n",
      "Cell \u001b[0;32mIn[286], line 9\u001b[0m, in \u001b[0;36mmulticlass_recalibration_error\u001b[0;34m(preds, preds_val, target, target_val, num_classes, n_bins, ignore_index, norm, method)\u001b[0m\n\u001b[1;32m      6\u001b[0m confidences, accuracies \u001b[38;5;241m=\u001b[39m _multiclass_calibration_error_update(preds, target)\n\u001b[1;32m      7\u001b[0m confidences_val, accuracies_val \u001b[38;5;241m=\u001b[39m _multiclass_calibration_error_update(preds_val, target_val)\n\u001b[0;32m----> 9\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_ce_compute_recalibrate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfidences\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfidences_val\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccuracies\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccuracies_val\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_bins\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnorm\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_val\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[286], line 100\u001b[0m, in \u001b[0;36m_ce_compute_recalibrate\u001b[0;34m(confidences, accuracies, confidences_val, accuracies_val, bin_boundaries, norm, method, target, target_val)\u001b[0m\n\u001b[1;32m     97\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mstrategy should be either \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muniform\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m or \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquantile\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m     99\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m--> 100\u001b[0m     acc_bin, conf_bin, prop_bin \u001b[38;5;241m=\u001b[39m \u001b[43m_binning_bucketize_recalibrate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfidences\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfidences_val\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccuracies\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccuracies_val\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbin_boundaries\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_val\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    102\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m norm \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124ml1\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m    103\u001b[0m     ce \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39msum(torch\u001b[38;5;241m.\u001b[39mabs(acc_bin \u001b[38;5;241m-\u001b[39m conf_bin) \u001b[38;5;241m*\u001b[39m prop_bin)\n",
      "Cell \u001b[0;32mIn[286], line 64\u001b[0m, in \u001b[0;36m_binning_bucketize_recalibrate\u001b[0;34m(confidences, confidences_val, accuracies, accuracies_val, bin_boundaries, method, target, target_val)\u001b[0m\n\u001b[1;32m     61\u001b[0m accuracies \u001b[38;5;241m=\u001b[39m accuracies\u001b[38;5;241m.\u001b[39mto(dtype\u001b[38;5;241m=\u001b[39mconfidences\u001b[38;5;241m.\u001b[39mdtype)\n\u001b[1;32m     62\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m method \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mquantile\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m     63\u001b[0m     \u001b[38;5;66;03m## confidence via recalibration\u001b[39;00m\n\u001b[0;32m---> 64\u001b[0m     confidences, bin_boundaries, recalibrator \u001b[38;5;241m=\u001b[39m \u001b[43m_recalibration\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfidences\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfidences_val\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbin_boundaries\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_val\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     66\u001b[0m indices \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mbucketize(confidences, bin_boundaries, right\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m     68\u001b[0m acc_bin \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mzeros(\u001b[38;5;28mlen\u001b[39m(bin_boundaries), device\u001b[38;5;241m=\u001b[39mconfidences\u001b[38;5;241m.\u001b[39mdevice, dtype\u001b[38;5;241m=\u001b[39mconfidences\u001b[38;5;241m.\u001b[39mdtype)\n",
      "Cell \u001b[0;32mIn[286], line 38\u001b[0m, in \u001b[0;36m_recalibration\u001b[0;34m(confidences, confidences_val, bin_boundaries, target, target_val)\u001b[0m\n\u001b[1;32m     36\u001b[0m \u001b[38;5;66;03m#print(target_val.shape)\u001b[39;00m\n\u001b[1;32m     37\u001b[0m \u001b[38;5;28mprint\u001b[39m(confidences_val\u001b[38;5;241m.\u001b[39mshape, target_val\u001b[38;5;241m.\u001b[39mshape)\n\u001b[0;32m---> 38\u001b[0m \u001b[43mcalibrator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfidences_val\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_val\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m## fitting to the predictive condifence on the \"calibration data\".\u001b[39;00m\n\u001b[1;32m     39\u001b[0m recab_conf \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor(calibrator\u001b[38;5;241m.\u001b[39mpredict(confidences))\u001b[38;5;241m.\u001b[39mfloat()\u001b[38;5;241m.\u001b[39mto(confidences\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m     40\u001b[0m bin_boundaries \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor(calibrator\u001b[38;5;241m.\u001b[39mbins)\u001b[38;5;241m.\u001b[39mfloat()\u001b[38;5;241m.\u001b[39mto(confidences\u001b[38;5;241m.\u001b[39mdtype)\n",
      "Cell \u001b[0;32mIn[284], line 67\u001b[0m, in \u001b[0;36mHistogramCalibrator.fit\u001b[0;34m(self, z, y)\u001b[0m\n\u001b[1;32m     65\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;28mlen\u001b[39m(binids), \u001b[38;5;28mlen\u001b[39m(y))\n\u001b[1;32m     66\u001b[0m bin_total \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mbincount(binids, minlength\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_bins) \u001b[38;5;66;03m## the number of samples per bins\u001b[39;00m\n\u001b[0;32m---> 67\u001b[0m bin_true \u001b[38;5;241m=\u001b[39m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbincount\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbinids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mminlength\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_bins\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m## the number of samples per bins weighted by labels\u001b[39;00m\n\u001b[1;32m     68\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m warnings\u001b[38;5;241m.\u001b[39mcatch_warnings():\n\u001b[1;32m     69\u001b[0m     warnings\u001b[38;5;241m.\u001b[39mfilterwarnings(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mignore\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
      "\u001b[0;31mValueError\u001b[0m: The weights and list don't have the same length."
     ]
    }
   ],
   "source": [
    "## TODO: check why the splite size is changed...\n",
    "print(multiclass_recalibration_error(preds_test, preds_val, labels_test, labels_val, num_classes=n_class, n_bins=n_bins, norm='l1',method='quantile'))\n",
    "#print(multiclass_recalibration_error(preds[0][indices], labels[indices], num_classes=n_class, n_bins=n_bins, norm='max',method='quantile'))\n",
    "#print(multiclass_recalibration_error(preds[0][indices], labels[indices], num_classes=n_class, n_bins=n_bins, norm='l2',method='quantile'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 308,
   "id": "f0ea78fb-fe63-4655-b3de-6a800fd1e835",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 308,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels[indices] != ignore_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 309,
   "id": "0afe627f-0570-46d5-90df-a073a64c92b2",
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "_recalibration() missing 1 required positional argument: 'target_val'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[309], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mmulticlass_calibration_error\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpreds\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[43mindices\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m[\u001b[49m\u001b[43mindices\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_classes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_class\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_bins\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_bins\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnorm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43ml1\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43mmethod\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mquantile\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m)\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28mprint\u001b[39m(multiclass_calibration_error(preds[\u001b[38;5;241m0\u001b[39m][indices],labels[indices], num_classes\u001b[38;5;241m=\u001b[39mn_class, n_bins\u001b[38;5;241m=\u001b[39mn_bins, norm\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmax\u001b[39m\u001b[38;5;124m'\u001b[39m,method\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mquantile\u001b[39m\u001b[38;5;124m'\u001b[39m))\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28mprint\u001b[39m(multiclass_calibration_error(preds[\u001b[38;5;241m0\u001b[39m][indices],labels[indices], num_classes\u001b[38;5;241m=\u001b[39mn_class, n_bins\u001b[38;5;241m=\u001b[39mn_bins, norm\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ml2\u001b[39m\u001b[38;5;124m'\u001b[39m,method\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mquantile\u001b[39m\u001b[38;5;124m'\u001b[39m))\n",
      "Cell \u001b[0;32mIn[285], line 10\u001b[0m, in \u001b[0;36mmulticlass_calibration_error\u001b[0;34m(preds, target, num_classes, n_bins, ignore_index, norm, method)\u001b[0m\n\u001b[1;32m      8\u001b[0m preds, target \u001b[38;5;241m=\u001b[39m _multiclass_confusion_matrix_format(preds, target, ignore_index, convert_to_labels\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m      9\u001b[0m confidences, accuracies \u001b[38;5;241m=\u001b[39m _multiclass_calibration_error_update(preds, target)\n\u001b[0;32m---> 10\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_ce_compute\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfidences\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccuracies\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_bins\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnorm\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[285], line 115\u001b[0m, in \u001b[0;36m_ce_compute\u001b[0;34m(confidences, accuracies, bin_boundaries, norm, method, target)\u001b[0m\n\u001b[1;32m    112\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mstrategy should be either \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muniform\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m or \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquantile\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m    114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m--> 115\u001b[0m     acc_bin, conf_bin, prop_bin \u001b[38;5;241m=\u001b[39m \u001b[43m_binning_bucketize\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfidences\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccuracies\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbin_boundaries\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    117\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m norm \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124ml1\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m    118\u001b[0m     ce \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39msum(torch\u001b[38;5;241m.\u001b[39mabs(acc_bin \u001b[38;5;241m-\u001b[39m conf_bin) \u001b[38;5;241m*\u001b[39m prop_bin)\n",
      "Cell \u001b[0;32mIn[285], line 87\u001b[0m, in \u001b[0;36m_binning_bucketize\u001b[0;34m(confidences, accuracies, bin_boundaries, method, target)\u001b[0m\n\u001b[1;32m     84\u001b[0m indices \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mbucketize(confidences, bin_boundaries, right\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m     85\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m method \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mquantile\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m     86\u001b[0m     \u001b[38;5;66;03m## confidence via recalibration\u001b[39;00m\n\u001b[0;32m---> 87\u001b[0m     confidences, recalibrator \u001b[38;5;241m=\u001b[39m \u001b[43m_recalibration\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfidences\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbin_boundaries\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindices\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     89\u001b[0m count_bin\u001b[38;5;241m.\u001b[39mscatter_add_(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m, index\u001b[38;5;241m=\u001b[39mindices, src\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mones_like(confidences))\n\u001b[1;32m     90\u001b[0m conf_bin\u001b[38;5;241m.\u001b[39mscatter_add_(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m, index\u001b[38;5;241m=\u001b[39mindices, src\u001b[38;5;241m=\u001b[39mconfidences)\n",
      "\u001b[0;31mTypeError\u001b[0m: _recalibration() missing 1 required positional argument: 'target_val'"
     ]
    }
   ],
   "source": [
    "print(multiclass_calibration_error(preds[0][indices],labels[indices], num_classes=n_class, n_bins=n_bins, norm='l1',method='quantile'))\n",
    "print(multiclass_calibration_error(preds[0][indices],labels[indices], num_classes=n_class, n_bins=n_bins, norm='max',method='quantile'))\n",
    "print(multiclass_calibration_error(preds[0][indices],labels[indices], num_classes=n_class, n_bins=n_bins, norm='l2',method='quantile'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 292,
   "id": "a2559868-3d5b-45b6-8917-e1a90f7e888c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0025)\n",
      "tensor(0.0025)\n",
      "tensor(0.0025)\n"
     ]
    }
   ],
   "source": [
    "metrics = ECE(num_classes=n_class, n_bins=n_bins, norm=\"l1\") ## norm='l1' (ECE); norm='l2' (RMSCE); bins=15 (default).\n",
    "ece = metrics(preds[0][indices],labels[indices]) ## softmax is impremented in the inside of ECE function.\n",
    "print(ece)\n",
    "\n",
    "metrics = ECE(num_classes=n_class, n_bins=n_bins, norm=\"max\") ## norm='l1' (ECE); norm='l2' (RMSCE); bins=15 (default).\n",
    "ece = metrics(preds[0][indices],labels[indices]) ## softmax is impremented in the inside of ECE function.\n",
    "print(ece)\n",
    "\n",
    "metrics = ECE(num_classes=n_class, n_bins=n_bins, norm='l2') ## norm='l1' (ECE); norm='l2' (RMSCE); bins=15 (default).\n",
    "ece = metrics(preds[0][indices],labels[indices]) ## softmax is impremented in the inside of ECE function.\n",
    "print(ece)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 310,
   "id": "f6f32975-9ed7-4673-9001-4b2e8affd398",
   "metadata": {},
   "outputs": [],
   "source": [
    "confidences, accuracies = _multiclass_calibration_error_update(preds[0][indices], labels[indices])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 311,
   "id": "3ba0fa2d-2bef-407c-af85-7c4c70d37ef8",
   "metadata": {},
   "outputs": [],
   "source": [
    "bin_boundaries = torch.tensor(np.quantile(confidences, np.linspace(0, 1, n_bins + 1)), dtype=confidences.dtype, device=confidences.device)\n",
    "bin_boundaries[0], bin_boundaries[-1] = 0., 1.\n",
    "indices = torch.bucketize(confidences, bin_boundaries, right=True) - 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 312,
   "id": "003959d5-21b4-4ec6-a67a-f2b357062c1a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.0000, 0.9948, 0.9968, 0.9981, 1.0000])\n",
      "tensor([3, 0, 2, 3, 0, 0, 3, 2, 2, 3, 0, 2, 3, 2, 0, 3, 3, 2, 2, 0, 3, 2, 3, 3,\n",
      "        3, 3, 3, 0, 3, 2, 3, 2, 0, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3, 3, 0, 3, 2,\n",
      "        3, 3, 0, 2, 2, 0, 3, 2, 3, 0, 3, 2, 2, 0, 0, 0, 3, 3, 3, 3, 2, 0, 3, 0,\n",
      "        2, 3, 2])\n"
     ]
    }
   ],
   "source": [
    "print(bin_boundaries)\n",
    "print(indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 313,
   "id": "f001a4c3-4589-49dc-a83e-687780a99d97",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "75 75\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/m_/mrnp6k153dq_mxbk9450fv5c0000gn/T/ipykernel_54108/2339584990.py:3: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  recab_conf = torch.tensor(calibrator.predict(confidences)).float().to(confidences.device)\n"
     ]
    }
   ],
   "source": [
    "calibrator = HistogramCalibrator(n_bins=4, strategy='quantile')\n",
    "calibrator.fit(confidences, labels[indices]) ## fitting to the predictive condifence on the \"calibration data\".\n",
    "recab_conf = torch.tensor(calibrator.predict(confidences)).float().to(confidences.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 314,
   "id": "5cc35ade-f191-498a-a783-27c923ba5570",
   "metadata": {},
   "outputs": [],
   "source": [
    "#bin_boundaries = torch.tensor(np.quantile(recab_conf, np.linspace(0, 1, n_bins + 1)), dtype=recab_conf.dtype, device=recab_conf.device)\n",
    "#bin_boundaries[0], bin_boundaries[-1] = 0., 1.\n",
    "bin_boundaries = torch.tensor(calibrator.bins).float().to(confidences.dtype)\n",
    "indices = torch.bucketize(confidences, bin_boundaries, right=True) - 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 315,
   "id": "f3ad82a5-8bc2-429b-a87d-d77e2d38b79f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.0000, 0.9948, 0.9968, 0.9981, 1.0000])\n",
      "tensor([3, 0, 2, 3, 0, 0, 3, 2, 2, 3, 0, 2, 3, 2, 0, 3, 3, 2, 2, 0, 3, 2, 3, 3,\n",
      "        3, 3, 3, 0, 3, 2, 3, 2, 0, 3, 0, 3, 3, 3, 3, 0, 3, 3, 3, 3, 3, 0, 3, 2,\n",
      "        3, 3, 0, 2, 2, 0, 3, 2, 3, 0, 3, 2, 2, 0, 0, 0, 3, 3, 3, 3, 2, 0, 3, 0,\n",
      "        2, 3, 2])\n"
     ]
    }
   ],
   "source": [
    "print(bin_boundaries)\n",
    "print(indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 316,
   "id": "324ab414-a783-4905-ab35-020f08e9b183",
   "metadata": {},
   "outputs": [],
   "source": [
    "preds_test, target_test = _multiclass_confusion_matrix_format_recalibration(preds_test, labels_test, convert_to_labels=False)\n",
    "preds_val, target_val = _multiclass_confusion_matrix_format_recalibration(preds_val, labels_val, convert_to_labels=False)\n",
    "#confidences, accuracies = _multiclass_calibration_error_update(preds, target)\n",
    "#confidences_val, accuracies_val = _multiclass_calibration_error_update(preds_val, target_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 317,
   "id": "e89a4bb8-9963-45e0-8a6e-2b6e70266e67",
   "metadata": {},
   "outputs": [],
   "source": [
    "confidences, accuracies = _multiclass_calibration_error_update(preds_test, target_test)\n",
    "confidences_val, accuracies_val = _multiclass_calibration_error_update(preds_val, target_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 318,
   "id": "eabde6d1-c38e-4f39-9e77-3acd5d80deef",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "15 15\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/m_/mrnp6k153dq_mxbk9450fv5c0000gn/T/ipykernel_54108/1884898717.py:3: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  recab_conf = torch.tensor(calibrator.predict(confidences)).float().to(confidences.device)\n"
     ]
    }
   ],
   "source": [
    "calibrator = HistogramCalibrator(n_bins=4, strategy='quantile')\n",
    "calibrator.fit(confidences_val, target_val) ## fitting to the predictive condifence on the \"calibration data\".\n",
    "recab_conf = torch.tensor(calibrator.predict(confidences)).float().to(confidences.device)\n",
    "bin_boundaries = torch.tensor(calibrator.bins).float().to(confidences.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 319,
   "id": "b5ca0bcf-a982-4ac7-a1c6-ad81b40f4066",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.3333, 0.5000, 0.3333,\n",
       "        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,\n",
       "        0.5000, 0.5000, 0.3333, 0.3333, 0.3333, 0.5000, 0.5000, 0.5000, 0.5000,\n",
       "        0.3333, 0.3333, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,\n",
       "        0.3333, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.3333, 0.5000, 0.5000,\n",
       "        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,\n",
       "        0.5000, 0.3333, 0.5000, 0.5000, 0.5000, 0.5000])"
      ]
     },
     "execution_count": 319,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "recab_conf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 320,
   "id": "8193acc5-fe11-4af7-8967-b2489451a414",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.9986, 0.9992, 0.9557, 0.9984, 0.9982, 0.5207, 0.9718, 0.9972, 0.9892,\n",
       "        0.9970, 0.8501, 0.9963, 0.6515, 0.9922, 0.9936, 0.8849, 0.9410, 0.9936,\n",
       "        0.9981, 0.9953, 0.9918, 0.9912, 0.9870, 0.9944, 0.9968, 0.9679, 0.9996,\n",
       "        0.9809, 0.9740, 0.9291, 0.9980, 0.9939, 0.5009, 0.9361, 0.9944, 0.9990,\n",
       "        0.9830, 0.9978, 0.9992, 0.8117, 0.9959, 0.9935, 0.9900, 0.9978, 0.9314,\n",
       "        0.9635, 0.9374, 0.9972, 0.9984, 0.9646, 0.9928, 0.6130, 0.9992, 0.9966,\n",
       "        0.9988, 0.9781, 0.9961, 0.9966, 0.5734, 0.5480])"
      ]
     },
     "execution_count": 320,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "confidences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82de81ec-c440-4352-a569-bdcf9836e824",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
