{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set_theme()\n",
    "sns.set_context(\"paper\")\n",
    "sns.set_style(\"ticks\")\n",
    "import tikzplotlib\n",
    "from pathlib import Path\n",
    "import sklearn.model_selection\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "from torchvision.models import (\n",
    "    resnet50, ResNet50_Weights,\n",
    "    vit_b_16, ViT_B_16_Weights,\n",
    ")\n",
    "from torchvision.datasets import ImageFolder\n",
    "from torch.utils.data import DataLoader\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "from calibrators import fit_scaling_model\n",
    "from utils import logits_labels_from_dataloader, LogitsDataset\n",
    "from netcal.presentation import ReliabilityDiagram\n",
    "\n",
    "# Bin count for calibration diagnostics -- presumably meant for the reliability\n",
    "# diagrams / ECE computed further down; verify that it is actually read there.\n",
    "n_bins = 15\n",
    "\n",
    "# One seed shared by every random number generator configured in this cell.\n",
    "seed = 0\n",
    "\n",
    "# Seed PyTorch and NumPy's legacy global RNG, and also build a dedicated NumPy Generator.\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "rng = np.random.default_rng(seed)\n",
    "\n",
    "# Results directory: '<parent of the current working directory>/results/'.\n",
    "path_results = os.path.dirname(os.getcwd()) + '/results/'\n",
    "\n",
    "\n",
    "# Root folder of the ImageNet data (absolute path; the 'val' sub-folder is read below).\n",
    "path_dataset = Path('/imagenet')\n",
    "\n",
    "# Batch size of the image dataloaders.\n",
    "BATCH_SIZE = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Which pretrained classifier to calibrate; must be a key of the table below.\n",
    "model_name = 'ViT-B/16'\n",
    "\n",
    "# Display name -> (torchvision model constructor, pretrained weights enum member).\n",
    "models_and_weights_torchvision = {\n",
    "    'ResNet-50': (resnet50, ResNet50_Weights.DEFAULT),\n",
    "    'ViT-B/16': (vit_b_16, ViT_B_16_Weights.DEFAULT), \n",
    "}\n",
    "\n",
    "architecture, weights = models_and_weights_torchvision[model_name]\n",
    "# The classifier is only used for inference: eval mode, moved to the GPU (CUDA required).\n",
    "classifier = architecture(weights=weights).eval().cuda()\n",
    "# Preprocessing pipeline bundled with the selected weights.\n",
    "transforms = weights.transforms()\n",
    "# Number of classes handed to the calibrators.\n",
    "num_classes = 1000\n",
    "# Number of epochs used when fitting the scaling calibrators.\n",
    "num_epochs = 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Number of images reserved for fitting the calibrators; the remainder is the test set.\n",
    "valid_size = 25000\n",
    "\n",
    "# Settings shared by every image dataloader built in this cell.\n",
    "loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)\n",
    "\n",
    "dataset_val = ImageFolder(path_dataset / 'val', transform=transforms)\n",
    "dataloader_val = DataLoader(dataset_val, shuffle=True, **loader_kwargs)\n",
    "\n",
    "# Stratified split of the sample indices into a test part and a validation (calibration) part.\n",
    "n_samples = len(dataset_val)\n",
    "test_indices, valid_indices = sklearn.model_selection.train_test_split(\n",
    "    np.arange(n_samples),\n",
    "    train_size=n_samples - valid_size,\n",
    "    stratify=dataset_val.targets,\n",
    "    random_state=seed,\n",
    ")\n",
    "valid_loader = DataLoader(dataset_val, sampler=SubsetRandomSampler(valid_indices), **loader_kwargs)\n",
    "test_loader = DataLoader(dataset_val, sampler=SubsetRandomSampler(test_indices), **loader_kwargs)\n",
    "\n",
    "# Logits and matching labels for each subset (presumably one forward pass of the\n",
    "# classifier over the loader -- see utils.logits_labels_from_dataloader).\n",
    "logits_val, labels_val = logits_labels_from_dataloader(classifier, valid_loader)\n",
    "logits_test, labels_test = logits_labels_from_dataloader(classifier, test_loader)\n",
    "\n",
    "# The calibration logits are wrapped so the calibrators can iterate over them in mini-batches.\n",
    "dataset_logits_val = LogitsDataset(logits_val, labels_val)\n",
    "dataloader_logits_val = DataLoader(dataset_logits_val, batch_size=512)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reliability diagrams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# NOTE(review): `HB_binary` (figure 4) is neither imported nor defined in this notebook;\n",
    "# presumably it comes from `calibrators` -- confirm and add it to the import cell.\n",
    "\n",
    "\n",
    "def plot_reliability(confidences, targets, suffix):\n",
    "    \"\"\"Save one reliability diagram as a TikZ file named after the current model.\n",
    "\n",
    "    Args:\n",
    "        confidences: predictions handed to netcal (a probability matrix for figures 1-3,\n",
    "            a 1-D confidence vector for figure 4).\n",
    "        targets: reference array matching `confidences` (class labels, or correctness flags).\n",
    "        suffix: tag appended to the file name, e.g. 'uncalibrated' or 'TS'.\n",
    "\n",
    "    Returns:\n",
    "        Whatever `ReliabilityDiagram.plot` returns.\n",
    "    \"\"\"\n",
    "    # The model name may contain '/', which must not end up in a file name.\n",
    "    filename = f\"ReliabilityDiagram_{model_name.replace('/', '')}_{suffix}.tikz\"\n",
    "    diagram = ReliabilityDiagram(bins=n_bins)\n",
    "    return diagram.plot(confidences, targets, tikz=True, filename=filename,\n",
    "                        axis_height='4cm', axis_width='8cm')\n",
    "\n",
    "\n",
    "def calibrated_test_probs(method, binary_loss, regularization):\n",
    "    \"\"\"Fit a scaling calibrator on the calibration logits and return test-set probabilities.\n",
    "\n",
    "    Args:\n",
    "        method: calibrator name forwarded to `fit_scaling_model` ('temperature' or 'vector' here).\n",
    "        binary_loss: forwarded to `fit_scaling_model`.\n",
    "        regularization: forwarded to `fit_scaling_model`.\n",
    "\n",
    "    Returns:\n",
    "        NumPy array with the softmax of the rescaled test logits.\n",
    "    \"\"\"\n",
    "    calibrator = fit_scaling_model(method, dataloader_logits_val, num_classes,\n",
    "                                   binary_loss=binary_loss, regularization=regularization,\n",
    "                                   num_epochs=num_epochs)\n",
    "    # Inference only: no autograd graph is needed for the rescaled logits.\n",
    "    with torch.no_grad():\n",
    "        logits_scaled = calibrator(logits_test.cuda()).cpu()\n",
    "    return torch.softmax(logits_scaled, axis=1).numpy()\n",
    "\n",
    "\n",
    "# figure 1: uncalibrated classifier\n",
    "probs_test = torch.softmax(logits_test, axis=1).numpy()\n",
    "ground_truth = labels_test.numpy()\n",
    "plot_reliability(probs_test, ground_truth, 'uncalibrated')\n",
    "\n",
    "# figures 2 and 3: per model, (method, binary_loss, regularization, file-name suffix)\n",
    "scaling_figures = {\n",
    "    'ResNet-50': [\n",
    "        ('temperature', False, False, 'TS'),\n",
    "        ('temperature', True, False, 'TStva'),\n",
    "    ],\n",
    "    'ViT-B/16': [\n",
    "        ('vector', False, False, 'VS'),\n",
    "        ('vector', True, True, 'VSregtva'),\n",
    "    ],\n",
    "}\n",
    "# Models without an entry simply get no scaling figures.\n",
    "for method, use_binary_loss, use_regularization, suffix in scaling_figures.get(model_name, []):\n",
    "    probs_scaled = calibrated_test_probs(method, use_binary_loss, use_regularization)\n",
    "    plot_reliability(probs_scaled, ground_truth, suffix)\n",
    "\n",
    "# figure 4: binary calibrator fitted on (confidence, prediction-is-correct) pairs\n",
    "confidences_val, y_pred_val = torch.softmax(logits_val, axis=1).max(axis=1)\n",
    "correct_val = y_pred_val == labels_val\n",
    "hb_model = HB_binary()\n",
    "hb_model.fit(confidences_val.numpy(), correct_val.numpy())\n",
    "\n",
    "confidences_test, y_pred_test = torch.softmax(logits_test, axis=1).max(axis=1)\n",
    "confidences_scaled = hb_model.predict_proba(confidences_test.cpu().numpy())\n",
    "correct_test = (y_pred_test == labels_test).numpy()\n",
    "\n",
    "plot_reliability(confidences_scaled, correct_test, 'HBtva');"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Regularization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from calibrators import Temperature, Vector, Dirichlet\n",
    "\n",
    "def compute_ECE(model):\n",
    "    \"\"\"Evaluate a calibrator on the calibration logits and on the test logits.\n",
    "\n",
    "    Args:\n",
    "        model: callable mapping a CUDA tensor of logits to rescaled logits.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(ece_train, ece_test)`: the value of `ece` on `logits_val` / `labels_val`\n",
    "        (the data the calibrator is fitted on) and on `logits_test` / `labels_test`.\n",
    "    \"\"\"\n",
    "    # NOTE(review): `ece` is neither imported nor defined in this notebook -- confirm\n",
    "    # where it comes from and add it to the import cell.\n",
    "    def _ece_on(logits, labels):\n",
    "        # Inference only: rescale on the GPU, then score on the CPU.\n",
    "        with torch.no_grad():\n",
    "            logits_scaled = model(logits.cuda()).cpu()\n",
    "        probs_scaled = torch.softmax(logits_scaled, axis=1).numpy()\n",
    "        return ece(labels, probs_scaled, num_bins=n_bins)\n",
    "\n",
    "    ece_train = _ece_on(logits_val, labels_val)\n",
    "    ece_test = _ece_on(logits_test, labels_test)\n",
    "    return ece_train, ece_test\n",
    "\n",
    "\n",
    "def fit_scaling_model_log(method, dataloader_logits_calib, num_classes, binary_loss, regularization, temperature_ref=1, num_epochs=200, lr=0.001):\n",
    "    \"\"\"Fit a scaling calibrator and record the test ECE after every epoch.\n",
    "\n",
    "    Args:\n",
    "        method: 'temperature', 'vector' or 'dirichlet'.\n",
    "        dataloader_logits_calib: iterable of `(logits, labels)` batches used for fitting.\n",
    "        num_classes: number of classes, forwarded to the Vector / Dirichlet calibrators.\n",
    "        binary_loss: if True, minimise the binary cross-entropy between the maximum softmax\n",
    "            probability and the correctness of the arg-max prediction; otherwise minimise\n",
    "            the usual cross-entropy on the rescaled logits.\n",
    "        regularization: if True, add the penalty that keeps the scaling vector (or the\n",
    "            diagonal of the Dirichlet weight matrix) close to 1.\n",
    "        temperature_ref: forwarded to the Vector / Dirichlet calibrators.\n",
    "        num_epochs: number of passes over `dataloader_logits_calib`.\n",
    "        lr: learning rate of the Adam optimizer.\n",
    "\n",
    "    Returns:\n",
    "        NumPy array of length `num_epochs` with the test ECE after each epoch.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `method` is not one of the supported names.\n",
    "    \"\"\"\n",
    "    if method == 'temperature':\n",
    "        calibrator = Temperature().cuda()\n",
    "    elif method == 'vector':\n",
    "        calibrator = Vector(num_classes, temperature_ref).cuda()\n",
    "    elif method == 'dirichlet':\n",
    "        calibrator = Dirichlet(num_classes, temperature_ref).cuda()\n",
    "    else:\n",
    "        raise ValueError('Unknown method')\n",
    "\n",
    "    optimizer = torch.optim.Adam(calibrator.parameters(), lr=lr)\n",
    "\n",
    "    ECE_test = np.zeros(num_epochs)\n",
    "    for epoch in range(num_epochs):\n",
    "        for x, y in dataloader_logits_calib:\n",
    "            optimizer.zero_grad()\n",
    "            x, y = x.cuda(), y.cuda()\n",
    "            logits_scaled = calibrator(x)\n",
    "            if binary_loss:\n",
    "                # Binary problem: is the most confident class the right one?\n",
    "                probas = torch.softmax(logits_scaled, axis=1)\n",
    "                confidence, y_pred = torch.max(probas, axis=1)\n",
    "                correct = (y_pred == y).float()\n",
    "                loss = nn.functional.binary_cross_entropy(confidence, correct)\n",
    "            else:\n",
    "                loss = nn.functional.cross_entropy(logits_scaled, y)\n",
    "            if method == 'dirichlet':\n",
    "                # Off-diagonal weights and biases of calibrator.model[0] are always penalised.\n",
    "                loss += calibrator.off_diag_reg * calibrator.model[0].weight.clone().fill_diagonal_(0).square().sum()\n",
    "                loss += calibrator.bias_reg * calibrator.model[0].bias.square().sum()\n",
    "                if regularization:\n",
    "                    # Optional penalty on the deviation of the diagonal from 1.\n",
    "                    loss += calibrator.diag_reg * (torch.diagonal(calibrator.model[0].weight) - 1).square().mean()\n",
    "            elif method == 'vector' and regularization:\n",
    "                # Optional penalty on the deviation of the scaling vector from 1.\n",
    "                loss += calibrator.vec_reg * (calibrator.vec - 1).square().mean()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        # compute_ECE also returns the calibration-set ECE; only the test value is logged.\n",
    "        _, ECE_test[epoch] = compute_ECE(calibrator)\n",
    "\n",
    "    return ECE_test\n",
    "\n",
    "# One row per experiment:\n",
    "# (key, method, binary_loss, regularization, line colour, line style, legend label)\n",
    "regularization_runs = [\n",
    "    ('TS_tva', 'temperature', True, False, 'k', ':', r'TS\\textsubscript{TvA}'),\n",
    "    ('VS', 'vector', False, False, 'C0', '--', 'VS'),\n",
    "    ('VS_reg', 'vector', False, True, 'C0', '-', r'VS\\textsubscript{reg}'),\n",
    "    ('VS_tva', 'vector', True, False, 'C1', '--', r'VS\\textsubscript{TvA}'),\n",
    "    ('VS_reg_tva', 'vector', True, True, 'C1', '-', r'VS\\textsubscript{reg\\_TvA}'),\n",
    "]\n",
    "\n",
    "# Fit every configuration first (in table order), keeping the per-epoch test ECE.\n",
    "ECE_test = {\n",
    "    run_key: fit_scaling_model_log(run_method, dataloader_logits_val, num_classes,\n",
    "                                   binary_loss=run_binary, regularization=run_reg)\n",
    "    for run_key, run_method, run_binary, run_reg, _, _, _ in regularization_runs\n",
    "}\n",
    "\n",
    "# The first epochs are left out of the plot.\n",
    "epoch_start = 20\n",
    "plt.figure()\n",
    "for run_key, _, _, _, run_color, run_ls, run_label in regularization_runs:\n",
    "    ece_curve = ECE_test[run_key]\n",
    "    epochs_shown = np.arange(epoch_start, len(ece_curve))\n",
    "    # ECE is plotted in percent.\n",
    "    plt.plot(epochs_shown, 100*ece_curve[epoch_start:], c=run_color, ls=run_ls, label=run_label)\n",
    "plt.legend()\n",
    "plt.xlabel('epochs')\n",
    "plt.ylabel('ECE test [%]')\n",
    "tikzplotlib.save('regularization.tikz')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Histogram for classwise ECE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-class probabilities predicted on the test set.\n",
    "probas = torch.softmax(logits_test, 1)\n",
    "\n",
    "# One histogram panel per randomly drawn class.\n",
    "n_panels = 3\n",
    "\n",
    "fig, axes = plt.subplots(nrows=1, ncols=n_panels, figsize=(12,3))\n",
    "axes[0].set_ylabel('number of Samples (log scale)')\n",
    "for ax in axes:\n",
    "    # Draw among the classes actually present in the logits instead of a hard-coded 1000.\n",
    "    random_class_idx = np.random.randint(probas.shape[1])\n",
    "    # Hand matplotlib a NumPy array rather than a torch tensor.\n",
    "    ax.hist(probas[:, random_class_idx].numpy(), log=True)\n",
    "    ax.set_xlabel(f'class probability for class {random_class_idx}')\n",
    "tikzplotlib.save('classwiseECE_histograms.tikz')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
