{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "99f2c53a-6cd4-4191-a17e-b373c7abbe2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import softmax\n",
    "\n",
    "from nnlib.nnlib import utils\n",
    "from nnlib.nnlib.matplotlib_utils import import_matplotlib\n",
    "\n",
    "matplotlib, plt = import_matplotlib()\n",
    "%matplotlib inline\n",
    "\n",
    "#from modules.bound_utils import estimate_fcmi_bound_classification, estimate_default_bound_classification, estimate_proposed_bound_classification, estimate_sgld_bound, estimate_kl_bound_classification, estimate_lg_bound_classification, estimate_interp_bound_classification\n",
    "from modules.bound_utils import estimate_fcmi_bound_classification, estimate_sgld_bound, estimate_kl_bound_classification, estimate_lg_bound_classification, estimate_interp_bound_classification\n",
    "from scripts.fcmi_train_classifier import mnist_ld_schedule, \\\n",
    "    cifar_resnet50_ld_schedule  # for pickle to be able to load LD methods\n",
    "import methods\n",
    "from torchmetrics.classification import MulticlassCalibrationError as ECE\n",
    "\n",
    "from sklearn.feature_selection import mutual_info_classif, mutual_info_regression\n",
    "\n",
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4f05137b-9353-4c89-8e2b-4751201d8a21",
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.data_utils import load_data_from_arguments\n",
    "from nnlib.nnlib.data_utils.wrappers import SubsetDataWrapper, LabelSubsetWrapper, ResizeImagesWrapper\n",
    "from nnlib.nnlib.data_utils.base import get_loaders_from_datasets, get_input_shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "11ee4684-ba62-4555-90f4-c1ec6107fc3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def interpolate_nan(a):\n",
    "    \"\"\"Linear interpolation for nan values in a 1d array.\n",
    "    Nans on the boundary are filled with the nearest non-nan value.\n",
    "    \"\"\"\n",
    "    b = a.copy()\n",
    "    nans = np.isnan(b)\n",
    "    i = np.arange(len(b))\n",
    "    b[nans] = np.interp(i[nans], i[~nans], b[~nans])\n",
    "    return torch.tensor(b).float()\n",
    "\n",
    "def load_all_data(saved_data):\n",
    "    args = saved_data['args']\n",
    "    examples ,_ ,_ ,_ = load_data_from_arguments(args, build_loaders=False)\n",
    "    examples = LabelSubsetWrapper(examples, which_labels=args.which_labels)\n",
    "    \n",
    "    return examples\n",
    "\n",
    "def preds_recalibrate(saved_data, all_data, n_data=100):\n",
    "    all_examples = saved_data['all_examples'] ## train, val data\n",
    "    args = saved_data['args']\n",
    "    #examples ,_ ,_ ,_ = load_data_from_arguments(args, build_loaders=False)\n",
    "    #examples = LabelSubsetWrapper(examples, which_labels=args.which_labels)\n",
    "    \n",
    "    assert len(all_data) >= n_data\n",
    "    np.random.seed(args.seed)\n",
    "    include_indices = all_examples.include_indices\n",
    "    exclude_indices = np.arange(len(all_data))[~np.isin(np.arange(len(all_data)), include_indices)]\n",
    "    if n_data >= len(exclude_indices):\n",
    "        raise ValueError(f\" The number of recalibration data is larger than that of all data\")\n",
    "    exclude_indices = np.random.choice(exclude_indices, size=n_data, replace=False)\n",
    "    examples = SubsetDataWrapper(all_data, include_indices=exclude_indices)\n",
    "    \n",
    "    return examples\n",
    "\n",
    "def calc_recalibration(ps, n_bins):\n",
    "    #idx = torch.bucketize(ps, n_bins, right=True) - 1\n",
    "    idx = idx_bins(ps, n_bins)\n",
    "    bin_total = torch.bincount(idx, minlength=len(n_bins)-1).float().to(ps.device) ## the number of samples per bins\n",
    "    bin_true = torch.bincount(idx, weights=ps, minlength=len(n_bins)-1).float().to(ps.device) ## the number of samples per bins weighted by labels\n",
    "    with warnings.catch_warnings():\n",
    "        warnings.filterwarnings('ignore')\n",
    "        bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "    return bin_mean[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b325d8bc-9e9f-4ac1-b45c-53aad31d3195",
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_softmax(preds):\n",
    "    \n",
    "    if not torch.all(torch.abs(torch.sum(preds, dim=1) - 1) < 1e-10):\n",
    "        print(\"make softmax prob.\")\n",
    "        preds = preds.softmax(1)\n",
    "    else:\n",
    "        print(\"already softmax prob.\")\n",
    "    \n",
    "    return preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 424,
   "id": "293fde71-3710-46e5-8cab-30496f18de6d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=4000,lr=0.001,seed=0,S_seed=0/checkpoints/epoch199.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 32/32 [00:02<00:00, 15.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 32/32 [00:01<00:00, 16.30it/s]\n",
      "INFO:root:Dataset mnist has no validation set. Splitting the training set...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "calling function __init__ of MNIST with the following parameters:\n",
      "\tdata_augmentation: False\n",
      "calling function build_datasets of MNIST with the following parameters:\n",
      "\tdata_dir: None\n",
      "\tval_ratio: 0.2\n",
      "\tnum_train_examples: None\n",
      "\tseed: 0\n",
      "\tdownload: True\n",
      "Dataset mnist is loaded\n",
      "\ttrain_samples: 48000\n",
      "\tval_samples: 12000\n",
      "\ttest_samples: 10000\n",
      "\texample_shape: torch.Size([1, 28, 28])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 1/1 [00:00<00:00, 38.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 1/1 [00:00<00:00, 37.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=4000,lr=0.001,seed=0,S_seed=1/checkpoints/epoch199.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 32/32 [00:01<00:00, 16.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 32/32 [00:01<00:00, 16.47it/s]\n",
      "100%|███████████████████████████████████████| 1/1 [00:00<00:00, 36.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 1/1 [00:00<00:00, 39.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=4000,lr=0.001,seed=0,S_seed=2/checkpoints/epoch199.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 32/32 [00:01<00:00, 16.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 32/32 [00:01<00:00, 16.57it/s]\n",
      "100%|███████████████████████████████████████| 1/1 [00:00<00:00, 36.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 1/1 [00:00<00:00, 35.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=4000,lr=0.001,seed=0,S_seed=3/checkpoints/epoch199.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 32/32 [00:01<00:00, 16.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 32/32 [00:01<00:00, 16.44it/s]\n",
      "100%|███████████████████████████████████████| 1/1 [00:00<00:00, 38.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 1/1 [00:00<00:00, 37.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=4000,lr=0.001,seed=0,S_seed=4/checkpoints/epoch199.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 32/32 [00:01<00:00, 16.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 32/32 [00:01<00:00, 16.50it/s]\n",
      "100%|███████████████████████████████████████| 1/1 [00:00<00:00, 37.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 1/1 [00:00<00:00, 24.25it/s]\n"
     ]
    }
   ],
   "source": [
    "#n = 75\n",
    "#n = 250\n",
    "n = 4000\n",
    "lr = 0.001\n",
    "seed = 0\n",
    "#n_S_seeds = 30\n",
    "n_S_seeds = 5\n",
    "epoch = 200 ## full: 200 ; 200 epochsに対する結果\n",
    "n_batch = 256\n",
    "\n",
    "preds = []\n",
    "labels = []\n",
    "masks = []\n",
    "\n",
    "#recalibrate_data = False\n",
    "recalibrate_data = True\n",
    "if recalibrate_data:\n",
    "    preds_recab = []\n",
    "    labels_recab = []\n",
    "\n",
    "#n_bins = torch.tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]) ## Uniform\n",
    "#n_bins = int(n ** (1/4)) ## l1\n",
    "n_bins = int(n ** (1/3)) ## l1\n",
    "n_bins = torch.linspace(0, 1, n_bins + 1) ## Uniform binning TODO: implementing Uniform-mass binning\n",
    "n_bins[0], n_bins[-1] = 0., 1.\n",
    "bins = []\n",
    "\n",
    "for S_seed in range(n_S_seeds):\n",
    "    dir_name = f'n={n},lr={lr},seed={seed},S_seed={S_seed}'\n",
    "    dir_path = os.path.join(os.getcwd() + \"/results/\", \"fcmi-mnist-4vs9-CNN\", dir_name)\n",
    "    if not os.path.exists(dir_path):\n",
    "        print(f\"Did not find results for {dir_name}\")\n",
    "        continue\n",
    "    \n",
    "    with open(os.path.join(dir_path, 'saved_data.pkl'), 'rb') as f:\n",
    "        saved_data = pickle.load(f)\n",
    "    \n",
    "    model = utils.load(path=os.path.join(dir_path, 'checkpoints', f'epoch{epoch - 1}.mdl'), methods=methods, device=\"cpu\")\n",
    "    if 'all_examples_wo_data_aug' in saved_data:\n",
    "        all_examples = saved_data['all_examples_wo_data_aug']\n",
    "    else:\n",
    "        all_examples = saved_data['all_examples']\n",
    "    \n",
    "    cur_preds = check_softmax(utils.apply_on_dataset(model=model, dataset=all_examples, batch_size=n_batch)['pred']) ## (2*n, num_classes)\n",
    "    cur_labels = utils.apply_on_dataset(model=model, dataset=all_examples, batch_size=n_batch)['label_0'] ## (2*n, num_classes)\n",
    "    cur_mask = saved_data['mask'] #(n_labels)\n",
    "    train_idx = 2*np.arange(len(cur_mask)) + cur_mask\n",
    "    cur_bins = torch.bucketize(cur_preds[train_idx].softmax(1).max(1).values, n_bins, right=True) - 1\n",
    "    #id = np.arange(0,len(torch.unique(cur_labels)))\n",
    "    #cur_preds = utils.apply_on_dataset(model=model, dataset=all_examples, batch_size=n_batch)['pred'].softmax(1)\n",
    "    #cur_labels = utils.apply_on_dataset(model=model, dataset=all_examples, batch_size=n_batch)['label_0'] ## (2*n, num_classes)\n",
    "    #a = torch.zeros_like(cur_preds)\n",
    "    #a[:,0] = cur_preds[:,id!=1].sum(1)\n",
    "    #a[:,1] = cur_preds[:,id==1].sum(1)\n",
    "    #cur_preds = a\n",
    "    #cur_bins = torch.bucketize(cur_preds[train_idx].softmax(1).max(1).values, n_bins, right=True) - 1\n",
    "    \n",
    "    if recalibrate_data:\n",
    "        if S_seed == 0: ## Loading the all dataset to prepare the recablation dataset at the first iteration\n",
    "            all_data = load_all_data(saved_data=saved_data)\n",
    "        recab_examples = preds_recalibrate(saved_data=saved_data, all_data=all_data, n_data=100)\n",
    "        cur_preds_recab = check_softmax(utils.apply_on_dataset(model=model, dataset=recab_examples, batch_size=n_batch)['pred']) ## (2*n, num_classes)\n",
    "        cur_labels_recab = utils.apply_on_dataset(model=model, dataset=recab_examples, batch_size=n_batch)['label_0'] ## (2*n, num_classes)\n",
    "        preds_recab.append(cur_preds_recab)\n",
    "        labels_recab.append(cur_labels_recab)\n",
    "    \n",
    "    preds.append(cur_preds)\n",
    "    labels.append(cur_labels)\n",
    "    masks.append(cur_mask)\n",
    "    bins.append(cur_bins.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 425,
   "id": "401b1dd6-79fa-4d35-b432-afb8ce28bf7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "class _LabelLoss(nn.Module):\n",
    "    def __init__(self, n_bins=15, method='uniform', logits=None):\n",
    "        \"\"\"\n",
    "        n_bins (int): number of confidence interval bins\n",
    "        \"\"\"\n",
    "        super(_LabelLoss, self).__init__()\n",
    "        if method == 'uniform':\n",
    "            bin_boundaries = torch.linspace(0, 1, n_bins + 1)\n",
    "        elif method == 'quantile':\n",
    "            conf, pred = logits.max(1)\n",
    "            conf[pred==0] = 1 - conf[pred==0]\n",
    "            bin_boundaries = torch.tensor(np.quantile(conf, torch.linspace(0, 1, n_bins + 1)))\n",
    "        \n",
    "        bin_boundaries[0], bin_boundaries[-1] = 0., 1.\n",
    "        self.bin_lowers = bin_boundaries[:-1]\n",
    "        self.bin_uppers = bin_boundaries[1:]\n",
    "\n",
    "    def forward(self, logits, labels, masks):\n",
    "        if not torch.all(torch.abs(torch.sum(logits, dim=1) - 1) < 1e-10):\n",
    "            print(\"make softmax prob.\")\n",
    "            softmaxes = F.softmax(logits, dim=1)\n",
    "        else:\n",
    "            softmaxes = logits\n",
    "\n",
    "        train_idx = 2*np.arange(len(masks)) + masks\n",
    "        test_idx = 2*np.arange(len(1-masks)) + (1-masks)\n",
    "        pred_tr, pred_te = logits[train_idx], logits[test_idx]\n",
    "        ls_tr, ls_te = labels[train_idx], labels[test_idx]\n",
    "        \n",
    "        conf_tr, pl_tr = torch.max(pred_tr, 1)\n",
    "        conf_tr[pl_tr==0] = 1 - conf_tr[pl_tr==0]\n",
    "        conf_te, pl_te = torch.max(pred_te, 1)\n",
    "        conf_te[pl_te==0] = 1 - conf_tr[pl_te==0]\n",
    "        \n",
    "        acc = torch.zeros(1, device=pred_tr.device)\n",
    "        bin_loss = torch.zeros(1, device=pred_tr.device)\n",
    "        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):\n",
    "            # Calculated |confidence - accuracy| in each bin\n",
    "            in_bin = conf_tr.gt(bin_lower.item()) * conf_tr.le(bin_upper.item())\n",
    "            in_bin_te = conf_te.gt(bin_lower.item()) * conf_te.le(bin_upper.item())\n",
    "            prop_in_bin = in_bin.float().mean()\n",
    "            prop_in_bin_te = in_bin_te.float().mean()\n",
    "            if prop_in_bin.item() > 0 and prop_in_bin_te.item() > 0:\n",
    "                freq_tr = ls_tr[in_bin].float()\n",
    "                freq_te = ls_te[in_bin_te].float()\n",
    "                acc += abs((freq_te.float().sum() - freq_tr.float().sum())/ n)\n",
    "                bin_loss += prop_in_bin_te\n",
    "            elif prop_in_bin.item() > 0:\n",
    "                freq_tr = ls_tr[in_bin].float()\n",
    "                acc += freq_tr.float().sum() / n\n",
    "            elif prop_in_bin_te.item() > 0:\n",
    "                freq_te = ls_te[in_bin_te].float()\n",
    "                acc += freq_te.float().sum() / n\n",
    "                bin_loss += prop_in_bin_te\n",
    "            else:\n",
    "                acc += 0\n",
    "\n",
    "        return (acc, bin_loss)\n",
    "\n",
    "def compute_lsloss(preds, mask, dataset, loss, target_label=None):\n",
    "    labels = [y for x, y in dataset]\n",
    "    labels = torch.tensor(labels).long()\n",
    "    if target_label is not None:\n",
    "        labels[labels!=target_label] = 0\n",
    "        labels[labels==target_label] = 1\n",
    "    \n",
    "    acc_loss, bin_loss = loss.forward(preds, labels, mask)\n",
    "    \n",
    "    return (utils.to_numpy(acc_loss), utils.to_numpy(bin_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31912911-241c-4380-ba60-a3cda5a11810",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 426,
   "id": "ff04e2e4-da1e-4eeb-b432-f3c8b847d03e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "B = 10\n",
    "train_idx = 2*np.arange(len(cur_mask)) + cur_mask\n",
    "test_idx = 2*np.arange(len(1-cur_mask)) + (1-cur_mask)\n",
    "pred_tr, pred_te = cur_preds[train_idx], cur_preds[test_idx]\n",
    "ls = [y for x, y in all_examples]\n",
    "ls = torch.tensor(ls).long()\n",
    "ls_tr, ls_te = ls[train_idx], ls[test_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 427,
   "id": "afb3fbae-c29f-4a61-9b8c-bb73cd04557f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(array([0.04525], dtype=float32), array([0.7845], dtype=float32))"
      ]
     },
     "execution_count": 427,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss = _LabelLoss(n_bins=B, method='quantile', logits=pred_tr)\n",
    "ecmi = compute_lsloss(cur_preds, cur_mask, all_examples, loss)\n",
    "#utils.to_numpy(loss.forward(cur_preds, cur_labels, cur_mask))\n",
    "ecmi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b317adad-6249-4b07-9fbf-cb45acc31c73",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 428,
   "id": "a229e483-a656-4547-8b5b-2c3316704733",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sklearn\n",
    "class CalibrationMethod(sklearn.base.BaseEstimator):\n",
    "    \"\"\"\n",
    "    A generic class for probability calibration\n",
    "\n",
    "    A calibration method takes a set of posterior class probabilities and transform them into calibrated posterior\n",
    "    probabilities. Calibrated in this sense means that the empirical frequency of a correct class prediction matches its\n",
    "    predicted posterior probability.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        \"\"\"\n",
    "        Fit the calibration method based on the given uncalibrated class probabilities X and ground truth labels y.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        X : array-like, shape (n_samples, n_classes)\n",
    "            Training data, i.e. predicted probabilities of the base classifier on the calibration set.\n",
    "        y : array-like, shape (n_samples,)\n",
    "            Target classes.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        self : object\n",
    "            Returns an instance of self.\n",
    "        \"\"\"\n",
    "        raise NotImplementedError(\"Subclass must implement this method.\")\n",
    "\n",
    "    def predict_proba(self, X):\n",
    "        \"\"\"\n",
    "        Compute calibrated posterior probabilities for a given array of posterior probabilities from an arbitrary\n",
    "        classifier.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        X : array-like, shape (n_samples, n_classes)\n",
    "            The uncalibrated posterior probabilities.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        P : array, shape (n_samples, n_classes)\n",
    "            The predicted probabilities.\n",
    "        \"\"\"\n",
    "        raise NotImplementedError(\"Subclass must implement this method.\")\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"\n",
    "        Predict the class of new samples after scaling. Predictions are identical to the ones from the uncalibrated\n",
    "        classifier.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        X : array-like, shape (n_samples, n_classes)\n",
    "            The uncalibrated posterior probabilities.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        C : array, shape (n_samples,)\n",
    "            The predicted classes.\n",
    "        \"\"\"\n",
    "        return np.argmax(self.predict_proba(X), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 429,
   "id": "25c3f712-aa50-4aed-82e9-461da7887c68",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics.pairwise import pairwise_distances\n",
    "\n",
    "class NadarayaWatson(CalibrationMethod):\n",
    "    def __init__(self, kernel='dirichlet'):\n",
    "        super().__init__()\n",
    "        self.kernel = kernel\n",
    "    \n",
    "    def fit(self, X):\n",
    "        if self.kernel=='dirichlet':\n",
    "            bandwidth = self.get_bandwidth(X)\n",
    "        else:\n",
    "            bandwidth = 0.1 ## That of Gaussian kernel is optimized as the median of pairwise distance\n",
    "        self.log_kern = self.get_kernel(X, bandwidth)\n",
    "        self.kern = torch.exp(self.log_kern).to(torch.float32)\n",
    "\n",
    "    def predict_proba(self, y):\n",
    "        y_onehot = nn.functional.one_hot(y, num_classes=len(y.unique())).to(torch.float32)\n",
    "        kern_y = torch.matmul(self.kern, y_onehot)\n",
    "        den = torch.sum(self.kern, dim=1)\n",
    "        \n",
    "        # to avoid division by 0\n",
    "        den = torch.clamp(den, min=1e-10)\n",
    "        \n",
    "        ratio = kern_y / den.unsqueeze(-1)\n",
    "\n",
    "        return ratio\n",
    "\n",
    "    def get_kernel(self, f, bandwidth):\n",
    "        if self.kernel == 'dirichlet':\n",
    "            if f.shape[1] == 1:\n",
    "                log_kern = self.beta_kernel(f, f, bandwidth).squeeze()\n",
    "            else:\n",
    "                log_kern = self.dirichlet_kernel(f, bandwidth).squeeze()\n",
    "        elif self.kernel == 'gaussian':\n",
    "            log_kern = self.gaussian_kernel(f, bandwidth).squeeze()\n",
    "        \n",
    "        # Trick: -inf on the diagonal\n",
    "        return log_kern + torch.diag(torch.finfo(torch.float).min * torch.ones(len(f)))\n",
    "\n",
    "    def gaussian_kernel(self, z, bandwidth=0.1, median=True):\n",
    "        Dxx = pairwise_distances(z.numpy(), metric='sqeuclidean')\n",
    "        if median:\n",
    "            bandwidth = np.median(Dxx)\n",
    "        log_gauss_pdf = -Dxx / 2 / bandwidth\n",
    "        \n",
    "        return torch.tensor(log_gauss_pdf).float()\n",
    "\n",
    "    def beta_kernel(self, z, zi, bandwidth=0.1):\n",
    "        p = zi / bandwidth + 1\n",
    "        q = (1-zi) / bandwidth + 1\n",
    "        z = z.unsqueeze(-2)\n",
    "        \n",
    "        log_beta = torch.lgamma(p) + torch.lgamma(q) - torch.lgamma(p + q)\n",
    "        log_num = (p-1) * torch.log(z) + (q-1) * torch.log(1-z)\n",
    "        log_beta_pdf = log_num - log_beta\n",
    "\n",
    "        return log_beta_pdf\n",
    "\n",
    "    def dirichlet_kernel(self, z, bandwidth=0.1):\n",
    "        alphas = z / bandwidth + 1\n",
    "        \n",
    "        log_beta = (torch.sum((torch.lgamma(alphas)), dim=1) - torch.lgamma(torch.sum(alphas, dim=1)))\n",
    "        log_num = torch.matmul(torch.log(z), (alphas-1).T)\n",
    "        log_dir_pdf = log_num - log_beta\n",
    "\n",
    "        return log_dir_pdf\n",
    "\n",
    "    def get_bandwidth(self, f):\n",
    "        \"\"\"\n",
    "        Select a bandwidth for the kernel based on maximizing the leave-one-out likelihood (LOO MLE).\n",
    "        \n",
    "        :param f: The vector containing the probability scores, shape [num_samples, num_classes]\n",
    "        :param device: The device type: 'cpu' or 'cuda'\n",
    "        :return: The bandwidth of the kernel\n",
    "        \"\"\"\n",
    "        bandwidths = torch.cat((torch.logspace(start=-5, end=-1, steps=15), torch.linspace(0.2, 1, steps=5)))\n",
    "        max_b = -1\n",
    "        max_l = 0\n",
    "        n = len(f)\n",
    "        for b in bandwidths:\n",
    "            log_kern = self.get_kernel(f, b)\n",
    "            log_fhat = torch.logsumexp(log_kern, 1) - np.log(n-1)\n",
    "            l = torch.sum(log_fhat)\n",
    "            if l > max_l:\n",
    "                max_l = l\n",
    "                max_b = b\n",
    "\n",
    "        return max_b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 430,
   "id": "89f02fc1-8ad6-41a0-ac70-86fda889d9d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_idx = 2*np.arange(len(cur_mask)) + cur_mask\n",
    "test_idx = 2*np.arange(len(1-cur_mask)) + (1-cur_mask)\n",
    "pred_tr, pred_te = cur_preds[train_idx], cur_preds[test_idx]\n",
    "ls = [y for x, y in all_examples]\n",
    "ls = torch.tensor(ls).long()\n",
    "ls_tr, ls_te = ls[train_idx], ls[test_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 431,
   "id": "13106ae3-cb80-47a9-b909-574e82319ada",
   "metadata": {},
   "outputs": [],
   "source": [
    "rec_tr = NadarayaWatson(kernel='gaussian')\n",
    "#rec_f = NadarayaWatson()\n",
    "rec_tr.fit(pred_tr)\n",
    "NW_tr_tr = rec_tr.predict_proba(ls_tr)\n",
    "NW_tr_te = rec_tr.predict_proba(ls_te)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 432,
   "id": "c3a49e9c-86bc-45ea-9900-ec0d307e1a86",
   "metadata": {},
   "outputs": [],
   "source": [
    "rec_te = NadarayaWatson(kernel='gaussian')\n",
    "#rec_f = NadarayaWatson()\n",
    "rec_te.fit(pred_te)\n",
    "NW_te_tr = rec_te.predict_proba(ls_tr)\n",
    "NW_te_te = rec_te.predict_proba(ls_te)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 433,
   "id": "9b756a7a-6e7d-40af-abc1-c27ea61a3f5f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8000, 2])"
      ]
     },
     "execution_count": 433,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.concat([NW_tr_tr, NW_tr_te]).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 434,
   "id": "83a2309f-58dd-4326-8811-63213293e06b",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_onehot = nn.functional.one_hot(cur_labels, num_classes=len(cur_labels.unique())).to(torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 435,
   "id": "22cf2732-4808-4eda-a23a-f577b25b7b71",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8000, 2])"
      ]
     },
     "execution_count": 435,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#NW_tr_tr.shape\n",
    "y_onehot.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 436,
   "id": "af62a909-691b-41c4-9f60-fcbd926b1185",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(-44.9447)"
      ]
     },
     "execution_count": 436,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.norm(y_onehot[train_idx] - NW_tr_tr) - torch.norm(y_onehot[train_idx] - NW_tr_te)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c46ac2f-0d5b-4fa7-b897-fd2f3f7ec496",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 437,
   "id": "b11b91b3-ada6-4c66-a5ba-f846f415ef2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#torch.norm(ls[0] - ps[0]) - torch.norm(ls[0] - ps[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 438,
   "id": "97b4ae19-28cd-45d6-8401-2390a3c75acd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_NWpreds_for_bound(preds, labels, masks, kernel='gaussian'):\n",
    "    NW_tr_preds = []\n",
    "    NW_te_preds = []\n",
    "    for i in range(len(preds)):\n",
    "        NW_tr, NW_te = NadarayaWatson(kernel), NadarayaWatson(kernel)\n",
    "        inv_mask = 1 - masks[i]\n",
    "        tr_idx, te_idx = 2*np.arange(len(masks[i])) + masks[i], 2*np.arange(len(inv_mask)) + inv_mask\n",
    "        pred_tr, pred_te = preds[i][tr_idx], preds[i][te_idx]\n",
    "        ls_tr, ls_te = labels[i][tr_idx], labels[i][te_idx]\n",
    "        # Fit & predict\n",
    "        NW_tr.fit(pred_tr)\n",
    "        NW_tr_tr, NW_tr_te = NW_tr.predict_proba(ls_tr), NW_tr.predict_proba(ls_te)\n",
    "        NW_tr_preds.append(torch.concat([NW_tr_tr, NW_tr_te]))\n",
    "        NW_te.fit(pred_te)\n",
    "        NW_te_tr, NW_te_te = NW_te.predict_proba(ls_tr), NW_te.predict_proba(ls_te)\n",
    "        NW_te_preds.append(torch.concat([NW_te_tr, NW_te_te]))\n",
    "\n",
    "    return preds, NW_tr_preds, NW_te_preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 544,
   "id": "01c4f25e-8d7c-479e-b187-55be1122fdda",
   "metadata": {},
   "outputs": [],
   "source": [
    "#preds, NW_tr_preds, NW_te_preds = get_NWpreds_for_bound(preds, labels, masks, kernel='dirichlet')\n",
    "preds, NW_tr_preds, NW_te_preds = get_NWpreds_for_bound(preds, labels, masks, kernel='gaussian')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 538,
   "id": "8430f3f4-c2ab-45bb-b070-a9d099ee24c5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f09381ebf34c408386589bcc42ae851d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "num_classes = len(labels[0].unique())\n",
    "num_examples = n\n",
    "list_of_mis = []\n",
    "bound = 0.0\n",
    "for idx in tqdm(range(num_examples)):\n",
    "    ms = np.array([p[idx] for p in masks])\n",
    "    ps = [p[2*idx:2*idx+2] for p in preds]\n",
    "    ps_tr = [p[2*idx:2*idx+2] for p in NW_tr_preds]\n",
    "    ps_te = [p[2*idx:2*idx+2] for p in NW_te_preds]\n",
    "    ls = [l[2*idx:2*idx+2] for l in labels]\n",
    "\n",
    "    loss = torch.zeros(len(ps), 1)\n",
    "    ls_fcmi = torch.zeros(len(ps), 3*num_classes, num_classes)\n",
    "    for i in range(len(ps)):\n",
    "        loss[i] = torch.norm(ls[i] - ps_tr[i]) - torch.norm(ls[i] - ps_te[i])\n",
    "        ls_fcmi[i] = torch.concat([ps[i], ps_tr[i], ps_te[i]])\n",
    "    cur_mi = max(0., mutual_info_classif(loss, ms, discrete_features=[False]).sum())\n",
    "    cur_fcmi = max(0., mutual_info_classif(ls_fcmi.reshape(-1, 3*num_classes*num_classes), ms, discrete_features=[False]*3*num_classes*num_classes).sum())\n",
    "    list_of_mis.append(np.sqrt(2*cur_mi) + 4*np.sqrt(2*cur_fcmi))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 540,
   "id": "c367e26f-8676-4701-bcda-c2e008e07d02",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6.861963744653054"
      ]
     },
     "execution_count": 540,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array(list_of_mis).sum()/num_examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 545,
   "id": "b6121ce5-711f-40bb-9821-ebade768099d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a336a37c15ed4c518e7c08a766a5f10f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "list_of_mis = []\n",
    "ecmi_list = []\n",
    "fcmi_list = []\n",
    "for idx in tqdm(range(num_examples)):\n",
    "    ms = np.array([p[idx] for p in masks])\n",
    "    ps = [p[idx] for p in preds]\n",
    "    ps_tr = [p[idx] for p in NW_tr_preds]\n",
    "    ps_te = [p[idx] for p in NW_te_preds]\n",
    "    ls = [l[idx] for l in labels]\n",
    "\n",
    "    ls_fcmi = torch.concat([torch.concat([ps[i],ps_tr[i],ps_te[i]]) for i in range(len(ps))]).reshape(-1, 3*num_classes)\n",
    "    \n",
    "    loss = torch.zeros(len(ps), 1)\n",
    "    for i in range(len(ps)):\n",
    "        loss[i] = torch.norm(ls[i] - ps_tr[i]) - torch.norm(ls[i] - ps_te[i])\n",
    "    cur_mi = max(0., mutual_info_classif(loss, ms, discrete_features=[False]).sum())\n",
    "    cur_fcmi = max(0., mutual_info_classif(ls_fcmi, ms, discrete_features=[False]*3*num_classes).sum())\n",
    "    ecmi_list.append(cur_mi)\n",
    "    fcmi_list.append(cur_fcmi)\n",
    "    list_of_mis.append(np.sqrt(2*cur_mi) + 4*np.sqrt(2*cur_fcmi))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 546,
   "id": "1529c32f-8f0a-4761-a50a-5d0d57d01bd5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5.377173405430758"
      ]
     },
     "execution_count": 546,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#np.array(fcmi_list).sum()\n",
    "#np.array(ecmi_list).sum()\n",
    "np.array(list_of_mis).sum() / n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 503,
   "id": "3341f663-3ddf-41b6-a0ac-55412be45cdc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(15,)"
      ]
     },
     "execution_count": 503,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.tile(ms, 3).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 517,
   "id": "81025778-a848-4354-a823-3d2e74934d2d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0])"
      ]
     },
     "execution_count": 517,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#ls_fcmi.reshape(-1,2)\n",
    "mutual_info_classif(ls_fcmi.reshape(-1,2), np.tile(ms, 3), discrete_features=[False, False])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 485,
   "id": "bb76eb3f-e422-4984-b84b-457450eee77a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([15, 2])"
      ]
     },
     "execution_count": 485,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ls_fcmi.reshape(-1,2).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 497,
   "id": "c3a8143f-aa1d-4b9d-b8ba-a87ba4d1211d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(30,)"
      ]
     },
     "execution_count": 497,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.tile(ms, 3*num_classes).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe3a2c98-09c2-414e-ad9a-2b940f8379a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "ls"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
