{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3158ddea-9aba-4448-984e-4710e11b8105",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from nnlib.nnlib import utils\n",
    "from nnlib.nnlib.matplotlib_utils import import_matplotlib\n",
    "\n",
    "matplotlib, plt = import_matplotlib()\n",
    "\n",
    "from modules.bound_utils import estimate_fcmi_bound_classification, estimate_sgld_bound, estimate_kl_bound_classification, estimate_lg_bound_classification, estimate_interp_bound_classification\n",
    "from scripts.fcmi_train_classifier import mnist_ld_schedule, \\\n",
    "    cifar_resnet50_ld_schedule  # for pickle to be able to load LD methods\n",
    "import methods\n",
    "from torchmetrics.classification import MulticlassCalibrationError as ECE\n",
    "\n",
    "from sklearn.feature_selection import mutual_info_classif"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86918b25-7afb-46c7-a390-7b5f0250f360",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "n = 75\n",
    "lr = 0.001\n",
    "seed = 0\n",
    "n_S_seeds = 30\n",
    "epoch = 200  ## full: 200; results after 200 epochs (the checkpoint file is epoch{epoch - 1}.mdl)\n",
    "n_batch = 256\n",
    "\n",
    "## Per-S_seed results, in load order (runs whose result directory is missing are skipped).\n",
    "preds = []\n",
    "labels = []\n",
    "masks = []\n",
    "\n",
    "for S_seed in range(n_S_seeds):\n",
    "    dir_name = f'n={n},lr={lr},seed={seed},S_seed={S_seed}'\n",
    "    dir_path = os.path.join(os.getcwd(), \"results\", \"fcmi-mnist-4vs9-CNN\", dir_name)\n",
    "    if not os.path.exists(dir_path):\n",
    "        print(f\"Did not find results for {dir_name}\")\n",
    "        continue\n",
    "\n",
    "    # NOTE: pickle.load can run arbitrary code; only load result files produced by your own runs.\n",
    "    with open(os.path.join(dir_path, 'saved_data.pkl'), 'rb') as f:\n",
    "        saved_data = pickle.load(f)\n",
    "\n",
    "    model = utils.load(path=os.path.join(dir_path, 'checkpoints', f'epoch{epoch - 1}.mdl'), methods=methods, device=\"cpu\")\n",
    "    # Prefer the examples without data augmentation when the run saved them.\n",
    "    if 'all_examples_wo_data_aug' in saved_data:\n",
    "        all_examples = saved_data['all_examples_wo_data_aug']\n",
    "    else:\n",
    "        all_examples = saved_data['all_examples']\n",
    "\n",
    "    # Evaluate the model on the supersample once and reuse the result for both predictions\n",
    "    # and labels (previously the whole dataset was evaluated twice per run).\n",
    "    cur_outputs = utils.apply_on_dataset(model=model, dataset=all_examples, batch_size=n_batch)\n",
    "    cur_preds = cur_outputs['pred']  ## (2*n, num_classes)\n",
    "    cur_labels = cur_outputs['label_0']  ## labels of the same examples; presumably shape (2*n,) -- TODO confirm\n",
    "    cur_mask = saved_data['mask']  ## presumably length n: which member of each supersample pair was trained on\n",
    "\n",
    "    preds.append(cur_preds)\n",
    "    labels.append(cur_labels)\n",
    "    masks.append(cur_mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e621585d-59ee-447c-91c4-242928afe43c",
   "metadata": {},
   "source": [
    "# Code analysis of the calibration error (CE) calculation in torchmetrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ab5d2794-cfd4-4fe5-9045-8ca198fbb37f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(1.2100e-05)\n",
      "tensor(1.2100e-05)\n",
      "tensor(1.2100e-05)\n"
     ]
    }
   ],
   "source": [
    "## Ground-truth labels of the 2n supersample examples.\n",
    "## NOTE(review): `all_examples` is left over from the LAST run loaded above, whereas preds[0]/masks[0]\n",
    "## come from the FIRST one. This is only consistent if the supersample does not depend on S_seed\n",
    "## (presumably true because `seed` is fixed); the assert below at least checks that the sizes agree.\n",
    "labels = torch.tensor([y for _, y in all_examples]).long()\n",
    "pair_starts = 2 * np.arange(len(masks[0]))\n",
    "indices = pair_starts + masks[0]  ## train dataset indices from supersamples\n",
    "indices_test = pair_starts + (1 - masks[0])  ## the held-out member of each pair\n",
    "n_class = preds[0].shape[1]\n",
    "n_bins = 4\n",
    "assert len(labels) == len(preds[0]) == 2 * len(masks[0]), \"labels, preds and mask sizes disagree\"\n",
    "\n",
    "## Calibration error of the first run on its training examples, for each norm\n",
    "## (ECE is torchmetrics' MulticlassCalibrationError, imported under that alias):\n",
    "##   norm='l1' -> ECE, norm='max' -> MCE, norm='l2' -> RMSCE; torchmetrics' default n_bins is 15.\n",
    "## Softmax is applied inside the metric when preds are not already in [0, 1].\n",
    "for norm in (\"l1\", \"max\", \"l2\"):\n",
    "    metrics = ECE(num_classes=n_class, n_bins=n_bins, norm=norm)\n",
    "    ece = metrics(preds[0][indices], labels[indices])\n",
    "    print(ece)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "cf9911e0-1fde-4d20-b3a2-99c32697d0ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "## From here on, restrict the analysis to the training examples of the first loaded run\n",
    "## (`indices` picks the training member of each supersample pair).\n",
    "p, l = preds[0][indices], labels[indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "701f6d8d-14f4-4975-994d-dae55d004a34",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a987cc36-1f26-4985-b014-3befb3dbdfb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Input formatting, adapted from torchmetrics' multiclass confusion-matrix formatting step.\n",
    "## With convert_to_labels=False the per-class scores are kept (no argmax), which the\n",
    "## confidence computation in the next cell (p.max(dim=1)) relies on.\n",
    "convert_to_labels = False\n",
    "\n",
    "if convert_to_labels:\n",
    "    # Collapse class scores to hard labels when p has one more dimension than l.\n",
    "    if p.ndim == l.ndim + 1:\n",
    "        p = p.argmax(dim=1)\n",
    "    p = p.flatten()\n",
    "else:\n",
    "    # Move the class axis (dim 1) last and flatten the rest: (N, C, ...) -> (N * ..., C).\n",
    "    p = torch.movedim(p, 1, -1).reshape(-1, p.shape[1])\n",
    "l = l.flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "47204cc2-d8f5-4337-9ac7-36bc253c3f30",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import softmax  # not used in this cell (the p.softmax method is called below)\n",
    "\n",
    "## Convert to probability vectors: unless every entry already lies in [0, 1],\n",
    "## treat p as logits and apply softmax over the class axis.\n",
    "scores_are_probs = torch.all((p >= 0) & (p <= 1))\n",
    "if not scores_are_probs:\n",
    "    p = p.softmax(1)\n",
    "\n",
    "## Per-example confidence = largest class probability;\n",
    "## accuracy = 1.0 where the arg-max class equals the label, else 0.0.\n",
    "confidences, predictions = p.max(dim=1)\n",
    "confidences = confidences.float()\n",
    "accuracies = predictions.eq(l).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c31af12-8782-4c9f-8134-a77ea0a8d9f9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "32c88c1b-088f-42bd-8179-2c40780069a5",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'torch' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m/var/folders/_s/jh8l69lx0wz3rhnd7rx1w5gr0000gn/T/ipykernel_88324/2427184297.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_bins\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mmethod\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'uniform'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m         \u001b[0mn_bins\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinspace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_bins\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfidences\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfidences\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m## Uniform binning TODO: implementing Uniform-mass binning\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      9\u001b[0m         \u001b[0mn_bins\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_bins\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0mmethod\u001b[0m \u001b[0;34m==\u001b[0m 
\u001b[0;34m'quantile'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined"
     ]
    }
   ],
   "source": [
    "## bin\n",
    "n_bins = 4\n",
    "#method = 'quantile'\n",
    "method = 'uniform'\n",
    "\n",
    "if isinstance(n_bins, int):\n",
    "    if method == 'uniform':\n",
    "        n_bins = torch.linspace(0, 1, n_bins + 1, dtype=confidences.dtype, device=confidences.device) ## Uniform binning TODO: implementing Uniform-mass binning\n",
    "        n_bins[0], n_bins[-1] = 0., 1.\n",
    "    elif method == 'quantile':\n",
    "        #pass\n",
    "        n_bins = torch.tensor(np.quantile(confidences, np.linspace(0, 1, n_bins + 1)))\n",
    "        n_bins[0], n_bins[-1] = 0., 1.\n",
    "    else:\n",
    "        raise ValueError('strategy should be either \"uniform\" or \"quantile\".')\n",
    "        \n",
    "with torch.no_grad():\n",
    "    #acc_bin, conf_bin, prop_bin = _binning_bucketize(confidences, accuracies, bin_boundaries)\n",
    "    ## binごとのaccなどを算出\n",
    "    accuracies = accuracies.to(dtype=confidences.dtype)\n",
    "    acc_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "    conf_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "    count_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "    label_bin = torch.zeros(len(n_bins), device=l.device, dtype=l.dtype)\n",
    "\n",
    "    idx = torch.bucketize(confidences, n_bins, right=True) - 1\n",
    "    count_bin.scatter_add_(dim=0, index=idx, src=torch.ones_like(confidences))\n",
    "    conf_bin.scatter_add_(dim=0, index=idx, src=confidences)\n",
    "    conf_bin = torch.nan_to_num(conf_bin / count_bin)\n",
    "\n",
    "    acc_bin.scatter_add_(dim=0, index=idx, src=accuracies)\n",
    "    acc_bin = torch.nan_to_num(acc_bin / count_bin)\n",
    "    prop_bin = count_bin / count_bin.sum()\n",
    "    \n",
    "    label_bin.scatter_add_(dim=0, index=idx, src=l)\n",
    "    l_bin = torch.nan_to_num(label_bin / count_bin)\n",
    "\n",
    "## ECE\n",
    "print(torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin))\n",
    "print(torch.max(torch.abs(acc_bin - conf_bin)))\n",
    "print(torch.sqrt(torch.sum(torch.pow(acc_bin - conf_bin, 2) * prop_bin)))\n",
    "\n",
    "print(torch.sum(torch.abs(l_bin - conf_bin) * prop_bin))\n",
    "print(torch.max(torch.abs(l_bin - conf_bin)))\n",
    "print(torch.sqrt(torch.sum(torch.pow(l_bin - conf_bin, 2) * prop_bin)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "f1d10bd6-614e-4a72-8d9d-1dc51b96b3c2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0.,  0.,  0., 75.,  0.])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#len(masks[0])\n",
    "#len(p)\n",
    "#idx\n",
    "count_bin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "b6614cac-8cd0-4ed9-befd-7cd06a4ff6f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "bin_total = torch.bincount(indices, minlength=len(n_bins)-1).to(confidences.device) ## the number of samples per bins\n",
    "bin_true = (torch.bincount(indices, weights=l, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels\n",
    "#bin_true = torch.tensor(bin_true, dtype=confidences.dtype, device=confidences.device)\n",
    "a = bin_true / bin_total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "a0d0f916-8c0e-4e05-9f8c-914c19c06089",
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "module 'torch' has no attribute 'interp'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m/var/folders/_s/jh8l69lx0wz3rhnd7rx1w5gr0000gn/T/ipykernel_4991/1945517639.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mnans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mb\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnans\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minterp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnans\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m~\u001b[0m\u001b[0mnans\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m~\u001b[0m\u001b[0mnans\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m: module 'torch' has no attribute 'interp'"
     ]
    }
   ],
   "source": [
    "b = a.clone()\n",
    "nans = torch.isnan(b)\n",
    "i = torch.arange(len(b))\n",
    "b[nans] = torch.interp(i[nans], i[~nans], b[~nans])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "5fa8097c-abe4-42ee-acaf-2777cbba7164",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 1 2 3 4 5 6 7 8 9]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.float32"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b = np.copy(a)\n",
    "nans = np.isnan(b)\n",
    "i = np.arange(len(b))\n",
    "print(i)\n",
    "b[nans] = np.interp(i[nans], i[~nans], b[~nans])\n",
    "a = torch.tensor(b)\n",
    "a.dtype\n",
    "#torch.isnan(b)\n",
    "#print(bin_total,bin_true)\n",
    "#print(torch.bincount(indices,minlength=10))\n",
    "#print(torch.bincount(indices, weights=l, minlength=10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "de2bda87-c59c-4ea7-8904-41c541efa4cc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#a[indices]\n",
    "prop_bin\n",
    "#torch.sqrt(torch.sum(torch.pow(acc_bin - conf_bin, 2) * prop_bin))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "e4c832ea-2bc2-465d-82d5-d215cdbc442f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(1.2100e-05)\n",
      "tensor(1.2100e-05)\n",
      "tensor(1.2100e-05)\n"
     ]
    }
   ],
   "source": [
    "## ECE\n",
    "print(torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin))\n",
    "print(torch.max(torch.abs(acc_bin - conf_bin)))\n",
    "print(torch.sqrt(torch.sum(torch.pow(acc_bin - conf_bin, 2) * prop_bin)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ccd9c76c-3bd6-42d4-b7b2-2c8c84f604a2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([6, 4, 5, 8, 5, 4, 9, 1, 0, 0, 8, 1, 9, 1, 3, 9, 7, 0, 1, 5, 6, 2, 8, 9,\n",
       "        2, 7, 7, 8, 4, 0, 9, 2, 5, 4, 0, 7, 9, 1, 8, 7, 6, 8, 4, 7, 0, 7, 3, 2,\n",
       "        6, 6, 5, 4, 4, 4, 5, 5, 4, 2, 8, 0, 1, 2, 6, 3, 9, 6, 9, 1, 3, 4, 8, 5,\n",
       "        0, 2, 2])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.bucketize(confidences, n_bins, right=True) - 1\n",
    "#prop_bin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "b23e77cd-d17e-4936-9a9f-b73a9e9e2c28",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import softmax\n",
    "\n",
    "def multiclass_calibration_error(preds, target, num_classes, n_bins=15, ignore_index=False, norm='l1', method='uniform'):\n",
    "    \"\"\"\n",
    "    Final code to compute multiclass calibration error\n",
    "    \"\"\"\n",
    "    preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index, convert_to_labels=False)\n",
    "    confidences, accuracies = _multiclass_calibration_error_update(preds, target)\n",
    "    return _ce_compute(confidences, accuracies, n_bins, norm, method)\n",
    "\n",
    "def _multiclass_confusion_matrix_format(preds, target, ignore_index=None, convert_to_labels=True):\n",
    "    \"\"\"\n",
    "    Formatting the predictive values and labels to calculate the calibration error\n",
    "    \"\"\"\n",
    "    # Apply argmax if we have one more dimension\n",
    "    if preds.ndim == target.ndim + 1 and convert_to_labels:\n",
    "        preds = preds.argmax(dim=1)\n",
    "    preds = preds.flatten() if convert_to_labels else torch.movedim(preds, 1, -1).reshape(-1, preds.shape[1])\n",
    "    target = target.flatten()\n",
    "    \n",
    "    if ignore_index is not None:\n",
    "        idx = target != ignore_index\n",
    "        preds = preds[idx]\n",
    "        target = target[idx]\n",
    "    \n",
    "    return preds, target\n",
    "\n",
    "def _multiclass_calibration_error_update(preds, target):\n",
    "    \"\"\"\n",
    "    Preparing the predictive probability vectors and accuracies for CE calculation\n",
    "    \"\"\"\n",
    "    if not torch.all((preds >= 0) * (preds <= 1)):\n",
    "        preds = preds.softmax(1) ## convert to prob. vector if preds are not probabilistic outputs\n",
    "    confidences, predictions = preds.max(dim=1)\n",
    "    accuracies = predictions.eq(target)\n",
    "    return confidences.float(), accuracies.float()\n",
    "\n",
    "def _binning_bucketize(confidences, accuracies, bin_boundaries):\n",
    "    \"\"\"\n",
    "    Calculate the contents count, accuracies, confidences, and proportions per bins.\n",
    "    \"\"\"\n",
    "    accuracies = accuracies.to(dtype=confidences.dtype)\n",
    "    acc_bin = torch.zeros(len(bin_boundaries), device=confidences.device, dtype=confidences.dtype)\n",
    "    conf_bin = torch.zeros(len(bin_boundaries), device=confidences.device, dtype=confidences.dtype)\n",
    "    count_bin = torch.zeros(len(bin_boundaries), device=confidences.device, dtype=confidences.dtype)\n",
    "\n",
    "    indices = torch.bucketize(confidences, bin_boundaries, right=True) - 1\n",
    "    count_bin.scatter_add_(dim=0, index=indices, src=torch.ones_like(confidences))\n",
    "    conf_bin.scatter_add_(dim=0, index=indices, src=confidences)\n",
    "    conf_bin = torch.nan_to_num(conf_bin / count_bin)\n",
    "\n",
    "    acc_bin.scatter_add_(dim=0, index=indices, src=accuracies)\n",
    "    acc_bin = torch.nan_to_num(acc_bin / count_bin)\n",
    "    prop_bin = count_bin / count_bin.sum()\n",
    "\n",
    "    return acc_bin, conf_bin, prop_bin\n",
    "\n",
    "def _ce_compute(confidences, accuracies, bin_boundaries, norm='l1', method='uniform'):\n",
    "    \"\"\"\n",
    "    Computing CE values\n",
    "    \"\"\"\n",
    "    ## bin setting (uniform or uniform-mass)\n",
    "    if isinstance(bin_boundaries, int):\n",
    "        if method == 'uniform':\n",
    "            bin_boundaries = torch.linspace(0, 1, bin_boundaries + 1, dtype=confidences.dtype, device=confidences.device)\n",
    "            print(bin_boundaries)\n",
    "        elif method == 'quantile':\n",
    "            bin_boundaries = torch.tensor(np.quantile(confidences, np.linspace(0, 1, bin_boundaries + 1)))\n",
    "            bin_boundaries[0], bin_boundaries[-1] = 0., 1.\n",
    "        else:\n",
    "            raise ValueError('strategy should be either \"uniform\" or \"quantile\".')\n",
    "\n",
    "    with torch.no_grad():\n",
    "        acc_bin, conf_bin, prop_bin = _binning_bucketize(confidences, accuracies, bin_boundaries)\n",
    "\n",
    "    if norm == \"l1\":\n",
    "        ce = torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin)\n",
    "    if norm == \"max\":\n",
    "        ce = torch.max(torch.abs(acc_bin - conf_bin))\n",
    "    if norm == \"l2\":\n",
    "        ce = torch.sqrt(torch.sum(torch.pow(acc_bin - conf_bin, 2) * prop_bin))\n",
    "\n",
    "    return ce"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "7425ca73-de27-47ba-b998-243ee92e07bf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0119)\n",
      "tensor(0.0355)\n",
      "tensor(0.0205)\n"
     ]
    }
   ],
   "source": [
    "print(multiclass_calibration_error(preds[0][indices],labels[indices], num_classes=n_class, n_bins=n_bins, norm='l1',method='uniform'))\n",
    "print(multiclass_calibration_error(preds[0][indices],labels[indices], num_classes=n_class, n_bins=n_bins, norm='max',method='uniform'))\n",
    "print(multiclass_calibration_error(preds[0][indices],labels[indices], num_classes=n_class, n_bins=n_bins, norm='l2',method='uniform'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "c5780bf1-a3db-46d2-a9f0-73ee373c3900",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0119)\n",
      "tensor(0.0355)\n",
      "tensor(0.0205)\n"
     ]
    }
   ],
   "source": [
    "print(multiclass_calibration_error(preds[0][indices],labels[indices], num_classes=n_class, n_bins=n_bins, norm='l1',method='quantile'))\n",
    "print(multiclass_calibration_error(preds[0][indices],labels[indices], num_classes=n_class, n_bins=n_bins, norm='max',method='quantile'))\n",
    "print(multiclass_calibration_error(preds[0][indices],labels[indices], num_classes=n_class, n_bins=n_bins, norm='l2',method='quantile'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98701cf5-68e1-4da9-9932-fb561699fae4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
