{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9e008aff-37e1-4d6a-b963-0392f30de6c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from nnlib.nnlib import utils\n",
    "from nnlib.nnlib.matplotlib_utils import import_matplotlib\n",
    "\n",
    "matplotlib, plt = import_matplotlib()\n",
    "\n",
    "from modules.bound_utils import estimate_fcmi_bound_classification, estimate_default_bound_classification, estimate_proposed_bound_classification, estimate_sgld_bound, estimate_kl_bound_classification, estimate_lg_bound_classification, estimate_interp_bound_classification\n",
    "from scripts.fcmi_train_classifier import mnist_ld_schedule, \\\n",
    "    cifar_resnet50_ld_schedule  # for pickle to be able to load LD methods\n",
    "import methods\n",
    "from torchmetrics.classification import MulticlassCalibrationError as ECE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ff01e793-b75b-45ef-a4bd-7a4ff2f595df",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=0/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 23.73it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=1/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.56it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=2/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.57it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=3/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.06it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=4/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.31it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=5/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.76it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=6/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.01it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=7/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.60it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=8/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 20.31it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=9/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.60it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=10/checkpoints/epoch19.mdl"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.52it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=11/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.95it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=12/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.91it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.14it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=13/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.95it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=14/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.64it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=15/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.39it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=16/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.28it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.09it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=17/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.87it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=18/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.13it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=19/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 19.86it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=20/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.78it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.09it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=21/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 24.98it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=22/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.02it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=23/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.83it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=24/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.90it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=25/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.31it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=26/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.61it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=27/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.75it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=28/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 27.13it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=75,lr=0.001,seed=0,S_seed=29/checkpoints/epoch19.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 26.09it/s]\n",
      "100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 25.13it/s]\n"
     ]
    }
   ],
   "source": [
    "# Experiment configuration; together these select the result directories to load.\n",
    "n = 75\n",
    "lr = 0.001\n",
    "seed = 0\n",
    "n_S_seeds = 30\n",
    "epoch = 20  # checkpoint epoch{epoch - 1}.mdl is loaded; full run: 200 (original note: results for 200 epochs)\n",
    "n_batch = 256\n",
    "\n",
    "# One entry per S_seed whose result directory exists; the three lists are appended in lockstep.\n",
    "preds = []\n",
    "labels = []\n",
    "masks = []\n",
    "\n",
    "for S_seed in range(n_S_seeds):\n",
    "    dir_name = f'n={n},lr={lr},seed={seed},S_seed={S_seed}'\n",
    "    # Build the path from components instead of concatenating a separator by hand.\n",
    "    dir_path = os.path.join(os.getcwd(), \"results\", \"fcmi-mnist-4vs9-CNN\", dir_name)\n",
    "    if not os.path.exists(dir_path):\n",
    "        print(f\"Did not find results for {dir_name}\")\n",
    "        continue\n",
    "\n",
    "    # NOTE(review): pickle.load can execute arbitrary code; only load result files you trust.\n",
    "    with open(os.path.join(dir_path, 'saved_data.pkl'), 'rb') as f:\n",
    "        saved_data = pickle.load(f)\n",
    "\n",
    "    model = utils.load(path=os.path.join(dir_path, 'checkpoints', f'epoch{epoch - 1}.mdl'), methods=methods, device=\"cpu\")\n",
    "    # Prefer the examples stored without data augmentation when the run saved them.\n",
    "    if 'all_examples_wo_data_aug' in saved_data:\n",
    "        all_examples = saved_data['all_examples_wo_data_aug']\n",
    "    else:\n",
    "        all_examples = saved_data['all_examples']\n",
    "\n",
    "    # Run the model over the dataset once and read both entries from the same result,\n",
    "    # instead of repeating the whole pass for each key.\n",
    "    cur_outputs = utils.apply_on_dataset(model=model, dataset=all_examples, batch_size=n_batch)\n",
    "    cur_preds = cur_outputs['pred']  # presumably (2*n, num_classes) -- TODO confirm\n",
    "    cur_labels = cur_outputs['label_0']  # presumably one label per example, (2*n,) -- TODO confirm\n",
    "    cur_mask = saved_data['mask']  # original note: (n_labels)\n",
    "\n",
    "    preds.append(cur_preds)\n",
    "    labels.append(cur_labels)\n",
    "    masks.append(cur_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 172,
   "id": "970a6849-cb44-498f-add6-717f5b2d14ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_ece(preds, mask, dataset, n_bins=15, norm='l1'):\n",
    "    \"\"\"Calibration error of `preds` on the examples selected by `mask`.\n",
    "\n",
    "    For every position i of `mask`, row `2*i + mask[i]` is selected from both\n",
    "    `preds` and the labels read from `dataset` (presumably `mask` holds 0/1\n",
    "    entries choosing one element of each consecutive pair -- TODO confirm).\n",
    "\n",
    "    Args:\n",
    "        preds: 2-D model outputs; `preds.shape[1]` is used as the number of classes.\n",
    "            NOTE(review): passed to the metric as-is, with no explicit softmax here --\n",
    "            confirm whether probabilities or logits are expected.\n",
    "        mask: per-pair selector added to `2*np.arange(len(mask))`.\n",
    "        dataset: iterable of `(x, y)` pairs; only the labels `y` are used.\n",
    "        n_bins: number of bins forwarded to the calibration metric.\n",
    "        norm: norm forwarded to the metric; 'l1' gives ECE, 'l2' gives RMSCE.\n",
    "\n",
    "    Returns:\n",
    "        The metric value converted with `utils.to_numpy`.\n",
    "    \"\"\"\n",
    "    targets = torch.tensor([y for x, y in dataset]).long()\n",
    "    selected = 2 * np.arange(len(mask)) + mask\n",
    "    calibration_metric = ECE(num_classes=preds.shape[1], n_bins=n_bins, norm=norm)\n",
    "    return utils.to_numpy(calibration_metric(preds[selected], targets[selected]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 173,
   "id": "591c65ca-668d-4994-b88a-711192398ecc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(0.0024929, dtype=float32)"
      ]
     },
     "execution_count": 173,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): `all_examples` is whatever the loading loop assigned last, while preds[0] and masks[0]\n",
    "# come from the first result directory that was found -- confirm the examples are identical across S_seed runs.\n",
    "compute_ece(preds=preds[0], mask=masks[0], dataset=all_examples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7509c3f2-41e9-4738-82ca-8d6da0baf921",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of classes, read from the first run's prediction matrix.\n",
    "n_class = preds[0].shape[1]\n",
    "# Label of every example in `all_examples`, as an int64 tensor.\n",
    "# This rebinds `labels`, which previously held the per-run list built while loading.\n",
    "labels = torch.tensor([y for _, y in all_examples]).long()\n",
    "# Row 2*i + masks[0][i] for every position i of the first run's mask.\n",
    "# The original note calls these the test indices of the supersample -- TODO confirm the mask convention.\n",
    "# The other element of each pair would be 2*np.arange(len(masks[0])) + (1-masks[0]) (noted originally as train indices).\n",
    "indices = 2 * np.arange(len(masks[0])) + masks[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6a469007-d235-4e97-89f6-caa48808e117",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'n_bins' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m metrics \u001b[38;5;241m=\u001b[39m ECE(num_classes\u001b[38;5;241m=\u001b[39mn_class, n_bins\u001b[38;5;241m=\u001b[39m\u001b[43mn_bins\u001b[49m, norm\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124ml1\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;66;03m## norm='l1' (ECE); norm='l2' (RMSCE); bins=15 (default).\u001b[39;00m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;66;03m#ece = metrics(preds[0][indices],labels[indices])\u001b[39;00m\n\u001b[1;32m      3\u001b[0m ece \u001b[38;5;241m=\u001b[39m metrics(torch\u001b[38;5;241m.\u001b[39msoftmax(preds[\u001b[38;5;241m0\u001b[39m][indices],\u001b[38;5;241m1\u001b[39m),labels[indices])\n",
      "\u001b[0;31mNameError\u001b[0m: name 'n_bins' is not defined"
     ]
    }
   ],
   "source": [
    "# The bin count is given literally: the name `n_bins` is only assigned in a later cell,\n",
    "# so referencing it here raised NameError on a fresh run. 15 is the default noted in the\n",
    "# original comment and the default of compute_ece.\n",
    "metrics = ECE(num_classes=n_class, n_bins=15, norm=\"l1\")  # norm='l1' (ECE); norm='l2' (RMSCE)\n",
    "# Softmax over dim 1 is applied explicitly before the metric\n",
    "# (an earlier variant passed preds[0][indices] directly).\n",
    "ece = metrics(torch.softmax(preds[0][indices], 1), labels[indices])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7d91b47-44a1-4636-a1ae-cf8991966289",
   "metadata": {},
   "outputs": [],
   "source": [
    "ece"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "9926f224-0f61-4b50-843e-1d21e32be24a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Softmax over dim 1 of the selected rows of the first run's predictions.\n",
    "p = torch.softmax(preds[0][indices], dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "b4880702-16bb-4566-ad10-e2950210396a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.        , 0.99677184, 0.99800593, 0.99870667, 1.        ])"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Quantile (equal-mass) binning: the edges are the quantiles of the per-row maximum of p\n",
    "# at n_bins + 1 evenly spaced levels, not evenly spaced edges.\n",
    "n_bins = 4\n",
    "#bins = np.linspace(0, 1, n_bins + 1)  # equal-width alternative\n",
    "bins = np.quantile(torch.max(p, dim=1).values, np.linspace(0, 1, n_bins + 1))\n",
    "# Force the outermost edges to 0 and 1.\n",
    "bins[0], bins[-1] = 0., 1.\n",
    "bins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "e98c4148-f61c-4624-8f93-265c93b83962",
   "metadata": {},
   "outputs": [],
   "source": [
    "#torch.max(preds[0][indices],dim=1).values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "126c3a2d-2634-41dd-8edf-9ed8b85f6774",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Per-bin statistics for the row-wise maximum of p under the edges in `bins`.\n",
    "# NOTE(review): `digitize` and `interpolate_nan` are defined in cells further down the notebook,\n",
    "# so this cell only runs after those cells have been executed.\n",
    "binids = digitize(torch.max(p, dim=1).values, bins)  # bin index of each sample\n",
    "bin_total = np.bincount(binids, minlength=n_bins)  # number of samples per bin\n",
    "bin_true = np.bincount(binids, weights=labels[indices], minlength=n_bins)  # per-bin sum of the labels\n",
    "bin_mean = interpolate_nan(bin_true / bin_total)  # per-bin label mean; nan from empty bins is interpolated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "bb7ecb01-7d3d-4724-a996-f2e927ddb7b7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.5       , 0.52631579, 0.10526316, 0.84210526, 0.52631579,\n",
       "       0.52631579, 0.84210526, 0.10526316, 0.10526316, 0.5       ,\n",
       "       0.52631579, 0.10526316, 0.84210526, 0.10526316, 0.52631579,\n",
       "       0.84210526, 0.84210526, 0.10526316, 0.10526316, 0.52631579,\n",
       "       0.84210526, 0.10526316, 0.5       , 0.84210526, 0.5       ,\n",
       "       0.84210526, 0.5       , 0.52631579, 0.5       , 0.10526316,\n",
       "       0.84210526, 0.10526316, 0.52631579, 0.5       , 0.52631579,\n",
       "       0.5       , 0.84210526, 0.5       , 0.5       , 0.52631579,\n",
       "       0.5       , 0.84210526, 0.5       , 0.84210526, 0.5       ,\n",
       "       0.52631579, 0.5       , 0.10526316, 0.5       , 0.5       ,\n",
       "       0.52631579, 0.10526316, 0.10526316, 0.52631579, 0.5       ,\n",
       "       0.10526316, 0.84210526, 0.52631579, 0.84210526, 0.10526316,\n",
       "       0.10526316, 0.52631579, 0.52631579, 0.52631579, 0.84210526,\n",
       "       0.84210526, 0.84210526, 0.84210526, 0.10526316, 0.52631579,\n",
       "       0.84210526, 0.52631579, 0.10526316, 0.5       , 0.10526316])"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-sample lookup: the mean of the bin each sample was assigned to.\n",
    "bin_mean[binids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "36178474-3553-4d3e-ad5e-31f512439c76",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.10526316 0.10526316 0.52631579 0.52631579 0.10526316 0.10526316\n",
      " 0.10526316 0.5        0.10526316 0.5        0.10526316 0.10526316\n",
      " 0.10526316 0.10526316 0.5        0.52631579 0.10526316 0.10526316\n",
      " 0.10526316 0.10526316 0.52631579 0.10526316 0.52631579 0.84210526\n",
      " 0.10526316 0.10526316 0.10526316 0.10526316 0.5        0.10526316\n",
      " 0.52631579 0.10526316 0.10526316 0.10526316 0.10526316 0.10526316\n",
      " 0.5        0.5        0.52631579 0.10526316 0.5        0.84210526\n",
      " 0.10526316 0.52631579 0.52631579 0.10526316 0.84210526 0.10526316\n",
      " 0.5        0.10526316 0.5        0.52631579 0.10526316 0.10526316\n",
      " 0.52631579 0.10526316 0.52631579 0.10526316 0.84210526 0.10526316\n",
      " 0.52631579 0.84210526 0.84210526 0.84210526 0.10526316 0.10526316\n",
      " 0.10526316 0.10526316 0.10526316 0.10526316 0.52631579 0.10526316\n",
      " 0.5        0.10526316 0.5       ]\n",
      "tensor([0.9940, 0.9964, 0.9982, 0.9982, 0.9938, 0.9965, 0.9763, 0.9970, 0.9866,\n",
      "        0.9970, 0.9833, 0.9953, 0.9759, 0.9947, 0.9978, 0.9980, 0.9902, 0.9880,\n",
      "        0.9961, 0.9836, 0.9985, 0.9953, 0.9987, 0.9991, 0.7418, 0.9959, 0.8980,\n",
      "        0.9578, 0.9977, 0.9857, 0.9986, 0.9966, 0.9466, 0.8479, 0.9967, 0.9867,\n",
      "        0.9977, 0.9972, 0.9984, 0.9958, 0.9979, 0.9998, 0.9940, 0.9986, 0.9985,\n",
      "        0.9958, 0.9991, 0.8651, 0.9971, 0.9509, 0.9978, 0.9982, 0.8071, 0.9772,\n",
      "        0.9981, 0.9951, 0.9986, 0.9946, 0.9995, 0.9955, 0.9985, 0.9988, 0.9996,\n",
      "        0.9994, 0.9852, 0.9947, 0.9310, 0.9578, 0.9956, 0.7191, 0.9984, 0.9558,\n",
      "        0.9978, 0.9939, 0.9979])\n"
     ]
    }
   ],
   "source": [
    "#bin_true\n",
    "#binids\n",
    "#bin_mean[binids]\n",
    "#torch.max(p,1).values\n",
    "# Assign the max softmax scores of preds[2] to the existing `bins`, then print the matching\n",
    "# entries of `bin_mean` followed by the scores themselves.\n",
    "# NOTE(review): the variable name `id` shadows the Python builtin id() for the rest of the session.\n",
    "id = digitize(torch.max(torch.softmax(preds[2][indices],1), dim=1).values, bins)\n",
    "print(bin_mean[id])\n",
    "print(torch.max(torch.softmax(preds[2][indices],1),1).values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 185,
   "id": "101bddd2-2f56-4633-ab6b-6f23f46a56c1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.25       0.5        0.125      1.         0.57142857 0.57142857\n",
      " 1.         0.25       0.14285714 0.28571429 0.5        0.125\n",
      " 1.         0.14285714 0.57142857 0.375      1.         0.125\n",
      " 0.25       0.375      1.         0.14285714 0.71428571 1.\n",
      " 0.25       1.         0.28571429 0.375      0.71428571 0.125\n",
      " 0.375      0.125      0.57142857 0.25       0.57142857 0.71428571\n",
      " 1.         0.71428571 0.28571429 0.57142857 0.28571429 1.\n",
      " 0.71428571 1.         0.71428571 0.57142857 0.28571429 0.125\n",
      " 0.25       0.71428571 0.5        0.25       0.125      0.375\n",
      " 0.28571429 0.14285714 1.         0.5        1.         0.14285714\n",
      " 0.14285714 0.5        0.5        0.5        1.         1.\n",
      " 0.375      0.375      0.14285714 0.5        1.         0.375\n",
      " 0.125      0.28571429 0.25      ]\n",
      "tensor([0.9968, 0.9981, 0.9948, 0.9993, 0.9982, 0.9984, 0.9995, 0.9968, 0.9955,\n",
      "        0.9973, 0.9981, 0.9931, 0.9994, 0.9950, 0.9983, 0.9988, 0.9989, 0.9914,\n",
      "        0.9966, 0.9984, 0.9991, 0.9964, 0.9980, 0.9996, 0.9970, 0.9991, 0.9978,\n",
      "        0.9986, 0.9979, 0.9950, 0.9987, 0.9948, 0.9983, 0.9971, 0.9981, 0.9978,\n",
      "        0.9998, 0.9979, 0.9974, 0.9981, 0.9976, 0.9994, 0.9978, 0.9992, 0.9978,\n",
      "        0.9981, 0.9976, 0.9950, 0.9969, 0.9978, 0.9981, 0.9968, 0.9914, 0.9987,\n",
      "        0.9975, 0.9963, 0.9989, 0.9980, 0.9994, 0.9952, 0.9954, 0.9981, 0.9980,\n",
      "        0.9980, 0.9994, 0.9991, 0.9988, 0.9988, 0.9962, 0.9981, 0.9995, 0.9984,\n",
      "        0.9940, 0.9978, 0.9967])\n"
     ]
    }
   ],
   "source": [
    "# Print the per-sample bin means, then the per-row maximum of p, for side-by-side comparison.\n",
    "print(bin_mean[binids])\n",
    "print(torch.max(p,1).values)\n",
    "#print(torch.max(preds[0][indices],1).values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3ec3b6bb-bdeb-48c2-b8af-2bcb43fad576",
   "metadata": {},
   "outputs": [],
   "source": [
    "def digitize(x, bins):\n",
    "    \"\"\"Map each value in ``x`` to the 0-based index of the bin it falls into.\n",
    "\n",
    "    Similar to np.digitize(x, bins, right=False) - 1, but both outer bins are unbounded:\n",
    "    values at or above bins[-1] go to the last bin and values below bins[0] go to bin 0\n",
    "    (rather than -1, which wraps around to the last bin when used as an index and is\n",
    "    rejected by np.bincount).\n",
    "    e.g. bins = [0, 0.5, 1] -> bin 0: [-inf, 0.5), bin 1: [0.5, inf]\n",
    "\n",
    "    Args:\n",
    "        x: scalar or array-like of values to bin.\n",
    "        bins: 1-d array-like of bin edges, as accepted by np.digitize.\n",
    "\n",
    "    Returns:\n",
    "        Integer bin indices clipped to [0, len(bins) - 2]. If the input is a Python\n",
    "        int or float, a plain Python int is returned.\n",
    "    \"\"\"\n",
    "    binids = np.clip(np.digitize(x, bins) - 1, 0, len(bins) - 2)\n",
    "    if isinstance(x, (int, float)):\n",
    "        binids = binids.item()\n",
    "    return binids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6e69b6da-a92f-4d3a-9c11-39824ff98608",
   "metadata": {},
   "outputs": [],
   "source": [
    "def interpolate_nan(a):\n",
    "    \"\"\"Return a copy of the 1d array `a` in which every nan has been filled in.\n",
    "\n",
    "    Interior nans are linearly interpolated (by position) between the surrounding\n",
    "    non-nan entries; nans at either end take the nearest non-nan value.\n",
    "    The input array itself is left untouched.\n",
    "    \"\"\"\n",
    "    filled = a.copy()\n",
    "    missing = np.isnan(filled)\n",
    "    present = ~missing\n",
    "    positions = np.arange(len(filled))\n",
    "    # np.interp clamps to the end values outside the sampled range, which covers boundary nans\n",
    "    estimates = np.interp(positions[missing], positions[present], filled[present])\n",
    "    filled[missing] = estimates\n",
    "    return filled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "594e1e11-d444-4343-a021-0b4f8def0c46",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "class BaseCalibrator:\n",
    "    \"\"\"Common interface for calibrators that map scores `z` to calibrated values.\n",
    "\n",
    "    Subclasses override `fit` and `predict`; the base versions are placeholders\n",
    "    (`fit` returns self unchanged and `predict` returns None).\n",
    "    \"\"\"\n",
    "\n",
    "    def fit(self, z, y):\n",
    "        \"\"\"Fit on scores `z` and targets `y`; returns self so calls can be chained.\"\"\"\n",
    "        return self\n",
    "\n",
    "    def predict(self, z):\n",
    "        \"\"\"Return calibrated values for `z`. Placeholder in the base class (returns None).\"\"\"\n",
    "        pass\n",
    "\n",
    "    def __call__(self, z):\n",
    "        \"\"\"Shorthand for `self.predict(z)`.\"\"\"\n",
    "        return self.predict(z)\n",
    "\n",
    "    def plot(self, *args, ax=None, set_layout=None, **kwargs):\n",
    "        \"\"\"Draw this calibrator's mapping on a matplotlib axes and return the axes.\n",
    "\n",
    "        What is drawn depends on the attributes present: a step plot of `bin_mean` over\n",
    "        `bins` if the instance has `bins`, else the curve (`z_`, `y_`) if it has `z_`,\n",
    "        else `predict` evaluated on 1000 evenly spaced points in [0, 1].\n",
    "\n",
    "        Args:\n",
    "            *args, **kwargs: forwarded to the drawing call (`ax.stairs` / `ax.plot`).\n",
    "            ax: axes to draw on. Defaults to the current axes if a figure is open,\n",
    "                otherwise a new 3.5 x 3.5 figure is created.\n",
    "            set_layout: if truthy, also add the diagonal reference line, axis labels,\n",
    "                equal aspect, grid and legend to `ax`.\n",
    "\n",
    "        Returns:\n",
    "            The axes that was drawn on.\n",
    "        \"\"\"\n",
    "        if ax is None:\n",
    "            if plt.get_fignums():\n",
    "                ax = plt.gca()\n",
    "            else:\n",
    "                _, ax = plt.subplots(figsize=(3.5, 3.5))\n",
    "\n",
    "        if hasattr(self, 'bins'):\n",
    "            ax.stairs(self.bin_mean, *args, edges=self.bins, baseline=None, **kwargs)\n",
    "            ax.set_ylim(bottom=-0.05)\n",
    "        elif hasattr(self, 'z_'):\n",
    "            ax.plot(self.z_, self.y_, *args, **kwargs)\n",
    "        else:\n",
    "            zz = np.linspace(0, 1, 1000)\n",
    "            ax.plot(zz, self.predict(zz), *args, **kwargs)\n",
    "\n",
    "        if set_layout:\n",
    "            # Decorate the axes that was drawn on; the pyplot-level helpers would act on\n",
    "            # the *current* axes, which is not necessarily the `ax` passed in.\n",
    "            ax.plot([0, 1], [0, 1], 'k--', alpha=0.3)\n",
    "            ax.set_xlabel('Predicted Probability')\n",
    "            ax.set_ylabel('Empirical Frequency')\n",
    "            ax.set_aspect('equal', 'box')  # plt.axis('scaled'): ylim bottom change back to 0\n",
    "            ax.grid()\n",
    "            ax.legend(framealpha=0.4)\n",
    "        return ax\n",
    "\n",
    "class HistogramCalibrator(BaseCalibrator):\n",
    "    \"\"\"Histogram-binning calibrator: a score is mapped to the mean target of the bin it falls into.\"\"\"\n",
    "\n",
    "    def __init__(self, n_bins=10, strategy='uniform'):\n",
    "        \"\"\"Store the number of bins and the binning strategy ('uniform' or 'quantile').\n",
    "\n",
    "        The strategy is only validated when `fit` is called.\n",
    "        \"\"\"\n",
    "        self.n_bins = n_bins\n",
    "        self.strategy = strategy\n",
    "        # The three attributes below stay None until fit() has been called.\n",
    "        self.bins = None  # edges\n",
    "        self.bin_mean = None  # heights\n",
    "        self.frac_data = None  # fraction of the fitted samples in each bin\n",
    "\n",
    "    def fit(self, z, y):\n",
    "        \"\"\"Estimate the bin edges and the per-bin target means from scores `z` and targets `y`.\n",
    "\n",
    "        Args:\n",
    "            z: 1d scores. The outer bin edges are pinned to 0 and 1, so the scores are\n",
    "                presumably probabilities in [0, 1] (confirm against callers).\n",
    "            y: per-sample targets; they are summed per bin (np.bincount weights).\n",
    "\n",
    "        Returns:\n",
    "            self, with `bins`, `bin_mean` and `frac_data` set.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if `strategy` is neither 'uniform' nor 'quantile'.\n",
    "        \"\"\"\n",
    "        if self.strategy == 'uniform':\n",
    "            bins = np.linspace(0, 1, self.n_bins + 1)\n",
    "        elif self.strategy == 'quantile':\n",
    "            bins = np.quantile(z, np.linspace(0, 1, self.n_bins + 1))\n",
    "        else:\n",
    "            raise ValueError('strategy should be either \"uniform\" or \"quantile\".')\n",
    "        bins[0], bins[-1] = 0., 1.\n",
    "        binids = digitize(z, bins)\n",
    "        bin_total = np.bincount(binids, minlength=self.n_bins) ## the number of samples per bins\n",
    "        bin_true = np.bincount(binids, weights=y, minlength=self.n_bins) ## the number of samples per bins weighted by labels\n",
    "        # Empty bins give 0/0 = nan; silence only that floating-point warning instead of every warning.\n",
    "        with np.errstate(divide='ignore', invalid='ignore'):\n",
    "            raw_mean = bin_true / bin_total\n",
    "        # fill nan by interpolation assuming smoothness\n",
    "        self.bin_mean = interpolate_nan(raw_mean) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "        self.frac_data = bin_total / len(z)\n",
    "\n",
    "        self.bins = bins\n",
    "        return self\n",
    "\n",
    "    def predict(self, z):\n",
    "        \"\"\"Return, for each score in `z`, the fitted mean of the bin it falls into.\"\"\"\n",
    "        binids = digitize(z, self.bins)\n",
    "        return self.bin_mean[binids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8e2485a-0809-4ad6-a668-0fcfaa6784a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this cell has no execution count (never run) and `bin_boundaries` is not assigned in the\n",
    "# nearby cells; `indices` is used elsewhere to select rows of preds/labels, so confirm that counting\n",
    "# its values is really what is intended here.\n",
    "bin_total = torch.bincount(indices, minlength=len(bin_boundaries)-1) ## the number of samples per bins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d6fddde7-48fa-4f60-9254-8b8111ad80ba",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'torch' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[3], line 3\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m#calibrator = HistogramCalibrator(strategy='uniform')\u001b[39;00m\n\u001b[1;32m      2\u001b[0m calibrator \u001b[38;5;241m=\u001b[39m HistogramCalibrator(strategy\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mquantile\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m----> 3\u001b[0m calibrator\u001b[38;5;241m.\u001b[39mfit(\u001b[43mtorch\u001b[49m\u001b[38;5;241m.\u001b[39mmax(p,\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mvalues, labels[indices])\n\u001b[1;32m      4\u001b[0m torch\u001b[38;5;241m.\u001b[39mtensor(calibrator\u001b[38;5;241m.\u001b[39mpredict(torch\u001b[38;5;241m.\u001b[39mmax(p,\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mvalues))\n",
      "\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined"
     ]
    }
   ],
   "source": [
    "# Fit a histogram calibrator with quantile binning (default n_bins=10) on the per-row maximum of p,\n",
    "# then show its predictions for the same scores as a tensor.\n",
    "# NOTE(review): the saved output is `NameError: name 'torch' is not defined`, i.e. this cell was run\n",
    "# before the imports cell; run the imports first.\n",
    "# NOTE(review): the targets passed are the raw labels[indices]; confirm this (rather than a\n",
    "# prediction-is-correct indicator) is the intended target for the max-probability scores.\n",
    "#calibrator = HistogramCalibrator(strategy='uniform')\n",
    "calibrator = HistogramCalibrator(strategy='quantile')\n",
    "calibrator.fit(torch.max(p,1).values, labels[indices])\n",
    "torch.tensor(calibrator.predict(torch.max(p,1).values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "2d763579-4ea5-46b6-9f64-35b6a760d028",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.125     , 0.14285714, 0.25      , 0.28571429, 0.71428571,\n",
       "       0.5       , 0.57142857, 0.375     , 1.        , 1.        ])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-bin means stored on the fitted calibrator.\n",
    "calibrator.bin_mean\n",
    "#calibrator.bin_mean\n",
    "#calibrator.frac_data[binids]\n",
    "#calibrator.predict(torch.max(p,1).values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 190,
   "id": "bbff063d-0fa8-4168-876b-52180c5bd028",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.2542, dtype=torch.float64)"
      ]
     },
     "execution_count": 190,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#((torch.tensor(calibrator.predict(torch.max(p,1).values)) - torch.max(p,1).values)**2).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 194,
   "id": "ac3dba5a-979f-4fc8-bad8-8637c3258078",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,\n",
       "       9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,\n",
       "       9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,\n",
       "       9, 9, 9, 9, 9, 9, 9, 9, 9])"
      ]
     },
     "execution_count": 194,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Bin index of each per-row maximum of p under the calibrator's fitted edges.\n",
    "digitize(torch.max(p,1).values, calibrator.bins)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 250,
   "id": "190956cb-9244-4d74-aca8-a0c3c798e149",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(3.8232, dtype=torch.float64)"
      ]
     },
     "execution_count": 250,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#(torch.tensor(calibrator.frac_data[binids]) * torch.abs(torch.tensor(calibrator.predict(torch.max(p,1).values)) - torch.max(p,1).values)).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "67c89eda-60d6-4a2d-b1de-592c8c11bf81",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1.   , 0.125, 1.   , 1.   , 1.   , 1.   , 1.   , 1.   , 1.   ,\n",
       "       1.   , 1.   , 1.   , 0.125, 1.   , 1.   , 1.   , 1.   , 1.   ,\n",
       "       1.   , 1.   , 1.   , 1.   , 1.   , 1.   , 1.   , 1.   , 1.   ,\n",
       "       1.   , 1.   , 1.   , 1.   , 0.125, 1.   , 1.   , 1.   , 1.   ,\n",
       "       1.   , 1.   , 1.   , 1.   , 1.   , 1.   , 1.   , 1.   , 1.   ,\n",
       "       1.   , 1.   , 1.   , 1.   , 1.   , 0.125, 1.   , 1.   , 1.   ,\n",
       "       1.   , 1.   , 1.   , 1.   , 1.   , 1.   , 1.   , 1.   , 1.   ,\n",
       "       1.   , 1.   , 1.   , 1.   , 1.   , 1.   , 0.125, 1.   , 1.   ,\n",
       "       1.   , 1.   , 1.   ])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): no torch.softmax is applied to preds[4][indices] before taking the row-wise max,\n",
    "# whereas the calibrator was fitted on the max of softmax outputs (p) - confirm whether softmax is missing.\n",
    "calibrator.predict(torch.max(preds[4][indices],1).values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 260,
   "id": "1357803d-b8af-4132-9f15-601c4fd87c05",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.50087279, 0.93723812, 0.98829868, 0.99384503, 0.99569635,\n",
       "       0.99686411, 0.99779634, 0.99808111, 0.99862044, 0.9991127 ,\n",
       "       0.99992001])"
      ]
     },
     "execution_count": 260,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# n_bins + 1 quantile edges of the max softmax score over all rows of preds[0] (no `indices` subset here).\n",
    "np.quantile(torch.softmax(preds[0],1).max(1).values, np.linspace(0, 1, n_bins + 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54e7942b-6ef0-4b93-952c-fd2b454682b8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
