{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "27e00f1b-9c6d-4519-a00f-fda6a99cbd3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import xgboost as xgb\n",
    "import sklearn\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "import sklearn.model_selection as ms\n",
    "import numpy as np\n",
    "import scipy\n",
    "\n",
    "import torch\n",
    "from nnlib.nnlib import utils\n",
    "\n",
    "import warnings\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "import h5py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de76dad9-a84b-447b-a5e5-e330cbf052fe",
   "metadata": {},
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "id": "e901f7c4-69af-4c1c-bc3a-2ff52ee9a127",
   "metadata": {},
   "outputs": [],
   "source": [
    "files = ['kitti_all_train.data', 'kitti_all_train.labels', 'kitti_all_test.data', 'kitti_all_test.labels']\n",
    "file_path = os.getcwd() + '/kitti_features/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "id": "f7800657-eac1-466f-8b53-f2e75df3ac66",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_split(name, dtype):\n",
    "    \"\"\"Read one file from `file_path` with np.loadtxt, skipping its first line.\"\"\"\n",
    "    return np.loadtxt(os.path.join(file_path, name), dtype=dtype, skiprows=1)\n",
    "\n",
    "X_train = _load_split(files[0], np.float64)\n",
    "y_train = _load_split(files[1], np.int32)\n",
    "X_test = _load_split(files[2], np.float64)\n",
    "y_test = _load_split(files[3], np.int32)\n",
    "\n",
    "# Binarise the labels: every positive label becomes 1, everything else 0.\n",
    "y_train = (y_train > 0).astype(int)\n",
    "y_test = (y_test > 0).astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "id": "4de3af5a-5d73-4e62-ad8f-ae106e4f022f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16000"
      ]
     },
     "execution_count": 144,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of training labels loaded above\n",
    "len(y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c389ec36-8363-4e50-9cdd-325a8447eec8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "start, end, num_candidates = 6, 10**3, 10\n",
    "int_candidates = np.ceil(np.logspace(np.log10(start), np.log10(end), num=num_candidates)).astype(int) ## bin candidates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5fc26407-bbfb-4f72-90d1-6523850f3da4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([   6,   11,   19,   34,   59,  103,  182,  321,  567, 1000])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the bin candidates computed in the previous cell\n",
    "int_candidates"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4028604e-de30-47d4-9a19-5b03d8879fc6",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "id": "24103c77-8ca0-4c3c-9ffe-ab27a1dd272d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base classifier for the KITTI data. Swap in the commented-out random forest to compare.\n",
    "random_state = 1\n",
    "xgb_params = dict(\n",
    "    booster=\"gbtree\",\n",
    "    n_estimators=100,\n",
    "    random_state=random_state,\n",
    "    n_jobs=-1,\n",
    ")\n",
    "model = xgb.XGBClassifier(**xgb_params)\n",
    "# model = RandomForestClassifier(n_estimators=100, criterion=\"gini\", min_samples_split=2,\n",
    "#                                bootstrap=True, n_jobs=-1, random_state=random_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "id": "90b02e79-9602-405b-803c-493491baed23",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-2 {color: black;}#sk-container-id-2 pre{padding: 0;}#sk-container-id-2 div.sk-toggleable {background-color: white;}#sk-container-id-2 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-2 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-2 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-2 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-2 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-2 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-2 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-2 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-2 div.sk-label:hover 
