{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9c491260",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-02 04:51:30.139461: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9373] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-10-02 04:51:30.139545: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-10-02 04:51:30.140879: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1534] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-10-02 04:51:30.152986: I tensorflow/core/platform/cpu_feature_guard.cc:183] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "# import tensorflow_probability as tfp\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from src.datasets import load_dataset, preprocess_dataset, prefetch_dataset\n",
    "import src.utils as utils\n",
    "import src.metrics as metrics\n",
    "import src.methods as methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2130cd18",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration shared by every evaluation cell below.\n",
    "# The whole dict is handed to load_dataset / preprocess_dataset; the notebook\n",
    "# itself only reads 'dataset', 'model' and 'batch_size' directly.\n",
    "cfg = {\n",
    "    'dataset': 'stl10',\n",
    "    'model': 'vgg16',\n",
    "    'batch_size': 512,\n",
    "    'optimizer': 'Adam',\n",
    "    'learning_rate': 0.001,\n",
    "    'max_epoch': 300,\n",
    "    'patience': 10,\n",
    "}\n",
    "\n",
    "# Short aliases used to build the results paths: <model>_<dataset>/run_<i>/...\n",
    "dataset_name = cfg['dataset']\n",
    "model_name = cfg['model']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f23d119",
   "metadata": {},
   "source": [
    "### Softmax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "30594d97",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "Classification Error: 0.562\n",
      "Run: 2\n",
      "Classification Error: 0.675\n",
      "Run: 3\n",
      "Classification Error: 0.638\n",
      "Run: 4\n",
      "Classification Error: 0.638\n",
      "Run: 5\n",
      "Classification Error: 0.475\n",
      "---------------------\n",
      "Average Classification Error: 0.60, std: 0.07\n"
     ]
    }
   ],
   "source": [
    "# Softmax baseline: ECE, NLL, Brier score and classification error for each of the\n",
    "# 5 saved runs, followed by mean/std over the runs.\n",
    "ds_train, ds_val, ds_test, ds_info = load_dataset(cfg, shuffle=False)\n",
    "n_classes = ds_info.features['label'].num_classes\n",
    "ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "\n",
    "# The labels do not depend on the run, so decode the one-hot targets once here\n",
    "# instead of re-iterating the whole test set inside the loop. This relies on ds_test\n",
    "# yielding samples in a fixed order, which pairing the model.predict outputs with\n",
    "# these labels already requires.\n",
    "true_y = np.argmax([y for x,y in ds_test], axis=1)\n",
    "\n",
    "ece_list = []\n",
    "nll_list = []\n",
    "bs_list = []\n",
    "class_error_list = []\n",
    "\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    print(f'Run: {run+1}')\n",
    "    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'\n",
    "    model = tf.keras.models.load_model(f'{run_dir}/saved_models/trained_model.keras')\n",
    "    preds = model.predict(ds_test.batch(cfg['batch_size']), verbose=0)\n",
    "    # NOTE(review): these class probabilities are a plain utils.softmax of the model\n",
    "    # outputs, while the ECE confidence below is computed with opt_temp. Confirm that\n",
    "    # NLL / Brier are meant to be reported without temperature scaling.\n",
    "    preds = np.array([utils.softmax(x) for x in preds])\n",
    "    pred_y = np.argmax(preds, axis=1)\n",
    "    opt_temp = np.load(f'{run_dir}/calibration/softmax_opt_temp.npy')\n",
    "    softmax_test = methods.max_softmax_prob(model, ds_test, opt_temp, cfg['batch_size'])\n",
    "    \n",
    "    ece = metrics.compute_ece(softmax_test, true_y, pred_y, n_bins=10)\n",
    "    ece_list.append(ece)\n",
    "    nll = metrics.compute_nll(preds, true_y, n_classes)\n",
    "    nll_list.append(nll)\n",
    "    bs = metrics.compute_brier_score(preds, true_y, n_classes)\n",
    "    bs_list.append(bs)\n",
    "    # NOTE(review): k=5 is passed here, whereas the PMI and PSI cells call\n",
    "    # compute_classification_error without k. Confirm the errors are comparable.\n",
    "    class_error = metrics.compute_classification_error(preds, true_y, k=5)\n",
    "    class_error_list.append(class_error)\n",
    "    \n",
    "    print(f'ECE: {ece:.3f}')\n",
    "    print(f'NLL: {nll:.3f}')\n",
    "    print(f'Brier Score: {bs:.3f}')\n",
    "    print(f'Classification Error: {class_error*100:.3f}')\n",
    "    \n",
    "print('---------------------')\n",
    "print(f'Average ECE: {np.mean(ece_list):.2f}, std: {np.std(ece_list):.2f}')\n",
    "print(f'Average NLL: {np.mean(nll_list):.3f}, std: {np.std(nll_list):.3f}')\n",
    "print(f'Average Brier Score: {np.mean(bs_list):.3f}, std: {np.std(bs_list):.3f}')\n",
    "print(f'Average Classification Error: {np.mean(class_error_list)*100:.2f}, std: {np.std(class_error_list)*100:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c651b776",
   "metadata": {},
   "source": [
    "### PMI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fd8e2f08",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "ECE: 12.765\n",
      "NLL: 1.543\n",
      "Brier Score: 0.273\n",
      "Classification Error: 14.450\n",
      "Run: 2\n",
      "ECE: 11.388\n",
      "NLL: 1.263\n",
      "Brier Score: 0.246\n",
      "Classification Error: 13.320\n",
      "Run: 3\n",
      "ECE: 12.104\n",
      "NLL: 1.502\n",
      "Brier Score: 0.257\n",
      "Classification Error: 13.650\n",
      "Run: 4\n",
      "ECE: 12.314\n",
      "NLL: 1.570\n",
      "Brier Score: 0.259\n",
      "Classification Error: 13.640\n",
      "Run: 5\n",
      "ECE: 12.671\n",
      "NLL: 1.518\n",
      "Brier Score: 0.267\n",
      "Classification Error: 14.120\n",
      "---------------------\n",
      "Average ECE: 12.25, std: 0.49\n",
      "Average NLL: 1.479, std: 0.111\n",
      "Average Brier Score: 0.260, std: 0.009\n",
      "Average Classification Error: 13.84, std: 0.40\n"
     ]
    }
   ],
   "source": [
    "# PMI: ECE, NLL, Brier score and classification error computed from the saved\n",
    "# per-class PMI outputs of each of the 5 runs, followed by mean/std over the runs.\n",
    "critic = 'separable'\n",
    "estimator = 'variational_f_js'\n",
    "\n",
    "ds_train, ds_val, ds_test, ds_info = load_dataset(cfg, shuffle=False)\n",
    "n_classes = ds_info.features['label'].num_classes\n",
    "ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "\n",
    "# The labels do not depend on the run, so decode the one-hot targets once here\n",
    "# instead of re-iterating the whole test set inside the loop. This relies on ds_test\n",
    "# yielding samples in a fixed order, which pairing the model.predict outputs with\n",
    "# these labels already requires.\n",
    "true_y = np.argmax([y for x,y in ds_test], axis=1)\n",
    "\n",
    "ece_list = []\n",
    "nll_list = []\n",
    "bs_list = []\n",
    "class_error_list = []\n",
    "\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    print(f'Run: {run+1}')\n",
    "    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'\n",
    "    model = tf.keras.models.load_model(f'{run_dir}/saved_models/trained_model.keras')\n",
    "    preds = model.predict(ds_test.batch(cfg['batch_size']), verbose=0)\n",
    "    pred_y = np.argmax(preds, axis=1)\n",
    "    \n",
    "    exp_name = f'{run_dir}/calibration/pmi/{critic}_{estimator}'\n",
    "    pmi_class = np.load(f'{exp_name}/pmi_output_class_test.npy')\n",
    "    temp = np.load(f'{exp_name}/pmi_opt_temp.npy')\n",
    "    # zip() below would silently truncate on a size mismatch (e.g. scores saved for a\n",
    "    # different dataset), so check the alignment explicitly.\n",
    "    assert len(pmi_class) == len(pred_y), f'{exp_name}: {len(pmi_class)} saved scores vs {len(pred_y)} test samples'\n",
    "    # Temperature-scaled softmax over the saved per-class scores.\n",
    "    pmi_class = np.array([utils.softmax(x/temp) for x in pmi_class])\n",
    "    # Confidence of a sample = its scaled score at the class predicted by the model.\n",
    "    pmi = np.array([pmi_value[pred_value] for pmi_value, pred_value in zip(pmi_class, pred_y)])\n",
    "    \n",
    "    ece = metrics.compute_ece(pmi, true_y, pred_y, n_bins=10)\n",
    "    ece_list.append(ece)\n",
    "    nll = metrics.compute_nll(pmi_class, true_y, n_classes)\n",
    "    nll_list.append(nll)\n",
    "    bs = metrics.compute_brier_score(pmi_class, true_y, n_classes)\n",
    "    bs_list.append(bs)\n",
    "    class_error = metrics.compute_classification_error(pmi_class, true_y)\n",
    "    class_error_list.append(class_error)\n",
    "    \n",
    "    print(f'ECE: {ece:.3f}')\n",
    "    print(f'NLL: {nll:.3f}')\n",
    "    print(f'Brier Score: {bs:.3f}')\n",
    "    print(f'Classification Error: {class_error*100:.3f}')\n",
    "    \n",
    "print('---------------------')\n",
    "print(f'Average ECE: {np.mean(ece_list):.2f}, std: {np.std(ece_list):.2f}')\n",
    "print(f'Average NLL: {np.mean(nll_list):.3f}, std: {np.std(nll_list):.3f}')\n",
    "print(f'Average Brier Score: {np.mean(bs_list):.3f}, std: {np.std(bs_list):.3f}')\n",
    "print(f'Average Classification Error: {np.mean(class_error_list)*100:.2f}, std: {np.std(class_error_list)*100:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e55b4f7b",
   "metadata": {},
   "source": [
    "### PVI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4e29e056",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "Classification Error: 0.625\n",
      "Run: 2\n",
      "Classification Error: 0.638\n",
      "Run: 3\n",
      "Classification Error: 0.575\n",
      "Run: 4\n",
      "Classification Error: 0.500\n",
      "Run: 5\n",
      "Classification Error: 0.687\n",
      "---------------------\n",
      "Average Classification Error: 0.60, std: 0.06\n"
     ]
    }
   ],
   "source": [
    "# PVI: ECE, NLL, Brier score and classification error computed from the saved\n",
    "# per-class PVI outputs of each of the 5 runs, followed by mean/std over the runs.\n",
    "estimator = 'training_from_scratch'\n",
    "\n",
    "ds_train, ds_val, ds_test, ds_info = load_dataset(cfg, shuffle=False)\n",
    "n_classes = ds_info.features['label'].num_classes\n",
    "ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "\n",
    "# The labels do not depend on the run, so decode the one-hot targets once here\n",
    "# instead of re-iterating the whole test set inside the loop. This relies on ds_test\n",
    "# yielding samples in a fixed order, which pairing the model.predict outputs with\n",
    "# these labels already requires.\n",
    "true_y = np.argmax([y for x,y in ds_test], axis=1)\n",
    "\n",
    "ece_list = []\n",
    "nll_list = []\n",
    "bs_list = []\n",
    "class_error_list = []\n",
    "\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    print(f'Run: {run+1}')\n",
    "    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'\n",
    "    model = tf.keras.models.load_model(f'{run_dir}/saved_models/trained_model.keras')\n",
    "    preds = model.predict(ds_test.batch(cfg['batch_size']), verbose=0)\n",
    "    pred_y = np.argmax(preds, axis=1)\n",
    "    \n",
    "    exp_name = f'{run_dir}/calibration/pvi/{estimator}'\n",
    "    pvi_class = np.load(f'{exp_name}/pvi_class_test.npy')\n",
    "    temp = np.load(f'{exp_name}/pvi_opt_temp.npy')\n",
    "    # zip() below would silently truncate on a size mismatch (e.g. scores saved for a\n",
    "    # different dataset), so check the alignment explicitly.\n",
    "    assert len(pvi_class) == len(pred_y), f'{exp_name}: {len(pvi_class)} saved scores vs {len(pred_y)} test samples'\n",
    "    # Temperature-scaled softmax over the saved per-class scores.\n",
    "    pvi_class = np.array([utils.softmax(x/temp) for x in pvi_class])\n",
    "    # Confidence of a sample = its scaled score at the class predicted by the model.\n",
    "    pvi = np.array([pvi_value[pred_value] for pvi_value, pred_value in zip(pvi_class, pred_y)])\n",
    "    \n",
    "    ece = metrics.compute_ece(pvi, true_y, pred_y, n_bins=10)\n",
    "    ece_list.append(ece)\n",
    "    nll = metrics.compute_nll(pvi_class, true_y, n_classes)\n",
    "    nll_list.append(nll)\n",
    "    bs = metrics.compute_brier_score(pvi_class, true_y, n_classes)\n",
    "    bs_list.append(bs)\n",
    "    # NOTE(review): k=5 is passed here, whereas the PMI and PSI cells call\n",
    "    # compute_classification_error without k. Confirm the errors are comparable.\n",
    "    class_error = metrics.compute_classification_error(pvi_class, true_y, k=5)\n",
    "    class_error_list.append(class_error)\n",
    "    \n",
    "    print(f'ECE: {ece:.3f}')\n",
    "    print(f'NLL: {nll:.3f}')\n",
    "    print(f'Brier Score: {bs:.3f}')\n",
    "    print(f'Classification Error: {class_error*100:.3f}')\n",
    "    \n",
    "print('---------------------')\n",
    "print(f'Average ECE: {np.mean(ece_list):.2f}, std: {np.std(ece_list):.2f}')\n",
    "print(f'Average NLL: {np.mean(nll_list):.3f}, std: {np.std(nll_list):.3f}')\n",
    "print(f'Average Brier Score: {np.mean(bs_list):.3f}, std: {np.std(bs_list):.3f}')\n",
    "print(f'Average Classification Error: {np.mean(class_error_list)*100:.2f}, std: {np.std(class_error_list)*100:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fb0f923",
   "metadata": {},
   "source": [
    "### PSI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4ce06b5a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "ECE: 12.350\n",
      "NLL: 1.508\n",
      "Brier Score: 0.271\n",
      "Classification Error: 14.580\n",
      "Run: 2\n",
      "ECE: 9.546\n",
      "NLL: 0.841\n",
      "Brier Score: 0.230\n",
      "Classification Error: 13.300\n",
      "Run: 3\n",
      "ECE: 11.619\n",
      "NLL: 1.468\n",
      "Brier Score: 0.259\n",
      "Classification Error: 13.850\n",
      "Run: 4\n",
      "ECE: 8.936\n",
      "NLL: 0.703\n",
      "Brier Score: 0.231\n",
      "Classification Error: 14.070\n",
      "Run: 5\n",
      "ECE: 12.388\n",
      "NLL: 1.531\n",
      "Brier Score: 0.266\n",
      "Classification Error: 14.120\n",
      "---------------------\n",
      "Average ECE: 10.97, std: 1.45\n",
      "Average NLL: 1.210, std: 0.361\n",
      "Average Brier Score: 0.251, std: 0.018\n",
      "Average Classification Error: 13.98, std: 0.42\n"
     ]
    }
   ],
   "source": [
    "# PSI: ECE, NLL, Brier score and classification error computed from the saved\n",
    "# per-class PSI outputs of each of the 5 runs, followed by mean/std over the runs.\n",
    "estimator = 'gaussian'\n",
    "n_projs = 500\n",
    "\n",
    "ds_train, ds_val, ds_test, ds_info = load_dataset(cfg, shuffle=False)\n",
    "n_classes = ds_info.features['label'].num_classes\n",
    "ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "\n",
    "# The labels do not depend on the run, so decode the one-hot targets once here\n",
    "# instead of re-iterating the whole test set inside the loop. This relies on ds_test\n",
    "# yielding samples in a fixed order, which pairing the model.predict outputs with\n",
    "# these labels already requires.\n",
    "true_y = np.argmax([y for x,y in ds_test], axis=1)\n",
    "\n",
    "ece_list = []\n",
    "nll_list = []\n",
    "bs_list = []\n",
    "class_error_list = []\n",
    "\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    print(f'Run: {run+1}')\n",
    "    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'\n",
    "    model = tf.keras.models.load_model(f'{run_dir}/saved_models/trained_model.keras')\n",
    "    preds = model.predict(ds_test.batch(cfg['batch_size']), verbose=0)\n",
    "    pred_y = np.argmax(preds, axis=1)\n",
    "    \n",
    "    exp_name = f'{run_dir}/calibration/psi/{estimator}'\n",
    "    psi_class = np.load(f'{exp_name}/psi_output_class_{n_projs}_projs_test.npy')\n",
    "    temp = np.load(f'{exp_name}/psi_opt_temp.npy')\n",
    "    # zip() below would silently truncate on a size mismatch (e.g. scores saved for a\n",
    "    # different dataset), so check the alignment explicitly.\n",
    "    assert len(psi_class) == len(pred_y), f'{exp_name}: {len(psi_class)} saved scores vs {len(pred_y)} test samples'\n",
    "    # Temperature-scaled softmax over the saved per-class scores.\n",
    "    psi_class = np.array([utils.softmax(x/temp) for x in psi_class])\n",
    "    # Confidence of a sample = its scaled score at the class predicted by the model.\n",
    "    psi = np.array([psi_value[pred_value] for psi_value, pred_value in zip(psi_class, pred_y)])\n",
    "    \n",
    "    ece = metrics.compute_ece(psi, true_y, pred_y, n_bins=10)\n",
    "    ece_list.append(ece)\n",
    "    nll = metrics.compute_nll(psi_class, true_y, n_classes)\n",
    "    nll_list.append(nll)\n",
    "    bs = metrics.compute_brier_score(psi_class, true_y, n_classes)\n",
    "    bs_list.append(bs)\n",
    "    class_error = metrics.compute_classification_error(psi_class, true_y)\n",
    "    class_error_list.append(class_error)\n",
    "    \n",
    "    print(f'ECE: {ece:.3f}')\n",
    "    print(f'NLL: {nll:.3f}')\n",
    "    print(f'Brier Score: {bs:.3f}')\n",
    "    print(f'Classification Error: {class_error*100:.3f}')\n",
    "    \n",
    "print('---------------------')\n",
    "print(f'Average ECE: {np.mean(ece_list):.2f}, std: {np.std(ece_list):.2f}')\n",
    "print(f'Average NLL: {np.mean(nll_list):.3f}, std: {np.std(nll_list):.3f}')\n",
    "print(f'Average Brier Score: {np.mean(bs_list):.3f}, std: {np.std(bs_list):.3f}')\n",
    "print(f'Average Classification Error: {np.mean(class_error_list)*100:.2f}, std: {np.std(class_error_list)*100:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5243f756",
   "metadata": {},
   "source": [
    "### Other Benchmark Methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "95b41cbb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Method: softmax_margin\n",
      "ECE: 10.205\n",
      "ECE: 9.041\n",
      "ECE: 9.507\n",
      "ECE: 9.883\n",
      "ECE: 10.535\n",
      "---------------------\n",
      "Average ECE: 9.83, std: 0.52\n"
     ]
    }
   ],
   "source": [
    "# ECE of the other confidence scores, one method at a time, over the 5 saved runs.\n",
    "ds_train, ds_val, ds_test, ds_info = load_dataset(cfg, shuffle=False)\n",
    "n_classes = ds_info.features['label'].num_classes\n",
    "ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "\n",
    "# The labels depend on neither the method nor the run, so decode the one-hot targets\n",
    "# once here instead of re-iterating the whole test set inside the loops. This relies\n",
    "# on ds_test yielding samples in a fixed order, which pairing the model.predict\n",
    "# outputs with these labels already requires.\n",
    "true_y = np.argmax([y for x,y in ds_test], axis=1)\n",
    "\n",
    "# Supported names: 'softmax_margin', 'logits_margin', 'negative_entropy', 'negative_gini'.\n",
    "methods_list = ['softmax_margin']\n",
    "\n",
    "for method in methods_list:\n",
    "    print(f'Method: {method}')\n",
    "    ece_list = []\n",
    "    for run in range(5):\n",
    "        tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "        run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'\n",
    "        model = tf.keras.models.load_model(f'{run_dir}/saved_models/trained_model.keras')\n",
    "        preds = model.predict(ds_test.batch(cfg['batch_size']), verbose=0)\n",
    "        pred_y = np.argmax(preds, axis=1)\n",
    "\n",
    "        opt_temp = np.load(f'{run_dir}/calibration/softmax_opt_temp.npy')\n",
    "        if method == 'softmax_margin':\n",
    "            conf_test = methods.softmax_margin(model, ds_test, opt_temp, cfg['batch_size']).numpy()\n",
    "        elif method == 'logits_margin':\n",
    "            conf_test = methods.logits_margin(model, ds_test, cfg['batch_size']).numpy()\n",
    "        elif method == 'negative_entropy':\n",
    "            conf_test = methods.negative_entropy(model, ds_test, opt_temp, cfg['batch_size'])\n",
    "        elif method == 'negative_gini':\n",
    "            conf_test = methods.negative_gini(model, ds_test, opt_temp, cfg['batch_size'])\n",
    "        else:\n",
    "            # Without this branch an unknown name either raises NameError on the first\n",
    "            # run or silently reuses conf_test left over from the previous method.\n",
    "            raise ValueError(f'Unknown confidence method: {method!r}')\n",
    "\n",
    "        ece = metrics.compute_ece(conf_test, true_y, pred_y, n_bins=10)\n",
    "        ece_list.append(ece)\n",
    "        print(f'ECE: {ece:.3f}')\n",
    "    print('---------------------')\n",
    "    print(f'Average ECE: {np.mean(ece_list):.2f}, std: {np.std(ece_list):.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c049f08",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
