{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a13b154c-60f2-4136-8268-ba5a5cb71d03",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import softmax\n",
    "\n",
    "from nnlib.nnlib import utils\n",
    "from nnlib.nnlib.matplotlib_utils import import_matplotlib\n",
    "\n",
    "matplotlib, plt = import_matplotlib()\n",
    "%matplotlib inline\n",
    "\n",
    "#from modules.bound_utils import estimate_fcmi_bound_classification, estimate_default_bound_classification, estimate_proposed_bound_classification, estimate_sgld_bound, estimate_kl_bound_classification, estimate_lg_bound_classification, estimate_interp_bound_classification\n",
    "from modules.bound_utils import estimate_fcmi_bound_classification, estimate_sgld_bound, estimate_kl_bound_classification, estimate_lg_bound_classification, estimate_interp_bound_classification\n",
    "from scripts.fcmi_train_classifier import mnist_ld_schedule, \\\n",
    "    cifar_resnet50_ld_schedule  # for pickle to be able to load LD methods\n",
    "import methods\n",
    "from torchmetrics.classification import MulticlassCalibrationError as ECE\n",
    "\n",
    "from sklearn.feature_selection import mutual_info_classif, mutual_info_regression\n",
    "\n",
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "26576d92-24ee-4d39-91c7-d2fcf712d501",
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.data_utils import load_data_from_arguments\n",
    "from nnlib.nnlib.data_utils.wrappers import SubsetDataWrapper, LabelSubsetWrapper, ResizeImagesWrapper\n",
    "from nnlib.nnlib.data_utils.base import get_loaders_from_datasets, get_input_shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7dfdefe6-ce2b-4de6-aa06-fc64027418f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def interpolate_nan(a):\n",
    "    \"\"\"Linear interpolation for nan values in a 1d array.\n",
    "    Nans on the boundary are filled with the nearest non-nan value.\n",
    "    \"\"\"\n",
    "    b = a.copy()\n",
    "    nans = np.isnan(b)\n",
    "    i = np.arange(len(b))\n",
    "    b[nans] = np.interp(i[nans], i[~nans], b[~nans])\n",
    "    return torch.tensor(b).float()\n",
    "\n",
    "def load_all_data(saved_data):\n",
    "    args = saved_data['args']\n",
    "    examples ,_ ,_ ,_ = load_data_from_arguments(args, build_loaders=False)\n",
    "    examples = LabelSubsetWrapper(examples, which_labels=args.which_labels)\n",
    "    \n",
    "    return examples\n",
    "\n",
    "def preds_recalibrate(saved_data, all_data, n_data=100):\n",
    "    all_examples = saved_data['all_examples'] ## train, val data\n",
    "    args = saved_data['args']\n",
    "    #examples ,_ ,_ ,_ = load_data_from_arguments(args, build_loaders=False)\n",
    "    #examples = LabelSubsetWrapper(examples, which_labels=args.which_labels)\n",
    "    \n",
    "    assert len(all_data) >= n_data\n",
    "    np.random.seed(args.seed)\n",
    "    include_indices = all_examples.include_indices\n",
    "    exclude_indices = np.arange(len(all_data))[~np.isin(np.arange(len(all_data)), include_indices)]\n",
    "    if n_data >= len(exclude_indices):\n",
    "        raise ValueError(f\" The number of recalibration data is larger than that of all data\")\n",
    "    exclude_indices = np.random.choice(exclude_indices, size=n_data, replace=False)\n",
    "    examples = SubsetDataWrapper(all_data, include_indices=exclude_indices)\n",
    "    \n",
    "    return examples\n",
    "\n",
    "def calc_recalibration(ps, n_bins):\n",
    "    #idx = torch.bucketize(ps, n_bins, right=True) - 1\n",
    "    idx = idx_bins(ps, n_bins)\n",
    "    bin_total = torch.bincount(idx, minlength=len(n_bins)-1).float().to(ps.device) ## the number of samples per bins\n",
    "    bin_true = torch.bincount(idx, weights=ps, minlength=len(n_bins)-1).float().to(ps.device) ## the number of samples per bins weighted by labels\n",
    "    with warnings.catch_warnings():\n",
    "        warnings.filterwarnings('ignore')\n",
    "        bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "    return bin_mean[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9fdb4d88-133a-4bc9-80fa-0d4c8f4bdaaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "#saved_data['args'].which_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "89db7d2b-4064-4ac0-989d-a7f77a3428a1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "\n",
    "# トレーニングデータをダウンロード\n",
    "trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)\n",
    "\n",
    "# テストデータをダウンロード\n",
    "testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6e925525-a067-44bd-9031-b115afe94397",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5000"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#len(trainset.data[torch.tensor(trainset.targets) == 3])\n",
    "#len(trainset.targets)\n",
    "#len(testset.data[torch.tensor(testset.targets) == 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "69c9646e-e4fb-44d5-8c0d-3b188eabad00",
   "metadata": {},
   "outputs": [],
   "source": [
    "class NestedDict(dict):\n",
    "    def __missing__(self, key):\n",
    "        self[key] = type(self)()\n",
    "        return self[key]\n",
    "\n",
    "results_dir = \"results\"\n",
    "\n",
    "exp_name = \"fcmi-mnist-4vs9-CNN\"\n",
    "results_file_path = os.path.join(results_dir, exp_name, 'results_ecmi_nonrecal.pkl')\n",
    "with open(results_file_path, 'rb') as f:\n",
    "    results = pickle.load(f)\n",
    "\n",
    "results_file_path = os.path.join(results_dir, exp_name, 'results_ecmi_recal_reuse_hist_umb.pkl')\n",
    "with open(results_file_path, 'rb') as f:\n",
    "    results2 = pickle.load(f)\n",
    "\n",
    "results_file_path = os.path.join(results_dir, exp_name, 'results_ecmi_recal_hist_umb.pkl')\n",
    "with open(results_file_path, 'rb') as f:\n",
    "    results3 = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4cf16b44-14fd-444d-b4e4-ecda1b5add4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_softmax(preds):\n",
    "    \n",
    "    if not torch.all(torch.abs(torch.sum(preds, dim=1) - 1) < 1e-10):\n",
    "        print(\"make softmax prob.\")\n",
    "        preds = preds.softmax(1)\n",
    "    else:\n",
    "        print(\"already softmax prob.\")\n",
    "    \n",
    "    return preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cf09583b-620d-4017-9e3a-c8a94ed73866",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.4739478230942338"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#np.sqrt((8*(np.array([results[4000][200][i]['mis_list_umb'] for i in range(len(results[4000][200]))]).mean() + math.floor(4000 * (1/3)) * np.log(2)))/4000)\n",
    "import math\n",
    "n = 4000\n",
    "cmi = np.array([results[n][200][i]['mis_list_umb'] for i in range(len(results[n][200]))]).mean()\n",
    "cmi += math.floor(n ** (1/3)) * np.log(2)\n",
    "np.sqrt(8*cmi / n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3453552c-d80a-48a5-b0a1-2e2cb49b3870",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.20413726799326548"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#ece_gap = np.tile(results[n][200][0]['exp_gap_ece_umb'], 4000).reshape(-1,1)\n",
    "#ece_gap = np.array([results[n][200][i]['exp_gap_ece_umb'] for i in range(len(results[n][200]))])[0]\n",
    "#mutual_info_classif(ece_gap, masks[0], discrete_features=[False]).sum()\n",
    "#len(masks[0])\n",
    "#masks\n",
    "#ece_loss = _ECELoss(n_bins=10, method='quantile', logits=cur_preds[train_idx])\n",
    "#ece_gap = np.abs(ece_loss.forward(cur_preds[train_idx], cur_labels[train_idx]) - ece_loss.forward(cur_preds[test_idx], cur_labels[test_idx]))\n",
    "#mutual_info_classif(np.tile(ece_gap, len(cur_mask)).reshape(-1,1), cur_mask, discrete_features=[False]).sum()\n",
    "\n",
    "#results[75][200][0]\n",
    "#results[n][200]\n",
    "cmi = np.array([results[n][200][i]['mis_ece_list_umb'] for i in range(len(results[n][200]))]).mean()\n",
    "cmi += math.floor(n ** (1/3)) * np.log(2)\n",
    "np.sqrt(16*cmi / n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "712fd980-ac75-417a-8e28-2da1001d0fce",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.006872499"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#np.array([results[4000][200][i]['exp_gap_ece_umb2'] for i in range(len(results[4000][200]))]).mean()\n",
    "np.array([results[n][200][i]['exp_gap_ece_uwb2'] for i in range(len(results[n][200]))]).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4aa7d655-a843-4321-9672-bffe9493bd24",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.006371367"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#np.sqrt((8*(np.array([results[4000][200][i]['mis_list_umb'] for i in range(len(results[4000][200]))]).mean() + math.floor(4000 * (1/3)) * np.log(2)))/4000)\n",
    "cmi = np.array([results2[n][200][i]['mis_list_umb'] for i in range(len(results2[n][200]))]).mean()\n",
    "cmi += math.floor(n ** (1/3)) * np.log(2)\n",
    "#np.sqrt(2*cmi / n)\n",
    "np.array([results2[n][200][i]['exp_gap_ece_umb2'] for i in range(len(results2[n][200]))]).mean() #+ np.array([results2[n][200][i]['exp_gap_ece_umb2'] for i in range(len(results2[n][200]))]).std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4077424e-b867-448f-a52b-04d58b2ac8a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "#np.array([results2[n][200][i]['exp_train_ece_umb'] for i in range(len(results2[n][200]))]).mean()\n",
    "#np.array([results3[n][200][i]['exp_train_ece_umb'] for i in range(len(results2[n][200]))]).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "eff0de95-b0cf-4967-838e-7160965f239c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.008591145"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array([results3[n][200][i]['exp_gap_ece_umb2'] for i in range(len(results3[n][200]))]).mean() #+ np.array([results3[n][200][i]['exp_gap_ece_umb2'] for i in range(len(results3[n][200]))]).std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "8769357f-c76b-4df0-83f9-2dc9fa720f78",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "#confidences, predictions = torch.max(cur_preds[train_idx], 1)\n",
    "#accuracies = cur_labels[train_idx]\n",
    "class _ECELoss(nn.Module):\n",
    "    \"\"\"\n",
    "    Calculates the Expected Calibration Error of a model.\n",
    "    (This isn't necessary for temperature scaling, just a cool metric).\n",
    "\n",
    "    The input to this loss is the logits of a model, NOT the softmax scores.\n",
    "\n",
    "    This divides the confidence outputs into equally-sized interval bins.\n",
    "    In each bin, we compute the confidence gap:\n",
    "\n",
    "    bin_gap = | avg_confidence_in_bin - accuracy_in_bin |\n",
    "\n",
    "    We then return a weighted average of the gaps, based on the number\n",
    "    of samples in each bin\n",
    "\n",
    "    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.\n",
    "    \"Obtaining Well Calibrated Probabilities Using Bayesian Binning.\" AAAI.\n",
    "    2015.\n",
    "    \"\"\"\n",
    "    def __init__(self, n_bins=15, method='uniform', logits=None):\n",
    "        \"\"\"\n",
    "        n_bins (int): number of confidence interval bins\n",
    "        \"\"\"\n",
    "        super(_ECELoss, self).__init__()\n",
    "        if method == 'uniform':\n",
    "            bin_boundaries = torch.linspace(0, 1, n_bins + 1)\n",
    "        elif method == 'quantile':\n",
    "            conf, pred = logits.max(1)\n",
    "            conf[pred == 0] = 1 - conf[pred == 0]\n",
    "            bin_boundaries = torch.tensor(np.quantile(conf, torch.linspace(0, 1, n_bins + 1)))\n",
    "        \n",
    "        bin_boundaries[0], bin_boundaries[-1] = 0., 1.\n",
    "        self.bin_lowers = bin_boundaries[:-1]\n",
    "        self.bin_uppers = bin_boundaries[1:]\n",
    "\n",
    "    def forward(self, logits, labels):\n",
    "        if not torch.all(torch.abs(torch.sum(logits, dim=1) - 1) < 1e-10):\n",
    "            print(\"make softmax prob.\")\n",
    "            softmaxes = F.softmax(logits, dim=1)\n",
    "        else:\n",
    "            softmaxes = logits\n",
    "        confidences, predictions = torch.max(softmaxes, 1)\n",
    "        confidences[predictions==0] = 1 - confidences[predictions==0]\n",
    "        #accuracies = predictions.eq(labels)\n",
    "        accuracies = labels\n",
    "\n",
    "        ece = torch.zeros(1, device=logits.device)\n",
    "        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):\n",
    "            # Calculated |confidence - accuracy| in each bin\n",
    "            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())\n",
    "            prop_in_bin = in_bin.float().mean()\n",
    "            if prop_in_bin.item() > 0:\n",
    "                accuracy_in_bin = accuracies[in_bin].float().mean()\n",
    "                avg_confidence_in_bin = confidences[in_bin].mean()\n",
    "                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin\n",
    "\n",
    "        return ece"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "7949e838-512e-4619-b3d9-9185b85d4384",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "train_idx = 2*np.arange(len(cur_mask)) + cur_mask\n",
    "test_idx = 2*np.arange(len(1-cur_mask)) + (1-cur_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "f48ef69a-8aa8-4a6c-817f-eb9ed751aa9e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n",
      "make softmax prob.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([0.00045314, 0.00045314, 0.00045314, 0.00045314, 0.00045314,\n",
       "       0.00045314, 0.00045314, 0.00045314, 0.00045314, 0.00045314],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#ece_loss = _ECELoss()\n",
    "ece_loss = _ECELoss(n_bins=math.floor(n**1/3), method='quantile', logits=cur_preds[train_idx])\n",
    "#ece_loss.forward(cur_preds[train_idx], cur_labels[train_idx])\n",
    "#ece_loss.forward(cur_preds[test_idx], cur_labels[test_idx])\n",
    "#print(abs(ece_loss.forward(cur_preds[train_idx], cur_labels[train_idx]) - ece_loss.forward(cur_preds[test_idx], cur_labels[test_idx])))\n",
    "\n",
    "ece_gap = abs(ece_loss.forward(cur_preds[train_idx], cur_labels[train_idx]) - ece_loss.forward(cur_preds[test_idx], cur_labels[test_idx]))\n",
    "#mutual_info_classif(ece_gap.tile(4000).reshape(-1,1), cur_mask, discrete_features=[False, False]).sum()\n",
    "#np.tile(ece_gap, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "8bb87065-92c9-4a1f-aca8-43806c9611bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "for idx in range(n):\n",
    "    ms = np.array([p[idx] for p in masks])\n",
    "    ps = [p[2*idx:2*idx+2] for p in preds]\n",
    "    ls = [l[2*idx:2*idx+2] for l in labels]\n",
    "    bs = np.array([b[idx] for b in bins])\n",
    "    #for i in range(len(ps)):\n",
    "    #    if not torch.all(torch.abs(torch.sum(ps[i], dim=1) - 1) < 1e-10):\n",
    "    #        ps[i] = torch.max(softmax(ps[i], 1), dim=1).values ## predictive values\n",
    "    #    else:\n",
    "    #        ps[i] = torch.max(ps[i], dim=1).values ## predictive values\n",
    "    #    ps[i][ls[i] == 0] = 1 - ps[i][ls[i] == 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "0fa277a3-fd44-4489-9b53-76f134a6d0af",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.])"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ece_loss = _ECELoss(n_bins=math.floor(n**1/3), method='quantile', logits=ps[0])\n",
    "#ece_loss.forward(ps[0],ms[0])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "id": "edf1bd17-9c46-409e-8003-b6f29998c1ee",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([100, 2])"
      ]
     },
     "execution_count": 116,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "#pred = torch.softmax(torch.randn(100,10),1)\n",
    "idx = np.arange(0, pred.shape[1])\n",
    "target_label = 3\n",
    "\n",
    "#binary_preds = torch.zeros(len(pred), 2)\n",
    "#binary_preds[:,0] = 1 - preds[:,idx==target_label]\n",
    "#binary_preds[:,1] = preds[:,idx==target_label]\n",
    "#labels[labels!=target_label] = 0\n",
    "#labels[labels==target_label] = 1\n",
    "binary_preds = torch.zeros(len(pred), 2)\n",
    "binary_preds[:,0] = 1 - pred[:,target_label]\n",
    "binary_preds[:,1] = pred[:,target_label]\n",
    "#labels[labels!=target_label] = 0\n",
    "#labels[labels==target_label] = 1\n",
    "binary_preds.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "a3b84c53-208f-4022-bf36-fb2be79eae14",
   "metadata": {},
   "outputs": [],
   "source": [
    "def multi_pred_to_binary(preds, labels, target_label=3):\n",
    "    if (preds.shape[1]-1) < target_label:\n",
    "        raise ValueError(f\" This target label is not included.\")\n",
    "    idx = np.arange(0,len(torch.unique(labels)))\n",
    "    #preds = preds.softmax(1)\n",
    "    binary_preds = torch.zeros(len(preds), 2)\n",
    "    binary_preds[:,0] = 1 - preds[:,idx==target_label].sum(1)\n",
    "    binary_preds[:,1] = preds[:,idx==target_label].sum(1)\n",
    "    labels[labels!=target_label] = 0\n",
    "    labels[labels==target_label] = 1\n",
    "    return binary_preds, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "id": "dcc66876-9931-40ea-b3ab-c5b76afc5a1e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([100, 2])"
      ]
     },
     "execution_count": 115,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_, label = torch.max(pred,1)\n",
    "multi_pred_to_binary(pred, label)[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "35652224-9a57-467b-bb5a-8b9b873656f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Dataset mnist has no validation set. Splitting the training set...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9472\n",
      "calling function __init__ of MNIST with the following parameters:\n",
      "\tdata_augmentation: False\n",
      "calling function build_datasets of MNIST with the following parameters:\n",
      "\tdata_dir: None\n",
      "\tval_ratio: 0.2\n",
      "\tnum_train_examples: None\n",
      "\tseed: 0\n",
      "\tdownload: True\n",
      "Dataset mnist is loaded\n",
      "\ttrain_samples: 48000\n",
      "\tval_samples: 12000\n",
      "\ttest_samples: 10000\n",
      "\texample_shape: torch.Size([1, 28, 28])\n",
      "10000\n"
     ]
    }
   ],
   "source": [
    "all_examples = saved_data['all_examples'] ## train, val data\n",
    "args = saved_data['args']\n",
    "\n",
    "include_indices = all_examples.include_indices\n",
    "#exclude_indices = np.arange(len(all_data))[~np.isin(np.arange(len(all_data)), include_indices)]\n",
    "#len(exclude_indices)\n",
    "#np.isin(np.arange(len(all_data)), include_indices)\n",
    "print(len(all_data))\n",
    "#load_all_data(saved_data=saved_data)\n",
    "from modules.data_utils import load_data_from_arguments\n",
    "\n",
    "_, _, examples, _ = load_data_from_arguments(args, build_loaders=False)\n",
    "print(len(examples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "761e0d80-86f1-442c-ab91-8eb50c6df1e3",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=4000,lr=0.001,seed=0,S_seed=0/checkpoints/epoch199.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████| 32/32 [00:01<00:00, 16.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████| 32/32 [00:01<00:00, 16.55it/s]\n",
      "INFO:root:Dataset mnist has no validation set. Splitting the training set...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "calling function __init__ of MNIST with the following parameters:\n",
      "\tdata_augmentation: False\n",
      "calling function build_datasets of MNIST with the following parameters:\n",
      "\tdata_dir: None\n",
      "\tval_ratio: 0.2\n",
      "\tnum_train_examples: None\n",
      "\tseed: 0\n",
      "\tdownload: True\n",
      "Dataset mnist is loaded\n",
      "\ttrain_samples: 48000\n",
      "\tval_samples: 12000\n",
      "\ttest_samples: 10000\n",
      "\texample_shape: torch.Size([1, 28, 28])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 1/1 [00:00<00:00, 37.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 1/1 [00:00<00:00, 37.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=4000,lr=0.001,seed=0,S_seed=1/checkpoints/epoch199.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████| 32/32 [00:02<00:00, 15.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████| 32/32 [00:01<00:00, 16.65it/s]\n",
      "100%|█████████████████████████████████████| 1/1 [00:00<00:00, 37.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 1/1 [00:00<00:00, 36.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=4000,lr=0.001,seed=0,S_seed=2/checkpoints/epoch199.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████| 32/32 [00:01<00:00, 16.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████| 32/32 [00:01<00:00, 16.80it/s]\n",
      "100%|█████████████████████████████████████| 1/1 [00:00<00:00, 37.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 1/1 [00:00<00:00, 35.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=4000,lr=0.001,seed=0,S_seed=3/checkpoints/epoch199.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████| 32/32 [00:01<00:00, 16.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████| 32/32 [00:01<00:00, 16.45it/s]\n",
      "100%|█████████████████████████████████████| 1/1 [00:00<00:00, 35.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 1/1 [00:00<00:00, 35.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from /Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/results/fcmi-mnist-4vs9-CNN/n=4000,lr=0.001,seed=0,S_seed=4/checkpoints/epoch199.mdl\n",
      "[None, 1, 28, 28]\n",
      "[None, 32, 14, 14]\n",
      "[None, 32, 7, 7]\n",
      "[None, 64, 3, 3]\n",
      "[None, 256, 1, 1]\n",
      "[None, 256]\n",
      "[None, 128]\n",
      "output.shape: [None, 2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████| 32/32 [00:01<00:00, 16.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████| 32/32 [00:01<00:00, 16.74it/s]\n",
      "100%|█████████████████████████████████████| 1/1 [00:00<00:00, 40.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 1/1 [00:00<00:00, 35.77it/s]\n"
     ]
    }
   ],
   "source": [
    "#n = 75\n",
    "#n = 250\n",
    "n = 4000\n",
    "lr = 0.001\n",
    "seed = 0\n",
    "#n_S_seeds = 30\n",
    "n_S_seeds = 5\n",
    "epoch = 200 ## full: 200 ; 200 epochsに対する結果\n",
    "n_batch = 256\n",
    "\n",
    "preds = []\n",
    "labels = []\n",
    "masks = []\n",
    "\n",
    "#recalibrate_data = False\n",
    "recalibrate_data = True\n",
    "if recalibrate_data:\n",
    "    preds_recab = []\n",
    "    labels_recab = []\n",
    "\n",
    "#n_bins = torch.tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]) ## Uniform\n",
    "#n_bins = int(n ** (1/4)) ## l1\n",
    "n_bins = int(n ** (1/3)) ## l1\n",
    "n_bins = torch.linspace(0, 1, n_bins + 1) ## Uniform binning TODO: implementing Uniform-mass binning\n",
    "n_bins[0], n_bins[-1] = 0., 1.\n",
    "bins = []\n",
    "\n",
    "for S_seed in range(n_S_seeds):\n",
    "    dir_name = f'n={n},lr={lr},seed={seed},S_seed={S_seed}'\n",
    "    dir_path = os.path.join(os.getcwd() + \"/results/\", \"fcmi-mnist-4vs9-CNN\", dir_name)\n",
    "    if not os.path.exists(dir_path):\n",
    "        print(f\"Did not find results for {dir_name}\")\n",
    "        continue\n",
    "    \n",
    "    with open(os.path.join(dir_path, 'saved_data.pkl'), 'rb') as f:\n",
    "        saved_data = pickle.load(f)\n",
    "    \n",
    "    model = utils.load(path=os.path.join(dir_path, 'checkpoints', f'epoch{epoch - 1}.mdl'), methods=methods, device=\"cpu\")\n",
    "    if 'all_examples_wo_data_aug' in saved_data:\n",
    "        all_examples = saved_data['all_examples_wo_data_aug']\n",
    "    else:\n",
    "        all_examples = saved_data['all_examples']\n",
    "    \n",
    "    cur_preds = check_softmax(utils.apply_on_dataset(model=model, dataset=all_examples, batch_size=n_batch)['pred']) ## (2*n, num_classes)\n",
    "    cur_labels = utils.apply_on_dataset(model=model, dataset=all_examples, batch_size=n_batch)['label_0'] ## (2*n, num_classes)\n",
    "    cur_mask = saved_data['mask'] #(n_labels)\n",
    "    train_idx = 2*np.arange(len(cur_mask)) + cur_mask\n",
    "    cur_bins = torch.bucketize(cur_preds[train_idx].softmax(1).max(1).values, n_bins, right=True) - 1\n",
    "    #id = np.arange(0,len(torch.unique(cur_labels)))\n",
    "    #cur_preds = utils.apply_on_dataset(model=model, dataset=all_examples, batch_size=n_batch)['pred'].softmax(1)\n",
    "    #cur_labels = utils.apply_on_dataset(model=model, dataset=all_examples, batch_size=n_batch)['label_0'] ## (2*n, num_classes)\n",
    "    #a = torch.zeros_like(cur_preds)\n",
    "    #a[:,0] = cur_preds[:,id!=1].sum(1)\n",
    "    #a[:,1] = cur_preds[:,id==1].sum(1)\n",
    "    #cur_preds = a\n",
    "    #cur_bins = torch.bucketize(cur_preds[train_idx].softmax(1).max(1).values, n_bins, right=True) - 1\n",
    "    \n",
    "    if recalibrate_data:\n",
    "        if S_seed == 0: ## Loading the all dataset to prepare the recablation dataset at the first iteration\n",
    "            all_data = load_all_data(saved_data=saved_data)\n",
    "        recab_examples = preds_recalibrate(saved_data=saved_data, all_data=all_data, n_data=100)\n",
    "        cur_preds_recab = check_softmax(utils.apply_on_dataset(model=model, dataset=recab_examples, batch_size=n_batch)['pred']) ## (2*n, num_classes)\n",
    "        cur_labels_recab = utils.apply_on_dataset(model=model, dataset=recab_examples, batch_size=n_batch)['label_0'] ## (2*n, num_classes)\n",
    "        preds_recab.append(cur_preds_recab)\n",
    "        labels_recab.append(cur_labels_recab)\n",
    "    \n",
    "    preds.append(cur_preds)\n",
    "    labels.append(cur_labels)\n",
    "    masks.append(cur_mask)\n",
    "    bins.append(cur_bins.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5d15e13-40fa-4a6a-87f6-867fdabbdc6c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "id": "573e64a3-4dc8-4f1f-b3a3-d8346f591419",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "class _LabelLoss(nn.Module):\n",
    "    def __init__(self, n_bins=15, method='uniform', logits=None):\n",
    "        \"\"\"\n",
    "        n_bins (int): number of confidence interval bins\n",
    "        \"\"\"\n",
    "        super(_LabelLoss, self).__init__()\n",
    "        if method == 'uniform':\n",
    "            bin_boundaries = torch.linspace(0, 1, n_bins + 1)\n",
    "        elif method == 'quantile':\n",
    "            conf, pred = logits.max(1)\n",
    "            conf[pred==0] = 1 - conf[pred==0]\n",
    "            bin_boundaries = torch.tensor(np.quantile(conf, torch.linspace(0, 1, n_bins + 1)))\n",
    "        \n",
    "        bin_boundaries[0], bin_boundaries[-1] = 0., 1.\n",
    "        self.bin_lowers = bin_boundaries[:-1]\n",
    "        self.bin_uppers = bin_boundaries[1:]\n",
    "\n",
    "    def forward(self, logits, labels):\n",
    "        if not torch.all(torch.abs(torch.sum(logits, dim=1) - 1) < 1e-10):\n",
    "            print(\"make softmax prob.\")\n",
    "            softmaxes = F.softmax(logits, dim=1)\n",
    "        else:\n",
    "            softmaxes = logits\n",
    "        \n",
    "        confidences, predictions = torch.max(softmaxes, 1)\n",
    "        confidences[predictions==0] = 1 - confidences[predictions==0]\n",
    "        accuracies = labels\n",
    "        #accuracies = predictions.eq(labels)\n",
    "\n",
    "        #acc = torch.zeros(1, device=logits.device)\n",
    "        acc_list = []\n",
    "        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):\n",
    "            # Calculated |confidence - accuracy| in each bin\n",
    "            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())\n",
    "            prop_in_bin = in_bin.float().mean()\n",
    "            if prop_in_bin.item() > 0:\n",
    "                accuracy_in_bin = accuracies[in_bin].float().mean()\n",
    "                acc_list.append(accuracy_in_bin)\n",
    "                #avg_confidence_in_bin = confidences[in_bin].mean()\n",
    "                #ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin\n",
    "                #ece += accuracy_in_bin\n",
    "\n",
    "        return acc_list\n",
    "\n",
    "def compute_labelloss(preds, mask, dataset, loss, target_label=None):\n",
    "    labels = [y for x, y in dataset]\n",
    "    labels = torch.tensor(labels).long()\n",
    "    if target_label is not None:\n",
    "        labels[labels!=target_label] = 0\n",
    "        labels[labels==target_label] = 1\n",
    "    indices = 2*np.arange(len(mask)) + mask\n",
    "\n",
    "    #ece = ece_loss.forward(preds[indices], labels[indices])\n",
    "    acc_list = loss.forward(preds[indices], labels[indices])\n",
    "    \n",
    "    #return utils.to_numpy(ece)\n",
    "    return acc_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "id": "10488ae1-ad6a-463a-81d1-dcda53ce5b93",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_idx = 2*np.arange(len(cur_mask)) + cur_mask\n",
    "test_idx = 2*np.arange(len(1-cur_mask)) + (1-cur_mask)\n",
    "pred_tr, pred_te = cur_preds[train_idx], cur_preds[test_idx]\n",
    "labels = [y for x, y in all_examples]\n",
    "labels = torch.tensor(labels).long()\n",
    "\n",
    "ls_tr, ls_te = labels[train_idx], labels[test_idx]\n",
    "\n",
    "conf, pred = pred_tr.max(1)\n",
    "conf[pred==0] = 1 - conf[pred==0]\n",
    "bin_boundaries = torch.tensor(np.quantile(conf, torch.linspace(0, 1, B + 1)))\n",
    "bin_lowers = bin_boundaries[:-1]\n",
    "bin_uppers = bin_boundaries[1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "id": "eb6faaa7-1ae1-433f-b16a-f4b100c8b14b",
   "metadata": {},
   "outputs": [],
   "source": [
    "softmaxes = pred_tr\n",
    "confidences, predictions = torch.max(softmaxes, 1)\n",
    "confidences[predictions==0] = 1 - confidences[predictions==0]\n",
    "#accuracies = labels\n",
    "accuracies = ls_tr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "id": "ae85e2e5-c3d9-4227-a90c-4c4f4fef88ef",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.4247)"
      ]
     },
     "execution_count": 123,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "acc_list = []\n",
    "for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):\n",
    "    in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())\n",
    "    prop_in_bin = in_bin.float().mean()\n",
    "    #print(in_bin, prop_in_bin)\n",
    "    if prop_in_bin.item() > 0:\n",
    "        accuracy_in_bin = accuracies[in_bin].float().mean()\n",
    "        acc_list.append(accuracy_in_bin)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "id": "38a33201-1227-4ec3-b338-6ad77ba37f05",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[tensor(0.5130)]"
      ]
     },
     "execution_count": 127,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss = _LabelLoss(n_bins=B, method='quantile', logits=cur_preds[train_idx])\n",
    "compute_labelloss(preds=cur_preds, mask=cur_mask, dataset=all_examples, loss=loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00b11a68-c57d-44b6-95c1-f23d8649a9b2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a05e0823-c5f6-44a2-b50c-9e79814e7d69",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import softmax, nn\n",
    "from utils.metrics import get_kernel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 393,
   "id": "e4dc3bbe-65a4-4358-8cbe-5cdb4fcb3bd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_bandwidth(f, device='cpu'):\n",
    "    \"\"\"\n",
    "    Select a bandwidth for the kernel based on maximizing the leave-one-out likelihood (LOO MLE).\n",
    "\n",
    "    :param f: The vector containing the probability scores, shape [num_samples, num_classes]\n",
    "    :param device: The device type: 'cpu' or 'cuda'\n",
    "\n",
    "    :return: The bandwidth of the kernel\n",
    "    \"\"\"\n",
    "    bandwidths = torch.cat((torch.logspace(start=-5, end=-1, steps=15), torch.linspace(0.2, 1, steps=5)))\n",
    "    max_b = -1\n",
    "    max_l = 0\n",
    "    n = len(f)\n",
    "    for b in bandwidths:\n",
    "        log_kern = get_kernel(f, b, device)\n",
    "        #log_fhat = torch.logsumexp(log_kern, 1) - torch.log(n-1)\n",
    "        log_fhat = torch.logsumexp(log_kern, 1) - np.log(n-1)\n",
    "        l = torch.sum(log_fhat)\n",
    "        if l > max_l:\n",
    "            max_l = l\n",
    "            max_b = b\n",
    "\n",
    "    return max_b\n",
    "\n",
    "def get_ratio_canonical(f, y, bandwidth, device='cpu'):\n",
    "    if f.shape[1] > 60:\n",
    "        # Slower but more numerically stable implementation for larger number of classes\n",
    "        return get_ratio_canonical_log(f, y, bandwidth, device='cpu')\n",
    "\n",
    "    log_kern = get_kernel(f, bandwidth, device)\n",
    "    kern = torch.exp(log_kern).to(torch.float32)\n",
    "\n",
    "    y_onehot = nn.functional.one_hot(y, num_classes=f.shape[1]).to(torch.float32)\n",
    "    kern_y = torch.matmul(kern, y_onehot)\n",
    "    den = torch.sum(kern, dim=1)\n",
    "    # to avoid division by 0\n",
    "    den = torch.clamp(den, min=1e-10)\n",
    "\n",
    "    ratio = kern_y / den.unsqueeze(-1)\n",
    "    #ratio = torch.sum(torch.abs(ratio - f)**p, dim=1)\n",
    "\n",
    "    return ratio\n",
    "\n",
    "def get_ratio_canonical_log(f, y, bandwidth, device='cpu'):\n",
    "    log_kern = get_kernel(f, bandwidth, device)\n",
    "    y_onehot = nn.functional.one_hot(y, num_classes=f.shape[1]).to(torch.float32)\n",
    "    log_y = torch.log(y_onehot)\n",
    "    log_den = torch.logsumexp(log_kern, dim=1)\n",
    "    #final_ratio = 0\n",
    "    ratio = torch.zeros_like(y_onehot)\n",
    "    for k in range(f.shape[1]):\n",
    "        log_kern_y = log_kern + (torch.ones([f.shape[0], 1]) * log_y[:, k].unsqueeze(0))\n",
    "        log_inner_ratio = torch.logsumexp(log_kern_y, dim=1) - log_den\n",
    "        ratio[:,k] = torch.exp(log_inner_ratio)\n",
    "        #inner_ratio = torch.exp(log_inner_ratio)\n",
    "        #inner_diff = torch.abs(inner_ratio - f[:, k])**p\n",
    "        #final_ratio += inner_diff\n",
    "    \n",
    "    return ratio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66528cfa-caac-4daf-8f8a-4e22b0b4dee6",
   "metadata": {},
   "outputs": [],
   "source": [
    "y = sklearn.utils.validation.column_or_1d(y)\n",
    "    y_pred = np.argmax(p_pred, axis=1)\n",
    "    y_pred = sklearn.utils.validation.column_or_1d(y_pred)\n",
    "    if n_classes is None:\n",
    "        n_classes = np.unique(np.concatenate([y, y_pred])).shape[0]\n",
    "\n",
    "    # Compute bin means\n",
    "    bin_range = [1 / n_classes, 1]\n",
    "    #bin_range = [0, 1]\n",
    "    bins = np.linspace(bin_range[0], bin_range[1], n_bins + 1)\n",
    "\n",
    "    # Find prediction confidence\n",
    "    p_max = np.max(p_pred, axis=1)\n",
    "\n",
    "    # Compute empirical accuracy\n",
    "    empirical_acc = scipy.stats.binned_statistic(p_max, (y_pred == y).astype(int),\n",
    "                                                 bins=n_bins, range=bin_range)[0]\n",
    "    nanindices = np.where(np.logical_not(np.isnan(empirical_acc)))[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1039,
   "id": "a9ea9653-2f6d-47bf-b5d0-63ebd9dbca98",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.04745001186834022"
      ]
     },
     "execution_count": 1039,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import sklearn\n",
    "import scipy\n",
    "from utils.metrics import compute_bins, idx_bins\n",
    "\n",
    "n_bins = cur_bins[0]\n",
    "\n",
    "y = sklearn.utils.validation.column_or_1d(cur_labels[train_idx])\n",
    "y_pred = np.argmax(cur_preds[train_idx], axis=1)\n",
    "y_pred = sklearn.utils.validation.column_or_1d(y_pred)\n",
    "n_classes = np.unique(np.concatenate([y, y_pred])).shape[0]\n",
    "\n",
    "#bin_range = [1 / n_classes, 1]\n",
    "bin_range = [0, 1]\n",
    "bins = np.linspace(bin_range[0], bin_range[1], n_bins + 1)\n",
    "#bins = np.quantile(cur_preds[train_idx].max(1).values, q=np.linspace(bin_range[0], bin_range[1], n_bins + 1))\n",
    "\n",
    "p_max = np.max(cur_preds[train_idx].numpy(), axis=1)\n",
    "empirical_acc = scipy.stats.binned_statistic(p_max, (y_pred == y).astype(int),\n",
    "                                                 bins=n_bins, range=bin_range)[0]\n",
    "\n",
    "nanindices = np.where(np.logical_not(np.isnan(empirical_acc)))[0]\n",
    "\n",
    "calibrated_acc = np.linspace(bin_range[0] + bin_range[1] / (2 * n_bins), bin_range[1] - bin_range[1] / (2 * n_bins), n_bins)\n",
    "weights_ece = np.histogram(p_max, bins)[0][nanindices]\n",
    "np.average(abs(empirical_acc[nanindices] - calibrated_acc[nanindices]),\n",
    "                         weights=weights_ece)\n",
    "\n",
    "#bins\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1040,
   "id": "8a29a93e-f744-47e1-94cf-d3e30c400fe1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([       nan,        nan,        nan,        nan,        nan,\n",
       "       1.        , 0.5       , 1.        , 0.28571429, 0.99623777])"
      ]
     },
     "execution_count": 1040,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "empirical_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 594,
   "id": "2622ef05-5fa7-41be-aaea-803886780859",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.4869)"
      ]
     },
     "execution_count": 594,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conf = cur_preds[train_idx].max(1).values\n",
    "#conf[cur_labels[train_idx]==0] = 1 - conf[cur_labels[train_idx]==0] ## MEMO: Reverse prob. for label y=0.\n",
    "\n",
    "bins = compute_bins(num_bins=n_bins, confidences=conf, method='quantile')\n",
    "#bins = compute_bins(num_bins=n_bins)\n",
    "conf_bin = torch.zeros(len(bins), device=conf.device, dtype=conf.dtype)\n",
    "count_bin = torch.zeros(len(bins), device=conf.device, dtype=conf.dtype)\n",
    "label_bin = torch.zeros(len(bins), device=cur_labels[train_idx].device, dtype=cur_labels[train_idx].dtype)\n",
    "idx = idx_bins(conf, bins)\n",
    "\n",
    "## recalibration\n",
    "#bin_total = torch.bincount(idx, minlength=len(bins)-1).float().to(conf.device) ## the number of samples per bins\n",
    "#bin_true = (torch.bincount(idx, weights=cur_labels[train_idx], minlength=len(bins)-1)).float().to(conf.device)\n",
    "#bin_true = (torch.bincount(idx, weights=conf, minlength=len(bins)-1)).float().to(conf.device)\n",
    "#bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "#conf = bin_mean[idx]\n",
    "\n",
    "count_bin.scatter_add_(dim=0, index=idx, src=torch.ones_like(conf))\n",
    "conf_bin.scatter_add_(dim=0, index=idx, src=conf)\n",
    "conf_bin = torch.nan_to_num(conf_bin / count_bin)\n",
    "prop_bin = count_bin / count_bin.sum()\n",
    "label_bin.scatter_add_(dim=0, index=idx, src=cur_labels[train_idx])\n",
    "label_bin = torch.nan_to_num(label_bin / count_bin)\n",
    "\n",
    "torch.sum(torch.abs(label_bin - conf_bin) * prop_bin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 661,
   "id": "7ab762b1-d8dc-4663-a075-cd2b411953e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn import functional as F\n",
    "class _ECELoss(nn.Module):\n",
    "    \"\"\"\n",
    "    Calculates the Expected Calibration Error of a model.\n",
    "    (This isn't necessary for temperature scaling, just a cool metric).\n",
    "\n",
    "    The input to this loss is the logits of a model, NOT the softmax scores.\n",
    "\n",
    "    This divides the confidence outputs into equally-sized interval bins.\n",
    "    In each bin, we compute the confidence gap:\n",
    "\n",
    "    bin_gap = | avg_confidence_in_bin - accuracy_in_bin |\n",
    "\n",
    "    We then return a weighted average of the gaps, based on the number\n",
    "    of samples in each bin\n",
    "\n",
    "    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.\n",
    "    \"Obtaining Well Calibrated Probabilities Using Bayesian Binning.\" AAAI.\n",
    "    2015.\n",
    "    \"\"\"\n",
    "    def __init__(self, n_bins=15, method='uniform', logits=None):\n",
    "        \"\"\"\n",
    "        n_bins (int): number of confidence interval bins\n",
    "        \"\"\"\n",
    "        super(_ECELoss, self).__init__()\n",
    "        if method == 'uniform':\n",
    "            bin_boundaries = torch.linspace(0, 1, n_bins + 1)\n",
    "        elif method == 'quantile':\n",
    "            conf, pred = logits.max(1)\n",
    "            conf[pred==0] = 1 - conf[pred==0]\n",
    "            bin_boundaries = torch.tensor(np.quantile(conf, torch.linspace(0, 1, n_bins + 1)))\n",
    "        \n",
    "        bin_boundaries[0], bin_boundaries[-1] = 0., 1.\n",
    "        self.bin_lowers = bin_boundaries[:-1]\n",
    "        self.bin_uppers = bin_boundaries[1:]\n",
    "\n",
    "    def forward(self, logits, labels):\n",
    "        if not torch.all(torch.abs(torch.sum(logits, dim=1) - 1) < 1e-10):\n",
    "            print(\"make softmax prob.\")\n",
    "            softmaxes = F.softmax(logits, dim=1)\n",
    "        else:\n",
    "            softmaxes = logits\n",
    "        confidences, predictions = torch.max(softmaxes, 1)\n",
    "        accuracies = predictions.eq(labels)\n",
    "\n",
    "        ece = torch.zeros(1, device=logits.device)\n",
    "        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):\n",
    "            # Calculated |confidence - accuracy| in each bin\n",
    "            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())\n",
    "            prop_in_bin = in_bin.float().mean()\n",
    "            if prop_in_bin.item() > 0:\n",
    "                accuracy_in_bin = accuracies[in_bin].float().mean()\n",
    "                avg_confidence_in_bin = confidences[in_bin].mean()\n",
    "                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin\n",
    "\n",
    "        return ece"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3447571c-2d8d-40f5-aaad-34f6b8f9a2e7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 662,
   "id": "3683b3ea-dee8-4fb8-9a72-53ff1d80be70",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 4.7684e-07, 9.9995e-01,\n",
       "        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00],\n",
       "       dtype=torch.float64)"
      ]
     },
     "execution_count": 662,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "conf, pred = cur_preds[train_idx].max(1)\n",
    "conf[pred==0] = 1 - conf[pred==0]\n",
    "#conf[cur_labels[train_idx]==0] = 1 - conf[cur_labels[train_idx]==0]\n",
    "torch.tensor(np.quantile(conf, torch.linspace(0, 1, n_bins + 1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 663,
   "id": "a7405065-88cf-451f-9f58-3248972a7ed5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n",
      "tensor([0.2690])\n",
      "make softmax prob.\n",
      "tensor([0.2618])\n"
     ]
    }
   ],
   "source": [
    "train_idx = 2*np.arange(len(cur_mask)) + cur_mask\n",
    "test_idx = 2*np.arange(len(1-cur_mask)) + (1-cur_mask)\n",
    "ece_loss = _ECELoss(n_bins=n_bins)\n",
    "print(ece_loss.forward(cur_preds[train_idx], cur_labels[train_idx]))\n",
    "print(ece_loss.forward(cur_preds[test_idx], cur_labels[test_idx]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 696,
   "id": "3cdac3f3-f7c6-46bd-bd65-aaf249f1392a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "make softmax prob.\n",
      "tensor([0.2690])\n",
      "make softmax prob.\n",
      "tensor([0.2609])\n"
     ]
    }
   ],
   "source": [
    "ece_loss = _ECELoss(n_bins=n_bins, method='quantile', logits=cur_preds[train_idx])\n",
    "#ece_loss = _ECELoss(n_bins=n_bins, method='quantile', logits=conf)\n",
    "print(ece_loss.forward(cur_preds[train_idx], cur_labels[train_idx]))\n",
    "print(ece_loss.forward(cur_preds[test_idx], cur_labels[test_idx]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 690,
   "id": "32853e39-df33-4da0-8f3f-b696fb86af13",
   "metadata": {},
   "outputs": [],
   "source": [
    "def expected_calibration_error(probs, labels, n_bins=10, method='uniform'):\n",
    "    \"\"\"\n",
    "    Compute the Expected Calibration Error (ECE) using uniform mass binning.\n",
    "    \n",
    "    Args:\n",
    "    probs (torch.Tensor): Tensor of predicted probabilities (softmax outputs of model).\n",
    "    labels (torch.Tensor): Tensor of ground truth labels.\n",
    "    n_bins (int): Number of bins to use for calibration error evaluation.\n",
    "    \n",
    "    Returns:\n",
    "    float: The expected calibration error.\n",
    "    \"\"\"\n",
    "    # Ensure probs and labels are tensors\n",
    "    #probs = torch.tensor(probs, dtype=torch.float32)\n",
    "    #labels = torch.tensor(labels, dtype=torch.long)\n",
    "\n",
    "    # Get the predictions and confidence\n",
    "    confidences, predictions = torch.max(probs, dim=1)\n",
    "    \n",
    "    # Sort by confidences for binning\n",
    "    indices = torch.argsort(confidences)\n",
    "    sorted_confidences = confidences[indices]\n",
    "    sorted_predictions = predictions[indices]\n",
    "    sorted_labels = labels[indices]\n",
    "\n",
    "    # Bin the data into quantiles\n",
    "    if method == 'uniform':\n",
    "        #bin_boundaries = torch.linspace(0, len(probs), n_bins + 1, dtype=torch.int64)\n",
    "        bin_boundaries = torch.linspace(0, 1, n_bins + 1, dtype=torch.int64)\n",
    "    elif method == 'quantile':\n",
    "        conf, pred = probs.max(1)\n",
    "        bin_boundaries = torch.tensor(np.quantile(conf, torch.linspace(0, 1, n_bins + 1)), dtype=torch.int64)\n",
    "    bin_lowers = bin_boundaries[:-1]\n",
    "    bin_uppers = bin_boundaries[1:]\n",
    "\n",
    "    # Calculate ECE\n",
    "    ece = 0.0\n",
    "    for lower, upper in zip(bin_lowers, bin_uppers):\n",
    "        bin_confidences = sorted_confidences[lower:upper]\n",
    "        bin_labels = sorted_labels[lower:upper]\n",
    "        bin_size = upper - lower\n",
    "        \n",
    "        if bin_size > 0:\n",
    "            # Average confidence in bin\n",
    "            average_confidence = bin_confidences.mean()\n",
    "            \n",
    "            # Accuracy in bin\n",
    "            correct_predictions = (sorted_predictions[lower:upper] == bin_labels).float()\n",
    "            accuracy = correct_predictions.mean()\n",
    "            \n",
    "            # Weight by the number of samples in the bin\n",
    "            bin_weight = bin_size / float(len(probs))\n",
    "            \n",
    "            # Contribution to ECE\n",
    "            bin_error = torch.abs(average_confidence - accuracy) * bin_weight\n",
    "            ece += bin_error\n",
    "\n",
    "    return ece.item()\n",
    "\n",
    "# Uncomment below to use the function\n",
    "#probs = torch.rand(1000, 10)  # Random probabilities for 10 classes\n",
    "#labels = torch.randint(0, 10, (1000,))  # Random true labels\n",
    "#print(f\"Expected Calibration Error: {expected_calibration_error(probs, labels)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 695,
   "id": "c08c2033-a158-4626-a04e-02e8fd154633",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.730001688178163e-05\n",
      "0.00012346250878181309\n",
      "2.730001688178163e-05\n",
      "0.00012346250878181309\n"
     ]
    }
   ],
   "source": [
    "print(expected_calibration_error(cur_preds[train_idx], cur_labels[train_idx], n_bins=n_bins))\n",
    "print(expected_calibration_error(cur_preds[test_idx], cur_labels[test_idx], n_bins=n_bins))\n",
    "\n",
    "print(expected_calibration_error(cur_preds[train_idx], cur_labels[train_idx], n_bins=n_bins, method='quantile'))\n",
    "print(expected_calibration_error(cur_preds[test_idx], cur_labels[test_idx], n_bins=n_bins, method='quantile'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "id": "b91890cf-5b39-4351-a3ed-7de6609dc82d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/methods/calibration.py:300: RuntimeWarning: Mean of empty slice.\n",
      "  self.prob_class_1 = interpolate_nan(np.array([y[digitized == i].mean() for i in range(1, len(self.binning))]))\n",
      "/Users/fujisawa/.pyenv/shims/versions/3.8.12/lib/python3.8/site-packages/numpy/core/_methods.py:189: RuntimeWarning: invalid value encountered in double_scalars\n",
      "  ret = ret.dtype.type(ret / rcount)\n"
     ]
    }
   ],
   "source": [
    "#bin_mean[idx]\n",
    "test_idx = 2*np.arange(len(1-cur_mask)) + (1-cur_mask)\n",
    "from methods.calibration import HistogramBinning\n",
    "\n",
    "temp = HistogramBinning(mode='equal_freq',n_bins=15)\n",
    "#temp = HistogramBinning(mode='equal_width',n_bins=15)\n",
    "temp.fit(cur_preds[train_idx].numpy(), cur_labels[train_idx].numpy())\n",
    "res2 = torch.tensor(temp.predict_proba(cur_preds[test_idx]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "0fe79b00-3860-41d2-ad8c-b96b4d023f19",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 1.],\n",
       "        [0., 1.],\n",
       "        [0., 1.],\n",
       "        ...,\n",
       "        [0., 1.],\n",
       "        [0., 1.],\n",
       "        [1., 0.]])"
      ]
     },
     "execution_count": 109,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 482,
   "id": "6607bafe-cd79-4816-b34a-fbd7a0112af9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.9797)\n",
      "tensor(0.9657)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/fujisawa/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/methods/calibration.py:276: RuntimeWarning: Mean of empty slice.\n",
      "  self.prob_class_1 = [y[digitized == i].mean() for i in range(1, len(self.binning))]\n",
      "/Users/fujisawa/.pyenv/shims/versions/3.8.12/lib/python3.8/site-packages/numpy/core/_methods.py:189: RuntimeWarning: invalid value encountered in double_scalars\n",
      "  ret = ret.dtype.type(ret / rcount)\n"
     ]
    }
   ],
   "source": [
    "from methods.calibration import CalibrationMethod, TemperatureScaling, HistogramBinning\n",
    "from fcmi_parse_results_recalibration import compute_bins, idx_bins, interpolate_nan, compute_acc\n",
    "\n",
    "train_idx = 2*np.arange(len(cur_mask)) + cur_mask\n",
    "test_idx = 2*np.arange(len(1-cur_mask)) + (1-cur_mask)\n",
    "\n",
    "#temp = TemperatureScaling()\n",
    "temp = HistogramBinning(mode='equal_freq',n_bins=15)\n",
    "#temp = HistogramBinning(mode='equal_width',n_bins=15)\n",
    "#temp.fit(cur_preds_recab.numpy(), cur_labels_recab.numpy())\n",
    "#temp.fit(cur_preds_recab.softmax(1).numpy(), cur_labels_recab.numpy())\n",
    "#temp.fit(cur_preds_recab.numpy(), cur_labels_recab.numpy())\n",
    "temp.fit(cur_preds[train_idx].numpy(), cur_labels[train_idx].numpy())\n",
    "confidences = torch.tensor(temp.predict_proba(cur_preds[train_idx]))\n",
    "conf_test = torch.tensor(temp.predict_proba(cur_preds[test_idx]))\n",
    "#cur_preds[train_idx] = torch.tensor(temp.predict_proba(cur_preds[train_idx]))\n",
    "#confidences = cur_preds[train_idx]\n",
    "\n",
    "print((confidences.argmax(1) == cur_labels[train_idx]).float().mean())\n",
    "print((conf_test.argmax(1) == cur_labels[test_idx]).float().mean())\n",
    "#(cur_preds[train_idx].argmax(1) == cur_labels[train_idx]).float().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 483,
   "id": "27bdbea0-6fba-4ae8-a490-ac3f37b0a7f7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 0.],\n",
       "        [1., 0.],\n",
       "        [1., 0.],\n",
       "        ...,\n",
       "        [1., 0.],\n",
       "        [0., 1.],\n",
       "        [1., 0.]], dtype=torch.float64)"
      ]
     },
     "execution_count": 483,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "confidences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 484,
   "id": "fb7620fa-ee4e-4f03-a4a8-f5b9d0c5a6ea",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.4668, dtype=torch.float64)"
      ]
     },
     "execution_count": 484,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "metrics(confidences.max(1).values, cur_labels[train_idx]) - metrics(conf_test.max(1).values, cur_labels[train_idx])\n",
    "#metrics(conf_test.max(1).values, cur_labels[train_idx])\n",
    "metrics(confidences.max(1).values, cur_labels[train_idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 369,
   "id": "106200be-565b-446c-8634-1c779737c2cf",
   "metadata": {},
   "outputs": [
    {
     "ename": "IndexError",
     "evalue": "Dimension out of range (expected to be in range of [-1, 0], but got 1)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mIndexError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m/var/folders/_s/jh8l69lx0wz3rhnd7rx1w5gr0000gn/T/ipykernel_54708/1523883688.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorchmetrics\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclassification\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mBinaryCalibrationError\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mconfidences\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconfidences\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m \u001b[0mconfidences\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mcur_labels\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtrain_idx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mconfidences\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mcur_labels\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtrain_idx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;31m## MEMO: Reverse prob. for label y=0.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mIndexError\u001b[0m: Dimension out of range (expected to be in range of [-1, 0], but got 1)"
     ]
    }
   ],
   "source": [
    "#confidences.softmax(1).max(1).values\n",
    "from torchmetrics.classification import BinaryCalibrationError\n",
    "\n",
    "#confidences = confidences.max(1).values\n",
    "#confidences[cur_labels[train_idx]==0] = 1 - confidences[cur_labels[train_idx]==0] ## MEMO: Reverse prob. for label y=0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 542,
   "id": "c259ef9c-2f00-4efb-838f-80591676e723",
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy\n",
    "binned_stat = scipy.stats.binned_statistic(x=cur_preds[train_idx][:, 1].numpy(), values=np.equal(1, cur_labels[train_idx].numpy()), statistic='mean',bins=15, range=[0,1])\n",
    "prob_class_1 = interpolate_nan(binned_stat.statistic)\n",
    "binning = binned_stat.bin_edges\n",
    "\n",
    "#binning = np.quantile(cur_preds[train_idx][:, 1].numpy(), q=np.linspace(0, 1, 15 + 1))\n",
    "#digitized = np.digitize(cur_preds[train_idx][:, 1].numpy(), bins=binning)\n",
    "#digitized[digitized == len(binning)] = len(binning) - 1  # include rightmost edge in partition\n",
    "#prob_class_1 = interpolate_nan(np.array([cur_labels[train_idx].numpy()[digitized == i].mean() for i in range(1, len(binning))]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 543,
   "id": "8a44843d-5bb1-4df4-8bc8-5ae3afc52a3f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0000, 0.0769, 0.1538, 0.2308, 0.3077, 0.3846, 0.4615, 0.5385, 0.6154,\n",
       "        0.6923, 0.7692, 0.8462, 0.9231, 1.0000, 1.0000])"
      ]
     },
     "execution_count": 543,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#interpolate_nan(np.array(prob_class_1))\n",
    "prob_class_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 544,
   "id": "f056bb6e-03ed-4870-8a83-3a18a5b3125b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.4870)"
      ]
     },
     "execution_count": 544,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "digitized = np.digitize(cur_preds[train_idx][:, 1].numpy(), bins=binning)\n",
    "digitized[digitized == len(binning)] = len(binning) - 1\n",
    "p1 = np.array([prob_class_1[j] for j in (digitized - 1)])\n",
    "p1 = np.where(np.isfinite(p1), p1, cur_preds[train_idx][:, 1].numpy())\n",
    "confidences = torch.tensor(np.column_stack([1 - p1, p1]))\n",
    "metrics(confidences.max(1).values, cur_labels[train_idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 532,
   "id": "15bfc367-7060-42b5-92de-c48705c6516d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#confidences = confidences.max(1).values\n",
    "#metrics(confidences, cur_labels[train_idx])\n",
    "#confidences[cur_labels[train_idx]==0] = 1 - confidences[cur_labels[train_idx]==0] ## MEMO: Reverse prob. for label y=0.\n",
    "#compute_bins(num_bins=15, confidences=confidences, method='uniform')\n",
    "#metrics(confidences, cur_labels[train_idx]==0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 522,
   "id": "158b6c9a-6e2c-4a87-9896-7c6b2360c666",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 0, 0,  ..., 0, 1, 0])"
      ]
     },
     "execution_count": 522,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#metrics(confidences,cur_labels[train_idx])\n",
    "cur_labels[train_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 545,
   "id": "71f38f2f-5695-4c0e-bb7c-4a2a93301c28",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.5033)"
      ]
     },
     "execution_count": 545,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "confidences = confidences.max(1).values\n",
    "confidences[cur_labels[train_idx]==0] = 1 - confidences[cur_labels[train_idx]==0] ## MEMO: Reverse prob. for label y=0.\n",
    "#n_bins = compute_bins(num_bins=15, confidences=confidences, method='quantile')\n",
    "n_bins = compute_bins(num_bins=15, confidences=confidences, method='uniform')\n",
    "\n",
    "conf_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "count_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "label_bin = torch.zeros(len(n_bins), device=cur_labels.device, dtype=cur_labels.dtype)\n",
    "idx = idx_bins(confidences, n_bins)\n",
    "\n",
    "count_bin.scatter_add_(dim=0, index=idx, src=torch.ones_like(confidences))\n",
    "conf_bin.scatter_add_(dim=0, index=idx, src=confidences)\n",
    "conf_bin = torch.nan_to_num(conf_bin / count_bin)\n",
    "prop_bin = count_bin / count_bin.sum()\n",
    "\n",
    "label_bin.scatter_add_(dim=0, index=idx, src=cur_labels)\n",
    "label_bin = torch.nan_to_num(label_bin / count_bin)\n",
    "torch.sum(torch.abs(label_bin - conf_bin) * prop_bin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 513,
   "id": "e66973ec-fca8-485d-a3aa-ec24e3e8711d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/_s/jh8l69lx0wz3rhnd7rx1w5gr0000gn/T/ipykernel_54708/2065498586.py:13: RuntimeWarning: invalid value encountered in true_divide\n",
      "  bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor(0.4842)"
      ]
     },
     "execution_count": 513,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "confidences = cur_preds[train_idx].max(1).values\n",
    "confidences[cur_labels[train_idx]==0] = 1 - confidences[cur_labels[train_idx]==0] ## MEMO: Reverse prob. for label y=0.\n",
    "n_bins = compute_bins(num_bins=15, confidences=confidences, method='quantile')\n",
    "#n_bins = compute_bins(num_bins=15, confidences=confidences, method='uniform')\n",
    "\n",
    "conf_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "count_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "label_bin = torch.zeros(len(n_bins), device=cur_labels.device, dtype=cur_labels.dtype)\n",
    "\n",
    "idx = idx_bins(confidences, n_bins)\n",
    "bin_total = torch.bincount(idx, minlength=len(n_bins)-1).float().to(confidences.device) ## the number of samples per bins\n",
    "bin_true = (torch.bincount(idx, weights=cur_labels[train_idx], minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels    \n",
    "bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "\n",
    "confidence = bin_mean[idx]\n",
    "\n",
    "count_bin.scatter_add_(dim=0, index=idx, src=torch.ones_like(confidences))\n",
    "conf_bin.scatter_add_(dim=0, index=idx, src=confidences)\n",
    "conf_bin = torch.nan_to_num(conf_bin / count_bin)\n",
    "prop_bin = count_bin / count_bin.sum()\n",
    "\n",
    "label_bin.scatter_add_(dim=0, index=idx, src=cur_labels)\n",
    "label_bin = torch.nan_to_num(label_bin / count_bin)\n",
    "torch.sum(torch.abs(label_bin - conf_bin) * prop_bin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 516,
   "id": "72c5c29b-e274-40a1-b39d-c390cc5e6a2c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6943, 1.0000,\n",
       "        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])"
      ]
     },
     "execution_count": 516,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bin_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 517,
   "id": "e426fcc9-d021-4686-9c32-81fada7348ad",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 5,  5,  5,  ...,  5, 14,  5])"
      ]
     },
     "execution_count": 517,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 231,
   "id": "e124496e-869e-4975-b7aa-69780d8978f2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/_s/jh8l69lx0wz3rhnd7rx1w5gr0000gn/T/ipykernel_54708/281065693.py:15: RuntimeWarning: invalid value encountered in true_divide\n",
      "  bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor(0.4830, dtype=torch.float64)"
      ]
     },
     "execution_count": 231,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#confidences = confidences.softmax(1).max(1).values\n",
    "confidences = confidences.max(1).values\n",
    "confidences[cur_labels[train_idx]==0] = 1 - confidences[cur_labels[train_idx]==0] ## MEMO: Reverse prob. for label y=0.\n",
    "\n",
    "n_bins = compute_bins(num_bins=10, confidences=confidences, method='quantile')\n",
    "#n_bins = compute_bins(num_bins=10, confidences=confidences, method='uniform')\n",
    "conf_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "count_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "label_bin = torch.zeros(len(n_bins), device=cur_labels.device, dtype=cur_labels.dtype)\n",
    "idx = idx_bins(confidences, n_bins)\n",
    "\n",
    "bin_total = torch.bincount(idx, minlength=len(n_bins)-1).float().to(confidences.device) ## the number of samples per bins\n",
    "bin_true = (torch.bincount(idx, weights=cur_labels[train_idx], minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels    \n",
    "#bin_true = (torch.bincount(idx, weights=confidences, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels    \n",
    "bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "confidence = bin_mean[idx]\n",
    "\n",
    "count_bin.scatter_add_(dim=0, index=idx, src=torch.ones_like(confidences))\n",
    "conf_bin.scatter_add_(dim=0, index=idx, src=confidences)\n",
    "conf_bin = torch.nan_to_num(conf_bin / count_bin)\n",
    "prop_bin = count_bin / count_bin.sum()\n",
    "\n",
    "label_bin.scatter_add_(dim=0, index=idx, src=cur_labels)\n",
    "label_bin = torch.nan_to_num(label_bin / count_bin)\n",
    "torch.sum(torch.abs(label_bin - conf_bin) * prop_bin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "ba00964d-cee7-4ac6-977d-d2d296465a20",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/_s/jh8l69lx0wz3rhnd7rx1w5gr0000gn/T/ipykernel_54708/3259621836.py:12: RuntimeWarning: invalid value encountered in true_divide\n",
      "  bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor(0.2513)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "confidences = cur_preds.softmax(1).max(1).values\n",
    "confidences[cur_labels==0] = 1 - confidences[cur_labels==0] ## MEMO: Reverse prob. for label y=0.\n",
    "n_bins = compute_bins(num_bins=15, confidences=confidences, method='quantile')\n",
    "conf_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "count_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "label_bin = torch.zeros(len(n_bins), device=cur_labels.device, dtype=cur_labels.dtype)\n",
    "idx = idx_bins(confidences, n_bins)\n",
    "\n",
    "bin_total = torch.bincount(idx, minlength=len(n_bins)-1).float().to(confidences.device) ## the number of samples per bins\n",
    "bin_true = (torch.bincount(idx, weights=cur_labels, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels    \n",
    "bin_true = (torch.bincount(idx, weights=confidences, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels    \n",
    "bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "\n",
    "confidence = bin_mean[idx]\n",
    "count_bin.scatter_add_(dim=0, index=idx, src=torch.ones_like(confidences))\n",
    "conf_bin.scatter_add_(dim=0, index=idx, src=confidences)\n",
    "conf_bin = torch.nan_to_num(conf_bin / count_bin)\n",
    "prop_bin = count_bin / count_bin.sum()\n",
    "\n",
    "label_bin.scatter_add_(dim=0, index=idx, src=cur_labels)\n",
    "label_bin = torch.nan_to_num(label_bin / count_bin)\n",
    "torch.sum(torch.abs(label_bin - conf_bin) * prop_bin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "id": "6f3403b9-73f1-4b46-ab50-b3ea0871aa66",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "\n",
    "#confidences = cur_preds.softmax(1).max(1).values\n",
    "confidences = torch.tensor(temp.predict_proba(cur_preds))\n",
    "confidences = confidences.softmax(1).max(1).values\n",
    "confidences[cur_labels==0] = 1 - confidences[cur_labels==0] ## MEMO: Reverse prob. for label y=0.\n",
    "#n_bins = compute_bins(num_bins=15, confidences=confidences, method='quantile')\n",
    "n_bins = compute_bins(num_bins=15, confidences=confidences, method='uniform')\n",
    "\n",
    "conf_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "count_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "label_bin = torch.zeros(len(n_bins), device=cur_labels.device, dtype=cur_labels.dtype)\n",
    "idx = idx_bins(confidences, n_bins)\n",
    "\n",
    "#bin_total = torch.bincount(idx, minlength=len(n_bins)-1).float().to(confidences.device) ## the number of samples per bins\n",
    "#bin_true = (torch.bincount(idx, weights=cur_labels, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels    \n",
    "#bin_true = (torch.bincount(idx, weights=confidences, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels    \n",
    "#bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "307402a2-9f9a-4093-abc8-c5ef54f1816d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "id": "f8ef6558-513b-453d-8287-c1beec1552b5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.2691, dtype=torch.float64)"
      ]
     },
     "execution_count": 166,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#confidence = bin_mean[idx]\n",
    "count_bin.scatter_add_(dim=0, index=idx, src=torch.ones_like(confidences))\n",
    "conf_bin.scatter_add_(dim=0, index=idx, src=confidences)\n",
    "conf_bin = torch.nan_to_num(conf_bin / count_bin)\n",
    "prop_bin = count_bin / count_bin.sum()\n",
    "\n",
    "label_bin.scatter_add_(dim=0, index=idx, src=cur_labels)\n",
    "label_bin = torch.nan_to_num(label_bin / count_bin)\n",
    "torch.sum(torch.abs(label_bin - conf_bin) * prop_bin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3069c000-3eb6-4a19-b550-a3cef0cfe98f",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'cur_labels' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28mid\u001b[39m \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marange(\u001b[38;5;241m0\u001b[39m,\u001b[38;5;28mlen\u001b[39m(torch\u001b[38;5;241m.\u001b[39munique(\u001b[43mcur_labels\u001b[49m)))\n\u001b[1;32m      2\u001b[0m a \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mzeros_like(cur_preds)\n\u001b[1;32m      3\u001b[0m a[:,\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m=\u001b[39m cur_preds[:,\u001b[38;5;28mid\u001b[39m\u001b[38;5;241m!=\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39msum(\u001b[38;5;241m1\u001b[39m)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'cur_labels' is not defined"
     ]
    }
   ],
   "source": [
    "id = np.arange(0,len(torch.unique(cur_labels)))\n",
    "a = torch.zeros_like(cur_preds)\n",
    "a[:,0] = cur_preds[:,id!=1].sum(1)\n",
    "a[:,1] = cur_preds[:,id==1].sum(1)\n",
    "#cur_preds = utils.apply_on_dataset(model=model, dataset=all_examples, batch_size=n_batch)['pred'].softmax(1)\n",
    "    #cur_labels = utils.apply_on_dataset(model=model, dataset=all_examples, batch_size=n_batch)['label_0'] ## (2*n, num_classes)\n",
    "    #a = torch.zeros_like(cur_preds)\n",
    "    #a[:,0] = cur_preds[:,id!=1].sum(1)\n",
    "    #a[:,1] = cur_preds[:,id==1].sum(1)\n",
    "    #cur_preds = a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f8501dc8-79bc-4de7-8e69-d95aabfea72a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#cur_preds[:,id==1].sum(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0779a3bf-1009-4c77-8af1-6713012af262",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.int64"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#labels = [y for x, y in all_examples]\n",
    "#labels = torch.tensor(labels).long()\n",
    "#indices = 2*np.arange(len(cur_mask)) + cur_mask\n",
    "#cur_preds[indices]\n",
    "labels = torch.randint(0, 9, (150,)).long()\n",
    "labels[labels!=3] = 0\n",
    "labels[labels==3] = 1\n",
    "labels.dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "27131391-624c-466b-8d4b-b0d3db56ef73",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.8731, 0.1269],\n",
       "        [0.8270, 0.1730],\n",
       "        [0.9688, 0.0312],\n",
       "        [0.9041, 0.0959],\n",
       "        [0.9417, 0.0583],\n",
       "        [0.6835, 0.3165],\n",
       "        [0.9645, 0.0355],\n",
       "        [0.9532, 0.0468],\n",
       "        [0.9521, 0.0479],\n",
       "        [0.9898, 0.0102],\n",
       "        [0.9458, 0.0542],\n",
       "        [0.9833, 0.0167],\n",
       "        [0.9540, 0.0460],\n",
       "        [0.8734, 0.1266],\n",
       "        [0.9620, 0.0380],\n",
       "        [0.9759, 0.0241],\n",
       "        [0.9441, 0.0559],\n",
       "        [0.9222, 0.0778],\n",
       "        [0.9132, 0.0868],\n",
       "        [0.8342, 0.1658],\n",
       "        [0.8936, 0.1064],\n",
       "        [0.9361, 0.0639],\n",
       "        [0.7249, 0.2751],\n",
       "        [0.7663, 0.2337],\n",
       "        [0.9051, 0.0949],\n",
       "        [0.7986, 0.2014],\n",
       "        [0.9261, 0.0739],\n",
       "        [0.7916, 0.2084],\n",
       "        [0.9616, 0.0384],\n",
       "        [0.9022, 0.0978],\n",
       "        [0.9326, 0.0674],\n",
       "        [0.9050, 0.0950],\n",
       "        [0.8834, 0.1166],\n",
       "        [0.8604, 0.1396],\n",
       "        [0.8346, 0.1654],\n",
       "        [0.9265, 0.0735],\n",
       "        [0.9578, 0.0422],\n",
       "        [0.9530, 0.0470],\n",
       "        [0.9926, 0.0074],\n",
       "        [0.9286, 0.0714],\n",
       "        [0.9847, 0.0153],\n",
       "        [0.7515, 0.2485],\n",
       "        [0.9771, 0.0229],\n",
       "        [0.9410, 0.0590],\n",
       "        [0.9929, 0.0071],\n",
       "        [0.8000, 0.2000],\n",
       "        [0.7604, 0.2396],\n",
       "        [0.9252, 0.0748],\n",
       "        [0.9073, 0.0927],\n",
       "        [0.9540, 0.0460],\n",
       "        [0.8380, 0.1620],\n",
       "        [0.9219, 0.0781],\n",
       "        [0.8850, 0.1150],\n",
       "        [0.8019, 0.1981],\n",
       "        [0.9176, 0.0824],\n",
       "        [0.9837, 0.0163],\n",
       "        [0.8799, 0.1201],\n",
       "        [0.9364, 0.0636],\n",
       "        [0.7703, 0.2297],\n",
       "        [0.9715, 0.0285],\n",
       "        [0.8703, 0.1297],\n",
       "        [0.9385, 0.0615],\n",
       "        [0.9744, 0.0256],\n",
       "        [0.9304, 0.0696],\n",
       "        [0.9436, 0.0564],\n",
       "        [0.9490, 0.0510],\n",
       "        [0.8534, 0.1466],\n",
       "        [0.9617, 0.0383],\n",
       "        [0.9873, 0.0127],\n",
       "        [0.9624, 0.0376],\n",
       "        [0.8758, 0.1242],\n",
       "        [0.9593, 0.0407],\n",
       "        [0.9301, 0.0699],\n",
       "        [0.9571, 0.0429],\n",
       "        [0.9490, 0.0510],\n",
       "        [0.7387, 0.2613],\n",
       "        [0.9948, 0.0052],\n",
       "        [0.9248, 0.0752],\n",
       "        [0.9513, 0.0487],\n",
       "        [0.9817, 0.0183],\n",
       "        [0.9105, 0.0895],\n",
       "        [0.9343, 0.0657],\n",
       "        [0.9364, 0.0636],\n",
       "        [0.9703, 0.0297],\n",
       "        [0.8346, 0.1654],\n",
       "        [0.9702, 0.0298],\n",
       "        [0.9757, 0.0243],\n",
       "        [0.7207, 0.2793],\n",
       "        [0.9832, 0.0168],\n",
       "        [0.8860, 0.1140],\n",
       "        [0.9799, 0.0201],\n",
       "        [0.8636, 0.1364],\n",
       "        [0.9651, 0.0349],\n",
       "        [0.9766, 0.0234],\n",
       "        [0.9680, 0.0320],\n",
       "        [0.9690, 0.0310],\n",
       "        [0.7665, 0.2335],\n",
       "        [0.8801, 0.1199],\n",
       "        [0.9362, 0.0638],\n",
       "        [0.9304, 0.0696],\n",
       "        [0.9875, 0.0125],\n",
       "        [0.9432, 0.0568],\n",
       "        [0.9418, 0.0582],\n",
       "        [0.9833, 0.0167],\n",
       "        [0.8611, 0.1389],\n",
       "        [0.9788, 0.0212],\n",
       "        [0.9471, 0.0529],\n",
       "        [0.9956, 0.0044],\n",
       "        [0.9845, 0.0155],\n",
       "        [0.9746, 0.0254],\n",
       "        [0.8722, 0.1278],\n",
       "        [0.9443, 0.0557],\n",
       "        [0.7335, 0.2665],\n",
       "        [0.7740, 0.2260],\n",
       "        [0.8024, 0.1976],\n",
       "        [0.7925, 0.2075],\n",
       "        [0.9875, 0.0125],\n",
       "        [0.9636, 0.0364],\n",
       "        [0.8971, 0.1029],\n",
       "        [0.9228, 0.0772],\n",
       "        [0.8937, 0.1063],\n",
       "        [0.9061, 0.0939],\n",
       "        [0.7912, 0.2088],\n",
       "        [0.8750, 0.1250],\n",
       "        [0.8591, 0.1409],\n",
       "        [0.9493, 0.0507],\n",
       "        [0.9308, 0.0692],\n",
       "        [0.9779, 0.0221],\n",
       "        [0.7896, 0.2104],\n",
       "        [0.7984, 0.2016],\n",
       "        [0.9225, 0.0775],\n",
       "        [0.9411, 0.0589],\n",
       "        [0.9203, 0.0797],\n",
       "        [0.9921, 0.0079],\n",
       "        [0.8285, 0.1715],\n",
       "        [0.9693, 0.0307],\n",
       "        [0.7100, 0.2900],\n",
       "        [0.9864, 0.0136],\n",
       "        [0.9290, 0.0710],\n",
       "        [0.9644, 0.0356],\n",
       "        [0.9255, 0.0745],\n",
       "        [0.9349, 0.0651],\n",
       "        [0.9715, 0.0285],\n",
       "        [0.9429, 0.0571],\n",
       "        [0.8998, 0.1002],\n",
       "        [0.9359, 0.0641],\n",
       "        [0.9291, 0.0709],\n",
       "        [0.8869, 0.1131],\n",
       "        [0.6795, 0.3205],\n",
       "        [0.8438, 0.1562]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idx = np.arange(0,10)\n",
    "p = torch.randn(150, 10).softmax(1)\n",
    "a = torch.zeros(150,2)\n",
    "a[:,0] = p[:,idx!=3].sum(1)\n",
    "a[:,1] = p[:,idx==3].sum(1)\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 745,
   "id": "ca153a28-4d81-4323-b697-e654dafe8a23",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 745,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#saved_data['args'].dataset == 'cifar'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "id": "bb2b0763-b511-46f8-ae73-1353e08df47d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0000, 1.0000, 1.0000, 1.0000, 1.0000], dtype=torch.float64)"
      ]
     },
     "execution_count": 134,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n_bins = int(n ** (1/3)) ## l1\n",
    "conf = cur_preds[train_idx].softmax(1).max(1).values\n",
    "n_bins = torch.tensor(np.quantile(conf, np.linspace(0, 1, n_bins + 1)))\n",
    "n_bins[0], n_bins[-1] = 0., 1.\n",
    "n_bins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "271899f0-591f-4864-a8c6-f90a4eee86ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "targets = torch.tensor([y for x, y in all_examples])\n",
    "indices = 2*np.arange(len(cur_mask)) + cur_mask\n",
    "y = targets[indices]\n",
    "\n",
    "if not torch.all((cur_preds[indices] >= 0) * (cur_preds[indices] <= 1)):\n",
    "    p = cur_preds[indices].softmax(1)\n",
    "confidences, _ = p.max(dim=1)\n",
    "confidences[y==0] = 1 - confidences[y==0]\n",
    "((y - confidences) ** 2).mean()\n",
    "#targets[indices] - torch.max(softmax(cur_preds[indices],1),1).values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "b3f51c74-91bf-424c-80c3-a8b160f1eb10",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.08350420864501443\n"
     ]
    }
   ],
   "source": [
    "num_classes = 2\n",
    "id = 50\n",
    "\n",
    "ms = [p[id] for p in masks] ## あるサンプルの mask (mask: select the train/val split (S)の際． train dataかval dataか)\n",
    "ps = [p[2*id:2*id+2] for p in preds] ## 異なる初期値から学習させたモデルのそれぞれに対し，　n-th sampleのsupersampleに対する予測確率 (e.g.; ３０個のモデルの，Z_{0,0},Z_{0,1}に対する予測確率)\n",
    "for i in range(len(ps)):\n",
    "    ps[i] = torch.argmax(ps[i], dim=1) ## 各モデルごとに，　supersampleに対する予測確率の高い方のラベルに変換 (supersample間の予測ラベルを算出;サンプルの変更に頑健なら確率の変化は小さいのでラベルは一致(＝汎化))\n",
    "    ps[i] = num_classes * ps[i][0] + ps[i][1] ## ???　--> supersampleに対する予測パターンのうち，どこに値するかを計算(２値の場合，予測は(0,0), (0,1), (1,0), (1,1)の4通り; 予測が(0,1)なら，2*0+1=1番目のパターン)\n",
    "    ps[i] = ps[i].item()\n",
    "\n",
    "## estimation of Hellström et al.\n",
    "prob = np.zeros((2, 4))\n",
    "for a, b in zip(ms, ps):\n",
    "    ## 30個のモデルのうち， (mask,予測パターン)に関する同時確率．例えば，30個のモデルのうち，(mask=0,予測パターン=2)のモデルが4個あった場合の確率は， 4/30．\n",
    "    prob[a,b] += 1 / len(ms)\n",
    "    #1.0/len(ms) ## {train or val}かどうかで(0,1)で条件付確率 (e.g., val時（ms=1）, supersampleに対する予測は（1,0） --> prob[0,3]; ２値の場合，予測は(0,0), (0,1), (1,0), (1,1)の4通り)\n",
    "\n",
    "pa = np.sum(prob, axis=1) ## 各予測確率パターンに対する周辺化\n",
    "pb = np.sum(prob, axis=0) ## maskに対する周辺化 (TODO: Check this corresponds to the 0-1 loss --> maybe, this is not 0-1 loss; overall, this estimation is based on Cor. 2 of H.)\n",
    "mi = 0\n",
    "for a in range(2):\n",
    "    for b in range(4):\n",
    "        if prob[a,b] < 1e-9:\n",
    "            continue\n",
    "        mi += prob[a,b] * np.log(prob[a,b]/(pa[a]*pb[b]))\n",
    "print(mi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "606fee8b-415e-496e-b8a0-f145f8ea4216",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.08350420864501451"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## Our estimation\n",
    "id = 50\n",
    "\n",
    "ms = [p[id] for p in masks] ## あるサンプルの mask (mask: select the train/val split (S)の際． train dataかval dataか)\n",
    "ps = [p[2*id:2*id+2] for p in preds] ## 異なる初期値から学習させたモデルのそれぞれに対し，　n-th sampleのsupersampleに対する予測確率 (e.g.; ３０個のモデルの，Z_{0,0},Z_{0,1}に対する予測確率)\n",
    "#ls = [l[2*id:2*id+2] for l in labels]\n",
    "#loss = 0.\n",
    "for i in range(len(ps)):\n",
    "    ps[i] = torch.argmax(ps[i], dim=1) ## 各モデルごとに，　supersampleに対する予測確率の高い方のラベルに変換 (supersample間の予測ラベルを算出;サンプルの変更に頑健なら確率の変化は小さいのでラベルは一致(＝汎化)) (= 0-1 loss)\n",
    "    #ls[i] = (ps[i] != ls[i]).sum()\n",
    "    #ls[i] = ls[i].item()\n",
    "\n",
    "#mutual_info_classif(ps.reshape(-1,2), ms, discrete_features=[True, True]).sum() ## discrete_features detects the dtype of \"X (= loss in this case)\" (There is no-problem on My MacBook Pro)\n",
    "mutual_info_classif(torch.concat(ps).reshape(-1,2).numpy(), ms, discrete_features=[True, True]).sum() ## discrete_features detects the dtype of \"X (= loss in this case)\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "0f8b2249-15bf-4c6a-9e0f-f3c56b10d679",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "30"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(ps)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2dc4dee0-be8d-467c-bbdf-f80f5afdebdf",
   "metadata": {},
   "source": [
    "# Evaluation of our bounds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c29fbb63-73aa-45c4-8aad-712eb35718ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_classes = 2\n",
    "\n",
    "def calc_l2_loss(preds, labels):\n",
    "    loss = torch.zeros(1, 2)\n",
    "    for i in range(2):\n",
    "        if labels[0][i].item() == 0:\n",
    "            loss[0][i] = ((1-preds[0][i]) - labels[0][i]) ** 2\n",
    "        else:\n",
    "            loss[0][i] = (preds[0][i] - labels[0][i]) ** 2\n",
    "    \n",
    "    return loss\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "d35269ff-4988-41d9-bd11-734e58083fda",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neighbors import KernelDensity\n",
    "import numpy as np\n",
    "\n",
    "def mutual_information_kernel_density(X, y, bandwidth=1., num_points=100):\n",
    "    X_y = np.column_stack((X, y))\n",
    "    # カーネル密度推定\n",
    "    kde_X = KernelDensity(bandwidth=bandwidth, kernel='gaussian')\n",
    "    kde_Y = KernelDensity(bandwidth=bandwidth, kernel='gaussian')\n",
    "    kde_X_Y = KernelDensity(bandwidth=bandwidth, kernel='gaussian')\n",
    "\n",
    "    kde_X.fit(X)\n",
    "    kde_Y.fit(y.reshape(-1, 1))\n",
    "    kde_X_Y.fit(X_y)\n",
    "\n",
    "    # 連続値Xおよび離散値yに対する密度を取得\n",
    "    density_X = np.exp(kde_X.score_samples(X))\n",
    "    density_Y = np.exp(kde_Y.score_samples(y.reshape(-1, 1)))\n",
    "    density_X_Y = np.exp(kde_X_Y.score_samples(X_y))\n",
    "    #density_X = np.exp(kde_X.score(X))\n",
    "    #density_Y = np.exp(kde_Y.score(y.reshape(-1, 1)))\n",
    "    #density_X_Y = np.exp(kde_X_Y.score(X_y))\n",
    "    \n",
    "    #mi = (density_X_Y * np.log(density_X_Y) - density_X_Y * np.log(density_X*density_Y)).sum()\n",
    "    mi = (density_X_Y * np.log(density_X_Y / (density_X*density_Y))).sum()\n",
    "\n",
    "    return max(0.0, mi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4473778d-2758-492c-88dc-b6d52d9d819d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "id": "c22103f8-8620-4995-ba36-d615f4768c89",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d0c5f9f5f9584716ab3911938b113ae2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "915.0480970663328\n",
      "915.0480970663334\n"
     ]
    }
   ],
   "source": [
    "bound = 0.0\n",
    "list_of_mis = []\n",
    "#metrics = ECE(num_classes=num_classes, n_bins=4, norm='l2') ## norm='l1' (ECE); norm='l2' (RMSCE); bins=15 (default).\n",
    "for idx in tqdm(range(n)):\n",
    "    ms = [m[idx] for m in masks]\n",
    "    ps = [p[2*idx:2*idx+2] for p in preds]\n",
    "    ls = [l[2*idx:2*idx+2] for l in labels]\n",
    "    loss = []\n",
    "    for i in range(len(ps)):\n",
    "        #ps[i] = metrics(ps[i],ls[i]) ** 2\n",
    "        ps[i] = torch.max(softmax(ps[i], 1), dim=1).values\n",
    "        ps[i][ls[i] == 0] = 1 - ps[i][ls[i] == 0]\n",
    "        #l = (ls[i] - ps[i])\n",
    "        l = ps[i]\n",
    "\n",
    "        #l = torch.zeros(1,2)\n",
    "        #for j in range(2):\n",
    "        #    if ls[i][j] == 0:\n",
    "        #        l[0][j] = ((1-ps[i][j]) - ls[i][j])**2\n",
    "        #    else:\n",
    "        #        l[0][j] = (ps[i][j] - ls[i][j])**2\n",
    "            \n",
    "        #loss.append((ps[i] - ls[i])**2)\n",
    "        loss.append(l)\n",
    "    \n",
    "    #ps = np.array(ps).reshape(-1,1)\n",
    "    #ps = torch.concat(ps).reshape(-1,2).numpy()\n",
    "    #ps = torch.concat(loss).numpy()\n",
    "    \n",
    "    ps = torch.concat(loss).reshape(-1,2).numpy()\n",
    "    #cur_mi = mutual_info_classif(ps, ms, discrete_features=[False, False]).sum()\n",
    "    cur_mi = mutual_info_classif(ps, ms).sum()\n",
    "    list_of_mis.append(cur_mi)\n",
    "    #bound += (cur_mi + np.log(3))\n",
    "    bound += cur_mi\n",
    "\n",
    "#bound *= 4/n\n",
    "#bound = np.sqrt(4/n*(bound + np.log(3)))\n",
    "print(bound)\n",
    "print(np.array(list_of_mis).sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a45c3df-acde-4285-a192-620e7ac559bf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 163,
   "id": "a9b5d9ee-0084-4cd4-9686-afa0ae52b439",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 1])"
      ]
     },
     "execution_count": 163,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#p = torch.max(softmax(preds[0][0:2],1),1).values\n",
    "#l = torch.max(softmax(preds[0][0:2],1),1).indices #.values\n",
    "#p[l==0] = 1 - p[l==0]\n",
    "#p\n",
    "labels[0][0:2]\n",
    "#masks[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "c25ddea7-8eed-4099-a9cd-342c765d30f6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1b9c08784ed34791a7e8a952732c186d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/75 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.241940467751646\n"
     ]
    }
   ],
   "source": [
    "## BinごとにMIを推定\n",
    "bound = 0.0\n",
    "#bound = torch.zeros(len(n_bins))\n",
    "#list_of_mis = []\n",
    "list_of_mis = np.zeros(len(n_bins))\n",
    "for idx in tqdm(range(n)):\n",
    "    ms = np.array([m[idx] for m in masks])\n",
    "    ps = [p[2*idx:2*idx+2] for p in preds]\n",
    "    ls = [l[2*idx:2*idx+2] for l in labels]\n",
    "    bs = np.array([b[idx] for b in bins])\n",
    "    #bs = np.array([b[2*idx:2*idx+2] for b in bins])\n",
    "    loss = []\n",
    "    #loss = torch.zeros(len(n_bins),2)\n",
    "    for i in range(len(ps)):\n",
    "        ps[i] = torch.max(softmax(ps[i], 1), dim=1).values\n",
    "        ps[i][ls[i] == 0] = 1 - ps[i][ls[i] == 0]\n",
    "        l = (ls[i] - ps[i])\n",
    "        loss.append(l)\n",
    "        #loss[bs[i]] = (ls[i] - ps[i])\n",
    "    \n",
    "    ps = torch.concat(loss).reshape(-1,2).numpy()\n",
    "    for i in np.unique(bs):\n",
    "        ps_bin = ps[bs == i]\n",
    "        ms_bin = ms[bs == i]\n",
    "        #print(i)\n",
    "        if len(ps_bin) < 3:\n",
    "            #list_of_mis[i] += 0\n",
    "            pass\n",
    "        else:\n",
    "            cur_mi = mutual_info_classif(ps_bin, ms_bin, discrete_features=[False, False]).sum()\n",
    "            list_of_mis[i] += cur_mi\n",
    "    #cur_mi = mutual_info_classif(ps, ms, discrete_features=[False, False]).sum()\n",
    "    #list_of_mis.append(cur_mi)\n",
    "    #bound += cur_mi\n",
    "B = len(n_bins)\n",
    "#bound = 2*(list_of_mis.sum() + 2*B*np.log(2)) / np.sqrt(n) + 1/B\n",
    "bound = np.sqrt(4*(list_of_mis.sum() + B*np.log(3)) / n)  + 1/B + np.sqrt(B/(4*n))\n",
    "\n",
    "#bound *= 4/n\n",
    "#bound = np.sqrt(4/n*(bound + np.log(3)))\n",
    "print(bound)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "5690e701-aab6-417e-80b9-1a545f50adfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mutual_information_kde_bin(loss_bin, masks, bandwidth=1.):\n",
    "    # prepare data\n",
    "    n_bins = len(loss_bin)\n",
    "    X_y = np.array([np.column_stack((loss_bin[i], ms)) for i in range(n_bins)]).reshape(-1,3)\n",
    "    l_mat = loss_bin.reshape(-1,2)\n",
    "    ms_mat = np.tile(ms, n_bins).reshape(-1,1)\n",
    "    \n",
    "    # カーネル密度推定\n",
    "    kde_X = KernelDensity(bandwidth=bandwidth, kernel='gaussian')\n",
    "    kde_Y = KernelDensity(bandwidth=bandwidth, kernel='gaussian')\n",
    "    kde_X_Y = KernelDensity(bandwidth=bandwidth, kernel='gaussian')\n",
    "\n",
    "    kde_X.fit(l_mat)\n",
    "    kde_Y.fit(ms_mat)\n",
    "    kde_X_Y.fit(X_y)\n",
    "\n",
    "    # 連続値Xおよび離散値yに対する密度を取得\n",
    "    density_X = np.exp(kde_X.score_samples(l_mat))\n",
    "    density_Y = np.exp(kde_Y.score_samples(ms_mat))\n",
    "    density_X_Y = np.exp(kde_X_Y.score_samples(X_y))\n",
    "    mi = (density_X_Y * np.log(density_X_Y / (density_X*density_Y))).sum()\n",
    "\n",
    "    return max(0.0, mi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "637b0a66-20ff-4b2d-b0fe-bd30adc1cacf",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "operands could not be broadcast together with shapes (150,) (210,) ",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[39], line 45\u001b[0m\n\u001b[1;32m     42\u001b[0m density_X_Y \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mexp(kde_X_Y\u001b[38;5;241m.\u001b[39mscore_samples(X_y))\n\u001b[1;32m     44\u001b[0m \u001b[38;5;66;03m#(density_X_Y * np.log(density_X_Y / (density_X*density_Y))).sum()\u001b[39;00m\n\u001b[0;32m---> 45\u001b[0m \u001b[38;5;28mmax\u001b[39m(\u001b[38;5;241m0\u001b[39m, (density_X_Y \u001b[38;5;241m*\u001b[39m np\u001b[38;5;241m.\u001b[39mlog(density_X_Y \u001b[38;5;241m/\u001b[39m (\u001b[43mdensity_X\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mdensity_Y\u001b[49m)))\u001b[38;5;241m.\u001b[39msum())\n",
      "\u001b[0;31mValueError\u001b[0m: operands could not be broadcast together with shapes (150,) (210,) "
     ]
    }
   ],
   "source": [
    "\n",
    "#loss_bin.shape\n",
    "X_y = np.array([np.column_stack((loss_bin[i], ms)) for i in range(len(n_bins))]).reshape(-1,3)\n",
    "bandwidth = 1.\n",
    "l_mat = loss_bin.reshape(-1,2)\n",
    "ms_mat = np.tile(ms, 7).reshape(-1,1)\n",
    "\n",
    "# カーネル密度推定\n",
    "kde_X = KernelDensity(bandwidth=bandwidth, kernel='gaussian')\n",
    "kde_Y = KernelDensity(bandwidth=bandwidth, kernel='gaussian')\n",
    "kde_X_Y = KernelDensity(bandwidth=bandwidth, kernel='gaussian')\n",
    "\n",
    "kde_X.fit(l_mat)\n",
    "kde_Y.fit(ms_mat)\n",
    "kde_X_Y.fit(X_y)\n",
    "\n",
    "density_X = np.exp(kde_X.score_samples(l_mat))\n",
    "density_Y = np.exp(kde_Y.score_samples(ms_mat))\n",
    "density_X_Y = np.exp(kde_X_Y.score_samples(X_y))\n",
    "\n",
    "#(density_X_Y * np.log(density_X_Y / (density_X*density_Y))).sum()\n",
    "max(0, (density_X_Y * np.log(density_X_Y / (density_X*density_Y))).sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 173,
   "id": "f5044e17-bd80-46bb-98f9-706c7020e63f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "84c7138762414c89b686d82da7fc4d1e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "88.39119518114852\n"
     ]
    }
   ],
   "source": [
    "n_bins = int(n ** (1/3)) ## l1\n",
    "n_bins = torch.linspace(0, 1, n_bins + 1) ## Uniform binning TODO: implementing Uniform-mass binning\n",
    "n_bins[0], n_bins[-1] = 0., 1.\n",
    "\n",
    "bound = 0.0\n",
    "#bound = torch.zeros(len(n_bins))\n",
    "list_of_mis = []\n",
    "#list_of_mis = np.zeros(len(n_bins))\n",
    "#loss = torch.zeros(len(n_bins), n, 2)\n",
    "for idx in tqdm(range(n)):\n",
    "    ms = np.array([m[idx] for m in masks])\n",
    "    ps = [p[2*idx:2*idx+2] for p in preds]\n",
    "    ls = [l[2*idx:2*idx+2] for l in labels]\n",
    "    bs = np.array([b[idx] for b in bins])\n",
    "    loss = []\n",
    "    loss_bin = torch.zeros(len(n_bins),len(ps),2)\n",
    "    for i in range(len(ps)):\n",
    "        ps[i] = torch.max(softmax(ps[i], 1), dim=1).values\n",
    "        #ps[i] = torch.max(ps[i], dim=1).values\n",
    "        ps[i][ls[i] == 0] = 1 - ps[i][ls[i] == 0]\n",
    "        #loss.append(ps[i])\n",
    "        l = (ls[i] - ps[i])\n",
    "        loss.append(l)\n",
    "        loss_bin[bs[i]][i] = l\n",
    "    #ps = torch.concat(loss).reshape(-1,2)\n",
    "    #loss_bin[bs] = ps\n",
    "    #cur_mi = np.array([mutual_info_classif(loss_bin[i], ms, discrete_features=[False, False]).sum() for i in range(len(n_bins))]).sum() #/ len(ps)\n",
    "    #cur_mi = np.array([mutual_information_kernel_density(loss_bin[i], ms) for i in range(len(n_bins))]).sum() #/ len(ps)\n",
    "    #cur_mi = mutual_information_kde_bin(loss_bin, ms, bandwidth=\"scott\")\n",
    "    cur_mi = mutual_info_classif(loss_bin.reshape(-1,2), np.tile(ms, len(n_bins)), discrete_features=[False, False]).sum()\n",
    "    list_of_mis.append(cur_mi)\n",
    "    bound += cur_mi\n",
    "\n",
    "#B = len(n_bins)\n",
    "#bound = np.sqrt((2*B*(bound + np.log(B)))/n)\n",
    "print(bound)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "id": "d5575c95-094c-432e-932e-e76b0f1a0353",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(320, 20)"
      ]
     },
     "execution_count": 178,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#np.tile(loss_bin,2)\n",
    "np.tile(loss_bin.reshape(-1,2), 10).shape\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "5aa4852f-315e-4767-9fad-0f6e0ae02c1f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#np.array([mutual_information_kernel_density(loss_bin[i], ms) for i in range(len(n_bins))]).sum() \n",
    "#loss_bin\n",
    "#np.median(loss_bin.reshape(-1,2))\n",
    "#mutual_info_classif(loss_bin.reshape(-1,2), np.tile(ms,5), discrete_features=[False, False]).sum()\n",
    "#loss_bin.reshape(-1,2).shape\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 867,
   "id": "09a572ce-b82a-4615-8f0b-22cef7ee7174",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([   6,   11,   19,   34,   59,  103,  182,  321,  567, 1000])"
      ]
     },
     "execution_count": 867,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "start, end, num_candidates = 6, 10**3, 10\n",
    "int_candidates = np.ceil(np.logspace(np.log10(start), np.log10(end), num=num_candidates)).astype(int) ## bin candidates\n",
    "int_candidates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 873,
   "id": "7ac89ea4-f98f-45ef-bae4-a6aa3b08d1a7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([   4,    6,    9,   15,   21,   44,   96,  210,  458, 1000])"
      ]
     },
     "execution_count": 873,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#int_candidates\n",
    "ns = np.array([75, 250, 1000, 4000])\n",
    "#ns = np.array([500, 1000, 5000, 20000])\n",
    "opt_bin = np.floor((ns ** (1/3)))\n",
    "start, end, num_candidates = max(opt_bin)+5, 10**3, (10 - len(opt_bin))\n",
    "int_candidates = np.ceil(np.logspace(np.log10(start), np.log10(end), num=num_candidates)).astype(int) ## bin candidates\n",
    "np.concatenate([opt_bin, int_candidates]).astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 835,
   "id": "365417c2-e376-4b03-b518-5b3414099e20",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a82c4e9ee53a4ba49c0112ed94c5d823",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/250 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.059849897740922316\n"
     ]
    }
   ],
   "source": [
    "n_bins = int(n ** (1/3)) ## l1\n",
    "n_bins = torch.linspace(0, 1, n_bins + 1) ## Uniform binning TODO: implementing Uniform-mass binning\n",
    "n_bins[0], n_bins[-1] = 0., 1.\n",
    "\n",
    "bound = 0.0\n",
    "#bound = torch.zeros(len(n_bins))\n",
    "list_of_mis = []\n",
    "#list_of_mis = np.zeros(len(n_bins))\n",
    "#loss = torch.zeros(len(n_bins), n, 2)\n",
    "for idx in tqdm(range(n)):\n",
    "    ms = np.array([m[idx] for m in masks])\n",
    "    ps = [p[2*idx:2*idx+2] for p in preds]\n",
    "    ls = [l[2*idx:2*idx+2] for l in labels]\n",
    "    bs = np.array([b[idx] for b in bins])\n",
    "    loss = []\n",
    "    loss_bin = torch.zeros(len(n_bins),len(ps),2)\n",
    "    for i in range(len(ps)):\n",
    "        ps[i] = torch.max(softmax(ps[i], 1), dim=1).values\n",
    "        #ps[i] = torch.max(ps[i], dim=1).values\n",
    "        ps[i][ls[i] == 0] = 1 - ps[i][ls[i] == 0]\n",
    "        loss.append(ps[i])\n",
    "        #l = (ls[i] - ps[i])\n",
    "        #loss.append(l)\n",
    "    ps = torch.concat(loss).reshape(-1,2)\n",
    "    #loss_bin[bs] = ps\n",
    "    #cur_mi = np.array([mutual_info_classif(loss_bin[i], ms, discrete_features=[False, False]).sum() for i in range(len(n_bins))]).sum() #/ len(ps)\n",
    "    cur_mi = mutual_information_kernel_density(ps, ms)\n",
    "    list_of_mis.append(cur_mi)\n",
    "    bound += cur_mi\n",
    "\n",
    "#B = len(n_bins)\n",
    "#bound = np.sqrt((2*B*(bound + np.log(B)))/n)\n",
    "print(bound)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcb06919-4e3e-412e-a9bd-bad379fcec5c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 699,
   "id": "97b80e72-5676-46b3-9b88-75a467bbaed2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e06e8fd12bd0419da4e17c0efd3a50e4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/250 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.07438428329149667\n"
     ]
    }
   ],
   "source": [
    "n_bins = int(n ** (1/3)) ## l1\n",
    "n_bins = torch.linspace(0, 1, n_bins + 1) ## Uniform binning TODO: implementing Uniform-mass binning\n",
    "n_bins[0], n_bins[-1] = 0., 1.\n",
    "\n",
    "bound = 0.0\n",
    "#bound = torch.zeros(len(n_bins))\n",
    "list_of_mis = []\n",
    "#list_of_mis = np.zeros(len(n_bins))\n",
    "#loss = torch.zeros(len(n_bins), n, 2)\n",
    "for idx in tqdm(range(n)):\n",
    "    ms = np.array([m[idx] for m in masks])\n",
    "    ps = [p[2*idx:2*idx+2] for p in preds]\n",
    "    ls = [l[2*idx:2*idx+2] for l in labels]\n",
    "    bs = np.array([b[idx] for b in bins])\n",
    "    loss = []\n",
    "    loss_bin = torch.zeros(len(n_bins),len(ps),2)\n",
    "    for i in range(len(ps)):\n",
    "        ps[i] = torch.max(softmax(ps[i], 1), dim=1).values\n",
    "        ps[i][ls[i] == 0] = 1 - ps[i][ls[i] == 0]\n",
    "        ps[i] = recalibration(ps[i], n_bins)\n",
    "        #idx = torch.bucketize(ps[i], n_bins, right=True) - 1\n",
    "        #bin_total = torch.bincount(idx, minlength=len(n_bins)-1).float().to(ps[i].device) ## the number of samples per bins\n",
    "        #bin_true = torch.bincount(idx, weights=ps[i], minlength=len(n_bins)-1).float().to(ps[i].device) ## the number of samples per bins weighted by labels\n",
    "        #with warnings.catch_warnings():\n",
    "        #    warnings.filterwarnings('ignore')\n",
    "        #    bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "        #ps[i] = bin_mean[idx]\n",
    "        l = (ls[i] - ps[i])\n",
    "        loss.append(l)\n",
    "    ps = torch.concat(loss).reshape(-1,2)\n",
    "    loss_bin[bs] = ps\n",
    "    cur_mi = np.array([mutual_information_kernel_density(loss_bin[i], ms) for i in range(len(n_bins))]).sum() #/ len(ps)\n",
    "    list_of_mis.append(cur_mi)\n",
    "    bound += cur_mi\n",
    "\n",
    "#B = len(n_bins)\n",
    "#bound = np.sqrt((2*B*(bound + np.log(B)))/n)\n",
    "print(bound)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 707,
   "id": "163afcd2-0161-4571-b9c6-fe127f6bf97b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bdf7c2c30ef6414381abf757cd123182",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/250 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def idx_bins(confidence, n_bins):\n",
    "    binids = np.minimum(np.digitize(confidence.numpy(), n_bins), len(n_bins) - 1)\n",
    "    binids -= 1\n",
    "    return torch.tensor(binids)\n",
    "\n",
    "bound = 0.0\n",
    "#bound = torch.zeros(len(n_bins))\n",
    "list_of_mis = []\n",
    "#list_of_mis = np.zeros(len(n_bins))\n",
    "#loss = torch.zeros(len(n_bins), n, 2)\n",
    "for idx in tqdm(range(n)):\n",
    "    ms = np.array([m[idx] for m in masks])\n",
    "    ps = [p[2*idx:2*idx+2] for p in preds]\n",
    "    ls = [l[2*idx:2*idx+2] for l in labels]\n",
    "    bs = np.array([b[idx] for b in bins])\n",
    "    loss = []\n",
    "    loss_bin = torch.zeros(len(n_bins),len(ps),2)\n",
    "    for i in range(len(ps)):\n",
    "        ps[i] = torch.max(softmax(ps[i], 1), dim=1).values\n",
    "        ps[i][ls[i] == 0] = 1 - ps[i][ls[i] == 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 712,
   "id": "7665cedf-588d-4bb3-aaed-f2465fa650e1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "30"
      ]
     },
     "execution_count": 712,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#b = torch.tensor(np.quantile(ps[0], torch.linspace(0, 1, len(n_bins) + 1)))\n",
    "#b[0], b[-1] = 0., 1.\n",
    "#calc_recalibration(ps[0], b)\n",
    "len(bins)\n",
    "\n",
    "#def calc_recalibration(ps, n_bins):\n",
    "#    #idx = torch.bucketize(ps, n_bins, right=True) - 1\n",
    "#    idx = idx_bins(ps, n_bins)\n",
    "#    bin_total = torch.bincount(idx, minlength=len(n_bins)-1).float().to(ps.device) ## the number of samples per bins\n",
    "#    bin_true = torch.bincount(idx, weights=ps, minlength=len(n_bins)-1).float().to(ps.device) ## the number of samples per bins weighted by labels\n",
    "#    with warnings.catch_warnings():\n",
    "#        warnings.filterwarnings('ignore')\n",
    "#        bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "#    return bin_mean[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 713,
   "id": "017783f8-6e3b-418d-a231-b5ecd7923f0e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dffa3fd7354a4c179178f6f927565ad3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/250 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.07438428329149667\n"
     ]
    }
   ],
   "source": [
    "n_bins = int(n ** (1/3)) ## l1\n",
    "n_bins = torch.linspace(0, 1, n_bins + 1) ## Uniform binning TODO: implementing Uniform-mass binning\n",
    "n_bins[0], n_bins[-1] = 0., 1.\n",
    "\n",
    "bound = 0.0\n",
    "#bound = torch.zeros(len(n_bins))\n",
    "list_of_mis = []\n",
    "#list_of_mis = np.zeros(len(n_bins))\n",
    "#loss = torch.zeros(len(n_bins), n, 2)\n",
    "for idx in tqdm(range(n)):\n",
    "    ms = np.array([m[idx] for m in masks])\n",
    "    ps = [p[2*idx:2*idx+2] for p in preds]\n",
    "    ls = [l[2*idx:2*idx+2] for l in labels]\n",
    "    bs = np.array([b[idx] for b in bins])\n",
    "    loss = []\n",
    "    loss_bin = torch.zeros(len(n_bins),len(ps),2)\n",
    "    for i in range(len(ps)):\n",
    "        ps[i] = torch.max(softmax(ps[i], 1), dim=1).values\n",
    "        ps[i][ls[i] == 0] = 1 - ps[i][ls[i] == 0]\n",
    "        ps[i] = recalibration(ps[i], n_bins)\n",
    "        #idx = torch.bucketize(ps[i], n_bins, right=True) - 1\n",
    "        #bin_total = torch.bincount(idx, minlength=len(n_bins)-1).float().to(ps[i].device) ## the number of samples per bins\n",
    "        #bin_true = torch.bincount(idx, weights=ps[i], minlength=len(n_bins)-1).float().to(ps[i].device) ## the number of samples per bins weighted by labels\n",
    "        #with warnings.catch_warnings():\n",
    "        #    warnings.filterwarnings('ignore')\n",
    "        #    bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "        #ps[i] = bin_mean[idx]\n",
    "        #l = (ls[i] - ps[i])\n",
    "        #loss.append(l)\n",
    "        loss.append(ps[i])\n",
    "    ps = torch.concat(loss).reshape(-1,2)\n",
    "    loss_bin[bs] = ps\n",
    "    cur_mi = np.array([mutual_information_kernel_density(loss_bin[i], ms) for i in range(len(n_bins))]).sum() #/ len(ps)\n",
    "    list_of_mis.append(cur_mi)\n",
    "    bound += cur_mi\n",
    "\n",
    "#B = len(n_bins)\n",
    "#bound = np.sqrt((2*B*(bound + np.log(B)))/n)\n",
    "print(bound)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 737,
   "id": "c9ff870b-9ca6-4669-98ef-6bbbc713e5f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_y = np.array([np.column_stack((loss_bin[i], ms)) for i in range(len(n_bins))])\n",
    "bandwidth = 1.\n",
    "kde_X = KernelDensity(bandwidth=bandwidth, kernel='gaussian')\n",
    "kde_Y = KernelDensity(bandwidth=bandwidth, kernel='gaussian')\n",
    "kde_X_Y = KernelDensity(bandwidth=bandwidth, kernel='gaussian')\n",
    "\n",
    "kde_X.fit(loss_bin.reshape(-1,2))\n",
    "kde_Y.fit(ms.reshape(-1, 1))\n",
    "kde_X_Y.fit(X_y.reshape(-1,3))\n",
    "\n",
    "density_X = np.exp(kde_X.score_samples(loss_bin.reshape(-1,2)))\n",
    "density_Y = np.exp(kde_Y.score_samples(ms.reshape(-1, 1)))\n",
    "density_X_Y = np.exp(kde_X_Y.score_samples(X_y.reshape(-1,3)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 728,
   "id": "1481cd79-86cd-4f19-8cd6-7cb5d63a93b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "#loss_bin.shape\n",
    "#X_y = np.array([np.column_stack((loss_bin[i], ms)) for i in range(len(n_bins))])\n",
    "\n",
    "def mutual_information_kernel_density(X, y, X_y, bandwidth=1., num_points=100):\n",
    "    #X_y = np.column_stack((X, y))\n",
    "    # カーネル密度推定\n",
    "    kde_X = KernelDensity(bandwidth=bandwidth, kernel='gaussian')\n",
    "    kde_Y = KernelDensity(bandwidth=bandwidth, kernel='gaussian')\n",
    "    kde_X_Y = KernelDensity(bandwidth=bandwidth, kernel='gaussian')\n",
    "\n",
    "    kde_X.fit(X)\n",
    "    kde_Y.fit(y.reshape(-1, 1))\n",
    "    kde_X_Y.fit(X_y)\n",
    "\n",
    "    # 連続値Xおよび離散値yに対する密度を取得\n",
    "    density_X = np.exp(kde_X.score_samples(X))\n",
    "    density_Y = np.exp(kde_Y.score_samples(y.reshape(-1, 1)))\n",
    "    density_X_Y = np.exp(kde_X_Y.score_samples(X_y))\n",
    "    #density_X = np.exp(kde_X.score(X))\n",
    "    #density_Y = np.exp(kde_Y.score(y.reshape(-1, 1)))\n",
    "    #density_X_Y = np.exp(kde_X_Y.score(X_y))\n",
    "    \n",
    "    #mi = (density_X_Y * np.log(density_X_Y) - density_X_Y * np.log(density_X*density_Y)).sum()\n",
    "    mi = (density_X_Y * np.log(density_X_Y / (density_X*density_Y))).sum()\n",
    "\n",
    "    return max(0.0, mi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "67c0f17c-c15f-4aae-a7e4-41cfd0dbd761",
   "metadata": {},
   "outputs": [],
   "source": [
    "def recalibration(ps, n_bins):\n",
    "    idx = torch.bucketize(ps, n_bins, right=True) - 1\n",
    "    bin_total = torch.bincount(idx, minlength=len(n_bins)-1).float().to(ps.device) ## the number of samples per bins\n",
    "    bin_true = torch.bincount(idx, weights=ps, minlength=len(n_bins)-1).float().to(ps.device) ## the number of samples per bins weighted by labels\n",
    "    with warnings.catch_warnings():\n",
    "        warnings.filterwarnings('ignore')\n",
    "        bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "    return bin_mean[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "921290c5-28e2-476c-94f9-066c0eed2dcf",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5042095125998685\n"
     ]
    }
   ],
   "source": [
    "## Hellstrom\n",
    "\n",
    "def discrete_mi_est(xs, ys, nx=2, ny=2):\n",
    "    prob = np.zeros((nx, ny))\n",
    "    for a, b in zip(xs, ys):\n",
    "        prob[a,b] += 1.0/len(xs)\n",
    "    pa = np.sum(prob, axis=1)\n",
    "    pb = np.sum(prob, axis=0)\n",
    "    mi = 0\n",
    "    for a in range(nx):\n",
    "        for b in range(ny):\n",
    "            if prob[a,b] < 1e-9:\n",
    "                continue\n",
    "            mi += prob[a,b] * np.log(prob[a,b]/(pa[a]*pb[b]))\n",
    "    return max(0.0, mi)\n",
    "\n",
    "bound = 0.0\n",
    "list_of_mis = []\n",
    "for idx in range(n):\n",
    "    ms = [p[idx] for p in masks]\n",
    "    ps = [p[2*idx:2*idx+2] for p in preds]\n",
    "    for i in range(len(ps)):\n",
    "        ps[i] = torch.argmax(ps[i], dim=1)\n",
    "        #ps[i] = num_classes * ps[i][0] + ps[i][1]\n",
    "        #ps[i] = ps[i].item()\n",
    "    #cur_mi = discrete_mi_est(ms, ps, nx=2, ny=num_classes**2)\n",
    "    ps = torch.concat(ps).reshape(-1,2).numpy()\n",
    "    #cur_mi = mutual_info_classif(ps, ms, discrete_features=[True, True]).sum() ## masks are discrete\n",
    "    cur_mi = mutual_information_kernel_density(ps, np.array(ms)) ## masks are discrete\n",
    "    list_of_mis.append(cur_mi)\n",
    "    bound += cur_mi\n",
    "    #bound += np.sqrt(2 * cur_mi)\n",
    "    #if verbose and idx < 10:\n",
    "    #    print(\"ms:\", ms)\n",
    "    #    print(\"ps:\", ps)\n",
    "    #    print(\"mi:\", cur_mi)\n",
    "    #bound *= 1/n\n",
    "print(bound)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 346,
   "id": "3c052e3f-8dfb-4258-bb13-f4bb8df45a67",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3.401129627514003e-16"
      ]
     },
     "execution_count": 346,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#mutual_info_classif(ps, ms, discrete_features=[False, False]).sum()\n",
    "mutual_information_kernel_density(ps, np.array(ms), bandwidth=1.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0a16d890-61cb-4e89-a49e-7cf49f1fdc85",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mutual Information: 0.009705067765971381\n"
     ]
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 188,
   "id": "80e1b2a7-64fa-42c1-aceb-91a36f0fdc89",
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'list' object has no attribute 'reshape'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m/var/folders/_s/jh8l69lx0wz3rhnd7rx1w5gr0000gn/T/ipykernel_13751/2399786298.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmutual_information_kernel_density\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss_bin\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/var/folders/_s/jh8l69lx0wz3rhnd7rx1w5gr0000gn/T/ipykernel_13751/4242211272.py\u001b[0m in \u001b[0;36mmutual_information_kernel_density\u001b[0;34m(X, y, bandwidth, num_points)\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m     \u001b[0mkde_X\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m     \u001b[0mkde_Y\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     13\u001b[0m     \u001b[0mkde_X_Y\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_y\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'reshape'"
     ]
    }
   ],
   "source": [
    "mutual_information_kernel_density(loss_bin[3], ms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "923325b9-5bf9-4831-8e4d-6b47e1bec90a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5955af134d1a48c4a81aba997edea920",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/75 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "63.04431836376926\n"
     ]
    }
   ],
   "source": [
    "bound = 0.0\n",
    "list_of_mis = []\n",
    "for idx in tqdm(range(n)):\n",
    "    ms = np.array([m[idx] for m in masks])\n",
    "    ps = [p[2*idx:2*idx+2] for p in preds]\n",
    "    ls = [l[2*idx:2*idx+2] for l in labels]\n",
    "    bs = np.array([b[idx] for b in bins])\n",
    "    loss = []\n",
    "    #loss_bin = torch.zeros(len(n_bins),len(ps),2)\n",
    "    for i in range(len(ps)):\n",
    "        #ps[i] = torch.argmax(softmax(ps[i], 1), dim=1)\n",
    "        #ps[i] = torch.argmax(ps[i], dim=1)\n",
    "        ps[i] = torch.max(softmax(ps[i], 1), dim=1).values\n",
    "        #ps[i][ls[i] == 0] = 1 - ps[i][ls[i] == 0]\n",
    "        loss.append(ps[i])\n",
    "    ps = torch.concat(loss).reshape(-1,2).numpy()\n",
    "    cur_mi = mutual_info_classif(ps, ms, discrete_features=[False, False]).sum()\n",
    "    #cur_mi = mutual_info_classif(ps, ms, discrete_features=[True, True]).sum()\n",
    "    list_of_mis.append(cur_mi)\n",
    "    bound += cur_mi\n",
    "print(bound)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 339,
   "id": "be44f32a-c3de-4b1a-80f7-080a1ddc550d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(30, 2)"
      ]
     },
     "execution_count": 339,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 301,
   "id": "a7c8b6e0-74c2-4865-9e4d-776977233203",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c8d46865e48448c4b58024290eb9c358",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/75 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.935048981307149\n"
     ]
    }
   ],
   "source": [
    "bound = 0.0\n",
    "list_of_mis = []\n",
    "for idx in tqdm(range(n)):\n",
    "    ms = np.array([m[idx] for m in masks])\n",
    "    ps = [p[2*idx:2*idx+2] for p in preds]\n",
    "    ls = [l[2*idx:2*idx+2] for l in labels]\n",
    "    bs = np.array([b[idx] for b in bins])\n",
    "    loss = []\n",
    "    loss_bin = torch.zeros(len(n_bins),len(ps),2)\n",
    "    for i in range(len(ps)):\n",
    "        ps[i] = torch.max(softmax(ps[i], 1), dim=1).values\n",
    "        ps[i][ls[i] == 0] = 1 - ps[i][ls[i] == 0]\n",
    "        loss.append(ps[i])\n",
    "        #l = (ls[i] - ps[i])\n",
    "        #loss.append(l)\n",
    "    #ps = torch.concat(loss).reshape(-1,2)\n",
    "    ps = torch.concat(loss).reshape(-1,2).numpy()\n",
    "    #loss_bin[bs] = ps\n",
    "    #cur_mi = np.array([mutual_info_classif(loss_bin[i], ms, discrete_features=[False, False]).sum() + 2*np.log(2) for i in range(len(n_bins))]).sum() / len(ps)\n",
    "    #cur_mi = np.array([mutual_info_classif(loss_bin[i], ms, discrete_features=[False, False]).sum() for i in range(len(n_bins))]).sum() / len(ps)\n",
    "    cur_mi = mutual_info_classif(ps, ms, discrete_features=[False, False]).sum()\n",
    "    list_of_mis.append(cur_mi)\n",
    "    bound += cur_mi\n",
    "\n",
    "bound = np.sqrt((2*B*(bound + np.log(B)))/n)\n",
    "#bound *= 2/np.sqrt(n)\n",
    "\n",
    "print(bound)\n",
    "#bound_value = np.sqrt(2*B*(bound + np.log(B)) / n)\n",
    "#bound_value_2 = (2*bound + 2*B*np.log(2)) / np.sqrt(n)\n",
    "#print(bound_value, bound_value_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 256,
   "id": "2d646bb4-2a3b-4ef6-b913-a6cc51e01720",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.05569641, 0.23890099, 0.1284248 , 0.81142966, 0.08006381])"
      ]
     },
     "execution_count": 256,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 227,
   "id": "7ec2fad0-ba64-4a1b-83b3-3aa5695c4263",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0235698194109224"
      ]
     },
     "execution_count": 227,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bound / len(ps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 318,
   "id": "9d2120fe-8a0f-401b-af50-4f3144ae2169",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "13.89935025051308\n"
     ]
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 233,
   "id": "b463b5c5-0fd9-42b9-89a4-cc625e0a210f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 233,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 325,
   "id": "4094e10a-6b5b-46b6-8cc6-2906f730d91f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#torch.zeros(len())\n",
    "#ps[]\n",
    "#mi_bin = torch.zeros(len(n_bins), device=torch.tensor(list_of_mis).device, dtype=torch.tensor(list_of_mis).dtype)\n",
    "#mi_bin.scatter_add_(dim=0, index=torch.tensor(bs), src=torch.tensor(list_of_mis))\n",
    "#len(torch.tensor(list_of_mis))\n",
    "#b_idx = np.array(bs)\n",
    "#ps[b_idx == 3]\n",
    "#np.array(ms)[b_idx==2]\n",
    "#len(np.array(ms)[b_idx==2])\n",
    "#ps_bin\n",
    "ms = np.array([p[idx] for p in masks])\n",
    "ps = [p[2*idx:2*idx+2] for p in preds]\n",
    "bs = np.array([b[idx] for b in bins])\n",
    "\n",
    "loss = []\n",
    "for i in range(len(ps)):\n",
    "    ps[i] = torch.max(softmax(ps[i], 1), dim=1).values\n",
    "    ps[i][ls[i] == 0] = 1 - ps[i][ls[i] == 0]\n",
    "    l = (ls[i] - ps[i])\n",
    "    loss.append(l)\n",
    "\n",
    "ps = torch.concat(loss).reshape(-1,2).numpy()\n",
    "for i in np.unique(bs):\n",
    "    ps_bin = ps[bs == i]\n",
    "    ms_bin = ms[bs == i]\n",
    "#mutual_info_classif(ps[b_idx == 2], np.array(ms)[b_idx==2], discrete_features=[False, False]).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 343,
   "id": "595d31ff-59ab-4a1c-9ab4-c0890e5749f7",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Found array with 0 sample(s) (shape=(0, 1)) while a minimum of 1 is required.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m/var/folders/_s/jh8l69lx0wz3rhnd7rx1w5gr0000gn/T/ipykernel_4988/4239155816.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmutual_info_classif\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mps\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbs\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mms\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbs\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiscrete_features\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mn_neighbors\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;31m#ms[bs==2].shape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.pyenv/shims/versions/3.8.12/lib/python3.8/site-packages/sklearn/utils/_param_validation.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    212\u001b[0m                     )\n\u001b[1;32m    213\u001b[0m                 ):\n\u001b[0;32m--> 214\u001b[0;31m                     \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    215\u001b[0m             \u001b[0;32mexcept\u001b[0m \u001b[0mInvalidParameterError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    216\u001b[0m                 \u001b[0;31m# When the function is just a wrapper around an estimator, we allow\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.pyenv/shims/versions/3.8.12/lib/python3.8/site-packages/sklearn/feature_selection/_mutual_info.py\u001b[0m in \u001b[0;36mmutual_info_classif\u001b[0;34m(X, y, discrete_features, n_neighbors, copy, random_state)\u001b[0m\n\u001b[1;32m    488\u001b[0m     \"\"\"\n\u001b[1;32m    489\u001b[0m     \u001b[0mcheck_classification_targets\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 490\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0m_estimate_mi\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiscrete_features\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_neighbors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrandom_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/.pyenv/shims/versions/3.8.12/lib/python3.8/site-packages/sklearn/feature_selection/_mutual_info.py\u001b[0m in \u001b[0;36m_estimate_mi\u001b[0;34m(X, y, discrete_features, discrete_target, n_neighbors, copy, random_state)\u001b[0m\n\u001b[1;32m    302\u001b[0m         )\n\u001b[1;32m    303\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 304\u001b[0;31m     mi = [\n\u001b[0m\u001b[1;32m    305\u001b[0m         \u001b[0m_compute_mi\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiscrete_feature\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiscrete_target\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_neighbors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    306\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiscrete_feature\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_iterate_columns\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiscrete_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.pyenv/shims/versions/3.8.12/lib/python3.8/site-packages/sklearn/feature_selection/_mutual_info.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    303\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    304\u001b[0m     mi = [\n\u001b[0;32m--> 305\u001b[0;31m         \u001b[0m_compute_mi\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiscrete_feature\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiscrete_target\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_neighbors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    306\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiscrete_feature\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_iterate_columns\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiscrete_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    307\u001b[0m     ]\n",
      "\u001b[0;32m~/.pyenv/shims/versions/3.8.12/lib/python3.8/site-packages/sklearn/feature_selection/_mutual_info.py\u001b[0m in \u001b[0;36m_compute_mi\u001b[0;34m(x, y, x_discrete, y_discrete, n_neighbors)\u001b[0m\n\u001b[1;32m    164\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0m_compute_mi_cd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_neighbors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    165\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mx_discrete\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0my_discrete\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 166\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0m_compute_mi_cd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_neighbors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    167\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    168\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0m_compute_mi_cc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_neighbors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.pyenv/shims/versions/3.8.12/lib/python3.8/site-packages/sklearn/feature_selection/_mutual_info.py\u001b[0m in \u001b[0;36m_compute_mi_cd\u001b[0;34m(c, d, n_neighbors)\u001b[0m\n\u001b[1;32m    139\u001b[0m     \u001b[0mradius\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mradius\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    140\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 141\u001b[0;31m     \u001b[0mkd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mKDTree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    142\u001b[0m     \u001b[0mm_all\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mkd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquery_radius\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mradius\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcount_only\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_distance\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    143\u001b[0m     \u001b[0mm_all\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm_all\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32msklearn/neighbors/_binary_tree.pxi\u001b[0m in \u001b[0;36msklearn.neighbors._kd_tree.BinaryTree.__init__\u001b[0;34m()\u001b[0m\n",
      "\u001b[0;32m~/.pyenv/shims/versions/3.8.12/lib/python3.8/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36mcheck_array\u001b[0;34m(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator, input_name)\u001b[0m\n\u001b[1;32m    965\u001b[0m         \u001b[0mn_samples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_num_samples\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    966\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mn_samples\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mensure_min_samples\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 967\u001b[0;31m             raise ValueError(\n\u001b[0m\u001b[1;32m    968\u001b[0m                 \u001b[0;34m\"Found array with %d sample(s) (shape=%s) while a\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    969\u001b[0m                 \u001b[0;34m\" minimum of %d is required%s.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: Found array with 0 sample(s) (shape=(0, 1)) while a minimum of 1 is required."
     ]
    }
   ],
   "source": [
    "mutual_info_classif(ps[bs == 2], ms[bs == 2], discrete_features=[False, False],n_neighbors=1).sum()\n",
    "#ms[bs==2].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "id": "cfbde620-8cb2-4e02-afd5-d06249c056c0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "30"
      ]
     },
     "execution_count": 113,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#(ps[0] - ls[0])\n",
    "#torch.concat(loss).reshape(-1,2).shape\n",
    "#print(ps[0])\n",
    "#ps[0][ls[0] == 0] = 1 - ps[0][ls[0] == 0]\n",
    "#print(ps[0])\n",
    "#ls[1][0] - ps[1][0]\n",
    "#ps[0]\n",
    "#idx = torch.bucketize(ps[0], n_bins, right=True) - 1\n",
    "\n",
    "#ps[0][ls[0] == 0] = 1 - ps[0][ls[0] == 0]\n",
    "#l = (ls[i] - ps[i])\n",
    "ps = [p[2*0:2*0+2] for p in preds]\n",
    "len(preds)\n",
    "#len(preds[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "fffb6549-5863-430b-bd9d-811c7adef7a9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
       "        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
       "        3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
       "        3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
       "        3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
       "        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
       "        3, 3, 3, 3, 3, 3])"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#ps = [p[2*0:2*0+2] for p in preds]\n",
    "n_bins = torch.tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])\n",
    "#torch.bucketize(ps[0].softmax(1).max(1).values, n_bins, right=True) - 1\n",
    "#ps[0].softmax(1).max(1).values\n",
    "\n",
    "torch.bucketize(preds[0].softmax(1).max(1).values, n_bins, right=True) - 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "03f32731-816a-4790-93c2-397d7e72706f",
   "metadata": {},
   "outputs": [],
   "source": [
    "## For softmax probability\n",
    "## Our estimation\n",
    "id = 50\n",
    "\n",
    "ms = [p[id] for p in masks] ## あるサンプルの mask (mask: select the train/val split (S)の際． train dataかval dataか)\n",
    "ps = [p[2*id:2*id+2] for p in preds] ## 異なる初期値から学習させたモデルのそれぞれに対し，　n-th sampleのsupersampleに対する予測確率 (e.g.; ３０個のモデルの，Z_{0,0},Z_{0,1}に対する予測確率)\n",
    "ls = [l[2*id:2*id+2] for l in labels]\n",
    "#loss = 0.\n",
    "#for i in range(len(ps)):\n",
    "#    ps[i] = torch.max(softmax(ps[i], 1), dim=1).values\n",
    "    #ls[i] = (ls[i] - ps[i])**2\n",
    "    #ps[i] = torch.argmax(ps[i], dim=1) ## 各モデルごとに，　supersampleに対する予測確率の高い方のラベルに変換 (supersample間の予測ラベルを算出;サンプルの変更に頑健なら確率の変化は小さいのでラベルは一致(＝汎化))\n",
    "    #ls[i] = ls[i].item()\n",
    "\n",
    "#mutual_info_classif(ps.reshape(-1,2), ms, discrete_features=[True, True]).sum() ## discrete_features detects the dtype of \"X (= loss in this case)\" (There is no-problem on My MacBook Pro)\n",
    "#mutual_info_classif(torch.concat(ls).reshape(-1,2).numpy(), ms, discrete_features=[False, False]).sum()\n",
    "#mutual_info_classif(torch.concat(ps).reshape(-1,2).numpy(), ms, discrete_features=[False, False]).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 605,
   "id": "2935494a-c811-464f-95af-234914ddf50f",
   "metadata": {},
   "outputs": [
    {
     "ename": "IndexError",
     "evalue": "Dimension out of range (expected to be in range of [-1, 0], but got 1)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mIndexError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[605], line 8\u001b[0m\n\u001b[1;32m      2\u001b[0m l \u001b[38;5;241m=\u001b[39m ls[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m      4\u001b[0m \u001b[38;5;66;03m# Apply argmax if we have one more dimension\u001b[39;00m\n\u001b[1;32m      5\u001b[0m \u001b[38;5;66;03m#if p.ndim == l.ndim + 1 and convert_to_labels:\u001b[39;00m\n\u001b[1;32m      6\u001b[0m \u001b[38;5;66;03m#    p = p.argmax(dim=1)\u001b[39;00m\n\u001b[1;32m      7\u001b[0m \u001b[38;5;66;03m#p = p.flatten() if convert_to_labels else torch.movedim(p, 1, -1).reshape(-1, p.shape[1])\u001b[39;00m\n\u001b[0;32m----> 8\u001b[0m p \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmovedim\u001b[49m\u001b[43m(\u001b[49m\u001b[43mp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, p\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m      9\u001b[0m l \u001b[38;5;241m=\u001b[39m l\u001b[38;5;241m.\u001b[39mflatten()\n\u001b[1;32m     11\u001b[0m \u001b[38;5;66;03m## 確率ベクトルへの変換．\u001b[39;00m\n",
      "\u001b[0;31mIndexError\u001b[0m: Dimension out of range (expected to be in range of [-1, 0], but got 1)"
     ]
    }
   ],
   "source": [
    "p = ps[0]\n",
    "l = ls[0]\n",
    "\n",
    "# Apply argmax if we have one more dimension\n",
    "#if p.ndim == l.ndim + 1 and convert_to_labels:\n",
    "#    p = p.argmax(dim=1)\n",
    "#p = p.flatten() if convert_to_labels else torch.movedim(p, 1, -1).reshape(-1, p.shape[1])\n",
    "p = torch.movedim(p, 1, -1).reshape(-1, p.shape[1])\n",
    "l = l.flatten()\n",
    "\n",
    "## 確率ベクトルへの変換．\n",
    "if not torch.all((p >= 0) * (p <= 1)):\n",
    "    p = p.softmax(1)\n",
    "confidences, predictions = p.max(dim=1)\n",
    "confidences[l==0] = 1 - confidences[l==0] ## MEMO: Reverse prob. for label y=0.\n",
    "accuracies = predictions.eq(l)\n",
    "confidences = confidences.float()\n",
    "accuracies = accuracies.float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 606,
   "id": "0b6c9584-21e2-4204-adf2-8df5bd9bbdfc",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'confidences' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[606], line 4\u001b[0m\n\u001b[1;32m      1\u001b[0m n_bins \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m4\u001b[39m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;66;03m#n_bins = torch.linspace(0, 1, n_bins + 1, dtype=confidences.dtype, device=confidences.device) ## Uniform binning TODO: implementing Uniform-mass binning\u001b[39;00m\n\u001b[1;32m      3\u001b[0m \u001b[38;5;66;03m#n_bins[0], n_bins[-1] = 0., 1.\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m n_bins \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor(np\u001b[38;5;241m.\u001b[39mquantile(\u001b[43mconfidences\u001b[49m, torch\u001b[38;5;241m.\u001b[39mlinspace(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m, n_bins \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m)))\n\u001b[1;32m      5\u001b[0m n_bins[\u001b[38;5;241m0\u001b[39m], n_bins[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.\u001b[39m, \u001b[38;5;241m1.\u001b[39m\n\u001b[1;32m      7\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m      8\u001b[0m     \u001b[38;5;66;03m#acc_bin, conf_bin, prop_bin = _binning_bucketize(confidences, accuracies, bin_boundaries)\u001b[39;00m\n\u001b[1;32m      9\u001b[0m     \u001b[38;5;66;03m## binごとのaccなどを算出\u001b[39;00m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'confidences' is not defined"
     ]
    }
   ],
   "source": [
    "n_bins = 4\n",
    "#n_bins = torch.linspace(0, 1, n_bins + 1, dtype=confidences.dtype, device=confidences.device) ## Uniform binning TODO: implementing Uniform-mass binning\n",
    "#n_bins[0], n_bins[-1] = 0., 1.\n",
    "n_bins = torch.tensor(np.quantile(confidences, torch.linspace(0, 1, n_bins + 1)))\n",
    "n_bins[0], n_bins[-1] = 0., 1.\n",
    "\n",
    "with torch.no_grad():\n",
    "    #acc_bin, conf_bin, prop_bin = _binning_bucketize(confidences, accuracies, bin_boundaries)\n",
    "    ## binごとのaccなどを算出\n",
    "    accuracies = accuracies.to(dtype=confidences.dtype)\n",
    "    acc_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "    conf_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "    count_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "    label_bin = torch.zeros(len(n_bins), device=l.device, dtype=l.dtype)\n",
    "\n",
    "    indices = torch.bucketize(confidences, n_bins, right=True) - 1\n",
    "    ## ここはrecalibration\n",
    "    bin_total = torch.bincount(indices, minlength=len(n_bins)-1).float().to(confidences.device) ## the number of samples per bins\n",
    "    bin_true = (torch.bincount(indices, weights=confidences, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels\n",
    "    bin_true_label = (torch.bincount(indices, weights=l, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels\n",
    "    with warnings.catch_warnings():\n",
    "        warnings.filterwarnings('ignore')\n",
    "        # fill nan by interpolation assuming smoothness\n",
    "        bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "        bin_mean_l = interpolate_nan(bin_true_label.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "        confidences = bin_mean[indices]\n",
    "        accuracies = bin_mean_l[indices]\n",
    "        \n",
    "    \n",
    "    count_bin.scatter_add_(dim=0, index=indices, src=torch.ones_like(confidences))\n",
    "    conf_bin.scatter_add_(dim=0, index=indices, src=confidences)\n",
    "    conf_bin = torch.nan_to_num(conf_bin / count_bin)\n",
    "    \n",
    "    label_bin.scatter_add_(dim=0, index=indices, src=l)\n",
    "    label_bin = torch.nan_to_num(label_bin / count_bin)\n",
    "\n",
    "    acc_bin.scatter_add_(dim=0, index=indices, src=accuracies)\n",
    "    acc_bin = torch.nan_to_num(acc_bin / count_bin)\n",
    "    prop_bin = count_bin / count_bin.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "id": "b9a1cadf-8da6-4ff4-9662-dda4e3fbc15f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.9983, 0.0000, 0.0000, 1.0000, 0.0000])"
      ]
     },
     "execution_count": 93,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "bd0ea292-f901-4756-985b-b3628d389961",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0009)\n",
      "tensor(0.0012)\n",
      "tensor(0.0009)\n",
      "tensor(0.0012)\n",
      "tensor(0.)\n",
      "tensor(0.)\n"
     ]
    }
   ],
   "source": [
    "## ECE\n",
    "print(torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin))\n",
    "#print(torch.max(torch.abs(acc_bin - conf_bin)))\n",
    "print(torch.sqrt(torch.sum(torch.pow(acc_bin - conf_bin, 2) * prop_bin)))\n",
    "\n",
    "print(torch.sum(torch.abs(label_bin - conf_bin) * prop_bin))\n",
    "print(torch.sqrt(torch.sum(torch.pow(label_bin - conf_bin, 2) * prop_bin)))\n",
    "print(torch.sum(torch.abs(label_bin - acc_bin) * prop_bin))\n",
    "print(torch.sqrt(torch.sum(torch.pow(label_bin - acc_bin, 2) * prop_bin)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "7484a4cb-7892-4eaf-a78c-ed796600e7ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "bin_total = torch.bincount(indices, minlength=len(n_bins)-1).to(confidences.device) ## the number of samples per bins\n",
    "#bin_true = (torch.bincount(indices, weights=l, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels\n",
    "bin_true = (torch.bincount(indices, weights=confidences, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "2a6f52d5-148f-48bc-862b-e9850c238ce9",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'indices' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[75], line 29\u001b[0m\n\u001b[1;32m     27\u001b[0m     \u001b[38;5;66;03m# fill nan by interpolation assuming smoothness\u001b[39;00m\n\u001b[1;32m     28\u001b[0m     bin_mean \u001b[38;5;241m=\u001b[39m interpolate_nan(bin_true\u001b[38;5;241m.\u001b[39mnumpy() \u001b[38;5;241m/\u001b[39m bin_total\u001b[38;5;241m.\u001b[39mnumpy()) \u001b[38;5;66;03m## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\u001b[39;00m\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28mprint\u001b[39m(bin_mean, bin_mean[\u001b[43mindices\u001b[49m])\n",
      "\u001b[0;31mNameError\u001b[0m: name 'indices' is not defined"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "\n",
    "def digitize(x, bins):\n",
    "    \"\"\"\n",
    "    Similar to np.digitize(x, bins, right=False) - 1, but the last bin unbounded on the right\n",
    "    e.g. bins = [0, 0.5, 1] -> bin 0: [0, 0.5), bin 1: [0.5, inf]\n",
    "    if the input is a number, return a number\n",
    "    \"\"\"\n",
    "    binids = np.minimum(np.digitize(x, bins), len(bins) - 1)\n",
    "    binids -= 1\n",
    "    if isinstance(x, (int, float)):\n",
    "        binids = binids.item()\n",
    "    return binids\n",
    "\n",
    "def interpolate_nan(a):\n",
    "    \"\"\"Linear interpolation for nan values in a 1d array.\n",
    "    Nans on the boundary are filled with the nearest non-nan value.\n",
    "    \"\"\"\n",
    "    b = a.copy()\n",
    "    nans = np.isnan(b)\n",
    "    i = np.arange(len(b))\n",
    "    b[nans] = np.interp(i[nans], i[~nans], b[~nans])\n",
    "    return torch.tensor(b).float()\n",
    "\n",
    "with warnings.catch_warnings():\n",
    "    warnings.filterwarnings('ignore')\n",
    "    # fill nan by interpolation assuming smoothness\n",
    "    bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "print(bin_mean, bin_mean[indices])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "d50e0609-270e-4c64-90fd-0dcc160b0958",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "d86f4af6-23f4-488c-a4fa-ac917a0c5523",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.9983027  0.99886505 0.9994274  0.99998975] [0.99998975 0.9983027 ]\n"
     ]
    }
   ],
   "source": [
    "indices = digitize(confidences.numpy(), n_bins.numpy())\n",
    "bin_total = np.bincount(indices, minlength=len(n_bins)-1) ## the number of samples per bins\n",
    "#bin_true = np.bincount(indices, weights=l, minlength=len(n_bins)-1) ## the number of samples per bins weighted by labels\n",
    "bin_true = np.bincount(indices, weights=confidences, minlength=len(n_bins)-1) ## the number of samples per bins weighted by labels\n",
    "with warnings.catch_warnings():\n",
    "    warnings.filterwarnings('ignore')\n",
    "    # fill nan by interpolation assuming smoothness\n",
    "    bin_mean = interpolate_nan(bin_true / bin_total) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "#indices\n",
    "print(bin_mean, bin_mean[indices])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "bc9bfa52-6714-4874-83be-e0bbfd2e751e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0009)\n",
      "tensor(0.0017)\n",
      "tensor(0.0012)\n"
     ]
    }
   ],
   "source": [
    "## ECE\n",
    "print(torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin))\n",
    "print(torch.max(torch.abs(acc_bin - conf_bin)))\n",
    "print(torch.sqrt(torch.sum(torch.pow(acc_bin - conf_bin, 2) * prop_bin)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1efa5a81-d3e8-452e-af99-ad4a15b4f5e5",
   "metadata": {},
   "source": [
    "## Recalibrate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4bb19a7b-707a-48b6-a5c3-3de896f1d70d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a08b4082-f42b-4bdb-8898-7ed6df3150cb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 612,
   "id": "d879d9e1-dc0f-4e44-bd7f-9a334f5e0591",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "train_idx = 2*np.arange(len(cur_mask)) + cur_mask\n",
    "test_idx = 2*np.arange(len(cur_mask)) + (1-cur_mask)\n",
    "p_train, p_test = cur_preds[train_idx], cur_preds[test_idx]\n",
    "targets, targets_test = cur_labels[train_idx], cur_labels[test_idx]\n",
    "\n",
    "#preds_cal, preds_test, targets_cal, targets_test = train_test_split(p_test, targets_test, test_size=0.8, random_state=42)\n",
    "\n",
    "## pred. on recalibrate dataset\n",
    "#p = torch.movedim(preds_cal, 1, -1).reshape(-1, preds_cal.shape[1])\n",
    "#l = targets_cal.flatten()\n",
    "#p = torch.movedim(cur_preds_recab, 1, -1).reshape(-1, cur_preds_recab.shape[1])\n",
    "#l = cur_labels_recab.flatten()\n",
    "p = torch.movedim(p_train, 1, -1).reshape(-1, p_train.shape[1])\n",
    "l = targets.flatten()\n",
    "\n",
    "if not torch.all((p >= 0) * (p <= 1)):\n",
    "    p = p.softmax(1)\n",
    "\n",
    "confidences, predictions = p.max(dim=1)\n",
    "confidences[l==0] = 1 - confidences[l==0] ## MEMO: Reverse prob. for label y=0.\n",
    "accuracies = predictions.eq(l)\n",
    "confidences = confidences.float()\n",
    "accuracies = accuracies.float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 613,
   "id": "c7bd8a21-2352-4a4e-aa11-9e9ba22b44fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "## preds on train dataset\n",
    "p_train = torch.movedim(p_train, 1, -1).reshape(-1, p_train.shape[1])\n",
    "l_train = targets.flatten()\n",
    "\n",
    "if not torch.all((p_train >= 0) * (p_train <= 1)):\n",
    "    p_train = p_train.softmax(1)\n",
    "\n",
    "conf_train, pred_train = p_train.max(dim=1)\n",
    "conf_train[l_train==0] = 1 - conf_train[l_train==0] ## MEMO: Reverse prob. for label y=0.\n",
    "acc_train = pred_train.eq(l_train)\n",
    "conf_train = conf_train.float()\n",
    "acc_train = acc_train.float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 614,
   "id": "b172c3f4-2627-4e58-9d85-121be19ea7d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "method = 'uniform'\n",
    "#method = 'quantile'\n",
    "\n",
    "n_bins = int(n ** (1/4))\n",
    "if method == 'uniform':\n",
    "    n_bins = torch.linspace(0, 1, n_bins + 1, dtype=confidences.dtype, device=confidences.device) ## Uniform binning TODO: implementing Uniform-mass binning\n",
    "    n_bins[0], n_bins[-1] = 0., 1.\n",
    "elif method == 'quantile':\n",
    "    n_bins = torch.tensor(np.quantile(confidences, torch.linspace(0, 1, n_bins + 1)))\n",
    "    n_bins[0], n_bins[-1] = 0., 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 615,
   "id": "2b1a926b-8755-4778-9ba0-ef1b639b2335",
   "metadata": {},
   "outputs": [],
   "source": [
    "def interpolate_nan(a):\n",
    "    \"\"\"Linear interpolation for nan values in a 1d array.\n",
    "    Nans on the boundary are filled with the nearest non-nan value.\n",
    "    \"\"\"\n",
    "    b = a.copy()\n",
    "    nans = np.isnan(b)\n",
    "    i = np.arange(len(b))\n",
    "    b[nans] = np.interp(i[nans], i[~nans], b[~nans])\n",
    "    \n",
    "    return torch.tensor(b).float()\n",
    "\n",
    "with torch.no_grad():\n",
    "    accuracies = accuracies.to(dtype=confidences.dtype)\n",
    "    acc_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "    conf_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "    count_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "    label_bin = torch.zeros(len(n_bins), device=l.device, dtype=l.dtype)\n",
    "    \n",
    "    indices = torch.bucketize(confidences, n_bins, right=True) - 1\n",
    "    bin_total = torch.bincount(indices, minlength=len(n_bins)-1).float().to(confidences.device) ## the number of samples per bins\n",
    "    bin_true = (torch.bincount(indices, weights=confidences, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels\n",
    "    bin_true_label = (torch.bincount(indices, weights=l, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels\n",
    "    with warnings.catch_warnings():\n",
    "        warnings.filterwarnings('ignore')\n",
    "        # fill nan by interpolation assuming smoothness\n",
    "        bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "        bin_mean_l = interpolate_nan(bin_true_label.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "    \n",
    "    idx = torch.bucketize(conf_train, n_bins, right=True) - 1\n",
    "    conf_train = bin_mean[idx]\n",
    "    acc_train = bin_mean_l[idx]\n",
    "    \n",
    "    count_bin.scatter_add_(dim=0, index=idx, src=torch.ones_like(conf_train))\n",
    "    conf_bin.scatter_add_(dim=0, index=idx, src=conf_train)\n",
    "    conf_bin = torch.nan_to_num(conf_bin / count_bin)\n",
    "    \n",
    "    label_bin.scatter_add_(dim=0, index=idx, src=l_train)\n",
    "    label_bin = torch.nan_to_num(label_bin / count_bin)\n",
    "\n",
    "    acc_bin.scatter_add_(dim=0, index=idx, src=acc_train)\n",
    "    acc_bin = torch.nan_to_num(acc_bin / count_bin)\n",
    "    prop_bin = count_bin / count_bin.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 616,
   "id": "4d135c63-8a66-4449-8872-b3d5ff80f5f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(6.6466e-06)\n",
      "tensor(6.7306e-06)\n",
      "tensor(0.)\n",
      "tensor(0.)\n"
     ]
    }
   ],
   "source": [
    "## ECE\n",
    "#print(torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin))\n",
    "#print(torch.max(torch.abs(acc_bin - conf_bin)))\n",
    "#print(torch.sqrt(torch.sum(torch.pow(acc_bin - conf_bin, 2) * prop_bin)))\n",
    "\n",
    "print(torch.sum(torch.abs(label_bin - conf_bin) * prop_bin))\n",
    "print(torch.sqrt(torch.sum(torch.pow(label_bin - conf_bin, 2) * prop_bin)))\n",
    "\n",
    "print(torch.sum(torch.abs(label_bin - acc_bin) * prop_bin))\n",
    "print(torch.sqrt(torch.sum(torch.pow(label_bin - acc_bin, 2) * prop_bin)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 617,
   "id": "95ab84d2-be19-433f-91fd-cefcb41ae71d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/fujisawa/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torchvision/datasets/mnist.py:65: UserWarning: train_labels has been renamed targets\n",
      "  warnings.warn(\"train_labels has been renamed targets\")\n"
     ]
    }
   ],
   "source": [
    "#all_examples.dataset.dataset.dataset.train_data\n",
    "#all_examples.include_indices\n",
    "\n",
    "#len(all_examples.dataset.valid_indices)\n",
    "#all_examples.dataset.dataset.dataset.test_data[all_examples.dataset.dataset.dataset.test_labels == 4]\n",
    "a = torch.concat(torch.where(all_examples.dataset.dataset.dataset.train_labels == 4))\n",
    "b = torch.concat(torch.where(all_examples.dataset.dataset.dataset.train_labels == 9))\n",
    "test_label = torch.concat([a,b])\n",
    "\n",
    "#len(test_label)\n",
    "#test_label[all_examples.include_indices]\n",
    "#all_examples.dataset.dataset.dataset.train_data[all_examples.include_indices != test_label]\n",
    "#len(all_examples.dataset.dataset.dataset.train_data[all_examples.include_indices])\n",
    "#torch.sum(test_label == 2560)\n",
    "#test_label != all_examples.include_indices\n",
    "#test_49 = all_examples.dataset.dataset.dataset.test_data[test_label]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2c15588-234d-47b1-93c1-68c0733dc439",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 246,
   "id": "5c5a313d-df54-4a51-a8b7-56fb4f84a8cc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 2787,  1316,  9658,  8185,   129,  8797, 11371,  2693,   162,\n",
       "         369, 10233, 11590,  1841,  4434,  2349, 10987, 11342, 11242,\n",
       "        1670,  8484,  5516,  3644,  6347,  8675,  7651,  5819,    15,\n",
       "        6186,  6408,  5076,   373,  1652,  7072,  7997, 11128,  6919,\n",
       "       10796,  4799,  8989,  5880,  6246,  8879, 10422, 10858,  4464,\n",
       "         660, 10715,  7705,  1514,  6917,  7464,  5132, 10037,   164,\n",
       "        3504,  8124, 10262,   578,  3193,  2496,  2236,  3026, 11271,\n",
       "       11409,  9235,  1738,  9514,  3458, 10646,  8880,  6194,  4791,\n",
       "        1802,  2637,  1884,    61,  7295, 11373,  8561,  8680,  6637,\n",
       "        6224,  8407,  8222,  3694,  2791,   911,  2991,  9270,  1361,\n",
       "        4332,  9681,  7471,  1626,  8444, 11692,  5616,  8664,  9088,\n",
       "         774])"
      ]
     },
     "execution_count": 246,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#np.random.randint()\n",
    "np.random.choice(len(test_49), 100, replace=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 316,
   "id": "62676562-a8e9-4e31-9c15-cd26d7d174fd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([7227,  843, 5799, 2560, 9104, 2724, 5507, 2861, 7170, 4826, 8127,\n",
       "       1237, 7852,  495, 2847, 2584, 4494, 5476, 2404, 1478, 9438, 5266,\n",
       "       1255,   36, 2513, 5634, 3735, 7608, 7741, 5135,  551, 4400,   33,\n",
       "       9065, 5948, 5325, 9334, 2613, 8496, 2443, 7082, 5753, 7428, 2571,\n",
       "       3925, 7032, 4277, 8705, 7268, 7709, 1852, 5838,  190, 6053, 5716,\n",
       "       5340,  888, 6777, 2604, 1662, 2608, 4523, 6601, 6966, 8821, 5654,\n",
       "       5114, 7240, 6868, 8812, 1525, 1813, 3305, 7051, 6427])"
      ]
     },
     "execution_count": 316,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#torch.tensor(all_examples.include_indices)\n",
    "#train_idx = 2*np.arange(len(cur_mask)) + cur_mask\n",
    "#all_examples.include_indices[train_idx]\n",
    "#all_examples.dataset.dataset\n",
    "#len(test_49)\n",
    "#test_label\n",
    "all_examples.include_indices[saved_data[\"train_indices\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 309,
   "id": "e765e3ce-493f-4ba4-8552-e5ac07f8bec2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([7227,  843, 5799, 2560, 9104, 2724, 5507, 2861, 7170, 4826, 8127,\n",
       "       1237, 7852,  495, 2847, 2584, 4494, 5476, 2404, 1478, 9438, 5266,\n",
       "       1255,   36, 2513, 5634, 3735, 7608, 7741, 5135,  551, 4400,   33,\n",
       "       9065, 5948, 5325, 9334, 2613, 8496, 2443, 7082, 5753, 7428, 2571,\n",
       "       3925, 7032, 4277, 8705, 7268, 7709, 1852, 5838,  190, 6053, 5716,\n",
       "       5340,  888, 6777, 2604, 1662, 2608, 4523, 6601, 6966, 8821, 5654,\n",
       "       5114, 7240, 6868, 8812, 1525, 1813, 3305, 7051, 6427])"
      ]
     },
     "execution_count": 309,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#all_examples.include_indices[saved_data[\"train_indices\"]]\n",
    "all_examples.include_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 398,
   "id": "92cb45bc-26c6-4490-966c-f2cd2c47e6cf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.utils.data.dataset.Subset at 0x29ec5ee20>"
      ]
     },
     "execution_count": 398,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#model.forward(all_examples.dataset.dataset.dataset.train_data[0],all_examples.dataset.dataset.dataset.train_labels[0])\n",
    "torch.utils.data.Subset(all_examples, range(100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "eaa34f85-b198-46d7-95c3-40233cc31045",
   "metadata": {},
   "outputs": [],
   "source": [
    "#utils.apply_on_dataset(model=model, dataset=all_examples, batch_size=n_batch)\n",
    "from nnlib.nnlib.data_utils.wrappers import SubsetDataWrapper, LabelSubsetWrapper, ResizeImagesWrapper\n",
    "from nnlib.nnlib.data_utils.base import get_loaders_from_datasets, get_input_shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "b27657b3-8d6d-4c99-a197-f0be7b8d8421",
   "metadata": {},
   "outputs": [],
   "source": [
    "mask = np.random.randint(2, size=(n,)) ## Ber(1/2)\n",
    "train_indices = 2*np.arange(n) + cur_mask #隣り合うsampleをtrain/valに分割（supersample）\n",
    "val_indices = 2*np.arange(n) + (1-cur_mask)\n",
    "train_data = SubsetDataWrapper(all_examples, include_indices=train_indices)\n",
    "val_data = SubsetDataWrapper(all_examples, include_indices=val_indices)\n",
    "num_workers = 0  # to make sure data shuffling is always done the same way\n",
    "\n",
    "#if args.deterministic:\n",
    "#    num_workers = 0  # to make sure data shuffling is always done the same way\n",
    "#else:\n",
    "#    num_workers = 2 # I changed this from 4 due to errors\n",
    "\n",
    "train_loader, val_loader, test_loader = get_loaders_from_datasets(train_data=train_data,\n",
    "                                                                     val_data=val_data,\n",
    "                                                                      test_data=None,\n",
    "                                                                      batch_size=256,\n",
    "                                                                      num_workers=num_workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "id": "e76e33ff-a015-4321-8377-3ec463b3d10e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Dataset mnist has no validation set. Splitting the training set...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "calling function __init__ of MNIST with the following parameters:\n",
      "\tdata_augmentation: False\n",
      "calling function build_datasets of MNIST with the following parameters:\n",
      "\tdata_dir: None\n",
      "\tval_ratio: 0.2\n",
      "\tnum_train_examples: None\n",
      "\tseed: 0\n",
      "\tdownload: True\n",
      "Dataset mnist is loaded\n",
      "\ttrain_samples: 48000\n",
      "\tval_samples: 12000\n",
      "\ttest_samples: 10000\n",
      "\texample_shape: torch.Size([1, 28, 28])\n"
     ]
    }
   ],
   "source": [
    "from modules.data_utils import load_data_from_arguments\n",
    "args = saved_data['args']\n",
    "examples,_,_,_ = load_data_from_arguments(args, build_loaders=False)\n",
    "examples = LabelSubsetWrapper(examples, which_labels=args.which_labels)\n",
    "\n",
    "assert len(examples) >= 2 * args.n\n",
    "np.random.seed(args.seed)\n",
    "include_indices = np.random.choice(range(len(examples)), size=2 * args.n, replace=False)\n",
    "exclude_indices = np.arange(len(examples))[~np.isin(np.arange(len(examples)), include_indices)]\n",
    "exclude_indices = np.random.choice(exclude_indices, size=100, replace=False)\n",
    "examples = SubsetDataWrapper(examples, include_indices=exclude_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 164,
   "id": "101f1d18-dc76-4683-bb50-47414f99e4b5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "150"
      ]
     },
     "execution_count": 164,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "id": "d35410ab-d8a0-4e49-98e1-ab76344a2825",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert len(examples) >= 2 * args.n\n",
    "np.random.seed(args.seed)\n",
    "include_indices = np.random.choice(range(len(examples)), size=2 * args.n, replace=False)\n",
    "exclude_indices = np.arange(len(examples))[~np.isin(np.arange(len(examples)), include_indices)]\n",
    "exclude_indices = np.random.choice(exclude_indices, size=100, replace=False)\n",
    "examples = SubsetDataWrapper(examples, include_indices=exclude_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "id": "eaee079d-5926-4f34-9de2-efaf695473a9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])"
      ]
     },
     "execution_count": 145,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "id": "48abbfa5-3653-4799-9e32-ae783ec0dac7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<function ndarray.all>"
      ]
     },
     "execution_count": 138,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.random.seed(args.seed)\n",
    "include_indices = np.random.choice(range(len(examples)), size=2 * args.n, replace=False)\n",
    "#np.isin(examples.valid_indices, include_indices)\n",
    "#np.isin(examples.dataset.indices, include_indices).sum()\n",
    "examples.dataset.dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "id": "e302e93b-3512-4a6c-8d40-5dc3faa140be",
   "metadata": {},
   "outputs": [],
   "source": [
    "exclude_indices = np.arange(len(examples))[~np.isin(np.arange(len(examples)), include_indices)]\n",
    "exclude_indices = np.random.choice(exclude_indices, size=100, replace=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 163,
   "id": "fde03694-a623-41b0-99be-25aed7040e66",
   "metadata": {},
   "outputs": [],
   "source": [
    "#utils.apply_on_dataset(model=model, dataset=all_examples, batch_size=args.batch_size)['pred'] ## (2*n, num_classes)\n",
    "#utils.apply_on_dataset(model=model, dataset=examples, batch_size=100)['pred'] ## (2*n, num_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 618,
   "id": "f242c436-85cf-46d7-86fa-126403bd8f92",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_idx = 2*np.arange(len(cur_mask)) + cur_mask\n",
    "conf = cur_preds[train_idx].softmax(1).max(1).values\n",
    "l = cur_labels[train_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 619,
   "id": "a617adfc-4ad8-4dc3-a125-4bff03fa503c",
   "metadata": {},
   "outputs": [],
   "source": [
    "conf[l==0] = 1 - conf[l==0]\n",
    "n_bins = torch.tensor(np.quantile(conf, torch.linspace(0, 1, int(n**(1/3)) + 1)))\n",
    "n_bins[0], n_bins[-1] = 0., 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 620,
   "id": "5cc1d288-0cb9-4e62-afdc-c8e16627cf4a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([5, 4, 4, 0, 3, 2, 5, 0, 1, 4, 0, 0, 3, 4, 2, 0, 4, 5, 1, 5, 4, 2, 0, 4,\n",
       "        4, 5, 2, 3, 3, 2, 3, 1, 3, 4, 5, 5, 3, 3, 3, 1, 2, 2, 4, 5, 5, 2, 4, 4,\n",
       "        1, 5, 3, 0, 1, 0, 0, 2, 2, 4, 3, 1, 0, 1, 0, 0, 1, 5, 1, 0, 4, 0, 3, 1,\n",
       "        4, 2, 1, 4, 2, 5, 4, 2, 6, 4, 0, 2, 5, 5, 1, 4, 3, 3, 3, 2, 0, 3, 1, 1,\n",
       "        1, 1, 0, 5, 5, 2, 2, 2, 3, 1, 4, 1, 4, 4, 4, 2, 1, 1, 5, 2, 0, 5, 3, 4,\n",
       "        5, 2, 3, 2, 1, 5, 1, 0, 4, 2, 6, 0, 1, 2, 3, 2, 3, 5, 1, 1, 3, 4, 5, 4,\n",
       "        0, 2, 3, 3, 3, 0, 2, 0, 0, 2, 3, 4, 0, 4, 3, 0, 0, 5, 2, 3, 2, 1, 2, 3,\n",
       "        6, 0, 5, 1, 0, 5, 3, 1, 6, 4, 5, 1, 0, 0, 4, 1, 2, 4, 0, 2, 3, 5, 5, 0,\n",
       "        4, 6, 0, 4, 1, 3, 3, 1, 3, 2, 2, 3, 5, 3, 0, 5, 0, 2, 4, 4, 1, 5, 5, 3,\n",
       "        3, 5, 1, 6, 4, 2, 1, 4, 2, 4, 1, 5, 4, 5, 2, 5, 4, 5, 1, 1, 2, 1, 2, 0,\n",
       "        0, 3, 2, 0, 3, 1, 3, 0, 2, 0])"
      ]
     },
     "execution_count": 620,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.bucketize(conf, n_bins, right=True) - 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 621,
   "id": "c6f95d2b-fe0c-423c-aa74-7fae090c25c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_idx = 2*np.arange(len(cur_mask)) + cur_mask\n",
    "conf = cur_preds[train_idx].softmax(1).max(1).values\n",
    "l = cur_labels[train_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 622,
   "id": "3ffba059-0494-4696-bb9e-d88c7699eadb",
   "metadata": {},
   "outputs": [],
   "source": [
    "conf[l==0] = 1 - conf[l==0]\n",
    "n_bins = torch.linspace(0, 1, int(n**(1/3)) + 1)\n",
    "n_bins[0], n_bins[-1] = 0., 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 654,
   "id": "34691a87-4498-44be-bf08-ff6e74064bfe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.int64"
      ]
     },
     "execution_count": 654,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(torch.bucketize(conf, n_bins, right=True) - 1).dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 664,
   "id": "5c6c4395-525d-4aa6-b2a2-66e9543c0498",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([5, 5, 5, 0, 5, 0, 5, 0, 0, 5, 0, 0, 5, 5, 0, 0, 5, 5, 0, 5, 5, 0, 0, 5,\n",
       "        5, 5, 0, 5, 5, 0, 0, 0, 0, 5, 5, 5, 5, 5, 5, 0, 0, 0, 5, 5, 5, 0, 5, 5,\n",
       "        0, 5, 5, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 5, 0, 5, 0,\n",
       "        5, 0, 0, 5, 0, 5, 5, 0, 5, 5, 0, 0, 5, 5, 0, 5, 5, 0, 5, 0, 0, 0, 0, 0,\n",
       "        0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 5, 0, 5, 5, 5, 0, 0, 0, 5, 0, 0, 5, 5, 5,\n",
       "        5, 0, 5, 0, 0, 5, 0, 0, 5, 0, 5, 0, 0, 0, 5, 0, 5, 5, 0, 0, 5, 5, 5, 5,\n",
       "        0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 5, 5, 0, 5, 5, 0, 0, 5, 0, 5, 0, 0, 0, 5,\n",
       "        5, 0, 5, 0, 0, 5, 0, 0, 5, 5, 5, 0, 0, 0, 5, 0, 0, 5, 0, 0, 5, 5, 5, 0,\n",
       "        5, 5, 0, 5, 0, 5, 5, 0, 5, 0, 0, 0, 5, 0, 0, 5, 0, 0, 5, 5, 0, 5, 5, 0,\n",
       "        0, 5, 0, 5, 5, 0, 0, 5, 0, 5, 0, 5, 5, 5, 0, 5, 5, 5, 0, 0, 0, 0, 0, 0,\n",
       "        0, 0, 0, 0, 5, 0, 5, 0, 0, 0])"
      ]
     },
     "execution_count": 664,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "binids = np.minimum(np.digitize(conf.numpy(), n_bins), len(n_bins) - 1)\n",
    "binids -= 1\n",
    "torch.tensor(binids)\n",
    "#isinstance(conf.numpy(), (int, float))\n",
    "#if isinstance(conf.numpy(), (int, float)):\n",
    "#    binids = binids.item()\n",
    "#torch.tensor(binids)\n",
    "#binids = np.minimum(np.digitize(x, bins), len(bins) - 1)\n",
    "#binids -= 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 642,
   "id": "41f47aef-e6df-4082-a3da-200db26ecb38",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([5, 5, 5, 0, 5, 0, 5, 0, 0, 5, 0, 0, 5, 5, 0, 0, 5, 5, 0, 5, 5, 0,\n",
       "       0, 5, 5, 5, 0, 5, 5, 0, 0, 0, 0, 5, 5, 5, 5, 5, 5, 0, 0, 0, 5, 5,\n",
       "       5, 0, 5, 5, 0, 5, 5, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 5,\n",
       "       0, 0, 5, 0, 5, 0, 5, 0, 0, 5, 0, 5, 5, 0, 6, 5, 0, 0, 5, 5, 0, 5,\n",
       "       5, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0, 5, 0, 5, 5,\n",
       "       5, 0, 0, 0, 5, 0, 0, 5, 5, 5, 5, 0, 5, 0, 0, 5, 0, 0, 5, 0, 6, 0,\n",
       "       0, 0, 5, 0, 5, 5, 0, 0, 5, 5, 5, 5, 0, 0, 0, 5, 5, 0, 0, 0, 0, 0,\n",
       "       5, 5, 0, 5, 5, 0, 0, 5, 0, 5, 0, 0, 0, 5, 6, 0, 5, 0, 0, 5, 0, 0,\n",
       "       6, 5, 5, 0, 0, 0, 5, 0, 0, 5, 0, 0, 5, 5, 5, 0, 5, 6, 0, 5, 0, 5,\n",
       "       5, 0, 5, 0, 0, 0, 5, 0, 0, 5, 0, 0, 5, 5, 0, 5, 5, 0, 0, 5, 0, 6,\n",
       "       5, 0, 0, 5, 0, 5, 0, 5, 5, 5, 0, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "       0, 0, 5, 0, 5, 0, 0, 0])"
      ]
     },
     "execution_count": 642,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(torch.bucketize(conf, n_bins, right=True) - 1).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 665,
   "id": "2b51a9c6-abf2-40fc-8c7a-644a94f3beb2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0000, 0.1667, 0.3333, 0.5000, 0.6667, 0.8333, 1.0000])"
      ]
     },
     "execution_count": 665,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n_bins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "9a8f1da3-7e32-4528-8343-6c76b7a3b047",
   "metadata": {},
   "outputs": [],
   "source": [
    "ns = [75, 250, 1000, 4000]\n",
    "#ns = [500, 1000, 5000, 20000]\n",
    "opt_bin = np.floor((np.array(ns) ** (1/3)))\n",
    "#start, end, num_candidates = max(opt_bin)+5, 10**2, (10 - len(opt_bin))\n",
    "#int_candidates = np.ceil(np.logspace(np.log10(start), np.log10(end), num=num_candidates)).astype(int) ## bin candidates\n",
    "#int_candidates = np.concatenate([opt_bin, int_candidates]).astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "2f033417-94a8-4095-a603-13a3837ae803",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "int"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "opt_bin\n",
    "a = np.floor((np.array(ns) ** (1/3))).astype(int)\n",
    "type(int(a[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 893,
   "id": "ff1ade6f-315b-4012-bb99-5b306175764d",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dir = \"results\"\n",
    "exp_name = \"fcmi-mnist-4vs9-CNN\"\n",
    "path = os.path.join(os.getcwd(), results_dir, exp_name)\n",
    "if not os.path.exists(os.path.join(path, \"results_n_vs_b\")):\n",
    "    os.mkdir(os.path.join(path, \"results_n_vs_b\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "f4f17f86-2baf-493b-987c-fbcff01b2ad3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "int(ns[0] ** (1/3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "62a494fa-eb4b-43ad-9f57-09e17528c622",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from results/cifar10-pretrained-resnet50/n=20000,lr=0.01,seed=0,S_seed=10/checkpoints/epoch39.mdl\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/fujisawa/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
      "  warnings.warn(\n",
      "/Users/fujisawa/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.\n",
      "  warnings.warn(msg)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output.shape: [None, 10]\n"
     ]
    }
   ],
   "source": [
    "#n = 500\n",
    "n = 20000\n",
    "lr = 0.01\n",
    "seed = 0\n",
    "S_seed = 10\n",
    "results_dir = \"results\"\n",
    "exp_name = \"cifar10-pretrained-resnet50\"\n",
    "epoch = 40\n",
    "\n",
    "dir_name = f'n={n},lr={lr},seed={seed},S_seed={S_seed}'\n",
    "dir_path = os.path.join(results_dir, exp_name, dir_name)\n",
    "#if not os.path.exists(dir_path):\n",
    "#    print(f\"Did not find results for {dir_name}\")\n",
    "#    continue\n",
    "\n",
    "with open(os.path.join(dir_path, 'saved_data.pkl'), 'rb') as f:\n",
    "    saved_data = pickle.load(f)\n",
    "\n",
    "model = utils.load(path=os.path.join(dir_path, 'checkpoints', f'epoch{epoch - 1}.mdl'),\n",
    "                           methods=methods, device='mps')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5e14fa0e-4422-4566-b5eb-a026b4ac6948",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'all_examples' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m cur_preds \u001b[38;5;241m=\u001b[39m utils\u001b[38;5;241m.\u001b[39mapply_on_dataset(model\u001b[38;5;241m=\u001b[39mmodel, dataset\u001b[38;5;241m=\u001b[39m\u001b[43mall_examples\u001b[49m,\n\u001b[1;32m      2\u001b[0m                                            batch_size\u001b[38;5;241m=\u001b[39margs\u001b[38;5;241m.\u001b[39mbatch_size)[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpred\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m      4\u001b[0m cur_labels \u001b[38;5;241m=\u001b[39m utils\u001b[38;5;241m.\u001b[39mapply_on_dataset(model\u001b[38;5;241m=\u001b[39mmodel, dataset\u001b[38;5;241m=\u001b[39mall_examples, \n\u001b[1;32m      5\u001b[0m                                             batch_size\u001b[38;5;241m=\u001b[39margs\u001b[38;5;241m.\u001b[39mbatch_size)[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlabel_0\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m      6\u001b[0m cur_mask \u001b[38;5;241m=\u001b[39m saved_data[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmask\u001b[39m\u001b[38;5;124m'\u001b[39m]\n",
      "\u001b[0;31mNameError\u001b[0m: name 'all_examples' is not defined"
     ]
    }
   ],
   "source": [
    "cur_preds = utils.apply_on_dataset(model=model, dataset=all_examples,\n",
    "                                           batch_size=args.batch_size)['pred']\n",
    "        \n",
    "cur_labels = utils.apply_on_dataset(model=model, dataset=all_examples, \n",
    "                                            batch_size=args.batch_size)['label_0']\n",
    "cur_mask = saved_data['mask']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "130bc660-6738-4250-a2a3-80b7cff8f251",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_all_data(saved_data, test_data=False):\n",
    "    args = saved_data['args']\n",
    "    if test_data:\n",
    "        _, _, examples, _ = load_data_from_arguments(args, build_loaders=False)\n",
    "    else:\n",
    "        examples, _, _, _ = load_data_from_arguments(args, build_loaders=False)\n",
    "    \n",
    "    # select labels if needed\n",
    "    if args.which_labels is not None:\n",
    "        examples = LabelSubsetWrapper(examples, which_labels=args.which_labels)\n",
    "    \n",
    "    # resize if needed\n",
    "    if args.resize_to_imagenet:\n",
    "        examples = ResizeImagesWrapper(examples, size=(224, 224))\n",
    "    \n",
    "    return examples\n",
    "    \n",
    "def preds_recalibrate(saved_data, all_data, n_data=100, test_data=False):\n",
    "    all_examples = saved_data['all_examples'] ## train, val data\n",
    "    args = saved_data['args']\n",
    "    \n",
    "    assert len(all_data) >= n_data\n",
    "    np.random.seed(args.seed)\n",
    "    if test_data:\n",
    "        exclude_indices = np.arange(len(all_data))\n",
    "        if n_data >= len(exclude_indices):\n",
    "            raise ValueError(f\" The number of recalibration data is larger than that of all data\")\n",
    "        exclude_indices = np.random.choice(exclude_indices, size=n_data, replace=False)\n",
    "    else:\n",
    "        include_indices = all_examples.include_indices\n",
    "        exclude_indices = np.arange(len(all_data))[~np.isin(np.arange(len(all_data)), include_indices)]\n",
    "        if n_data >= len(exclude_indices):\n",
    "            raise ValueError(f\" The number of recalibration data is larger than that of all data\")\n",
    "        exclude_indices = np.random.choice(exclude_indices, size=n_data, replace=False)\n",
    "    \n",
    "    examples = SubsetDataWrapper(all_data, include_indices=exclude_indices)\n",
    "    \n",
    "    return examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "id": "035ef6b4-0b10-458f-9404-eccc90be3c10",
   "metadata": {},
   "outputs": [],
   "source": [
    "#examples , val_data, test_data, _ = load_data_from_arguments(args, build_loaders=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "id": "c2466521-5d0e-495f-9595-944d4ee76002",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset CIFAR10\n",
       "    Number of datapoints: 50000\n",
       "    Root location: cifar10\n",
       "    Split: Train\n",
       "    StandardTransform\n",
       "Transform: Compose(\n",
       "               ToTensor()\n",
       "               Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))\n",
       "           )"
      ]
     },
     "execution_count": 129,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val_data.dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "id": "061e0225-7d0c-4c46-b731-54851074d7b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "calling function __init__ of CIFAR with the following parameters:\n",
      "\tn_classes: 10\n",
      "\tdata_augmentation: False\n",
      "calling function build_datasets of CIFAR with the following parameters:\n",
      "\tdata_dir: None\n",
      "\tval_ratio: 0.2\n",
      "\tnum_train_examples: None\n",
      "\tseed: 0\n",
      "\tdownload: True\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Dataset cifar10 has no validation set. Splitting the training set...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset cifar10 is loaded\n",
      "\ttrain_samples: 40000\n",
      "\tval_samples: 10000\n",
      "\ttest_samples: 10000\n",
      "\texample_shape: torch.Size([3, 32, 32])\n"
     ]
    }
   ],
   "source": [
    "all_data = load_all_data(saved_data, test_data=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "id": "9ccaba56-649a-44ca-a1ae-0c0a1f5f9d70",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|██▊                                         | 1/16 [00:06<01:44,  6.96s/it]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[148], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m examples \u001b[38;5;241m=\u001b[39m preds_recalibrate(saved_data, all_data, n_data\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1000\u001b[39m, test_data\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m----> 2\u001b[0m cur_preds_recab \u001b[38;5;241m=\u001b[39m \u001b[43mutils\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply_on_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mexamples\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m64\u001b[39;49m\u001b[43m)\u001b[49m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpred\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;66;03m## (2*n, num_classes)\u001b[39;00m\n",
      "File \u001b[0;32m~/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/nnlib/nnlib/utils.py:127\u001b[0m, in \u001b[0;36mwith_no_grad.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    125\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    126\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m--> 127\u001b[0m         ret \u001b[38;5;241m=\u001b[39m \u001b[43minit_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    128\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m ret\n",
      "File \u001b[0;32m~/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/nnlib/nnlib/utils.py:155\u001b[0m, in \u001b[0;36mapply_on_dataset\u001b[0;34m(model, dataset, batch_size, cpu, description, output_keys_regexp, max_num_examples, num_workers, **kwargs)\u001b[0m\n\u001b[1;32m    153\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(labels_batch, \u001b[38;5;28mlist\u001b[39m):\n\u001b[1;32m    154\u001b[0m     labels_batch \u001b[38;5;241m=\u001b[39m [labels_batch]\n\u001b[0;32m--> 155\u001b[0m outputs_batch \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_batch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels_batch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    156\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m outputs_batch\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m    157\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m re\u001b[38;5;241m.\u001b[39mfullmatch(output_keys_regexp, k) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "File \u001b[0;32m~/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/methods/classifiers.py:35\u001b[0m, in \u001b[0;36mStandardClassifier.forward\u001b[0;34m(self, inputs, labels, grad_enabled, detailed_output, **kwargs)\u001b[0m\n\u001b[1;32m     31\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(grad_enabled)\n\u001b[1;32m     33\u001b[0m x \u001b[38;5;241m=\u001b[39m inputs[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[0;32m---> 35\u001b[0m details \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclassifier\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     36\u001b[0m pred \u001b[38;5;241m=\u001b[39m details\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpred\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m     38\u001b[0m out \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m     39\u001b[0m     \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpred\u001b[39m\u001b[38;5;124m'\u001b[39m: pred\n\u001b[1;32m     40\u001b[0m }\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/nnlib/nnlib/nn_utils.py:251\u001b[0m, in \u001b[0;36mStandardNetworkWrapper.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    250\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 251\u001b[0m     ret \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    252\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m {\n\u001b[1;32m    253\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_name: ret\n\u001b[1;32m    254\u001b[0m     }\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torchvision/models/resnet.py:285\u001b[0m, in \u001b[0;36mResNet.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    284\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 285\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_forward_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torchvision/models/resnet.py:273\u001b[0m, in \u001b[0;36mResNet._forward_impl\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    270\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrelu(x)\n\u001b[1;32m    271\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmaxpool(x)\n\u001b[0;32m--> 273\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayer1\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    274\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer2(x)\n\u001b[1;32m    275\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer3(x)\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/container.py:215\u001b[0m, in \u001b[0;36mSequential.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    213\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[1;32m    214\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[0;32m--> 215\u001b[0m         \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m    216\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torchvision/models/resnet.py:150\u001b[0m, in \u001b[0;36mBottleneck.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    147\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbn1(out)\n\u001b[1;32m    148\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrelu(out)\n\u001b[0;32m--> 150\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv2\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    151\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbn2(out)\n\u001b[1;32m    152\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrelu(out)\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/conv.py:460\u001b[0m, in \u001b[0;36mConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    459\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 460\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/conv.py:456\u001b[0m, in \u001b[0;36mConv2d._conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m    452\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzeros\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m    453\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mconv2d(F\u001b[38;5;241m.\u001b[39mpad(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode),\n\u001b[1;32m    454\u001b[0m                     weight, bias, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstride,\n\u001b[1;32m    455\u001b[0m                     _pair(\u001b[38;5;241m0\u001b[39m), \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdilation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgroups)\n\u001b[0;32m--> 456\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv2d\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    457\u001b[0m \u001b[43m                \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgroups\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "examples = preds_recalibrate(saved_data, all_data, n_data=1000, test_data=True)\n",
    "#cur_preds_recab = utils.apply_on_dataset(model=model, dataset=examples, batch_size=64)['pred'] ## (2*n, num_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "id": "8cba43b0-30cf-46f3-bd00-3538571bec24",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_data = 1000\n",
    "all_examples = saved_data['all_examples'] ## train, val data\n",
    "args = saved_data['args']\n",
    "\n",
    "np.random.seed(args.seed)\n",
    "exclude_indices = np.arange(len(all_data))\n",
    "#include_indices = all_examples.include_indices\n",
    "#exclude_indices = np.arange(len(all_data))[~np.isin(np.arange(len(all_data)), include_indices)]\n",
    "#exclude_indices = np.arange(len(all_examples.dataset.dataset.dataset.data))[~np.isin(np.arange(len(all_examples.dataset.dataset.dataset.data)), include_indices)]\n",
    "exclude_indices = np.random.choice(exclude_indices, size=n_data, replace=False)\n",
    "\n",
    "examples = SubsetDataWrapper(all_data, include_indices=exclude_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f49231ed-596e-43dd-8527-36963426a299",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "id": "4b0a33f6-d0ae-4ce7-8736-4ccf5786b39c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading the model from results/cifar10-pretrained-resnet50/n=20000,lr=0.01,seed=0,S_seed=10/checkpoints/epoch39.mdl\n",
      "output.shape: [None, 10]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|██▊                                         | 1/16 [00:09<02:18,  9.24s/it]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[141], line 5\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmethods\u001b[39;00m\n\u001b[1;32m      3\u001b[0m model \u001b[38;5;241m=\u001b[39m utils\u001b[38;5;241m.\u001b[39mload(path\u001b[38;5;241m=\u001b[39mos\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(dir_path, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcheckpoints\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mepoch\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;241m40\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.mdl\u001b[39m\u001b[38;5;124m'\u001b[39m),\n\u001b[1;32m      4\u001b[0m                            methods\u001b[38;5;241m=\u001b[39mmethods, device\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m----> 5\u001b[0m cur_preds_recab \u001b[38;5;241m=\u001b[39m \u001b[43mutils\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply_on_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mexamples\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m64\u001b[39;49m\u001b[43m)\u001b[49m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpred\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;66;03m## (2*n, num_classes)\u001b[39;00m\n",
      "File \u001b[0;32m~/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/nnlib/nnlib/utils.py:127\u001b[0m, in \u001b[0;36mwith_no_grad.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    125\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    126\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m--> 127\u001b[0m         ret \u001b[38;5;241m=\u001b[39m \u001b[43minit_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    128\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m ret\n",
      "File \u001b[0;32m~/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/nnlib/nnlib/utils.py:155\u001b[0m, in \u001b[0;36mapply_on_dataset\u001b[0;34m(model, dataset, batch_size, cpu, description, output_keys_regexp, max_num_examples, num_workers, **kwargs)\u001b[0m\n\u001b[1;32m    153\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(labels_batch, \u001b[38;5;28mlist\u001b[39m):\n\u001b[1;32m    154\u001b[0m     labels_batch \u001b[38;5;241m=\u001b[39m [labels_batch]\n\u001b[0;32m--> 155\u001b[0m outputs_batch \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_batch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels_batch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    156\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m outputs_batch\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m    157\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m re\u001b[38;5;241m.\u001b[39mfullmatch(output_keys_regexp, k) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "File \u001b[0;32m~/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/methods/classifiers.py:35\u001b[0m, in \u001b[0;36mStandardClassifier.forward\u001b[0;34m(self, inputs, labels, grad_enabled, detailed_output, **kwargs)\u001b[0m\n\u001b[1;32m     31\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(grad_enabled)\n\u001b[1;32m     33\u001b[0m x \u001b[38;5;241m=\u001b[39m inputs[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[0;32m---> 35\u001b[0m details \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclassifier\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     36\u001b[0m pred \u001b[38;5;241m=\u001b[39m details\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpred\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m     38\u001b[0m out \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m     39\u001b[0m     \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpred\u001b[39m\u001b[38;5;124m'\u001b[39m: pred\n\u001b[1;32m     40\u001b[0m }\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/Dropbox/research/Gen_bound_calibration/a_new_family_of_generalization_bounds_supp/klbound/f-CMI/nnlib/nnlib/nn_utils.py:251\u001b[0m, in \u001b[0;36mStandardNetworkWrapper.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    250\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 251\u001b[0m     ret \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    252\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m {\n\u001b[1;32m    253\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_name: ret\n\u001b[1;32m    254\u001b[0m     }\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torchvision/models/resnet.py:285\u001b[0m, in \u001b[0;36mResNet.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    284\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 285\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_forward_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torchvision/models/resnet.py:275\u001b[0m, in \u001b[0;36mResNet._forward_impl\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    273\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer1(x)\n\u001b[1;32m    274\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer2(x)\n\u001b[0;32m--> 275\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayer3\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    276\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer4(x)\n\u001b[1;32m    278\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mavgpool(x)\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/container.py:215\u001b[0m, in \u001b[0;36mSequential.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    213\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[1;32m    214\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[0;32m--> 215\u001b[0m         \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m    216\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torchvision/models/resnet.py:154\u001b[0m, in \u001b[0;36mBottleneck.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    151\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbn2(out)\n\u001b[1;32m    152\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrelu(out)\n\u001b[0;32m--> 154\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv3\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    155\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbn3(out)\n\u001b[1;32m    157\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdownsample \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/conv.py:460\u001b[0m, in \u001b[0;36mConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    459\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 460\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/nn/modules/conv.py:456\u001b[0m, in \u001b[0;36mConv2d._conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m    452\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzeros\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m    453\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mconv2d(F\u001b[38;5;241m.\u001b[39mpad(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode),\n\u001b[1;32m    454\u001b[0m                     weight, bias, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstride,\n\u001b[1;32m    455\u001b[0m                     _pair(\u001b[38;5;241m0\u001b[39m), \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdilation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgroups)\n\u001b[0;32m--> 456\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv2d\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    457\u001b[0m \u001b[43m                \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgroups\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import methods\n",
    "\n",
    "model = utils.load(path=os.path.join(dir_path, 'checkpoints', f'epoch{40 - 1}.mdl'),\n",
    "                           methods=methods, device='cpu')\n",
    "cur_preds_recab = utils.apply_on_dataset(model=model, dataset=examples, batch_size=64)['pred'] ## (2*n, num_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "04ef5969-4f29-4b89-8c67-fd54e5fbe580",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([40000, 40001, 40002, ..., 49997, 49998, 49999])"
      ]
     },
     "execution_count": 98,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.arange(len(all_examples.dataset.dataset.dataset.data))[~np.isin(np.arange(len(all_examples.dataset.dataset.dataset.data)), include_indices)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "id": "7943f84d-f011-4bfd-89e2-e6768bb703b6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "40000"
      ]
     },
     "execution_count": 97,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#len(all_examples.dataset.dataset.dataset.data)\n",
    "#np.arange(len(all_data))\n",
    "len(all_examples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "id": "56c4c6ec-3081-408f-b008-638f31d60eb1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([49394, 40898, 42398, 45906, 42343, 48225, 45506, 46451, 42670,\n",
       "       43497, 41087, 41819, 42308, 46084, 43724, 43184, 46387, 43728,\n",
       "       42702, 47883, 42930, 45988, 44890, 46718, 45423, 43213, 43017,\n",
       "       40382, 44237, 44721, 49547, 49477, 44795, 44747, 49366, 45334,\n",
       "       46652, 49032, 40580, 49491, 46526, 44346, 44974, 47913, 45611,\n",
       "       48480, 46625, 45615, 45602, 44857, 46734, 48451, 46332, 46798,\n",
       "       45313, 42821, 49300, 42375, 41478, 45013, 41559, 48885, 43986,\n",
       "       44429, 43951, 42932, 46419, 40713, 48089, 46058, 48711, 44185,\n",
       "       44379, 47813, 40843, 40655, 47219, 40982, 47439, 48894, 40598,\n",
       "       45350, 44221, 41422, 43646, 44387, 42606, 44228, 41096, 48940,\n",
       "       45814, 43706, 42846, 40467, 48314, 44595, 43725, 46450, 46846,\n",
       "       43891, 47407, 45892, 45301, 44808, 48148, 43931, 49672, 41312,\n",
       "       41942, 40841, 45792, 43683, 48876, 45660, 44447, 43066, 48135,\n",
       "       45957, 44678, 42869, 42535, 45482, 44302, 45608, 47527, 41534,\n",
       "       48110, 46699, 44696, 42580, 47178, 44644, 44858, 45946, 49822,\n",
       "       44673, 43341, 41376, 49320, 46275, 41308, 47569, 47358, 42252,\n",
       "       49892, 48830, 47980, 48850, 49345, 49276, 40521, 43232, 40226,\n",
       "       43953, 41427, 48748, 48060, 45525, 40167, 41103, 49089, 46759,\n",
       "       47648, 45473, 42376, 40296, 41785, 47684, 47174, 44700, 45657,\n",
       "       46492, 47978, 41085, 45888, 49715, 48057, 41402, 49827, 43977,\n",
       "       43652, 44456, 44912, 42579, 42288, 48301, 42090, 45818, 42929,\n",
       "       40256, 43838, 46049, 49862, 44100, 47977, 43039, 40334, 47053,\n",
       "       49617, 46939, 46768, 40779, 44277, 41507, 43776, 42206, 46836,\n",
       "       48461, 44430, 47923, 48249, 48447, 46061, 44643, 43546, 47176,\n",
       "       45141, 41261, 42128, 49748, 42849, 44168, 45355, 40992, 45498,\n",
       "       46714, 42054, 49333, 41536, 49071, 40487, 49228, 47325, 46005,\n",
       "       47230, 47235, 42890, 40578, 40651, 40272, 48083, 47810, 41997,\n",
       "       47049, 42436, 42006, 46877, 44608, 47481, 41946, 47415, 41201,\n",
       "       43813, 44033, 47473, 43667, 41730, 48709, 45567, 46837, 43812,\n",
       "       46740, 44676, 41283, 41455, 45322, 49792, 40190, 44645, 40565,\n",
       "       42501, 49368, 40048, 41611, 49482, 46821, 42043, 40377, 44413,\n",
       "       45607, 44087, 46755, 41495, 42081, 47612, 42201, 41382, 48960,\n",
       "       44737, 46948, 43086, 44339, 48339, 46152, 45461, 47805, 45228,\n",
       "       42685, 45725, 48051, 47655, 48544, 44007, 48195, 46985, 49698,\n",
       "       46488, 42587, 48221, 44534, 45416, 42284, 43358, 40834, 45667,\n",
       "       48046, 43169, 48018, 40555, 40265, 43637, 41109, 47346, 47457,\n",
       "       40159, 41106, 40230, 47467, 46642, 49802, 43071, 49543, 47070,\n",
       "       49520, 44813, 47854, 42113, 42585, 41824, 41813, 48310, 44520,\n",
       "       44568, 47037, 43585, 42007, 48849, 40125, 41149, 49098, 41121,\n",
       "       40247, 49371, 42092, 49534, 45235, 43274, 46552, 49970, 40693,\n",
       "       45367, 45053, 49298, 40499, 47398, 46575, 45777, 48726, 45741,\n",
       "       42183, 41176, 45529, 41029, 44945, 46425, 43932, 43181, 40528,\n",
       "       43850, 46665, 49416, 43222, 42017, 49151, 43402, 47509, 46412,\n",
       "       44752, 46561, 41748, 46591, 43147, 42036, 40009, 45121, 43374,\n",
       "       46027, 45523, 45505, 48362, 42817, 43114, 42734, 47892, 41132,\n",
       "       42268, 42506, 40201, 48330, 48185, 45891, 47945, 40221, 40454,\n",
       "       40242, 42892, 47731, 44254, 46712, 43828, 48655, 41095, 48006,\n",
       "       40042, 45868, 40154, 46694, 48132, 46685, 41379, 46041, 44757,\n",
       "       48072, 44546, 49693, 41593, 49262, 47699, 40176, 49334, 43696,\n",
       "       43140, 43079, 43981, 48276, 48845, 47602, 46968, 43118, 42613,\n",
       "       46825, 46748, 42434, 47450, 41645, 40189, 44135, 43510, 49054,\n",
       "       49983, 44214, 48789, 41644, 48034, 45186, 44001, 43100, 45686,\n",
       "       48128, 48259, 42411, 47062, 47577, 48662, 46554, 43368, 40597,\n",
       "       42476, 48204, 43790, 45034, 44501, 47986, 46046, 45399, 48734,\n",
       "       40994, 45562, 46376, 42784, 41468, 45232, 43065, 44490, 42500,\n",
       "       48938, 49291, 40491, 42021, 44299, 40799, 41069, 48410, 49436,\n",
       "       45099, 42227, 40462, 45794, 47359, 49213, 43566, 44539, 40836,\n",
       "       48562, 41761, 42948, 44558, 46353, 49750, 49474, 44848, 43030,\n",
       "       48899, 43842, 43456, 41835, 46721, 48754, 48643, 42193, 49083,\n",
       "       42639, 42031, 47188, 43856, 40119, 42620, 44855, 40424, 49517,\n",
       "       44055, 47722, 43990, 40320, 48724, 49204, 44653, 42754, 47610,\n",
       "       46756, 42972, 44392, 48117, 49230, 47991, 49335, 41454, 41167,\n",
       "       45561, 43023, 45010, 43594, 46902, 48098, 46883, 44639, 40036,\n",
       "       49691, 49980, 49616, 48750, 40920, 49946, 44166, 40121, 49742,\n",
       "       48488, 45425, 45447, 42124, 41351, 48780, 43431, 47149, 40692,\n",
       "       42032, 42431, 41297, 46706, 49796, 42569, 42530, 48187, 41654,\n",
       "       49569, 44492, 43636, 45221, 47080, 41514, 40150, 42774, 44064,\n",
       "       45122, 47982, 49133, 46277, 40888, 46810, 47556, 44310, 42797,\n",
       "       49538, 47688, 42872, 44634, 45134, 42550, 44959, 48412, 45764,\n",
       "       41766, 47297, 47432, 44964, 49557, 46651, 40044, 43944, 40181,\n",
       "       45941, 47499, 47857, 46727, 47065, 43051, 44298, 46520, 47041,\n",
       "       41373, 48497, 48214, 49189, 49934, 40449, 41156, 42137, 46667,\n",
       "       47780, 47061, 40140, 42448, 48251, 44960, 49459, 48302, 46020,\n",
       "       45680, 46848, 41220, 40665, 43253, 45622, 45587, 46272, 47841,\n",
       "       46671, 46125, 43689, 49172, 44581, 48123, 40915, 47479, 46426,\n",
       "       41845, 43019, 48607, 44406, 46794, 49006, 41328, 44284, 40551,\n",
       "       48243, 47471, 47530, 42359, 47001, 47848, 42048, 41466, 48815,\n",
       "       46434, 45286, 42936, 47562, 45916, 41746, 43883, 41519, 45938,\n",
       "       44399, 45474, 42503, 45886, 45664, 43954, 45167, 47735, 40541,\n",
       "       41359, 41985, 48213, 49495, 49036, 46880, 49641, 45747, 40253,\n",
       "       40660, 46828, 42241, 48691, 47201, 47976, 48594, 40974, 48941,\n",
       "       46963, 44306, 48205, 49073, 45922, 47411, 49268, 45018, 45771,\n",
       "       48939, 46924, 45398, 44127, 43297, 41516, 46101, 41035, 49678,\n",
       "       40033, 46268, 42933, 48570, 40683, 47214, 43705, 40244, 49391,\n",
       "       47057, 44532, 44579, 41235, 43858, 41238, 46138, 48434, 41418,\n",
       "       46970, 40704, 43443, 44138, 46129, 40574, 48130, 49203, 44659,\n",
       "       46887, 44063, 49585, 49506, 40234, 46984, 41467, 45352, 42162,\n",
       "       45371, 44601, 40495, 40572, 46726, 45815, 46455, 41102, 41145,\n",
       "       46009, 44156, 45510, 44343, 45806, 44412, 44029, 41091, 43925,\n",
       "       43524, 49630, 43857, 42499, 42250, 42134, 45579, 44024, 45873,\n",
       "       44251, 48746, 49338, 46452, 41098, 47455, 44846, 40390, 47032,\n",
       "       48078, 43750, 42009, 42760, 41662, 48505, 44398, 47422, 45366,\n",
       "       47105, 44753, 40170, 45218, 49382, 43553, 42826, 42167, 42919,\n",
       "       44580, 41240, 43482, 44085, 47721, 49146, 41237, 47925, 43315,\n",
       "       43992, 48052, 43125, 40224, 40735, 48973, 41577, 49798, 40671,\n",
       "       41363, 46368, 46214, 44226, 49896, 42509, 49026, 47947, 48178,\n",
       "       49380, 40039, 47858, 44191, 46284, 49305, 40031, 40722, 44780,\n",
       "       47953, 44022, 45485, 48482, 48805, 40415, 40126, 43789, 47195,\n",
       "       42958, 41461, 44894, 45188, 48278, 41294, 47327, 48633, 45093,\n",
       "       48955, 47242, 45742, 44892, 46443, 48668, 48265, 41669, 46920,\n",
       "       43152, 49283, 42898, 45925, 47497, 46678, 48690, 46648, 49971,\n",
       "       46632, 47092, 47309, 44665, 45926, 41831, 40410, 49870, 49964,\n",
       "       42049, 47256, 42749, 43273, 48682, 43841, 46411, 40476, 41872,\n",
       "       41070, 45867, 41270, 45855, 44290, 42406, 45540, 43855, 48398,\n",
       "       49582, 40839, 40023, 48605, 47590, 49536, 48838, 46078, 46929,\n",
       "       41597, 48005, 43545, 42470, 47662, 45341, 46882, 44074, 48972,\n",
       "       46178, 47859, 45326, 45838, 49685, 49674, 46092, 45861, 44575,\n",
       "       40369, 47583, 41861, 43657, 41286, 44687, 45559, 46065, 42189,\n",
       "       45654, 49759, 41533, 47971, 46109, 46224, 40883, 48001, 43557,\n",
       "       42060, 47631, 43955, 49730, 41453, 43170, 48659, 47241, 48865,\n",
       "       45069, 48986, 48920, 40277, 41640, 49319, 42662, 46925, 48070,\n",
       "       43651])"
      ]
     },
     "execution_count": 105,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7ac5f32-a0ac-4060-be6f-05f374fe6a6e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