label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-2 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-2 div.sk-item {position: relative;z-index: 1;}#sk-container-id-2 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-2 div.sk-item::before, #sk-container-id-2 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-2 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-2 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-2 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-2 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-2 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-2 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-2 div.sk-label-container {text-align: center;}#sk-container-id-2 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-2 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-2\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>XGBClassifier(base_score=None, booster=&#x27;gbtree&#x27;, callbacks=None,\n",
       "              colsample_bylevel=None, colsample_bynode=None,\n",
       "              colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
       "              enable_categorical=False, eval_metric=None, feature_types=None,\n",
       "              gamma=None, grow_policy=None, importance_type=None,\n",
       "              interaction_constraints=None, learning_rate=None, max_bin=None,\n",
       "              max_cat_threshold=None, max_cat_to_onehot=None,\n",
       "              max_delta_step=None, max_depth=None, max_leaves=None,\n",
       "              min_child_weight=None, missing=nan, monotone_constraints=None,\n",
       "              multi_strategy=None, n_estimators=100, n_jobs=-1,\n",
       "              num_parallel_tree=None, random_state=1, ...)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" checked><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">XGBClassifier</label><div class=\"sk-toggleable__content\"><pre>XGBClassifier(base_score=None, booster=&#x27;gbtree&#x27;, callbacks=None,\n",
       "              colsample_bylevel=None, colsample_bynode=None,\n",
       "              colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
       "              enable_categorical=False, eval_metric=None, feature_types=None,\n",
       "              gamma=None, grow_policy=None, importance_type=None,\n",
       "              interaction_constraints=None, learning_rate=None, max_bin=None,\n",
       "              max_cat_threshold=None, max_cat_to_onehot=None,\n",
       "              max_delta_step=None, max_depth=None, max_leaves=None,\n",
       "              min_child_weight=None, missing=nan, monotone_constraints=None,\n",
       "              multi_strategy=None, n_estimators=100, n_jobs=-1,\n",
       "              num_parallel_tree=None, random_state=1, ...)</pre></div></div></div></div></div>"
      ],
      "text/plain": [
       "XGBClassifier(base_score=None, booster='gbtree', callbacks=None,\n",
       "              colsample_bylevel=None, colsample_bynode=None,\n",
       "              colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
       "              enable_categorical=False, eval_metric=None, feature_types=None,\n",
       "              gamma=None, grow_policy=None, importance_type=None,\n",
       "              interaction_constraints=None, learning_rate=None, max_bin=None,\n",
       "              max_cat_threshold=None, max_cat_to_onehot=None,\n",
       "              max_delta_step=None, max_depth=None, max_leaves=None,\n",
       "              min_child_weight=None, missing=nan, monotone_constraints=None,\n",
       "              multi_strategy=None, n_estimators=100, n_jobs=-1,\n",
       "              num_parallel_tree=None, random_state=1, ...)"
      ]
     },
     "execution_count": 146,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train the classifier on the training split (the fitted estimator is echoed as the cell output)\n",
    "model.fit(X=X_train, y=y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c3cff4db-afb8-44a0-b932-192c2a9a1c47",
   "metadata": {},
   "outputs": [],
   "source": [
    "p_pred_train = model.predict_proba(X_train)\n",
    "p_pred = model.predict_proba(X_test)\n",
    "# Column 0 holds the true test label, column 1 the predicted class (argmax of the probabilities).\n",
    "y_y_pred = np.column_stack([y_test, np.argmax(p_pred, axis=1)])\n",
    "\n",
    "# Remove rows containing NaN or +/-inf. np.isfinite rejects both, whereas the\n",
    "# previous np.isnan check let infinite values through.\n",
    "isnotnan_ind = np.isfinite(p_pred).all(axis=1)\n",
    "p_pred = p_pred[isnotnan_ind, :]\n",
    "y_y_pred = y_y_pred[isnotnan_ind, :]\n",
    "\n",
    "# NOTE(review): y_train is not filtered with this mask, so it only stays row-aligned\n",
    "# with p_pred_train when no training row is dropped here.\n",
    "isnotnan_ind = np.isfinite(p_pred_train).all(axis=1)\n",
    "p_pred_train = p_pred_train[isnotnan_ind, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9fbe06cc-bfed-42f9-97eb-88faf399ccb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_acc(preds, labels):\n",
    "    \"\"\"Return the top-1 accuracy of per-class score rows against integer labels.\n",
    "\n",
    "    Args:\n",
    "        preds: array-like of shape (n_samples, n_classes); the predicted class of a\n",
    "            row is the argmax over its columns.\n",
    "        labels: array-like with one integer class label per row of `preds`.\n",
    "\n",
    "    Returns:\n",
    "        The mean accuracy, converted with `utils.to_numpy`.\n",
    "    \"\"\"\n",
    "    # Bug fix: score the `preds` argument. The previous version ignored it and\n",
    "    # always evaluated the notebook-global `p_pred`.\n",
    "    preds = torch.as_tensor(preds)\n",
    "    labels = torch.as_tensor(labels).long()\n",
    "    acc = (preds.argmax(dim=1) == labels).float().mean()\n",
    "    return utils.to_numpy(acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e1e9d3e7-c0ef-4481-b1b2-bf15b46fb7c3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(0.960066, dtype=float32)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score against the label column of y_y_pred: it went through the same row filter\n",
    "# as p_pred, whereas y_test was never filtered and can be longer than p_pred.\n",
    "acc = compute_acc(p_pred, y_y_pred[:, 0])\n",
    "acc"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec803934-a3f1-436e-bf5a-cd6aa6872281",
   "metadata": {},
   "source": [
    "# GP calibration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "id": "040cca87-3a94-41af-b84d-f295d5c8cf1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import gpytorch\n",
    "from scipy.cluster.vq import kmeans\n",
    "from scipy.special import softmax\n",
    "from methods.calibration import CalibrationMethod\n",
    "import numpy as np\n",
    "from gpytorch.models import ApproximateGP\n",
    "\n",
    "class GPCalibration(CalibrationMethod):\n",
    "    \"\"\"\n",
    "    Probability calibration using a latent Gaussian process with PyTorch and GPyTorch.\n",
    "\n",
    "    One scalar latent GP is applied to every classifier score; the latent values that\n",
    "    belong to one sample are then combined with a softmax. Score matrices of shape\n",
    "    (n_samples, n_classes) are therefore flattened row-major to\n",
    "    (n_samples * n_classes, 1) before they reach the GP, and the likelihood is expected\n",
    "    to accept latent means/variances in that flat layout.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_classes : int\n",
    "        Number of classes (columns of the score matrix).\n",
    "    logits : bool\n",
    "        If True the default prior mean is the identity and the RBF lengthscale starts\n",
    "        at 10; otherwise a log prior mean, a lengthscale of 0.5 and interval\n",
    "        constraints on the RBF hyperparameters are used.\n",
    "    mean_function, kernel, likelihood : optional\n",
    "        Replacements for the default components built here.\n",
    "    n_inducing_points : int\n",
    "        Number of k-means centroids requested as inducing points.\n",
    "    maxiter : int\n",
    "        Number of Adam steps taken by `fit`.\n",
    "    n_monte_carlo : int\n",
    "        Latent draws per sample in `predict_proba`; also passed to the default likelihood.\n",
    "    max_samples_monte_carlo : int\n",
    "        Cap on how many latent values `predict_proba` draws in one batch.\n",
    "    inf_mean_approx : bool\n",
    "        If True `predict_proba` always returns the softmax of the latent mean.\n",
    "    random_state : int\n",
    "        Seed for torch and for the k-means initialisation.\n",
    "    verbose : bool\n",
    "        If True the loss is printed after every optimisation step.\n",
    "    \"\"\"\n",
    "    def __init__(self, n_classes, logits=False, mean_function=None, kernel=None, likelihood=None,\n",
    "                 n_inducing_points=100, maxiter=1000, n_monte_carlo=100, max_samples_monte_carlo=10**7,\n",
    "                 inf_mean_approx=False, random_state=1, verbose=False):\n",
    "        super().__init__()\n",
    "\n",
    "        # Initialization\n",
    "        self.n_classes = n_classes\n",
    "        self.logits = logits\n",
    "        self.verbose = verbose\n",
    "        self.n_inducing_points = n_inducing_points\n",
    "        self.maxiter = maxiter\n",
    "        self.n_monte_carlo = n_monte_carlo\n",
    "        # Previously accepted by the constructor but silently dropped.\n",
    "        self.max_samples_monte_carlo = max_samples_monte_carlo\n",
    "        self.inf_mean_approx = inf_mean_approx\n",
    "        self.random_state = random_state\n",
    "        torch.manual_seed(self.random_state)  # for reproducibility\n",
    "\n",
    "        # Setting up likelihood\n",
    "        if likelihood is None:\n",
    "            self.likelihood = MultiCal(num_classes=self.n_classes, num_monte_carlo_points=self.n_monte_carlo)\n",
    "        else:\n",
    "            self.likelihood = likelihood\n",
    "\n",
    "        # Setting up mean function\n",
    "        if mean_function is None:\n",
    "            self.mean_function = IdentityMean() if logits else LogMean()\n",
    "        else:\n",
    "            self.mean_function = mean_function\n",
    "\n",
    "        # Setting up kernel\n",
    "        self.kernel = self._default_kernel(logits) if kernel is None else kernel\n",
    "\n",
    "        # Both are set by fit().\n",
    "        self.model = None\n",
    "        self._device = torch.device(\"cpu\")\n",
    "\n",
    "    @staticmethod\n",
    "    def _default_kernel(logits):\n",
    "        \"\"\"Build the default kernel: scaled RBF plus white noise.\"\"\"\n",
    "        if logits:\n",
    "            kernel_lengthscale = 10.0\n",
    "            rbf = gpytorch.kernels.RBFKernel()\n",
    "            k_rbf = gpytorch.kernels.ScaleKernel(rbf)\n",
    "        else:\n",
    "            kernel_lengthscale = 0.5\n",
    "            # Constraints have to be passed at construction time. Assigning an Interval\n",
    "            # to `.lengthscale` / `.outputscale` afterwards does not constrain anything.\n",
    "            rbf = gpytorch.kernels.RBFKernel(\n",
    "                lengthscale_constraint=gpytorch.constraints.Interval(0.001, 10)\n",
    "            )\n",
    "            k_rbf = gpytorch.kernels.ScaleKernel(\n",
    "                rbf, outputscale_constraint=gpytorch.constraints.Interval(0.0, 5.0)\n",
    "            )\n",
    "        # Initial hyperparameter values are assigned through the property setters;\n",
    "        # `lengthscale=` / `outputscale=` constructor keywords do not initialise them.\n",
    "        rbf.lengthscale = kernel_lengthscale\n",
    "        k_rbf.outputscale = 1.0\n",
    "        return k_rbf + WhiteNoiseKernel(variance=0.01)\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        \"\"\"\n",
    "        Fit the latent GP on calibration data.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        X : array-like, shape (n_samples, n_classes)\n",
    "            Uncalibrated classifier scores.\n",
    "        y : array-like with n_samples entries\n",
    "            Integer class labels, one per row of X.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        self\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        ValueError\n",
    "            If X is not (n_samples, n_classes) or y does not hold one label per row.\n",
    "        \"\"\"\n",
    "        # Setting device\n",
    "        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "        # Convert data to tensor\n",
    "        X = torch.as_tensor(X, dtype=torch.float32, device=device)\n",
    "        y = torch.as_tensor(y, dtype=torch.float32, device=device)\n",
    "\n",
    "        # Handling input shapes\n",
    "        if X.dim() != 2 or X.size(1) != self.n_classes:\n",
    "            raise ValueError(\"Calibration data must have shape (n_samples, n_classes).\")\n",
    "        if y.numel() != X.size(0):\n",
    "            raise ValueError(\"Labels must contain exactly one entry per row of the calibration data.\")\n",
    "        y = y.reshape(-1)\n",
    "\n",
    "        # One latent input per (sample, class) score. The inducing points below are\n",
    "        # one-dimensional as well, so the GP must see the same flattened layout.\n",
    "        X_flat = X.reshape(-1, 1)\n",
    "\n",
    "        # Inducing points via k-means, seeded so repeated fits pick the same centroids\n",
    "        centroids, _ = kmeans(X_flat.cpu().numpy(), self.n_inducing_points, seed=self.random_state)\n",
    "        Z = torch.tensor(centroids, dtype=torch.float32, device=device)\n",
    "\n",
    "        # Define and initialize the model; model and likelihood must live on the\n",
    "        # same device as the data.\n",
    "        self.model = GPModel(Z, self.mean_function, self.kernel).to(device)\n",
    "        self.likelihood = self.likelihood.to(device)\n",
    "        self._device = device\n",
    "        self.model.train()\n",
    "        self.likelihood.train()\n",
    "\n",
    "        # Use the Adam optimizer; likelihood parameters (if any) are trained too\n",
    "        trainable = list(self.model.parameters()) + list(self.likelihood.parameters())\n",
    "        optimizer = torch.optim.Adam(trainable, lr=0.01)\n",
    "\n",
    "        # Our loss object. VariationalELBO divides the likelihood term by the number of\n",
    "        # latent inputs in the batch, so num_data is the flattened count to keep the\n",
    "        # KL term on the same scale.\n",
    "        mll = gpytorch.mlls.VariationalELBO(self.likelihood, self.model, num_data=X_flat.size(0))\n",
    "\n",
    "        # Optimization loop\n",
    "        for i in tqdm(range(self.maxiter)):\n",
    "            optimizer.zero_grad()\n",
    "            output = self.model(X_flat)\n",
    "            loss = -mll(output, y)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            if self.verbose:\n",
    "                print(f\"Iteration {i+1}/{self.maxiter}, Loss: {loss.item()}\")\n",
    "\n",
    "        return self\n",
    "\n",
    "    def predict_proba(self, X, mean_approx=False):\n",
    "        \"\"\"\n",
    "        Return calibrated class probabilities.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        X : array-like, shape (n_samples, n_classes)\n",
    "            Uncalibrated classifier scores.\n",
    "        mean_approx : bool\n",
    "            If True (or if `inf_mean_approx` was set) return the softmax of the latent\n",
    "            mean. Otherwise average the softmax over `n_monte_carlo` latent draws that\n",
    "            use the marginal (per-score) latent variances.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        numpy.ndarray of shape (n_samples, n_classes)\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        RuntimeError\n",
    "            If called before `fit`.\n",
    "        ValueError\n",
    "            If X is not (n_samples, n_classes).\n",
    "        \"\"\"\n",
    "        # Check model is fitted\n",
    "        if self.model is None:\n",
    "            raise RuntimeError(\"The model must be fitted before prediction.\")\n",
    "\n",
    "        # Evaluate\n",
    "        self.model.eval()\n",
    "        self.likelihood.eval()\n",
    "\n",
    "        # Setup data on the device the model was fitted on\n",
    "        X = torch.as_tensor(X, dtype=torch.float32, device=self._device)\n",
    "        if X.dim() != 2 or X.size(1) != self.n_classes:\n",
    "            raise ValueError(\"Calibration data must have shape (n_samples, n_classes).\")\n",
    "        n_samples = X.size(0)\n",
    "        if n_samples == 0:\n",
    "            return np.empty((0, self.n_classes), dtype=np.float32)\n",
    "\n",
    "        with torch.no_grad(), gpytorch.settings.fast_pred_var():\n",
    "            latent = self.model(X.reshape(-1, 1))\n",
    "            f_mean = latent.mean.reshape(n_samples, self.n_classes)\n",
    "            if mean_approx or self.inf_mean_approx:\n",
    "                return softmax(f_mean.cpu().numpy(), axis=1)\n",
    "\n",
    "            # Monte Carlo estimate of E[softmax(f)]. clamp guards against tiny negative\n",
    "            # variances caused by round-off.\n",
    "            f_std = latent.variance.clamp_min(0.0).sqrt().reshape(n_samples, self.n_classes)\n",
    "            draws_per_row = max(1, self.n_monte_carlo * self.n_classes)\n",
    "            rows_per_chunk = max(1, int(self.max_samples_monte_carlo) // draws_per_row)\n",
    "            chunks = []\n",
    "            for lo in range(0, n_samples, rows_per_chunk):\n",
    "                mu = f_mean[lo:lo + rows_per_chunk]\n",
    "                sd = f_std[lo:lo + rows_per_chunk]\n",
    "                eps = torch.randn((self.n_monte_carlo,) + tuple(mu.shape), dtype=mu.dtype, device=mu.device)\n",
    "                chunks.append(torch.softmax(mu + sd * eps, dim=-1).mean(dim=0))\n",
    "            return torch.cat(chunks, dim=0).cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "id": "d40b4765-5f93-4c81-9636-aba077d5eb3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from gpytorch.models import ApproximateGP\n",
    "from gpytorch.variational import CholeskyVariationalDistribution\n",
    "from gpytorch.variational import VariationalStrategy\n",
    "\n",
    "class GPModel(ApproximateGP):\n",
    "    \"\"\"Variational GP built from given inducing points, mean module and kernel.\n",
    "\n",
    "    The inducing locations are learned together with the other parameters.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, inducing_points, mean_function, kernel):\n",
    "        num_inducing = inducing_points.size(0)\n",
    "        q_u = CholeskyVariationalDistribution(num_inducing)\n",
    "        strategy = VariationalStrategy(\n",
    "            self,\n",
    "            inducing_points,\n",
    "            q_u,\n",
    "            learn_inducing_locations=True,\n",
    "        )\n",
    "        super().__init__(strategy)\n",
    "        self.mean_module = mean_function\n",
    "        self.covar_module = kernel\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the prior MultivariateNormal at `x` from the mean and covariance modules.\"\"\"\n",
    "        prior_mean = self.mean_module(x)\n",
    "        prior_covar = self.covar_module(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(prior_mean, prior_covar)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea6e37b1-bd6c-422b-a825-51aed353f343",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "id": "df1921bf-fd52-478c-b89b-7f187b1a0f16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CPU tensors for assembling the GP model by hand in the following cells.\n",
    "# NOTE(review): X holds training-set probabilities while y is built from y_y_pred,\n",
    "# i.e. test-set [label, prediction] pairs, so the rows of X and y are not aligned.\n",
    "X = torch.tensor(p_pred_train, dtype=torch.float32, device='cpu')\n",
    "y = torch.tensor(y_y_pred, dtype=torch.float32, device='cpu')\n",
    "\n",
    "# Inducing points via k-means\n",
    "# k-means runs on X flattened to a single column, so Z has one column as well.\n",
    "Z = torch.tensor(kmeans(X.view(-1, 1).cpu().numpy(), 10)[0], dtype=torch.float32, device='cpu')\n",
    "#Z = torch.tensor(kmeans(X.cpu().numpy(), 100)[0], dtype=torch.float32, device='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "id": "f4071da7-33fc-46e2-ab26-3d0bddd6e100",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define and initialize the model\n",
    "mean_function = IdentityMean()\n",
    "kernel_lengthscale = 10.0\n",
    "k_white = WhiteNoiseKernel(variance=0.01)\n",
    "\n",
    "# RBF kernel. `lengthscale=` / `outputscale=` constructor keywords do not initialise\n",
    "# the hyperparameters, so the starting values are assigned through the property setters.\n",
    "k_rbf = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "k_rbf.base_kernel.lengthscale = kernel_lengthscale\n",
    "k_rbf.outputscale = 1.0\n",
    "\n",
    "kernel = k_rbf + k_white\n",
    "\n",
    "# NOTE(review): this rebinds `model`, which previously held the fitted XGBoost\n",
    "# classifier; re-running the prediction cell after this one will not work.\n",
    "model = GPModel(Z, mean_function, kernel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "id": "df90e927-38bc-4da1-97bd-4fd96840525e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.9091],\n",
       "        [0.9879],\n",
       "        [0.5467],\n",
       "        [0.7996],\n",
       "        [0.9618],\n",
       "        [0.2348],\n",
       "        [0.0026],\n",
       "        [0.9989],\n",
       "        [0.0980],\n",
       "        [0.0357]])"
      ]
     },
     "execution_count": 140,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the k-means inducing points\n",
    "Z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "id": "29a33500-942c-463e-87d8-475e0fa052a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.distributions import MultivariateNormal, Normal\n",
    "\n",
    "class WhiteNoiseKernel(gpytorch.kernels.Kernel):\n",
    "    \"\"\"\n",
    "    White-noise kernel: contributes `noise` on the diagonal when both arguments are\n",
    "    the same set of points and zero everywhere else.\n",
    "\n",
    "    Parameters:\n",
    "        variance (float): Initial value of the learnable `noise` parameter.\n",
    "        **kwargs: Forwarded to `gpytorch.kernels.Kernel`.\n",
    "    \"\"\"\n",
    "    has_lengthscale = False\n",
    "\n",
    "    def __init__(self, variance=0.01, **kwargs):\n",
    "        super(WhiteNoiseKernel, self).__init__(**kwargs)\n",
    "        self.register_parameter(\n",
    "            name='noise', \n",
    "            # float() so that an integer variance still yields a floating-point parameter\n",
    "            parameter=torch.nn.Parameter(torch.tensor(float(variance)))\n",
    "        )\n",
    "\n",
    "    def forward(self, x1, x2, diag=False, **params):\n",
    "        \"\"\"\n",
    "        Evaluate the kernel between `x1` (..., n1, d) and `x2` (..., n2, d).\n",
    "\n",
    "        Returns a (..., n1, n2) matrix, or its (..., n1) diagonal when `diag=True`.\n",
    "        \"\"\"\n",
    "        n1, n2 = x1.size(-2), x2.size(-2)\n",
    "        batch_shape = tuple(x1.shape[:-2])\n",
    "        # Noise is only shared between identical inputs. The cross-covariance between\n",
    "        # different point sets (e.g. inducing points vs. data) is zero and has to be\n",
    "        # (n1, n2)-shaped, not a square identity of size n1.\n",
    "        same_points = x1.shape == x2.shape and torch.equal(x1, x2)\n",
    "        if not same_points:\n",
    "            out_shape = batch_shape + ((n1,) if diag else (n1, n2))\n",
    "            return torch.zeros(out_shape, device=x1.device, dtype=x1.dtype)\n",
    "        if diag:\n",
    "            return self.noise.expand(*batch_shape, n1)\n",
    "        eye = torch.eye(n1, device=x1.device, dtype=x1.dtype)\n",
    "        return (self.noise * eye).expand(*batch_shape, n1, n1)\n",
    "\n",
    "class IdentityMean(gpytorch.means.Mean):\n",
    "    \"\"\"\n",
    "    Prior mean that passes the inputs through unchanged, flattened to one dimension.\n",
    "    Useful when the input itself is already a sensible prior estimate of the output.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Return the input as the prior mean.\n",
    "\n",
    "        Parameters:\n",
    "            x (torch.Tensor): The inputs given to the GP model.\n",
    "\n",
    "        Returns:\n",
    "            torch.Tensor: `x` reshaped to a one-dimensional tensor.\n",
    "        \"\"\"\n",
    "        flat = x.reshape(-1)\n",
    "        return flat\n",
    "\n",
    "\n",
    "class LogMean(gpytorch.means.Mean):\n",
    "    \"\"\"\n",
    "    Prior mean equal to the natural logarithm of the inputs. Inputs are clipped from\n",
    "    below at the smallest positive normal number of their dtype so log(0) never occurs.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"Return log(max(X, tiny)) reshaped to a one-dimensional tensor.\"\"\"\n",
    "        floor = torch.finfo(X.dtype).tiny\n",
    "        return X.clamp(min=floor).log().reshape(-1)\n",
    "\n",
    "class ScalarMultMean(gpytorch.means.Mean):\n",
    "    \"\"\"\n",
    "    Scalar multiplication mean function. Multiplies input by a learnable scalar alpha.\n",
    "\n",
    "    :math:`y_i = \\\\alpha x_i`\n",
    "\n",
    "    Parameters:\n",
    "        alpha (float): Initial value of the scalar.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, alpha=1.0):\n",
    "        super(ScalarMultMean, self).__init__()\n",
    "        # Learnable parameters are torch.nn.Parameter objects; `gpytorch.Parameter` is\n",
    "        # not part of GPyTorch's public API. float() keeps an integer alpha trainable.\n",
    "        self.alpha = torch.nn.Parameter(torch.tensor(float(alpha)))\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"Return alpha * X reshaped to a one-dimensional tensor.\"\"\"\n",
    "        return (self.alpha * X).reshape(-1)\n",
    "\n",
    "\n",
    "class SoftArgMax(gpytorch.Module):\n",
    "    \"\"\"\n",
    "    Multi-class softargmax inverse-link function. A vector :math:`f=[f_1, ..., f_k]`\n",
    "    is mapped to :math:`y = [y_1, ..., y_k]` with\n",
    "    :math:`y_i = \\\\frac{\\\\exp(f_i)}{\\\\sum_{j=1}^k\\\\exp(f_j)}`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_classes):\n",
    "        super().__init__()\n",
    "        self.num_classes = num_classes\n",
    "\n",
    "    def forward(self, F):\n",
    "        \"\"\"\n",
    "        Softmax of `F` over its last dimension.\n",
    "\n",
    "        Parameters:\n",
    "            F (torch.Tensor): Latent values; the last dimension indexes the classes.\n",
    "\n",
    "        Returns:\n",
    "            torch.Tensor: Probabilities with the same shape as `F`.\n",
    "        \"\"\"\n",
    "        class_probs = torch.softmax(F, dim=-1)\n",
    "        return class_probs\n",
    "        \n",
    "\n",
    "class MultiCal(gpytorch.likelihoods._OneDimensionalLikelihood):\n",
    "    \"\"\"\n",
    "    A likelihood for multiclass calibration using the softargmax link function and a\n",
    "    single latent process.\n",
    "\n",
    "    The latent process is evaluated once per (sample, class) score, so latent means and\n",
    "    variances may arrive flat with shape (n_samples * num_classes,). They are regrouped\n",
    "    to (n_samples, num_classes) before the softmax; inputs that already have that shape\n",
    "    pass through unchanged.\n",
    "    \"\"\"\n",
    "    def __init__(self, num_classes, invlink=None, num_monte_carlo_points=100):\n",
    "        \"\"\"\n",
    "        Parameters:\n",
    "            num_classes (int): Number of classes sharing the latent process.\n",
    "            invlink (SoftArgMax, optional): Inverse link; only SoftArgMax is accepted.\n",
    "            num_monte_carlo_points (int): Stored on the instance; the closed-form\n",
    "                approximations in this class do not use it.\n",
    "\n",
    "        Raises:\n",
    "            NotImplementedError: If `invlink` is given and is not a SoftArgMax.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.num_classes = num_classes\n",
    "        self.num_monte_carlo_points = num_monte_carlo_points\n",
    "        \n",
    "        if invlink is None:\n",
    "            self.invlink = SoftArgMax(self.num_classes)\n",
    "        elif not isinstance(invlink, SoftArgMax):\n",
    "            raise NotImplementedError(\"Only SoftArgMax is implemented as an invlink function.\")\n",
    "        else:\n",
    "            self.invlink = invlink\n",
    "\n",
    "    def _class_logits(self, latent):\n",
    "        \"\"\"Regroup latent values into rows of `num_classes` entries (one row per sample).\"\"\"\n",
    "        return latent.reshape(-1, self.num_classes)\n",
    "\n",
    "    def _label_log_probs(self, logits, labels):\n",
    "        \"\"\"Return (log-probability of each row's label, full log-softmax matrix).\"\"\"\n",
    "        labels = labels.long().reshape(-1)\n",
    "        if labels.numel() != logits.size(0):\n",
    "            raise ValueError(\n",
    "                f\"Expected one label per sample ({logits.size(0)}), got {labels.numel()}.\"\n",
    "            )\n",
    "        # log_softmax is numerically safer than log(softmax(.)).\n",
    "        log_probs = torch.log_softmax(logits, dim=1)\n",
    "        picked = log_probs.gather(1, labels.unsqueeze(1)).squeeze(1)\n",
    "        return picked, log_probs\n",
    "\n",
    "    def expected_log_prob(self, target, input, *params, **kwargs):\n",
    "        \"\"\"\n",
    "        Approximate expected log probability of `target` under the latent distribution\n",
    "        `input`, using both its mean and its marginal variance.\n",
    "\n",
    "        Returns a tensor with one entry per sample.\n",
    "        \"\"\"\n",
    "        # The variance was previously read but ignored; delegate so it is accounted for.\n",
    "        return self.variational_expectations(input.mean, input.variance, target)\n",
    "\n",
    "    def variational_expectations(self, mean, variance, observations, *params, **kwargs):\n",
    "        \"\"\"\n",
    "        Second-order Taylor approximation of E[log softmax(f)_y] for f ~ N(mean, diag(variance)).\n",
    "\n",
    "        Parameters:\n",
    "            mean, variance (torch.Tensor): Latent moments, flat or (n_samples, num_classes).\n",
    "            observations (torch.Tensor): One integer class label per sample.\n",
    "\n",
    "        Returns:\n",
    "            torch.Tensor: One approximate expected log-probability per sample.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If the number of labels differs from the number of samples.\n",
    "        \"\"\"\n",
    "        logits = self._class_logits(mean)\n",
    "        var = self._class_logits(variance)\n",
    "        expected_log_probs, log_probs = self._label_log_probs(logits, observations)\n",
    "        softmax_probs = log_probs.exp()\n",
    "\n",
    "        # The Hessian of log softmax(f)_y has diagonal -p_i (1 - p_i), so latent\n",
    "        # uncertainty lowers the expectation. `variance` is already sigma^2 and\n",
    "        # therefore enters linearly. Only the diagonal variance is used.\n",
    "        term2 = var * softmax_probs * (1 - softmax_probs)\n",
    "\n",
    "        return expected_log_probs - 0.5 * term2.sum(dim=1)\n",
    "\n",
    "    def forward(self, function_samples, *params, **kwargs):\n",
    "        \"\"\"\n",
    "        Returns a sample of class labels from the likelihood given function values.\n",
    "\n",
    "        NOTE(review): GPyTorch's Likelihood API expects `forward` to return a\n",
    "        Distribution; this keeps returning a sampled tensor as before - confirm intent.\n",
    "        \"\"\"\n",
    "        latent = function_samples\n",
    "        if latent.size(-1) != self.num_classes:\n",
    "            # Flat layout: regroup the last dimension into (n_samples, num_classes) so\n",
    "            # the softmax runs over the classes of one sample, not over all points.\n",
    "            latent = latent.reshape(*latent.shape[:-1], -1, self.num_classes)\n",
    "        probs = self.invlink(latent)\n",
    "        cat_dist = torch.distributions.Categorical(probs=probs)\n",
    "        return cat_dist.sample()\n",
    "\n",
    "    def log_prob(self, value, mean, variance, *params, **kwargs):\n",
    "        \"\"\"\n",
    "        Log probability of the labels `value` under the softmax of the latent `mean`.\n",
    "        `variance` is accepted for signature compatibility and not used.\n",
    "        \"\"\"\n",
    "        picked, _ = self._label_log_probs(self._class_logits(mean), value)\n",
    "        return picked"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a3a2759-f41a-408b-9ca4-22dd55f59a67",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "id": "2c2c050f-7700-4230-ab4c-01435fb021f3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f151325e9974467fb9d2d5026664ef3d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "RuntimeError",
     "evalue": "Sizes of tensors must match except in dimension 0. Expected size 1 but got size 2 for tensor number 1 in the list.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[139], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m recal_f \u001b[38;5;241m=\u001b[39m GPCalibration(n_classes\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m, n_monte_carlo\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m100\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mrecal_f\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mp_pred\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_y_pred\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[122], line 101\u001b[0m, in \u001b[0;36mGPCalibration.fit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m     99\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmaxiter)):\n\u001b[1;32m    100\u001b[0m     optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m--> 101\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    102\u001b[0m     loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39mmll(output, y)\n\u001b[1;32m    103\u001b[0m     loss\u001b[38;5;241m.\u001b[39mbackward()\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/gpytorch/models/approximate_gp.py:108\u001b[0m, in \u001b[0;36mApproximateGP.__call__\u001b[0;34m(self, inputs, prior, **kwargs)\u001b[0m\n\u001b[1;32m    106\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m inputs\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m    107\u001b[0m     inputs \u001b[38;5;241m=\u001b[39m inputs\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m--> 108\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvariational_strategy\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprior\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprior\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/gpytorch/variational/variational_strategy.py:272\u001b[0m, in \u001b[0;36mVariationalStrategy.__call__\u001b[0;34m(self, x, prior, **kwargs)\u001b[0m\n\u001b[1;32m    269\u001b[0m         \u001b[38;5;66;03m# Mark that we have updated the variational strategy\u001b[39;00m\n\u001b[1;32m    270\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mupdated_strategy\u001b[38;5;241m.\u001b[39mfill_(\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m--> 272\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprior\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprior\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/gpytorch/variational/_variational_strategy.py:341\u001b[0m, in \u001b[0;36m_VariationalStrategy.__call__\u001b[0;34m(self, x, prior, **kwargs)\u001b[0m\n\u001b[1;32m    339\u001b[0m \u001b[38;5;66;03m# Get q(f)\u001b[39;00m\n\u001b[1;32m    340\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(variational_dist_u, MultivariateNormal):\n\u001b[0;32m--> 341\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m    342\u001b[0m \u001b[43m        \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    343\u001b[0m \u001b[43m        \u001b[49m\u001b[43minducing_points\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    344\u001b[0m \u001b[43m        \u001b[49m\u001b[43minducing_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvariational_dist_u\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmean\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    345\u001b[0m \u001b[43m        \u001b[49m\u001b[43mvariational_inducing_covar\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvariational_dist_u\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlazy_covariance_matrix\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    346\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    347\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    348\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(variational_dist_u, Delta):\n\u001b[1;32m    349\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m(\n\u001b[1;32m    350\u001b[0m         
x, inducing_points, inducing_values\u001b[38;5;241m=\u001b[39mvariational_dist_u\u001b[38;5;241m.\u001b[39mmean, variational_inducing_covar\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m    351\u001b[0m     )\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/gpytorch/module.py:31\u001b[0m, in \u001b[0;36mModule.__call__\u001b[0;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[1;32m     30\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39minputs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[Tensor, Distribution, LinearOperator]:\n\u001b[0;32m---> 31\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     32\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(outputs, \u001b[38;5;28mlist\u001b[39m):\n\u001b[1;32m     33\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m [_validate_module_outputs(output) \u001b[38;5;28;01mfor\u001b[39;00m output \u001b[38;5;129;01min\u001b[39;00m outputs]\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/gpytorch/variational/variational_strategy.py:189\u001b[0m, in \u001b[0;36mVariationalStrategy.forward\u001b[0;34m(self, x, inducing_points, inducing_values, variational_inducing_covar, **kwargs)\u001b[0m\n\u001b[1;32m    180\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m    181\u001b[0m     \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m    182\u001b[0m     x: Tensor,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    187\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m MultivariateNormal:\n\u001b[1;32m    188\u001b[0m     \u001b[38;5;66;03m# Compute full prior distribution\u001b[39;00m\n\u001b[0;32m--> 189\u001b[0m     full_inputs \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcat\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43minducing_points\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m    190\u001b[0m     full_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mforward(full_inputs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m    191\u001b[0m     full_covar \u001b[38;5;241m=\u001b[39m full_output\u001b[38;5;241m.\u001b[39mlazy_covariance_matrix\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Sizes of tensors must match except in dimension 0. Expected size 1 but got size 2 for tensor number 1 in the list."
     ]
    }
   ],
   "source": [
    "recal_f = GPCalibration(n_classes=2, n_monte_carlo=100)\n",
    "recal_f.fit(p_pred, y_y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "id": "3df5bd51-5b1c-4df3-8ec7-33db6f12a97a",
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[117], line 60\u001b[0m\n\u001b[1;32m     57\u001b[0m model \u001b[38;5;241m=\u001b[39m SVGPcal(train_x, train_y, likelihood, inducing_points\u001b[38;5;241m=\u001b[39minducing_points)\n\u001b[1;32m     58\u001b[0m optimizer \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39moptim\u001b[38;5;241m.\u001b[39mAdam(model\u001b[38;5;241m.\u001b[39mparameters(), lr\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.01\u001b[39m)\n\u001b[0;32m---> 60\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtraining_iterations\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m500\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[117], line 46\u001b[0m, in \u001b[0;36mSVGPcal.train\u001b[0;34m(self, optimizer, training_iterations)\u001b[0m\n\u001b[1;32m     44\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m(batch_x)\n\u001b[1;32m     45\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlikelihood(train_x,train_y)\u001b[38;5;241m.\u001b[39msample()\u001b[38;5;241m.\u001b[39msum()\n\u001b[0;32m---> 46\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     47\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m     49\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (i \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m) \u001b[38;5;241m%\u001b[39m \u001b[38;5;241m50\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/_tensor.py:492\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    482\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    483\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m    484\u001b[0m         Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m    485\u001b[0m         (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    490\u001b[0m         inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m    491\u001b[0m     )\n\u001b[0;32m--> 492\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    493\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m    494\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.6/lib/python3.11/site-packages/torch/autograd/__init__.py:251\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    246\u001b[0m     retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m    248\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m    249\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m    250\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 251\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m    252\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    253\u001b[0m \u001b[43m    \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    254\u001b[0m \u001b[43m    \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    255\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    256\u001b[0m \u001b[43m    \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    257\u001b[0m \u001b[43m    \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    258\u001b[0m \u001b[43m    \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    259\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
     ]
    }
   ],
   "source": [
    "from gpytorch.kernels import ScaleKernel, RBFKernel\n",
    "from gpytorch.means import ZeroMean\n",
    "\n",
    "class SVGPcal(AbstractVariationalGP):\n",
    "    def __init__(self, train_x, train_y, likelihood, kernel=None, inducing_points=None, num_latent=1, q_diag=False, whiten=True):\n",
    "        self.train_inputs = train_x\n",
    "        self.train_targets = train_y\n",
    "        # Initialize variational distribution and strategy\n",
    "        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(0), batch_shape=torch.Size([num_latent]))\n",
    "\n",
    "        # Depending on q_diag flag, you can use different distributions, here we use Cholesky as default\n",
    "        variational_strategy = VariationalStrategy(self, inducing_points, variational_distribution, learn_inducing_locations=True)\n",
    "\n",
    "        super().__init__(variational_strategy)\n",
    "        self.likelihood = likelihood\n",
    "        \n",
    "        # Set the kernel\n",
    "        self.kernel = kernel if kernel is not None else ScaleKernel(RBFKernel())\n",
    "        \n",
    "        # Mean function\n",
    "        self.mean_module = ZeroMean()\n",
    "\n",
    "        # This can handle a batch shape which GPflow handles with num_latent\n",
    "        self.num_latent = num_latent\n",
    "\n",
    "    def forward(self, x):\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.kernel(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
    "\n",
    "    def _get_batch_idx(self, batch_size):\n",
    "        valid_indices = torch.arange(0, self.train_inputs[0].size(0))\n",
    "        batch_indices = valid_indices[torch.randperm(len(valid_indices))[:batch_size]]\n",
    "        return batch_indices\n",
    "\n",
    "    def train(self, optimizer, training_iterations=1):\n",
    "        self.likelihood.train()\n",
    "        \n",
    "        for i in range(training_iterations):\n",
    "            optimizer.zero_grad()\n",
    "            sample_idx = self._get_batch_idx(batch_size=64)\n",
    "            batch_x = self.train_inputs[0][sample_idx]\n",
    "            batch_y = self.train_targets[sample_idx]\n",
    "            output = self(batch_x)\n",
    "            loss = - self.likelihood(train_x,train_y).sample().sum()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            if (i + 1) % 50 == 0:\n",
    "                print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iterations, loss.item()))\n",
    "\n",
    "# Usage\n",
    "train_x = torch.randn(1000, 1)\n",
    "train_y = torch.randn(1000)\n",
    "inducing_points = train_x[:50]  # Use 50 inducing points\n",
    "likelihood = GaussianLikelihood()\n",
    "model = SVGPcal(train_x, train_y, likelihood, inducing_points=inducing_points)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "\n",
    "model.train(optimizer, training_iterations=500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00fa0b9d-256b-4159-8f2a-a79dc8e63ee3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "af5f5acd-5795-45a8-a1fc-77dddd98951f",
   "metadata": {},
   "source": [
    "# Calibration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "id": "c82fcb63-48e4-4cf7-8304-abd443078304",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_recab_ECE(confidences, labels, conf_cal, labels_cal, n_bins, norm='l1', strategy='label'):\n",
    "    \"\"\"\n",
    "    Calcurating recalibrate ECE with calibration dataset.\n",
    "    \"\"\"\n",
    "    if not torch.all(torch.abs(torch.sum(confidences, dim=1) - 1) < 1e-10):\n",
    "        print(\"make softmax prob.\")\n",
    "        confidences = confidences.softmax(1)\n",
    "    \n",
    "    if not torch.all((confidences >= 0) & (confidences <= 1)):\n",
    "        raise ValueError(f\"This is not softmax prob.\")\n",
    "    \n",
    "    confidences, _ = confidences.max(dim=1)\n",
    "    confidences[labels==0] = 1 - confidences[labels==0] ## MEMO: Reverse prob. for label y=0.\n",
    "\n",
    "    if not torch.all(torch.abs(torch.sum(conf_cal, dim=1) - 1) < 1e-10):\n",
    "        print(\"make softmax prob.\")\n",
    "        conf_cal = conf_cal.softmax(1)\n",
    "    \n",
    "    if not torch.all((conf_cal >= 0) & (conf_cal <= 1)):\n",
    "        raise ValueError(f\"This is not softmax prob.\")\n",
    "    \n",
    "    conf_cal, _ = conf_cal.max(dim=1)\n",
    "    conf_cal[labels_cal==0] = 1 - conf_cal[labels_cal==0] ## MEMO: Reverse prob. for label y=0.\n",
    "\n",
    "\n",
    "    with torch.no_grad():\n",
    "        conf_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "        count_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "        label_bin = torch.zeros(len(n_bins), device=labels.device, dtype=labels.dtype)\n",
    "        \n",
    "        #idx = torch.bucketize(conf_cal, n_bins, right=True) - 1\n",
    "        idx = idx_bins(conf_cal, n_bins)\n",
    "        bin_total = torch.bincount(idx, minlength=len(n_bins)-1).float().to(confidences.device) ## the number of samples per bins\n",
    "        if strategy == 'label':\n",
    "            bin_true = (torch.bincount(idx, weights=labels_cal, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels    \n",
    "        elif strategy == 'probability':\n",
    "            bin_true = (torch.bincount(idx, weights=conf_cal, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels\n",
    "        else:\n",
    "            raise ValueError(f\"Unexpected strategy: {strategy}.\")\n",
    "            \n",
    "        with warnings.catch_warnings():\n",
    "            warnings.filterwarnings('ignore')\n",
    "            # fill nan by interpolation assuming smoothness\n",
    "            bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "        \n",
    "        ## prediction based on the recalibrate function\n",
    "        idx = idx_bins(confidences, n_bins)\n",
    "        confidences = bin_mean[idx]\n",
    "\n",
    "        count_bin.scatter_add_(dim=0, index=idx, src=torch.ones_like(confidences).type_as(count_bin))\n",
    "        conf_bin.scatter_add_(dim=0, index=idx, src=confidences.type_as(conf_bin))\n",
    "        conf_bin = torch.nan_to_num(conf_bin / count_bin)\n",
    "        prop_bin = count_bin / count_bin.sum()\n",
    "        \n",
    "        label_bin.scatter_add_(dim=0, index=idx, src=labels)\n",
    "        label_bin = torch.nan_to_num(label_bin / count_bin)\n",
    "    \n",
    "    if norm == 'l1':\n",
    "        ece = torch.sum(torch.abs(label_bin - conf_bin) * prop_bin)\n",
    "    elif norm == 'l2':\n",
    "        ece = torch.sqrt(torch.sum(torch.pow(label_bin - conf_bin, 2) * prop_bin))\n",
    "    else:\n",
    "        raise ValueError(f\"Unexpected norm type: {norm}\")\n",
    "    \n",
    "    return ece\n",
    "\n",
    "def calc_ECE(confidences, labels, n_bins, norm='l1', recalibrate=False, strategy='label'):\n",
    "    if not torch.all(torch.abs(torch.sum(confidences, dim=1) - 1) < 1e-10):\n",
    "        print(\"make softmax prob.\")\n",
    "        confidences = confidences.softmax(1)\n",
    "    \n",
    "    if not torch.all((confidences >= 0) & (confidences <= 1)):\n",
    "        raise ValueError(f\"This is not softmax prob.\")\n",
    "    \n",
    "    confidences, _ = confidences.max(dim=1)\n",
    "    confidences[labels==0] = 1 - confidences[labels==0] ## MEMO: Reverse prob. for label y=0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        conf_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "        count_bin = torch.zeros(len(n_bins), device=confidences.device, dtype=confidences.dtype)\n",
    "        label_bin = torch.zeros(len(n_bins), device=labels.device, dtype=labels.dtype)\n",
    "        \n",
    "        #idx = torch.bucketize(confidences, n_bins, right=True) - 1\n",
    "        idx = idx_bins(confidences, n_bins)\n",
    "        if recalibrate:\n",
    "            bin_total = torch.bincount(idx, minlength=len(n_bins)-1).float().to(confidences.device) ## the number of samples per bins\n",
    "            if strategy == 'label':\n",
    "                bin_true = (torch.bincount(idx, weights=labels, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels    \n",
    "            elif strategy == 'probability':\n",
    "                bin_true = (torch.bincount(idx, weights=confidences, minlength=len(n_bins)-1)).float().to(confidences.device) ## the number of samples per bins weighted by labels\n",
    "            else:\n",
    "                raise ValueError(f\"Unexpected strategy: {strategy}.\")\n",
    "            \n",
    "            with warnings.catch_warnings():\n",
    "                warnings.filterwarnings('ignore')\n",
    "                # fill nan by interpolation assuming smoothness\n",
    "                bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy()) ## \\hat{\\mu} in Eq.(9) of Sun et al. (2023)\n",
    "                confidences = bin_mean[idx]\n",
    "\n",
    "        count_bin.scatter_add_(dim=0, index=idx, src=torch.ones_like(confidences).type_as(count_bin))\n",
    "        conf_bin.scatter_add_(dim=0, index=idx, src=confidences.type_as(conf_bin))\n",
    "        conf_bin = torch.nan_to_num(conf_bin / count_bin)\n",
    "        prop_bin = count_bin / count_bin.sum()\n",
    "        \n",
    "        label_bin.scatter_add_(dim=0, index=idx, src=labels)\n",
    "        label_bin = torch.nan_to_num(label_bin / count_bin)\n",
    "    \n",
    "    if norm == 'l1':\n",
    "        ece = torch.sum(torch.abs(label_bin - conf_bin) * prop_bin)\n",
    "    elif norm == 'l2':\n",
    "        ece = torch.sqrt(torch.sum(torch.pow(label_bin - conf_bin, 2) * prop_bin))\n",
    "    else:\n",
    "        raise ValueError(f\"Unexpected norm type: {norm}\")\n",
    "    \n",
    "    return ece\n",
    "\n",
    "#def compute_ece(preds, mask, dataset, n_bins, norm='l1', recalibrate=False, strategy='label'):\n",
    "def compute_ece(preds, mask, labels, n_bins, norm='l1', recalibrate=False, strategy='label', cal_data=False, preds_cal=None, labels_cal=None):\n",
    "    preds = torch.tensor(preds)\n",
    "    labels = torch.tensor(labels).long()\n",
    "    indices = 2*np.arange(len(mask)) + mask\n",
    "    if cal_data:\n",
    "        ece = calc_recab_ECE(preds[indices], labels[indices], preds_cal, labels_cal, n_bins, norm=norm, strategy=strategy)\n",
    "    else:\n",
    "        ece = calc_ECE(preds[indices], labels[indices], n_bins, norm=norm, recalibrate=recalibrate, strategy=strategy)\n",
    "    return utils.to_numpy(ece)\n",
    "\n",
    "def compute_bins(num_bins, confidences=None, method='uniform'):\n",
    "    if method == 'uniform':\n",
    "        n_bins = torch.linspace(0, 1, num_bins + 1)\n",
    "        n_bins[0], n_bins[-1] = 0., 1.\n",
    "    elif method == 'quantile':\n",
    "        if confidences.all() == None:\n",
    "            raise ValueError(f\"confidence values are needed.\")\n",
    "        n_bins = torch.tensor(np.quantile(confidences, torch.linspace(0, 1, num_bins + 1)))\n",
    "        n_bins[0], n_bins[-1] = 0., 1.\n",
    "    else:\n",
    "        raise ValueError(f\"Unexpected binning method: {method}\")\n",
    "    \n",
    "    return n_bins\n",
    "\n",
    "def idx_bins(confidence, n_bins):\n",
    "    binids = np.minimum(np.digitize(confidence.numpy(), n_bins), len(n_bins) - 1)\n",
    "    binids -= 1\n",
    "    return torch.tensor(binids)\n",
    "\n",
    "def interpolate_nan(a):\n",
    "    \"\"\"Linear interpolation for nan values in a 1d array.\n",
    "    Nans on the boundary are filled with the nearest non-nan value.\n",
    "    Slightly modified From the code in the \"minimum-calibration...\" NeurIPS2023.\n",
    "    \"\"\"\n",
    "    b = a.copy()\n",
    "    nans = np.isnan(b)\n",
    "    i = np.arange(len(b))\n",
    "    b[nans] = np.interp(i[nans], i[~nans], b[~nans])\n",
    "    return torch.tensor(b).float()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9edc57f-4544-41de-9c1e-09f03ba8f145",
   "metadata": {},
   "source": [
    "## Non-recalibration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 293,
   "id": "f475baba-da41-4247-8034-57abc2c0fac2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.06119911618903185\n",
      "0.040410167157689136\n"
     ]
    }
   ],
   "source": [
    "num_bins = 100\n",
    "conf = np.max(p_pred_train,1)\n",
    "bins_uwb = compute_bins(num_bins=num_bins)\n",
    "bins_umb = compute_bins(num_bins=num_bins, confidences=conf, method='quantile')\n",
    "gap_ece_uwb = np.abs(compute_ece(p_pred_train, y_train, n_bins=bins_uwb) - compute_ece(p_pred, y_test, n_bins=bins_uwb))\n",
    "gap_ece_umb = np.abs(compute_ece(p_pred_train, y_train, n_bins=bins_umb) - compute_ece(p_pred, y_test, n_bins=bins_umb))\n",
    "#print(compute_ece(p_pred, y_test, n_bins=bins_uwb))\n",
    "#print(compute_ece(p_pred, y_test, n_bins=bins_umb))\n",
    "print(gap_ece_uwb)\n",
    "print(gap_ece_umb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 294,
   "id": "e873a707-715c-4c8b-8975-363f9a4e935d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.058672505657443155\n",
      "0.03977561523345215\n"
     ]
    }
   ],
   "source": [
    "num_bins = int(len(X_train) ** (1/3))\n",
    "bins_uwb = compute_bins(num_bins=num_bins)\n",
    "bins_umb = compute_bins(num_bins=num_bins, confidences=conf, method='quantile')\n",
    "gap_ece_uwb = np.abs(compute_ece(p_pred_train, y_train, n_bins=bins_uwb) - compute_ece(p_pred, y_test, n_bins=bins_uwb))\n",
    "gap_ece_umb = np.abs(compute_ece(p_pred_train, y_train, n_bins=bins_umb) - compute_ece(p_pred, y_test, n_bins=bins_umb))\n",
    "#print(compute_ece(p_pred, y_test, n_bins=bins_uwb))\n",
    "#print(compute_ece(p_pred, y_test, n_bins=bins_umb))\n",
    "print(gap_ece_uwb)\n",
    "print(gap_ece_umb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0e7f03e-0ac8-4fe6-9fcb-6b0be8b3ceb6",
   "metadata": {},
   "source": [
    "## Recalibration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 301,
   "id": "0f37fbb6-aee8-46d0-961c-62740f6b3579",
   "metadata": {},
   "outputs": [],
   "source": [
    "# preparing the prediction on the calibration data.\n",
    "train_size = 1000 ## recalibration\n",
    "test_size = 8000\n",
    "if test_size > 1:\n",
    "    ## absolute value\n",
    "    calib_size_tmp = train_size\n",
    "    test_size_tmp = test_size\n",
    "else:\n",
    "    ## ratio\n",
    "    calib_size_tmp = int(np.shape(X)[0] * train_size)\n",
    "    test_size_tmp = np.shape(X)[0] - int(np.shape(X)[0] * train_size)\n",
    "\n",
    "n_splits = 10\n",
    "cv = ms.ShuffleSplit(n_splits=n_splits, test_size=test_size, train_size=train_size,\n",
    "                             random_state=sklearn.utils.check_random_state(random_state))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 302,
   "id": "83143c31-8268-48eb-84a6-869c8f4c562e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0009351562499999999 0.0004092146930879468\n",
      "9.313225741991448e-12 1.2605580936878598e-11\n",
      "0.022519885861268265 0.004873780009573754\n",
      "5.332054574863545e-10 3.9127601053684666e-10\n"
     ]
    }
   ],
   "source": [
    "num_bins = 100\n",
    "#conf = np.max(p_pred,1)\n",
    "conf = np.max(p_pred_train,1)\n",
    "bins_uwb = compute_bins(num_bins=num_bins)\n",
    "bins_umb = compute_bins(num_bins=num_bins, confidences=conf, method='quantile')\n",
    "\n",
    "ece_gap_uwb = []\n",
    "ece_gap_umb = []\n",
    "ece_gap_uwb_proposed = []\n",
    "ece_gap_umb_proposed = []\n",
    "\n",
    "for i, (recal_idx, test_idx) in enumerate(cv.split(p_pred, y_test)):\n",
    "    pred_cal, label_cal = p_pred[recal_idx], y_test[recal_idx]\n",
    "    pred_tes, l_tes = p_pred[test_idx], y_test[test_idx]\n",
    "    # UWB recalibration\n",
    "    ece_tr = compute_ece(p_pred_train, y_train, n_bins=bins_uwb, recalibrate=True, cal_data=True, preds_cal=torch.tensor(pred_cal), labels_cal=torch.tensor(label_cal).long())\n",
    "    ece_tes = compute_ece(pred_tes, l_tes, n_bins=bins_uwb, recalibrate=True, cal_data=True, preds_cal=torch.tensor(pred_cal), labels_cal=torch.tensor(label_cal).long())\n",
    "    ece_gap_uwb.append(np.abs(ece_tr - ece_tes))\n",
    "    # UMB recalibration\n",
    "    ece_tr = compute_ece(p_pred_train, y_train, n_bins=bins_umb, recalibrate=True, cal_data=True, preds_cal=torch.tensor(pred_cal), labels_cal=torch.tensor(label_cal).long())\n",
    "    ece_tes = compute_ece(pred_tes, l_tes, n_bins=bins_umb, recalibrate=True, cal_data=True, preds_cal=torch.tensor(pred_cal), labels_cal=torch.tensor(label_cal).long())\n",
    "    ece_gap_umb.append(np.abs(ece_tr - ece_tes))\n",
    "    \n",
    "    # UWB proposed recalibration\n",
    "    ece_tr = compute_ece(p_pred_train, y_train, n_bins=bins_uwb, recalibrate=True)\n",
    "    ece_tes = compute_ece(pred_tes, l_tes, n_bins=bins_uwb, recalibrate=True)\n",
    "    ece_gap_uwb_proposed.append(np.abs(ece_tr - ece_tes))\n",
    "    \n",
    "    # UMB proposed recalibration\n",
    "    ece_tr = compute_ece(p_pred_train, y_train, n_bins=bins_umb, recalibrate=True)\n",
    "    ece_tes = compute_ece(pred_tes, l_tes, n_bins=bins_umb, recalibrate=True)\n",
    "    ece_gap_umb_proposed.append(np.abs(ece_tr - ece_tes))\n",
    "\n",
    "print(np.array(ece_gap_uwb).mean(), np.array(ece_gap_uwb).std())\n",
    "print(np.array(ece_gap_uwb_proposed).mean(),np.array(ece_gap_uwb_proposed).std())\n",
    "print(np.array(ece_gap_umb).mean(), np.array(ece_gap_umb).std())\n",
    "print(np.array(ece_gap_umb_proposed).mean(), np.array(ece_gap_umb_proposed).std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff65004c-9098-4188-bf94-601539c31e9e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 303,
   "id": "3478a3b9-a17e-4c79-bb7a-0b57423c719a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0008525183630175888\n",
      "4.6566128775182845e-11\n",
      "0.020432157210237344\n",
      "1.0126270350805171e-09\n"
     ]
    }
   ],
   "source": [
    "num_bins = int(len(X_train) ** (1/3))\n",
    "conf = np.max(p_pred_train,1)\n",
    "bins_uwb = compute_bins(num_bins=num_bins)\n",
    "bins_umb = compute_bins(num_bins=num_bins, confidences=conf, method='quantile')\n",
    "\n",
    "\n",
    "def ece_gap(bins, pred_tes, l_tes, pred_cal=None, label_cal=None):\n",
    "    \"\"\"Return |ECE(train) - ECE(held-out fold)| for one set of bin edges.\n",
    "\n",
    "    Both ECE values come from `compute_ece(..., recalibrate=True)`. When `pred_cal` and `label_cal` are\n",
    "    given they are additionally forwarded as `cal_data=True, preds_cal=..., labels_cal=...`.\n",
    "    \"\"\"\n",
    "    def cal_kwargs():\n",
    "        # Fresh tensors for every call, exactly as the inline calls used to build them.\n",
    "        if pred_cal is None:\n",
    "            return {}\n",
    "        return dict(cal_data=True, preds_cal=torch.tensor(pred_cal), labels_cal=torch.tensor(label_cal).long())\n",
    "\n",
    "    ece_tr = compute_ece(p_pred_train, y_train, n_bins=bins, recalibrate=True, **cal_kwargs())\n",
    "    ece_tes = compute_ece(pred_tes, l_tes, n_bins=bins, recalibrate=True, **cal_kwargs())\n",
    "    return np.abs(ece_tr - ece_tes)\n",
    "\n",
    "\n",
    "ece_gap_uwb = []\n",
    "ece_gap_umb = []\n",
    "ece_gap_uwb_proposed = []\n",
    "ece_gap_umb_proposed = []\n",
    "\n",
    "for recal_idx, test_idx in cv.split(p_pred, y_test):\n",
    "    pred_cal, label_cal = p_pred[recal_idx], y_test[recal_idx]\n",
    "    pred_tes, l_tes = p_pred[test_idx], y_test[test_idx]\n",
    "\n",
    "    # UWB / UMB recalibration using the calibration fold\n",
    "    ece_gap_uwb.append(ece_gap(bins_uwb, pred_tes, l_tes, pred_cal, label_cal))\n",
    "    ece_gap_umb.append(ece_gap(bins_umb, pred_tes, l_tes, pred_cal, label_cal))\n",
    "\n",
    "    # UWB / UMB proposed recalibration (no calibration fold passed)\n",
    "    ece_gap_uwb_proposed.append(ece_gap(bins_uwb, pred_tes, l_tes))\n",
    "    ece_gap_umb_proposed.append(ece_gap(bins_umb, pred_tes, l_tes))\n",
    "\n",
    "print(np.array(ece_gap_uwb).mean())\n",
    "print(np.array(ece_gap_uwb_proposed).mean())\n",
    "print(np.array(ece_gap_umb).mean())\n",
    "print(np.array(ece_gap_umb_proposed).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 258,
   "id": "569b2a69-b3d5-4ed4-800f-d11e47c0642e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.return_types.max(\n",
       "values=tensor([0.9912, 0.9982, 0.9998,  ..., 0.9986, 1.0000, 0.9996]),\n",
       "indices=tensor([0, 0, 0,  ..., 0, 0, 0]))"
      ]
     },
     "execution_count": 258,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Row-wise maximum of the training probabilities: `p` holds the values, `ind` the arg-max column.\n",
    "train_max = torch.tensor(p_pred_train).max(1)\n",
    "p, ind = train_max\n",
    "train_max"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 259,
   "id": "b4d49459-acd9-4dd3-a587-66f1ccbdc4d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace p by 1 - p wherever the training label is 0.\n",
    "# This edits `p` in place, so running the cell a second time flips the values back.\n",
    "label_is_zero = y_train == 0\n",
    "p[label_is_zero] = 1 - p[label_is_zero]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 266,
   "id": "80a6f072-a0cb-4dc9-83f6-63be5afc9589",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14])"
      ]
     },
     "execution_count": 266,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Distinct bin indices that the scores in `p` fall into under the `bins_umb` edges.\n",
    "# Swap to the commented line below to inspect the `bins_uwb` edges instead.\n",
    "#torch.unique(idx_bins(p, bins_uwb))\n",
    "torch.unique(idx_bins(p, bins_umb))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c935269a-5e4d-4608-955e-cb9a568093b6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "66a053e2-5ef3-4511-b2a1-f1e267004c4c",
   "metadata": {},
   "source": [
    "# Experimental code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "26c603df-d87e-4379-995b-3b8cb8b1da68",
   "metadata": {},
   "outputs": [],
   "source": [
    "files = ['kitti_all_train.data', 'kitti_all_train.labels', 'kitti_all_test.data', 'kitti_all_test.labels']\n",
    "file_path = os.getcwd() + '/kitti_features/'\n",
    "\n",
    "# Resolve the four KITTI files once, in the order they are listed in `files`.\n",
    "train_data_path, train_label_path, test_data_path, test_label_path = [\n",
    "    os.path.join(file_path, name) for name in files\n",
    "]\n",
    "\n",
    "# Every file starts with one header row, which is skipped.\n",
    "X_train = np.loadtxt(train_data_path, dtype=np.float64, skiprows=1)\n",
    "y_train = np.loadtxt(train_label_path, dtype=np.int32, skiprows=1)\n",
    "X_test = np.loadtxt(test_data_path, dtype=np.float64, skiprows=1)\n",
    "y_test = np.loadtxt(test_label_path, dtype=np.int32, skiprows=1)\n",
    "\n",
    "# Binarise the labels: any positive label becomes 1, everything else 0.\n",
    "y_train, y_test = (np.where(raw_labels > 0, 1, 0) for raw_labels in (y_train, y_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "54dede21-625a-45e3-a23b-43fc2a4a8be0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "files = [\"camelyonpatch_level_2_split_valid_x.h5\", \"camelyonpatch_level_2_split_valid_y.h5\", \"camelyonpatch_level_2_split_test_x.h5\", \"camelyonpatch_level_2_split_test_y.h5\"]\n",
    "file_path = os.getcwd() + '/pcam/'\n",
    "\n",
    "\n",
    "def _read_h5(name, key):\n",
    "    \"\"\"Read the whole dataset `key` from the HDF5 file `name` located under `file_path`.\"\"\"\n",
    "    with h5py.File(os.path.join(file_path, name), 'r') as hf:\n",
    "        return hf[key][:]\n",
    "\n",
    "\n",
    "def _gray_rows(images):\n",
    "    \"\"\"Weight the first three channels into one gray value, flatten each sample to a row and divide by 255.\"\"\"\n",
    "    gray = np.dot(images[..., :3], [0.299, 0.587, 0.114]) ## gray scale\n",
    "    return gray.reshape(np.shape(gray)[0], -1) / 255 ## normalization\n",
    "\n",
    "\n",
    "# The \"valid\" split plays the role of the training set in this notebook.\n",
    "X_train = _gray_rows(_read_h5(files[0], \"x\"))\n",
    "y_train = _read_h5(files[1], \"y\").flatten()\n",
    "\n",
    "X_test = _gray_rows(_read_h5(files[2], \"x\"))\n",
    "y_test = _read_h5(files[3], \"y\").flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "57ac6d4f-8179-47db-a74e-be1de57ff6b1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(16000, 60)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "e977e083-8d7c-4005-9dbf-a70676b2d274",
   "metadata": {},
   "outputs": [],
   "source": [
    "from methods.calibration import CalibrationMethod, TemperatureScaling, HistogramBinning"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4250316a-9f8f-4353-9c99-a5e7e3e1b876",
   "metadata": {},
   "source": [
    "### Calibration method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "91331c6f-4913-4810-9eef-da0275e78406",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.utils.validation import check_is_fitted\n",
    "\n",
    "class CalibrationMethod(sklearn.base.BaseEstimator):\n",
    "    \"\"\"\n",
    "    Base class for probability calibration methods.\n",
    "\n",
    "    A calibration method maps posterior class probabilities to calibrated posterior probabilities, i.e. ones for\n",
    "    which the empirical frequency of a correct class prediction matches the predicted posterior probability.\n",
    "    Subclasses have to provide `fit` and `predict_proba`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        \"\"\"Fit the calibration map. Not implemented in the base class.\"\"\"\n",
    "        raise NotImplementedError(\"Subclass must implement this method.\")\n",
    "\n",
    "    def predict_proba(self, X):\n",
    "        \"\"\"Return calibrated probabilities. Not implemented in the base class.\"\"\"\n",
    "        raise NotImplementedError(\"Subclass must implement this method.\")\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"Return, for every sample, the column index of its largest calibrated probability.\"\"\"\n",
    "        calibrated = self.predict_proba(X)\n",
    "        return np.argmax(calibrated, axis=1)\n",
    "\n",
    "    def plot(self, filename, xlim=[0, 1], **kwargs):\n",
    "        \"\"\"\n",
    "        Plot the calibrated class-1 probability over a grid of uncalibrated class-1 probabilities.\n",
    "\n",
    "        `kwargs` are forwarded to `plt.plot`. `filename` is accepted but not read by the current body.\n",
    "        NOTE(review): `plt` is not imported in this cell; it has to exist in the notebook namespace already.\n",
    "        \"\"\"\n",
    "        # TODO: Fix this plotting function\n",
    "\n",
    "        # Evaluate the map on two-column inputs [1 - g, g] for g on a regular grid over [0, 1]\n",
    "        grid = np.linspace(0, 1, 10000)\n",
    "        two_class_input = np.column_stack([1 - grid, grid])\n",
    "        mapped = self.predict_proba(two_class_input)[:, 1]\n",
    "\n",
    "        # Draw and label\n",
    "        plt.plot(grid, mapped, **kwargs)\n",
    "        plt.xlim(xlim)\n",
    "        plt.xlabel(\"p(y=1|x)\")\n",
    "        plt.ylabel(\"f(p(y=1|x))\")\n",
    "\n",
    "class PlattScaling(CalibrationMethod):\n",
    "    \"\"\"\n",
    "    Calibration by logistic regression on the class-1 score.\n",
    "\n",
    "    For two-column input a `LogisticRegression` is fitted on column 1 of X. For more than two columns the work\n",
    "    is handed to `OneVsRestCalibrator` wrapping a clone of this estimator.\n",
    "    NOTE(review): `clone` and `OneVsRestCalibrator` are not defined in this cell; they have to exist in the\n",
    "    notebook namespace already for the multi-class path to run.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    regularization : float\n",
    "        The logistic regression is created with C = 1 / regularization.\n",
    "    random_state : int, RandomState instance or None\n",
    "        Converted with `sklearn.utils.check_random_state` and passed to the logistic regression.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, regularization=10 ** -12, random_state=None):\n",
    "        super().__init__()\n",
    "        self.regularization = regularization\n",
    "        self.random_state = sklearn.utils.check_random_state(random_state)\n",
    "\n",
    "    def fit(self, X, y, n_jobs=None):\n",
    "        \"\"\"Fit on probabilities X of shape (n_samples, n_classes) and labels y; returns self.\"\"\"\n",
    "        if X.ndim == 1:\n",
    "            raise ValueError(\"Calibration training data must have shape (n_samples, n_classes).\")\n",
    "\n",
    "        n_columns = np.shape(X)[1]\n",
    "        if n_columns == 2:\n",
    "            # Binary case: one-dimensional logistic regression on the class-1 column\n",
    "            self.logistic_regressor_ = sklearn.linear_model.LogisticRegression(\n",
    "                C=1 / self.regularization, solver='lbfgs', random_state=self.random_state)\n",
    "            class_1_scores = X[:, 1].reshape(-1, 1)\n",
    "            self.logistic_regressor_.fit(class_1_scores, y)\n",
    "        elif n_columns > 2:\n",
    "            # Multi-class case: delegate to the one-vs-rest wrapper\n",
    "            self.onevsrest_calibrator_ = OneVsRestCalibrator(calibrator=clone(self), n_jobs=n_jobs)\n",
    "            self.onevsrest_calibrator_.fit(X, y)\n",
    "\n",
    "        return self\n",
    "\n",
    "    def predict_proba(self, X):\n",
    "        \"\"\"Return calibrated probabilities for X of shape (n_samples, n_classes).\"\"\"\n",
    "        if X.ndim == 1:\n",
    "            raise ValueError(\"Calibration data must have shape (n_samples, n_classes).\")\n",
    "\n",
    "        n_columns = np.shape(X)[1]\n",
    "        if n_columns == 2:\n",
    "            check_is_fitted(self, \"logistic_regressor_\")\n",
    "            class_1_scores = X[:, 1].reshape(-1, 1)\n",
    "            return self.logistic_regressor_.predict_proba(class_1_scores)\n",
    "        if n_columns > 2:\n",
    "            check_is_fitted(self, \"onevsrest_calibrator_\")\n",
    "            return self.onevsrest_calibrator_.predict_proba(X)\n",
    "\n",
    "class TemperatureScaling(CalibrationMethod):\n",
    "    \"\"\"\n",
    "    Probability calibration using temperature scaling\n",
    "\n",
    "    Temperature scaling [1]_ is a one parameter multi-class scaling method. Output confidence scores are calibrated,\n",
    "    meaning they match empirical frequencies of the associated class prediction. Temperature scaling does not change the\n",
    "    class predictions of the underlying model.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    T_init : float\n",
    "        Initial temperature parameter used for scaling. This parameter is optimized in order to calibrate output\n",
    "        probabilities. Must be greater than 0.\n",
    "    verbose : bool\n",
    "        Print information on optimization procedure.\n",
    "\n",
    "    References\n",
    "    ----------\n",
    "    .. [1] On calibration of modern neural networks, C. Guo, G. Pleiss, Y. Sun, K. Weinberger, ICML 2017\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, T_init=1, verbose=False):\n",
    "        super().__init__()\n",
    "        if T_init <= 0:\n",
    "            raise ValueError(\"Temperature not greater than 0.\")\n",
    "        self.T_init = T_init\n",
    "        self.verbose = verbose\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        \"\"\"\n",
    "        Fit the calibration method based on the given uncalibrated class probabilities or logits X and ground truth\n",
    "        labels y.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        X : array-like, shape (n_samples, n_classes)\n",
    "            Training data, i.e. predicted probabilities or logits of the base classifier on the calibration set.\n",
    "        y : array-like, shape (n_samples,)\n",
    "            Target classes, used as column indices into X.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        self : object\n",
    "            Returns an instance of self.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        ValueError\n",
    "            If the optimized temperature is not greater than 0.\n",
    "        \"\"\"\n",
    "        # Both closures need the true-class entry of every row; build the selector once\n",
    "        rows = np.arange(X.shape[0])\n",
    "        X_true = X[rows, y].reshape(-1, 1)\n",
    "\n",
    "        # Define objective function (NLL / cross entropy)\n",
    "        def objective(T):\n",
    "            # Calibrate with given T\n",
    "            P = scipy.special.softmax(X / T, axis=1)\n",
    "\n",
    "            # Compute negative log-likelihood\n",
    "            P_y = P[rows, y]\n",
    "            tiny = np.finfo(np.float64).tiny  # keeps the argument of log strictly positive\n",
    "            NLL = - np.sum(np.log(P_y + tiny))\n",
    "            return NLL\n",
    "\n",
    "        # Derivative of the objective with respect to the temperature T\n",
    "        def gradient(T):\n",
    "            # softmax(X / T) equals exp(X / T) / sum(exp(X / T)); the raw exponentials overflow to inf for\n",
    "            # large X / T and turn the ratio into NaN, so the normalized weights are computed directly.\n",
    "            P = scipy.special.softmax(X / T, axis=1)\n",
    "\n",
    "            # Gradient\n",
    "            dT_i = np.sum(P * (X - X_true), axis=1)\n",
    "            grad = - dT_i.sum() / T ** 2\n",
    "            return grad\n",
    "\n",
    "        # Optimize\n",
    "        self.T = scipy.optimize.fmin_bfgs(f=objective, x0=self.T_init,\n",
    "                                          fprime=gradient, gtol=1e-06, disp=self.verbose)[0]\n",
    "\n",
    "        # Check for T > 0\n",
    "        if self.T <= 0:\n",
    "            raise ValueError(\"Temperature not greater than 0.\")\n",
    "\n",
    "        return self\n",
    "\n",
    "    def predict_proba(self, X):\n",
    "        \"\"\"\n",
    "        Compute calibrated posterior probabilities for a given array of posterior probabilities from an arbitrary\n",
    "        classifier.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        X : array-like, shape (n_samples, n_classes)\n",
    "            The uncalibrated posterior probabilities.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        P : array, shape (n_samples, n_classes)\n",
    "            The predicted probabilities, softmax(X / T) taken along the class axis.\n",
    "        \"\"\"\n",
    "        # Check is fitted\n",
    "        check_is_fitted(self, \"T\")\n",
    "\n",
    "        # Transform with scaled softmax\n",
    "        return scipy.special.softmax(X / self.T, axis=1)\n",
    "\n",
    "    def latent(self, z):\n",
    "        \"\"\"\n",
    "        Evaluate the latent function Tz of temperature scaling.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        z : array-like, shape=(n_evaluations,)\n",
    "            Input confidence for which to evaluate the latent function.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        f : array-like, shape=(n_evaluations,)\n",
    "            Values of the latent function at z.\n",
    "        \"\"\"\n",
    "        check_is_fitted(self, \"T\")\n",
    "        return self.T * z\n",
    "\n",
    "    def plot_latent(self, z, filename, **kwargs):\n",
    "        \"\"\"\n",
    "        Plot the latent function of the calibration method.\n",
    "\n",
    "        NOTE(review): uses `pycalib.texfig` and `plt`, neither of which is imported in this cell.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        z : array-like, shape=(n_evaluations,)\n",
    "            Input confidence to plot latent function for.\n",
    "        filename :\n",
    "            Filename / -path where to save output.\n",
    "        kwargs\n",
    "            Additional arguments passed on to pycalib.texfig.subplots.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "\n",
    "        \"\"\"\n",
    "        check_is_fitted(self, \"T\")\n",
    "\n",
    "        # Plot latent function\n",
    "        fig, axes = pycalib.texfig.subplots(nrows=1, ncols=1, sharex=True, **kwargs)\n",
    "        axes.plot(z, self.T * z, label=\"latent function\")\n",
    "        axes.set_ylabel(\"$T\\\\bm{z}$\")\n",
    "        axes.set_xlabel(\"$\\\\bm{z}_k$\")\n",
    "        fig.align_labels()\n",
    "\n",
    "        # Save plot to file\n",
    "        pycalib.texfig.savefig(filename)\n",
    "        plt.close()\n",
    "\n",
    "class HistogramBinning(CalibrationMethod):\n",
    "    \"\"\"\n",
    "    Probability calibration using histogram binning\n",
    "\n",
    "    Histogram binning [1]_ is a nonparametric approach to probability calibration. Classifier scores are binned into a given\n",
    "    number of bins either based on fixed width or frequency. Classifier scores are then computed based on the empirical\n",
    "    frequency of class 1 in each bin.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "        mode : str, default='equal_freq'\n",
    "            Binning mode used. One of ['equal_width', 'equal_freq'].\n",
    "        n_bins : int, default=20\n",
    "            Number of bins to bin classifier scores into.\n",
    "        input_range : list, shape (2,), default=[0, 1]\n",
    "            Range of the classifier scores. In 'equal_freq' mode the two values are used as the lowest and\n",
    "            highest quantile level of the bin edges.\n",
    "\n",
    "    .. [1] Zadrozny, B. & Elkan, C. Obtaining calibrated probability estimates from decision trees and naive Bayesian\n",
    "           classifiers in Proceedings of the 18th International Conference on Machine Learning (ICML, 2001), 609–616.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, mode='equal_freq', n_bins=20, input_range=[0, 1]):\n",
    "        super().__init__()\n",
    "        if mode in ['equal_width', 'equal_freq']:\n",
    "            self.mode = mode\n",
    "        else:\n",
    "            raise ValueError(\"Mode not recognized. Choose on of 'equal_width', or 'equal_freq'.\")\n",
    "        self.n_bins = n_bins\n",
    "        self.input_range = input_range\n",
    "\n",
    "    def fit(self, X, y, n_jobs=None):\n",
    "        \"\"\"\n",
    "        Fit the calibration method based on the given uncalibrated class probabilities X and ground truth labels y.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        X : array-like, shape (n_samples, n_classes)\n",
    "            Training data, i.e. predicted probabilities of the base classifier on the calibration set.\n",
    "        y : array-like, shape (n_samples,)\n",
    "            Target classes.\n",
    "        n_jobs : int or None, optional (default=None)\n",
    "            The number of jobs to use for the computation.\n",
    "            ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.\n",
    "            ``-1`` means using all processors.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        self : object\n",
    "            Returns an instance of self.\n",
    "        \"\"\"\n",
    "        if X.ndim == 1:\n",
    "            raise ValueError(\"Calibration training data must have shape (n_samples, n_classes).\")\n",
    "        elif np.shape(X)[1] == 2:\n",
    "            return self._fit_binary(X, y)\n",
    "        elif np.shape(X)[1] > 2:\n",
    "            # NOTE(review): `clone` and `OneVsRestCalibrator` are not defined in this cell; they have to exist\n",
    "            # in the notebook namespace already for the multi-class path to run.\n",
    "            self.onevsrest_calibrator_ = OneVsRestCalibrator(calibrator=clone(self), n_jobs=n_jobs)\n",
    "            self.onevsrest_calibrator_.fit(X, y)\n",
    "        return self\n",
    "\n",
    "    def _fit_binary(self, X, y):\n",
    "        \"\"\"Estimate the bin edges and the class-1 frequency of every bin from the class-1 column of X.\"\"\"\n",
    "        if self.mode == 'equal_width':\n",
    "            # Compute probability of class 1 in each equal width bin\n",
    "            binned_stat = scipy.stats.binned_statistic(x=X[:, 1], values=np.equal(1, y), statistic='mean',\n",
    "                                                       bins=self.n_bins, range=self.input_range)\n",
    "            self.prob_class_1 = binned_stat.statistic\n",
    "            self.binning = binned_stat.bin_edges\n",
    "        elif self.mode == 'equal_freq':\n",
    "            # Find binning based on equal frequency\n",
    "            self.binning = np.quantile(X[:, 1],\n",
    "                                       q=np.linspace(self.input_range[0], self.input_range[1], self.n_bins + 1))\n",
    "\n",
    "            # Compute probability of class 1 in equal frequency bins\n",
    "            digitized = np.digitize(X[:, 1], bins=self.binning)\n",
    "            digitized[digitized == len(self.binning)] = len(self.binning) - 1  # include rightmost edge in partition\n",
    "            self.prob_class_1 = [y[digitized == i].mean() for i in range(1, len(self.binning))]\n",
    "\n",
    "        return self\n",
    "\n",
    "    def predict_proba(self, X):\n",
    "        \"\"\"\n",
    "        Compute calibrated posterior probabilities for a given array of posterior probabilities from an arbitrary\n",
    "        classifier.\n",
    "\n",
    "        Scores outside the fitted bin edges are assigned to the nearest (first or last) bin.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        X : array-like, shape (n_samples, n_classes)\n",
    "            The uncalibrated posterior probabilities.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        P : array, shape (n_samples, n_classes)\n",
    "            The predicted probabilities.\n",
    "        \"\"\"\n",
    "        if X.ndim == 1:\n",
    "            raise ValueError(\"Calibration data must have shape (n_samples, n_classes).\")\n",
    "        elif np.shape(X)[1] == 2:\n",
    "            check_is_fitted(self, [\"binning\", \"prob_class_1\"])\n",
    "            # Find bin of predictions\n",
    "            digitized = np.digitize(X[:, 1], bins=self.binning)\n",
    "            # np.digitize returns 0 below the first edge and len(binning) at or above the last one. Without\n",
    "            # the lower clamp, 0 - 1 == -1 would silently pick the LAST bin for the lowest scores.\n",
    "            digitized = np.clip(digitized, 1, len(self.binning) - 1)\n",
    "            # Transform to empirical frequency of class 1 in each bin\n",
    "            p1 = np.asarray(self.prob_class_1, dtype=float)[digitized - 1]\n",
    "            # If empirical frequency is NaN, do not change prediction\n",
    "            p1 = np.where(np.isfinite(p1), p1, X[:, 1])\n",
    "            assert np.all(np.isfinite(p1)), \"Predictions are not all finite.\"\n",
    "\n",
    "            return np.column_stack([1 - p1, p1])\n",
    "        elif np.shape(X)[1] > 2:\n",
    "            check_is_fitted(self, \"onevsrest_calibrator_\")\n",
    "            return self.onevsrest_calibrator_.predict_proba(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "713e0b23-064a-45b3-ad08-eaf1c980ede5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_recab_ECE(confidences, labels, conf_cal, labels_cal, n_bins, norm='l1', strategy='label'):\n",
    "    \"\"\"\n",
    "    Compute the ECE of `confidences` after histogram recalibration fitted on a separate calibration set.\n",
    "\n",
    "    Both probability matrices are reduced to one score per sample (row maximum, replaced by 1 - maximum where\n",
    "    the label is 0). Per-bin means are estimated on the calibration set, every evaluation score is replaced by\n",
    "    the mean of the bin it falls into, and the ECE of those recalibrated scores against `labels` is returned.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    confidences : torch.Tensor, shape (n_eval, n_classes)\n",
    "        Evaluation probabilities. Rows that do not sum to 1 (within 1e-10) are passed through softmax first.\n",
    "    labels : torch.Tensor, shape (n_eval,)\n",
    "        Evaluation labels.\n",
    "    conf_cal : torch.Tensor, shape (n_cal, n_classes)\n",
    "        Calibration probabilities, treated like `confidences`.\n",
    "    labels_cal : torch.Tensor, shape (n_cal,)\n",
    "        Calibration labels.\n",
    "    n_bins : torch.Tensor\n",
    "        Bin edges, passed to `idx_bins`.\n",
    "    norm : str\n",
    "        'l1' or 'l2'.\n",
    "    strategy : str\n",
    "        Quantity averaged per bin on the calibration set: 'label' (labels) or 'probability' (scores).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ece : torch.Tensor\n",
    "        Scalar tensor with the calibration error.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        For values outside [0, 1], an empty calibration set, or an unknown `strategy` / `norm`.\n",
    "    \"\"\"\n",
    "    def _scores(probs, targets):\n",
    "        # Shared reduction of a probability matrix to one score per sample.\n",
    "        if not torch.all(torch.abs(torch.sum(probs, dim=1) - 1) < 1e-10):\n",
    "            print(\"make softmax prob.\")\n",
    "            probs = probs.softmax(1)\n",
    "\n",
    "        if not torch.all((probs >= 0) & (probs <= 1)):\n",
    "            raise ValueError(f\"This is not softmax prob.\")\n",
    "\n",
    "        scores, _ = probs.max(dim=1)\n",
    "        scores[targets==0] = 1 - scores[targets==0] ## MEMO: Reverse prob. for label y=0.\n",
    "        return scores\n",
    "\n",
    "    confidences = _scores(confidences, labels)\n",
    "    conf_cal = _scores(conf_cal, labels_cal)\n",
    "\n",
    "    if conf_cal.numel() == 0:\n",
    "        raise ValueError(\"Calibration set is empty.\")\n",
    "\n",
    "    if strategy == 'label':\n",
    "        cal_weights = labels_cal\n",
    "    elif strategy == 'probability':\n",
    "        cal_weights = conf_cal\n",
    "    else:\n",
    "        raise ValueError(f\"Unexpected strategy: {strategy}.\")\n",
    "\n",
    "    with torch.no_grad():\n",
    "        n_slots = len(n_bins)\n",
    "\n",
    "        ## fit the recalibration map on the calibration set: mean of the chosen quantity per bin\n",
    "        idx_cal = idx_bins(conf_cal, n_bins)\n",
    "        bin_total = torch.bincount(idx_cal, minlength=n_slots - 1).double()\n",
    "        bin_true = torch.bincount(idx_cal, weights=cal_weights.double(), minlength=n_slots - 1)\n",
    "        with warnings.catch_warnings():\n",
    "            warnings.filterwarnings('ignore')\n",
    "            # empty bins give 0 / 0 = nan; fill them by interpolation assuming smoothness\n",
    "            bin_mean = interpolate_nan(bin_true.numpy() / bin_total.numpy())\n",
    "\n",
    "        ## apply the map to the evaluation scores; the bin index must come from the evaluation scores\n",
    "        ## themselves so that it lines up with `confidences` and `labels`\n",
    "        idx = idx_bins(confidences, n_bins)\n",
    "        confidences = bin_mean[idx].type_as(confidences)\n",
    "\n",
    "        conf_bin = torch.zeros(n_slots, device=confidences.device, dtype=confidences.dtype)\n",
    "        count_bin = torch.zeros(n_slots, device=confidences.device, dtype=confidences.dtype)\n",
    "        label_bin = torch.zeros(n_slots, device=labels.device, dtype=labels.dtype)\n",
    "\n",
    "        count_bin.scatter_add_(dim=0, index=idx, src=torch.ones_like(confidences).type_as(count_bin))\n",
    "        conf_bin.scatter_add_(dim=0, index=idx, src=confidences.type_as(conf_bin))\n",
    "        conf_bin = torch.nan_to_num(conf_bin / count_bin)\n",
    "        prop_bin = count_bin / count_bin.sum()\n",
    "\n",
    "        label_bin.scatter_add_(dim=0, index=idx, src=labels)\n",
    "        label_bin = torch.nan_to_num(label_bin / count_bin)\n",
    "\n",
    "    if norm == 'l1':\n",
    "        ece = torch.sum(torch.abs(label_bin - conf_bin) * prop_bin)\n",
    "    elif norm == 'l2':\n",
    "        ece = torch.sqrt(torch.sum(torch.pow(label_bin - conf_bin, 2) * prop_bin))\n",
    "    else:\n",
    "        raise ValueError(f\"Unexpected norm type: {norm}\")\n",
    "\n",
    "    return ece\n",
    "\n",
    "def calc_ECE(confidences, labels, n_bins, norm='l1', recalibrate=False, strategy='label'):\n",
    "    \"\"\"\n",
    "    Binned calibration error between per-sample scores and labels.\n",
    "\n",
    "    Each row of `confidences` is reduced to its maximum (replaced by 1 - maximum where the label is 0). The\n",
    "    scores are assigned to bins with `idx_bins(scores, n_bins)`, and the per-bin mean score is compared with the\n",
    "    per-bin mean label, weighted by the share of samples in the bin.\n",
    "\n",
    "    `recalibrate` and `strategy` are part of the signature but are not read by the current body.\n",
    "\n",
    "    Returns a scalar tensor; raises ValueError for values outside [0, 1] or an unknown `norm`.\n",
    "    \"\"\"\n",
    "    row_sums = torch.sum(confidences, dim=1)\n",
    "    if not torch.all(torch.abs(row_sums - 1) < 1e-10):\n",
    "        print(\"make softmax prob.\")\n",
    "        confidences = confidences.softmax(1)\n",
    "\n",
    "    inside_unit_interval = (confidences >= 0) & (confidences <= 1)\n",
    "    if not torch.all(inside_unit_interval):\n",
    "        raise ValueError(f\"This is not softmax prob.\")\n",
    "\n",
    "    confidences = confidences.max(dim=1).values\n",
    "    zero_label = labels==0\n",
    "    confidences[zero_label] = 1 - confidences[zero_label] ## MEMO: Reverse prob. for label y=0\n",
    "\n",
    "    n_slots = len(n_bins)\n",
    "    with torch.no_grad():\n",
    "        # Per-bin accumulators: sample count, summed score, summed label\n",
    "        count_bin = confidences.new_zeros(n_slots)\n",
    "        conf_sum = confidences.new_zeros(n_slots)\n",
    "        label_sum = labels.new_zeros(n_slots)\n",
    "\n",
    "        idx = idx_bins(confidences, n_bins)\n",
    "\n",
    "        count_bin.scatter_add_(dim=0, index=idx, src=torch.ones_like(confidences).type_as(count_bin))\n",
    "        conf_sum.scatter_add_(dim=0, index=idx, src=confidences.type_as(conf_sum))\n",
    "        label_sum.scatter_add_(dim=0, index=idx, src=labels)\n",
    "\n",
    "        # Empty bins produce 0 / 0 = nan, which nan_to_num turns into 0\n",
    "        conf_bin = torch.nan_to_num(conf_sum / count_bin)\n",
    "        label_bin = torch.nan_to_num(label_sum / count_bin)\n",
    "        prop_bin = count_bin / count_bin.sum()\n",
    "\n",
    "    if norm == 'l1':\n",
    "        return torch.sum(torch.abs(label_bin - conf_bin) * prop_bin)\n",
    "    if norm == 'l2':\n",
    "        return torch.sqrt(torch.sum(torch.pow(label_bin - conf_bin, 2) * prop_bin))\n",
    "    raise ValueError(f\"Unexpected norm type: {norm}\")\n",
    "\n",
    "#def compute_ece(preds, mask, dataset, n_bins, norm='l1', recalibrate=False, strategy='label'):\n",
    "def compute_ece(preds, mask, labels, n_bins, norm='l1', recalibrate=False, strategy='label', cal_data=False, preds_cal=None, labels_cal=None):\n",
    "    \"\"\"\n",
    "    Select rows of `preds` / `labels` with `mask` and return their calibration error as a numpy value.\n",
    "\n",
    "    For every position i of `mask`, row 2*i + mask[i] is selected. With `cal_data` set the selected rows go to\n",
    "    `calc_recab_ECE` together with `preds_cal` / `labels_cal`; otherwise they go to `calc_ECE`.\n",
    "    \"\"\"\n",
    "    pred_tensor = torch.tensor(preds)\n",
    "    label_tensor = torch.tensor(labels).long()\n",
    "    selected = 2*np.arange(len(mask)) + mask\n",
    "    picked_preds, picked_labels = pred_tensor[selected], label_tensor[selected]\n",
    "    if not cal_data:\n",
    "        ece = calc_ECE(picked_preds, picked_labels, n_bins, norm=norm, recalibrate=recalibrate, strategy=strategy)\n",
    "    else:\n",
    "        ece = calc_recab_ECE(picked_preds, picked_labels, preds_cal, labels_cal, n_bins, norm=norm, strategy=strategy)\n",
    "    return utils.to_numpy(ece)\n",
    "\n",
    "def compute_bins(num_bins, confidences=None, method='uniform'):\n",
    "    \"\"\"\n",
    "    Build the num_bins + 1 bin edges used to histogram confidence scores.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    num_bins : int\n",
    "        Number of bins.\n",
    "    confidences : array-like or None\n",
    "        Scores whose quantiles define the edges; required for method='quantile'.\n",
    "    method : str\n",
    "        'uniform' for equally spaced edges on [0, 1], 'quantile' for edges at equally spaced quantile levels\n",
    "        of `confidences`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    n_bins : torch.Tensor, shape (num_bins + 1,)\n",
    "        Bin edges; the first and last edge are always set to 0 and 1.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `method` is unknown, or if method='quantile' is requested without any confidence values.\n",
    "    \"\"\"\n",
    "    if method == 'uniform':\n",
    "        n_bins = torch.linspace(0, 1, num_bins + 1)\n",
    "    elif method == 'quantile':\n",
    "        # `confidences.all() == None` was never true and raised AttributeError for None; test the\n",
    "        # argument itself so that the intended ValueError is raised.\n",
    "        if confidences is None or len(confidences) == 0:\n",
    "            raise ValueError(f\"confidence values are needed.\")\n",
    "        n_bins = torch.tensor(np.quantile(confidences, torch.linspace(0, 1, num_bins + 1)))\n",
    "    else:\n",
    "        raise ValueError(f\"Unexpected binning method: {method}\")\n",
    "\n",
    "    # Pin the outer edges so that every score in [0, 1] falls inside the partition\n",
    "    n_bins[0], n_bins[-1] = 0., 1.\n",
    "    return n_bins\n",
    "\n",
    "def idx_bins(confidence, n_bins):\n",
    "    \"\"\"Return, as an integer tensor, the zero-based bin index of every score in `confidence`.\n",
    "\n",
    "    `n_bins` holds the bin edges. A score at or above the last edge is placed in the last bin.\n",
    "    \"\"\"\n",
    "    last_bin = len(n_bins) - 1\n",
    "    one_based = np.digitize(confidence.numpy(), n_bins)\n",
    "    capped = np.minimum(one_based, last_bin)\n",
    "    return torch.tensor(capped - 1)\n",
    "\n",
    "def interpolate_nan(a):\n",
    "    \"\"\"Fill the nan entries of a 1d array by linear interpolation over the index.\n",
    "\n",
    "    Leading and trailing nans take the nearest non-nan value. The input is left untouched; the result is\n",
    "    returned as a float tensor.\n",
    "    Slightly modified From the code in the \"minimum-calibration...\" NeurIPS2023.\n",
    "    \"\"\"\n",
    "    filled = a.copy()\n",
    "    missing = np.isnan(filled)\n",
    "    known = ~missing\n",
    "    positions = np.arange(len(filled))\n",
    "    filled[missing] = np.interp(positions[missing], positions[known], filled[known])\n",
    "    return torch.tensor(filled).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "6d2ce133-d3ed-4445-85ef-71373f0c5083",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_acc(preds, mask, labels):\n",
    "    \"\"\"Return the accuracy of the rows selected by `mask`, as a numpy value.\n",
    "\n",
    "    For every position i of `mask`, row 2*i + mask[i] of `preds` / `labels` is selected; a row counts as\n",
    "    correct when the arg-max of its prediction equals its label.\n",
    "    \"\"\"\n",
    "    pred_tensor = torch.tensor(preds)\n",
    "    label_tensor = torch.tensor(labels).long()\n",
    "    selected = 2*np.arange(len(mask)) + mask\n",
    "    predicted_class = pred_tensor[selected].argmax(dim=1)\n",
    "    is_correct = predicted_class == label_tensor[selected]\n",
    "    return utils.to_numpy(is_correct.float().mean())\n",
    "\n",
    "def isnotnan(pred):\n",
    "    \"\"\"Return only the rows of the 2d array `pred` that contain no nan.\"\"\"\n",
    "    row_has_nan = np.isnan(pred).any(axis=1)\n",
    "    keep = np.logical_not(row_has_nan)\n",
    "    return pred[keep, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "a5915e1f-6bf3-4d37-814c-649fcfdbf6ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.metrics import compute_acc, compute_acc_l2, compute_ece, compute_bins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "f152840c-95b8-4558-b82d-c7df20f6c0ea",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "92d26df1e65349b6b1a2bd655fb0c805",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7df18492d7494199a28e3871598523df",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/40 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "TypeError",
     "evalue": "compute_acc() got an unexpected keyword argument 'labels'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m/var/folders/_s/jh8l69lx0wz3rhnd7rx1w5gr0000gn/T/ipykernel_12394/469272573.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     46\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     47\u001b[0m         \u001b[0;31m## Accuracy\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 48\u001b[0;31m         \u001b[0mcur_train_acc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_acc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpreds\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcur_preds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcur_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     49\u001b[0m         \u001b[0mcur_val_acc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_acc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpreds\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcur_preds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mcur_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     50\u001b[0m         \u001b[0mtrain_accs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcur_train_acc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mTypeError\u001b[0m: compute_acc() got an unexpected keyword argument 'labels'"
     ]
    }
   ],
   "source": [
    "# ECE-gap experiment. For every size n in n_sample and every seed, 2*n points are drawn\n",
    "# and viewed as n pairs; a Bernoulli(1/2) mask picks one element of each pair for fitting,\n",
    "# the other element is held out. The absolute ECE difference between the two halves is\n",
    "# recorded for recalibration on separate data and on the full training sample.\n",
    "cal_size = 1000 ## recalibration\n",
    "n_bins = 15\n",
    "\n",
    "random_state = 1\n",
    "S_seeds = np.arange(0,40)\n",
    "n_sample = [200, 1000, 3000, 7000] ## KITTI\n",
    "#n_sample = [500, 3000, 7000, 10000] ## Pcam\n",
    "param_recal = True\n",
    "\n",
    "train_accs = []\n",
    "val_accs = []\n",
    "\n",
    "preds = []\n",
    "masks = []\n",
    "labels = []\n",
    "\n",
    "ece_gap_umb = []\n",
    "ece_gap_umb_proposed = []\n",
    "bins_umb = []\n",
    "bins_list_umb = []\n",
    "for n in tqdm(n_sample):\n",
    "    for seed in tqdm(S_seeds):\n",
    "        np.random.seed(seed)\n",
    "        ## preparing all datasets (train/test)\n",
    "        all_indices = np.random.choice(X_train.shape[0], size=2*n, replace=False)\n",
    "        X, y = X_train[all_indices], y_train[all_indices]\n",
    "        \n",
    "        model = xgb.XGBClassifier(booster=\"gbtree\", n_estimators=100, random_state=random_state, n_jobs=-1)\n",
    "        #model = RandomForestClassifier(n_estimators=100, criterion=\"gini\", min_samples_split=2, bootstrap=True, n_jobs=-1, random_state=random_state)\n",
    "        cur_mask = np.random.randint(2, size=(n,)) ## Ber(1/2)\n",
    "        train_indices = 2*np.arange(n) + cur_mask\n",
    "        #val_indices = 2*np.arange(n) + (1-cur_mask)\n",
    "        x_tr, y_tr = X[train_indices], y[train_indices]\n",
    "        #x_val, y_val = X_train[val_indices], y_train[val_indices]\n",
    "        model.fit(X=x_tr, y=y_tr)\n",
    "        \n",
    "        cur_preds = isnotnan(model.predict_proba(X))\n",
    "        if param_recal:\n",
    "            #temp = TemperatureScaling()\n",
    "            temp = HistogramBinning(mode='equal_freq', n_bins=10)\n",
    "            temp.fit(cur_preds, y)\n",
    "            cur_preds = temp.predict_proba(cur_preds)\n",
    "        preds.append(torch.tensor(cur_preds))\n",
    "        masks.append(torch.tensor(cur_mask))\n",
    "        labels.append(torch.tensor(y))\n",
    "\n",
    "        ## Accuracy\n",
    "        cur_train_acc = compute_acc(preds=cur_preds, mask=cur_mask, labels=y)\n",
    "        cur_val_acc = compute_acc(preds=cur_preds, mask=1-cur_mask, labels=y)\n",
    "        train_accs.append(cur_train_acc)\n",
    "        val_accs.append(cur_val_acc)\n",
    "\n",
    "        ## ECE w./ recalibration (UMB)\n",
    "        conf = torch.tensor(cur_preds[train_indices]).softmax(1).max(1).values\n",
    "        cur_bins_umb = compute_bins(num_bins=n_bins, confidences=conf, method='quantile') ## for L1-ECE\n",
    "        bins_list_umb.append(cur_bins_umb)\n",
    "        cur_bins = idx_bins(conf, cur_bins_umb)\n",
    "        bins_umb.append(cur_bins.numpy())\n",
    "        \n",
    "        # UMB on recalibration data\n",
    "        cal_indices = np.random.choice(X_test.shape[0], size=cal_size, replace=False)\n",
    "        x_cal, y_cal = X_test[cal_indices], y_test[cal_indices]\n",
    "        cal_preds = isnotnan(model.predict_proba(x_cal))\n",
    "        \n",
    "        # cur_preds has one row per sampled point in X, so the matching labels are y\n",
    "        # (= y_train[all_indices]); passing the full y_train would pair every prediction\n",
    "        # with the label of an unrelated training point.\n",
    "        ece_tr = compute_ece(cur_preds, mask=cur_mask, labels=y, n_bins=cur_bins_umb, recalibrate=True, cal_data=True, preds_cal=torch.tensor(cal_preds), labels_cal=torch.tensor(y_cal).long())\n",
    "        ece_tes = compute_ece(cur_preds, mask=1-cur_mask, labels=y, n_bins=cur_bins_umb, recalibrate=True, cal_data=True, preds_cal=torch.tensor(cal_preds), labels_cal=torch.tensor(y_cal).long())\n",
    "        ece_gap_umb.append(np.abs(ece_tr - ece_tes))\n",
    "\n",
    "        # UMB on full training dataset (proposed recalibration)\n",
    "        ece_tr = compute_ece(cur_preds, mask=cur_mask, labels=y, n_bins=cur_bins_umb, recalibrate=True)\n",
    "        ece_tes = compute_ece(cur_preds, mask=1-cur_mask, labels=y, n_bins=cur_bins_umb, recalibrate=True)\n",
    "        ece_gap_umb_proposed.append(np.abs(ece_tr - ece_tes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "e6d32113-8fad-424d-88c1-14ffd10dbadf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([9, 9, 9, 9, 9, 8, 8, 8, 8, 9, 9, 9, 9, 9, 8, 9, 8, 8, 9, 9, 9, 9, 9, 9,\n",
       "        8, 9, 9, 8, 9, 9, 8, 9, 9, 8, 8, 9, 9, 9, 9, 9, 9, 8, 9, 9, 9, 8, 8, 9,\n",
       "        8, 9, 9, 8, 9, 9, 7, 9, 9, 8, 8, 9, 8, 9, 8, 9, 9, 9, 8, 8, 9, 7, 9, 9,\n",
       "        9, 8, 9, 8, 9, 9, 8, 9, 9, 9, 9, 9, 8, 9, 8, 8, 9, 8, 9, 8, 8, 9, 9, 9,\n",
       "        9, 8, 8, 9, 8, 8, 9, 9, 9, 8, 8, 9, 8, 9, 9, 9, 9, 9, 9, 9, 9, 8, 9, 9,\n",
       "        8, 9, 9, 9, 8, 9, 8, 8, 8, 8, 8, 9, 8, 9, 8, 8, 9, 9, 9, 9, 8, 9, 9, 9,\n",
       "        9, 9, 9, 8, 9, 9, 8, 9, 8, 8, 9, 8, 8, 9, 7, 8, 9, 9, 8, 9, 8, 9, 9, 8,\n",
       "        9, 8, 8, 8, 8, 8, 8, 8, 8, 9, 8, 9, 9, 8, 8, 9, 8, 9, 8, 9, 9, 9, 9, 8,\n",
       "        9, 9, 9, 9, 8, 8, 8, 9])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(temp.binning)\n",
    "idx_bins(torch.tensor(cur_preds[train_indices]).max(1).values, torch.tensor(temp.binning))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "9e2525bf-0f7e-4406-919f-941dec8ad789",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.009573755756835453 6.459886209081922e-10\n",
      "0.04374423210565001 9.144423531196001e-10\n",
      "0.024687134942394915 5.180404206865345e-10\n",
      "0.021870277801086198 4.26898311397451e-10\n"
     ]
    }
   ],
   "source": [
    "# Mean ECE gap per sample size: the loop above stores 40 seeds per entry of n_sample, back to back.\n",
    "for start in range(0, 160, 40):\n",
    "    stop = start + 40\n",
    "    print(np.array(ece_gap_umb[start:stop]).mean(), np.array(ece_gap_umb_proposed[start:stop]).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "33972720-b0df-40e6-848c-e14a7d7af3ce",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.009573755756835453 6.459886209081922e-10\n",
      "0.04374423210565001 9.144423531196001e-10\n",
      "0.024687134942394915 5.180404206865345e-10\n",
      "0.021870277801086198 4.26898311397451e-10\n"
     ]
    }
   ],
   "source": [
    "# Same summary as the previous cell: one line per block of 40 consecutive results.\n",
    "for lo, hi in zip(range(0, 160, 40), range(40, 200, 40)):\n",
    "    print(np.array(ece_gap_umb[lo:hi]).mean(), np.array(ece_gap_umb_proposed[lo:hi]).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "23c290ff-6886-47f8-aa19-aa0aea1522e1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1.        , 0.        ],\n",
       "       [1.        , 0.        ],\n",
       "       [1.        , 0.        ],\n",
       "       ...,\n",
       "       [0.57428571, 0.42571429],\n",
       "       [1.        , 0.        ],\n",
       "       [1.        , 0.        ]])"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "temp = HistogramBinning(mode='equal_freq', n_bins=10)\n",
    "temp.fit(cur_preds[train_indices], y[train_indices])\n",
    "#cur_preds = temp.predict_proba(cur_preds)\n",
    "temp.predict_proba(cur_preds[2*np.arange(n) + (1-cur_mask)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d24e7835-8e75-4212-a265-93baed37e105",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.022993719472756108"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#print(cur_preds.shape, y.shape)\n",
    "from KDEpy import FFTKDE\n",
    "y_cal = np.eye(len(np.unique(y)))[y]\n",
    "ece_kde_binary(cur_preds, y_cal)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "8e74bd75-a454-479b-bad5-c3bdeee26883",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, ..., 0, 0, 0])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array([np.where(r==1)[0][0] for r in y_cal])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b9d31923-cd46-42de-857f-0df9cc8a6ab7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mirror_1d(d, xmin=None, xmax=None):\n",
    "    \"\"\"Reflect the samples ``d`` about the given domain boundaries.\n",
    "\n",
    "    With both bounds, values below the midpoint are reflected about ``xmin`` and the\n",
    "    remaining ones about ``xmax`` (each reshaped to a column) and placed around ``d``.\n",
    "    With a single bound, all of ``d`` is reflected about it. Without bounds ``d`` is\n",
    "    returned unchanged.\n",
    "    \"\"\"\n",
    "    has_lower = xmin is not None\n",
    "    has_upper = xmax is not None\n",
    "    if not has_lower and not has_upper:\n",
    "        return d\n",
    "    if has_lower and has_upper:\n",
    "        midpoint = (xmin+xmax)/2\n",
    "        left_reflection = (2*xmin-d[d < midpoint]).reshape(-1,1)\n",
    "        right_reflection = (2*xmax-d[d >= midpoint]).reshape(-1,1)\n",
    "        return np.concatenate((left_reflection, d, right_reflection))\n",
    "    if has_lower:\n",
    "        return np.concatenate((2*xmin-d, d))\n",
    "    return np.concatenate((d, 2*xmax-d))\n",
    "\n",
    "def ece_kde_binary(p,label,p_int=None,order=1):\n",
    "    \"\"\"Kernel-density estimate of the calibration error.\n",
    "\n",
    "    Two densities are estimated on a grid with a triweight FFTKDE (data reflected at 0 and 1):\n",
    "    the scores of the samples whose binary label is 1, and the scores taken from ``p_int``.\n",
    "    Their ratio, scaled by the positive rate and capped at 1, is the estimated accuracy at each\n",
    "    grid point; ``|conf - accuracy|**order`` is then integrated against the score density.\n",
    "\n",
    "    Args:\n",
    "        p: numpy array of predicted scores, shape (N, K). For K == 2 the normalised second\n",
    "            column is the score; otherwise the normalised top-class score is used.\n",
    "        label: one-hot labels, one row per sample.\n",
    "        p_int: optional scores used for the density of predictions; defaults to a copy of ``p``.\n",
    "        order: exponent applied to the absolute calibration gap.\n",
    "\n",
    "    Returns:\n",
    "        Integral of the weighted gap over [0, 1] divided by the integral of the score density.\n",
    "    \"\"\"\n",
    "\n",
    "    # points from numerical integration\n",
    "    if p_int is None:\n",
    "        p_int = np.copy(p)\n",
    "\n",
    "    p = np.clip(p,1e-256,1-1e-256)\n",
    "    p_int = np.clip(p_int,1e-256,1-1e-256)\n",
    "\n",
    "    x_int = np.linspace(-0.6, 1.6, num=2**14)\n",
    "\n",
    "    N = p.shape[0]\n",
    "\n",
    "    # this is needed to convert labels from one-hot to conventional form\n",
    "    label_index = np.array([np.where(r==1)[0][0] for r in label])\n",
    "    with torch.no_grad():\n",
    "        if p.shape[1] !=2:\n",
    "            p_new = torch.from_numpy(p)\n",
    "            p_b = torch.zeros(N,1)\n",
    "            label_binary = np.zeros((N,1))\n",
    "            for i in range(N):\n",
    "                pred_label = int(torch.argmax(p_new[i]).numpy())\n",
    "                if pred_label == label_index[i]:\n",
    "                    label_binary[i] = 1\n",
    "                p_b[i] = p_new[i,pred_label]/torch.sum(p_new[i,:])\n",
    "        else:\n",
    "            p_b = torch.from_numpy((p/np.sum(p,1)[:,None])[:,1])\n",
    "            label_binary = label_index\n",
    "\n",
    "    method = 'triweight'\n",
    "\n",
    "    dconf_1 = (p_b[np.where(label_binary==1)].reshape(-1,1)).numpy()\n",
    "    # Bandwidth from the spread of the positive-label scores (the earlier estimate based on\n",
    "    # all scores was overwritten before use and has been dropped).\n",
    "    kbw = np.std(dconf_1)*(N*2)**-0.2\n",
    "    # Mirror the data about the domain boundary\n",
    "    low_bound = 0.0\n",
    "    up_bound = 1.0\n",
    "    dconf_1m = mirror_1d(dconf_1,low_bound,up_bound)\n",
    "    # Compute KDE using the bandwidth found, and twice as many grid points\n",
    "    pp1 = FFTKDE(bw=kbw, kernel=method).fit(dconf_1m).evaluate(x_int)\n",
    "    pp1[x_int<=low_bound] = 0  # Set the KDE to zero outside of the domain\n",
    "    pp1[x_int>=up_bound] = 0  # Set the KDE to zero outside of the domain\n",
    "    pp1 = pp1 * 2  # Double the y-values to get integral of ~1\n",
    "\n",
    "    p_int = p_int/np.sum(p_int,1)[:,None]\n",
    "    N1 = p_int.shape[0]\n",
    "    pred_b_int = np.zeros((N1,1))\n",
    "    if p_int.shape[1]!=2:\n",
    "        # top-label score of every row (the value at the argmax is the row maximum)\n",
    "        pred_b_int[:, 0] = p_int.max(axis=1)\n",
    "    else:\n",
    "        pred_b_int[:, 0] = p_int[:, 1]\n",
    "\n",
    "    pred_b_intm = mirror_1d(pred_b_int,low_bound,up_bound)\n",
    "    # Compute KDE using the bandwidth found, and twice as many grid points\n",
    "    pp2 = FFTKDE(bw=kbw, kernel=method).fit(pred_b_intm).evaluate(x_int)\n",
    "    pp2[x_int<=low_bound] = 0  # Set the KDE to zero outside of the domain\n",
    "    pp2[x_int>=up_bound] = 0  # Set the KDE to zero outside of the domain\n",
    "    pp2 = pp2 * 2  # Double the y-values to get integral of ~1\n",
    "\n",
    "    if p.shape[1] !=2: # top label (confidence)\n",
    "        perc = np.mean(label_binary)\n",
    "    else: # or joint calibration for binary cases\n",
    "        perc = np.mean(label_index)\n",
    "\n",
    "    integral = np.zeros(x_int.shape)\n",
    "    for i in range(x_int.shape[0]):\n",
    "        conf = x_int[i]\n",
    "        # conf is the i-th grid point itself, so the nearest grid index is always i;\n",
    "        # indexing directly avoids a full argmin scan of the grid on every iteration.\n",
    "        if np.max([pp1[i],pp2[i]])>1e-6:\n",
    "            accu = np.min([perc*pp1[i]/pp2[i],1.0])\n",
    "            if not np.isnan(accu):\n",
    "                integral[i] = np.abs(conf-accu)**order*pp2[i]\n",
    "        else:\n",
    "            # negligible density: carry the previous integrand value forward\n",
    "            if i>1:\n",
    "                integral[i] = integral[i-1]\n",
    "\n",
    "    # np.trapz was renamed to np.trapezoid in NumPy 2.0; use whichever is available.\n",
    "    integrate = getattr(np, 'trapezoid', None) or np.trapz\n",
    "    ind = np.where((x_int >= 0.0) & (x_int <= 1.0))\n",
    "    return integrate(integral[ind],x_int[ind])/integrate(pp2[ind],x_int[ind])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "8f7cbfb6-288e-4e86-bf9d-25ecbbc531dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.bound_utils import estimate_fcmi_bound_umb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "bb699e3e-8f97-4ee6-8a2d-c1c614e7efe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_bins_eval = n_bins\n",
    "\n",
    "#cal_bound_umb_l1, mis_list = estimate_fcmi_bound_umb(masks=masks[0:40], preds=preds[0:40], labels=labels[0:40], bins=bins_umb[0:40], num_examples=n_sample[0], num_classes=2, n_bins=n_bins_eval, bins_list=bins_list_umb[0:40],\n",
    "#                                       norm='l1', loss='reuse', recalibration=True, verbose=False, return_list_of_mis=True)\n",
    "cal_bound_umb_l1, mis_list = estimate_fcmi_bound_umb(masks=masks[120:160], preds=preds[120:160], labels=labels[120:160], bins=bins_umb[120:160], num_examples=n_sample[3], num_classes=2, n_bins=n_bins_eval, bins_list=bins_list_umb[120:160],\n",
    "                                       norm='l1', loss='reuse', recalibration=True, verbose=False, return_list_of_mis=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "fb494b93-800c-4487-8eb7-b2ccc2fe0bf5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def calc_bound_cal(mi, n, b):\n",
    "    \"\"\"Return ``sqrt(2 * (mi + b * log 2) / n)`` for a mutual-information term ``mi``.\"\"\"\n",
    "    complexity = mi + b*np.log(2)\n",
    "    return np.sqrt(2*complexity / n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "id": "76bc2c96-2d62-4ed0-93ef-b3d81eacab9d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.06137769431336504\n",
      "3.865926738371572e-05\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the bound once for every mutual-information estimate, then report mean and std.\n",
    "bound_b = int(n_sample[3] ** (1/3))\n",
    "bound_vals = np.array([calc_bound_cal(mi, n_sample[3], b=bound_b) for mi in mis_list])\n",
    "print(bound_vals.mean())\n",
    "print(bound_vals.std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "440ba5d7-0cf6-42b8-8758-4df0829e354d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.11336427694257725"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "denom = np.floor(1000 / int(1000**(1/3)))\n",
    "2/denom + 1/np.sqrt(denom-1)\n",
    "#(2/np.floor(1000/(1000**(1/3)))) + 1/np.sqrt(np.floor(1000/(1000**(1/3))) - 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec2f68c0-c015-4cab-a08f-2189edcba10e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e43fdff8-af4e-483d-8cbc-f716321cd451",
   "metadata": {},
   "source": [
    "## Result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "50ce0103-6cdd-4549-b7a1-214df08dcc1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d3e06fa7-b348-453d-9f4a-4a729df3ad54",
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_name = \"kitti\"\n",
    "results_dir = \"results\"\n",
    "\n",
    "results_file_path = os.path.join(results_dir, exp_name, 'results_xgboost.pkl')\n",
    "with open(results_file_path, 'rb') as f:\n",
    "    results_xg = pickle.load(f)\n",
    "\n",
    "results_file_path = os.path.join(results_dir, exp_name, 'results_randomforest.pkl')\n",
    "with open(results_file_path, 'rb') as f:\n",
    "    results_rf = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "001d7a1b-879f-4d9b-a782-00e01c81dc42",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.0056793294, 0.0022095845] [2.2226581e-07, 1.1884055e-07]\n",
      "[0.012477067320102026, 0.0016375295579531075] [2.989543383820977e-10, 1.9910538755847342e-10]\n"
     ]
    }
   ],
   "source": [
    "print(results_xg['ece_gap'][3], results_xg['ece_gap_proposed'][3])\n",
    "print(results_rf['ece_gap'][3], results_rf['ece_gap_proposed'][3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "897407fc-d644-42a8-ab28-f910e90065db",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.061369867915001694 2.9653493706447115e-05\n",
      "0.06136981171217612 2.9634696146157814e-05\n"
     ]
    }
   ],
   "source": [
    "print(results_xg['bound_values'][3].mean(), results_xg['bound_values'][3].std())\n",
    "print(results_rf['bound_values'][3].mean(), results_rf['bound_values'][3].std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3e61004a-79af-4e2a-9d3c-5232ec297a83",
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_name = \"pcam\"\n",
    "results_dir = \"results\"\n",
    "\n",
    "results_file_path = os.path.join(results_dir, exp_name, 'results_xgboost.pkl')\n",
    "with open(results_file_path, 'rb') as f:\n",
    "    results_xg = pickle.load(f)\n",
    "\n",
    "results_file_path = os.path.join(results_dir, exp_name, 'results_randomforest.pkl')\n",
    "with open(results_file_path, 'rb') as f:\n",
    "    results_rf = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1d02fc0f-d13f-41ac-be29-235378ce217e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.06032624, 0.0029271988] [2.6562195e-06, 1.493409e-06]\n",
      "[0.10390117613397538, 0.003991066741647776] [2.3968316237841123e-09, 1.3871859800108341e-09]\n"
     ]
    }
   ],
   "source": [
    "print(results_xg['ece_gap'][3], results_xg['ece_gap_proposed'][3])\n",
    "print(results_rf['ece_gap'][3], results_rf['ece_gap_proposed'][3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "41fbbeec-6001-48bd-b953-01cf18bde9cc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.05397570100143702 2.1518297961225384e-05\n",
      "0.05397573207243811 2.149118248128494e-05\n"
     ]
    }
   ],
   "source": [
    "print(results_xg['bound_values'][3].mean(), results_xg['bound_values'][3].std())\n",
    "print(results_rf['bound_values'][3].mean(), results_rf['bound_values'][3].std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "bb577056-2371-452a-862e-308842e893b1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10000"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(results_xg['bound_values'][3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "eab8420a-116a-4bc2-96d3-c009fde7a36f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.6701, -0.2976],\n",
       "        [ 0.4078,  0.4191],\n",
       "        [ 0.0099, -1.2178],\n",
       "        [ 0.0203,  1.4816],\n",
       "        [ 2.2968, -0.5772],\n",
       "        [-0.1616, -0.1067],\n",
       "        [-0.2254,  1.6605],\n",
       "        [ 1.9329,  0.2608],\n",
       "        [ 1.3131, -2.1901],\n",
       "        [ 0.0248, -1.3805]])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "3a1a28c4-fd62-4d5e-a565-c7887cc8cd75",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.6701, -0.2976])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "b750ad9f-4b1c-4a2d-a1f6-18d40be41840",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.019333040460546055"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#print(cur_preds.shape, y.shape)\n",
    "from KDEpy import FFTKDE\n",
    "y_cal = np.eye(len(np.unique(y)))[y]\n",
    "ece_kde_binary(cur_preds, y_cal)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "fe32c034-8aee-447b-b9dc-4f2bc10aa8b3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0305)"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#device='cpu'\n",
    "#get_bandwidth(torch.tensor(cur_preds),\"cpu\")\n",
    "#len(torch.tensor(cur_preds))\n",
    "#get_ece_kde(torch.tensor(cur_preds), torch.tensor(y).long(), bandwidth=0.001, p=1, mc_type='canonical', device='cpu')\n",
    "get_ece_kde(torch.tensor(cur_preds), torch.tensor(y).long(), bandwidth=0.001, p=1, mc_type='marginal', device='cpu')\n",
    "#get_ece_kde(torch.tensor(cur_preds), torch.tensor(y).long(), bandwidth=0.001, p=1, mc_type='top_label', device='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "d78e93f0-7c93-42fa-8dec-ee8f74123953",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "def get_bandwidth(f, device):\n",
    "    \"\"\"\n",
    "    Select a bandwidth for the kernel based on maximizing the leave-one-out likelihood (LOO MLE).\n",
    "\n",
    "    :param f: The vector containing the probability scores, shape [num_samples, num_classes]\n",
    "    :param device: The device type: 'cpu' or 'cuda'\n",
    "\n",
    "    :return: The candidate bandwidth with the largest LOO log-likelihood, or -1 if no\n",
    "        candidate produced a comparable (non-NaN) likelihood\n",
    "    \"\"\"\n",
    "    bandwidths = torch.cat((torch.logspace(start=-5, end=-1, steps=15), torch.linspace(0.2, 1, steps=5)))\n",
    "    max_b = -1\n",
    "    # Start from -inf, not 0: the summed log-likelihood can be negative for every\n",
    "    # candidate, and a 0 baseline would then reject them all and return the -1 sentinel.\n",
    "    max_l = float('-inf')\n",
    "    n = len(f)\n",
    "    for b in bandwidths:\n",
    "        log_kern = get_kernel(f, b, device)\n",
    "        # leave-one-out density estimate in log space (n - 1 neighbours per sample)\n",
    "        log_fhat = torch.logsumexp(log_kern, 1) - np.log(n-1)\n",
    "        log_lik = torch.sum(log_fhat)\n",
    "        if log_lik > max_l:\n",
    "            max_l = log_lik\n",
    "            max_b = b\n",
    "\n",
    "    return max_b\n",
    "\n",
    "\n",
    "def get_ece_kde(f, y, bandwidth, p, mc_type, device):\n",
    "    \"\"\"\n",
    "    Estimate the Lp calibration error with a kernel density estimator.\n",
    "\n",
    "    :param f: The vector containing the probability scores, shape [num_samples, num_classes]\n",
    "    :param y: The vector containing the labels, shape [num_samples]\n",
    "    :param bandwidth: The bandwidth of the kernel\n",
    "    :param p: The p-norm. Typically, p=1 or p=2\n",
    "    :param mc_type: The type of multiclass calibration: canonical, marginal or top_label\n",
    "    :param device: The device type: 'cpu' or 'cuda'\n",
    "\n",
    "    :return: An estimate of Lp calibration error\n",
    "    \"\"\"\n",
    "    check_input(f, bandwidth, mc_type)\n",
    "\n",
    "    # Single-column scores go through the binary estimator, scaled by two.\n",
    "    if f.shape[1] == 1:\n",
    "        return 2 * get_ratio_binary(f, y, bandwidth, p, device)\n",
    "\n",
    "    if mc_type == 'canonical':\n",
    "        return get_ratio_canonical(f, y, bandwidth, p, device)\n",
    "    if mc_type == 'marginal':\n",
    "        return get_ratio_marginal_vect(f, y, bandwidth, p, device)\n",
    "    if mc_type == 'top_label':\n",
    "        return get_ratio_toplabel(f, y, bandwidth, p, device)\n",
    "    return None\n",
    "\n",
    "\n",
    "def get_ratio_binary(f, y, bandwidth, p, device):\n",
    "    \"\"\"Binary calibration error for single-column scores ``f`` (shape [num_samples, 1]).\"\"\"\n",
    "    assert f.shape[1] == 1\n",
    "\n",
    "    return get_kde_for_ece(f, y, get_kernel(f, bandwidth, device), p)\n",
    "\n",
    "\n",
    "def get_ratio_canonical(f, y, bandwidth, p, device):\n",
    "    \"\"\"Canonical calibration error: mean over samples of the summed per-class gap between\n",
    "    the kernel-weighted label frequencies and the scores ``f``, raised to the power ``p``.\n",
    "    \"\"\"\n",
    "    num_classes = f.shape[1]\n",
    "    if num_classes > 60:\n",
    "        # Slower but more numerically stable implementation for larger number of classes\n",
    "        return get_ratio_canonical_log(f, y, bandwidth, p, device)\n",
    "\n",
    "    kern = get_kernel(f, bandwidth, device).exp()\n",
    "    y_onehot = nn.functional.one_hot(y, num_classes=num_classes).to(torch.float32)\n",
    "\n",
    "    # clamp the kernel mass to avoid division by 0\n",
    "    den = kern.sum(dim=1).clamp(min=1e-10)\n",
    "    cond_freq = (kern @ y_onehot) / den.unsqueeze(-1)\n",
    "\n",
    "    per_sample = ((cond_freq - f).abs()**p).sum(dim=1)\n",
    "    return per_sample.mean()\n",
    "\n",
    "\n",
    "# Note for training: Make sure there are at least two examples for every class present in the batch, otherwise\n",
    "# LogsumexpBackward returns nans.\n",
    "def get_ratio_canonical_log(f, y, bandwidth, p, device='cpu'):\n",
    "    \"\"\"\n",
    "    Canonical calibration error computed class by class in log space.\n",
    "\n",
    "    :param f: The vector containing the probability scores, shape [num_samples, num_classes]\n",
    "    :param y: The vector containing the labels, shape [num_samples]\n",
    "    :param bandwidth: The bandwidth of the kernel\n",
    "    :param p: The p-norm. Typically, p=1 or p=2\n",
    "    :param device: The device type: 'cpu' or 'cuda'\n",
    "\n",
    "    :return: Mean over samples of the per-class gaps summed over classes\n",
    "    \"\"\"\n",
    "    log_kern = get_kernel(f, bandwidth, device)\n",
    "    y_onehot = nn.functional.one_hot(y, num_classes=f.shape[1]).to(torch.float32)\n",
    "    # log(0) = -inf removes the samples that do not belong to class k from the numerator\n",
    "    log_y = torch.log(y_onehot)\n",
    "    log_den = torch.logsumexp(log_kern, dim=1)\n",
    "    final_ratio = 0\n",
    "    for k in range(f.shape[1]):\n",
    "        # The (1, N) row of log-labels broadcasts over the rows of log_kern. This gives the\n",
    "        # same values as multiplying by an explicit ones column first, but does not create\n",
    "        # a CPU-only helper tensor, so it also works when the inputs live on another device.\n",
    "        log_kern_y = log_kern + log_y[:, k].unsqueeze(0)\n",
    "        log_inner_ratio = torch.logsumexp(log_kern_y, dim=1) - log_den\n",
    "        inner_ratio = torch.exp(log_inner_ratio)\n",
    "        inner_diff = torch.abs(inner_ratio - f[:, k])**p\n",
    "        final_ratio += inner_diff\n",
    "\n",
    "    return torch.mean(final_ratio)\n",
    "\n",
    "\n",
    "def get_ratio_marginal_vect(f, y, bandwidth, p, device):\n",
    "    \"\"\"Marginal calibration error: a beta kernel is evaluated per class, the diagonal\n",
    "    (self-pairs) is pushed to the most negative float, and the result is passed to\n",
    "    ``get_kde_for_ece_vect`` together with the one-hot labels.\n",
    "    \"\"\"\n",
    "    num_samples, num_classes = len(f), f.shape[1]\n",
    "    y_onehot = nn.functional.one_hot(y, num_classes=num_classes).to(torch.float32)\n",
    "    diag_penalty = torch.diag(torch.finfo(torch.float).min * torch.ones(num_samples)).to(device)\n",
    "    # Multiclass case: one copy of the diagonal penalty per class along the last axis\n",
    "    penalty_per_class = torch.stack([diag_penalty] * num_classes, dim=2)\n",
    "    log_kern_vect = beta_kernel(f, f, bandwidth).squeeze() + penalty_per_class\n",
    "\n",
    "    return get_kde_for_ece_vect(f, y_onehot, log_kern_vect, p)\n",
    "\n",
    "\n",
    "def get_ratio_toplabel(f, y, bandwidth, p, device):\n",
    "    \"\"\"Top-label calibration error: reduce to the binary problem \"was the highest-scoring\n",
    "    class correct\" with the top score as the confidence.\n",
    "    \"\"\"\n",
    "    top = torch.max(f, 1)\n",
    "    top_score = top.values.unsqueeze(-1)\n",
    "    top_is_correct = (y == top.indices).to(torch.int)\n",
    "\n",
    "    return get_ratio_binary(top_score, top_is_correct, bandwidth, p, device)\n",
    "\n",
    "\n",
    "def get_kde_for_ece_vect(f, y, log_kern, p):\n",
    "    \"\"\"Per-class kernel ratio estimate: mean over samples of ``|ratio - f|**p``, summed\n",
    "    over the class axis. ``log_kern`` is left unmodified.\n",
    "    \"\"\"\n",
    "    masked_log_kern = log_kern * y\n",
    "    # Trick: -inf instead of 0 in log space\n",
    "    masked_log_kern = masked_log_kern.masked_fill(masked_log_kern == 0, torch.finfo(torch.float).min)\n",
    "\n",
    "    log_ratio = torch.logsumexp(masked_log_kern, dim=1) - torch.logsumexp(log_kern, dim=1)\n",
    "    gap = (log_ratio.exp() - f).abs()**p\n",
    "\n",
    "    return gap.mean(dim=0).sum()\n",
    "\n",
    "\n",
    "def get_kde_for_ece(f, y, log_kern, p):\n",
    "    \"\"\"Binary kernel ratio estimate of the calibration error.\n",
    "\n",
    "    The ratio of kernel mass on positive samples (``y == 1``) to total kernel mass is\n",
    "    compared with the scores ``f``. The cases with zero or exactly one positive sample\n",
    "    are handled separately.\n",
    "    \"\"\"\n",
    "    scores = f.squeeze()\n",
    "    num_samples = len(scores)\n",
    "    # Select the entries where y = 1\n",
    "    positives = torch.where(y == 1)[0]\n",
    "    num_positives = positives.numel()\n",
    "\n",
    "    if num_positives == 0:\n",
    "        return torch.sum((torch.abs(-scores))**p) / num_samples\n",
    "\n",
    "    single_positive = num_positives == 1\n",
    "    if single_positive:\n",
    "        # because of -inf in the vector: drop the lone positive sample's own row and score\n",
    "        log_kern = torch.cat((log_kern[:positives], log_kern[positives+1:]))\n",
    "        lone_score = scores[positives]\n",
    "        scores = torch.cat((scores[:positives], scores[positives+1:]))\n",
    "\n",
    "    log_num = torch.logsumexp(torch.index_select(log_kern, 1, positives), dim=1)\n",
    "    log_den = torch.logsumexp(log_kern, dim=1)\n",
    "    gap = torch.abs(torch.exp(log_num - log_den) - scores)**p\n",
    "\n",
    "    if single_positive:\n",
    "        return (gap.sum() + lone_score ** p)/num_samples\n",
    "\n",
    "    return torch.mean(gap)\n",
    "\n",
    "\n",
    "def get_kernel(f, bandwidth, device):\n",
    "    \"\"\"Pairwise log-kernel of `f` against itself with the self-pairs masked out.\n",
    "\n",
    "    A single-column `f` uses `beta_kernel`; anything wider uses\n",
    "    `dirichlet_kernel`. The most negative float is then added on the diagonal\n",
    "    (after moving that mask to `device`).\n",
    "    \"\"\"\n",
    "    single_column = f.shape[1] == 1\n",
    "    if single_column:\n",
    "        pairwise = beta_kernel(f, f, bandwidth)\n",
    "    else:\n",
    "        pairwise = dirichlet_kernel(f, bandwidth)\n",
    "\n",
    "    # Trick: -inf (approximated by the float minimum) on the diagonal, so that\n",
    "    # a sample carries no weight in its own estimate.\n",
    "    self_mask = torch.diag(torch.full((len(f),), torch.finfo(torch.float).min)).to(device)\n",
    "    return pairwise.squeeze() + self_mask\n",
    "\n",
    "\n",
    "def beta_kernel(z, zi, bandwidth=0.1):\n",
    "    p = zi / bandwidth + 1\n",
    "    q = (1-zi) / bandwidth + 1\n",
    "    z = z.unsqueeze(-2)\n",
    "\n",
    "    log_beta = torch.lgamma(p) + torch.lgamma(q) - torch.lgamma(p + q)\n",
    "    log_num = (p-1) * torch.log(z) + (q-1) * torch.log(1-z)\n",
    "    log_beta_pdf = log_num - log_beta\n",
    "\n",
    "    return log_beta_pdf\n",
    "\n",
    "\n",
    "def dirichlet_kernel(z, bandwidth=0.1):\n",
    "    alphas = z / bandwidth + 1\n",
    "\n",
    "    log_beta = (torch.sum((torch.lgamma(alphas)), dim=1) - torch.lgamma(torch.sum(alphas, dim=1)))\n",
    "    log_num = torch.matmul(torch.log(z), (alphas-1).T)\n",
    "    log_dir_pdf = log_num - log_beta\n",
    "\n",
    "    return log_dir_pdf\n",
    "\n",
    "\n",
    "def check_input(f, bandwidth, mc_type):\n",
    "    \"\"\"Assert that the estimator inputs are usable.\n",
    "\n",
    "    Requires `f` to be NaN-free, two-dimensional and within [0, 1], and\n",
    "    `bandwidth` to be strictly positive. `mc_type` is accepted but not\n",
    "    inspected here.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if any of the conditions above is violated.\n",
    "    \"\"\"\n",
    "    assert not isnan(f)\n",
    "    assert f.dim() == 2\n",
    "    assert bandwidth > 0\n",
    "    assert f.min() >= 0\n",
    "    assert f.max() <= 1\n",
    "\n",
    "\n",
    "def isnan(a):\n",
    "    return torch.any(torch.isnan(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fa0a87c-cf26-4baa-94f0-2809ef09837b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
