{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "054409ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "from src.datasets import load_dataset, preprocess_dataset, prefetch_dataset\n",
    "import src.utils as utils\n",
    "import src.metrics as metrics\n",
    "import src.methods as methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "36549c69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration: which trained model / dataset pair is evaluated.\n",
    "cfg = {\n",
    "    'dataset': 'cifar10',\n",
    "    'model': 'resnet50',\n",
    "    'batch_size': 128,\n",
    "}\n",
    "dataset_name = cfg['dataset']\n",
    "model_name = cfg['model']\n",
    "\n",
    "# Load the three splits and apply identical preprocessing (resize, normalize, one-hot labels) to each.\n",
    "ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "n_classes = ds_info.features['label'].num_classes\n",
    "preprocess_opts = dict(resize=True, normalize=True, onehot=True)\n",
    "ds_train = preprocess_dataset(ds_train, cfg, n_classes, **preprocess_opts)\n",
    "ds_val = preprocess_dataset(ds_val, cfg, n_classes, **preprocess_opts)\n",
    "ds_test = preprocess_dataset(ds_test, cfg, n_classes, **preprocess_opts)\n",
    "\n",
    "# Ground-truth class indices, recovered from the one-hot labels of the (unbatched) datasets.\n",
    "true_y_val = np.argmax([label for _, label in ds_val], axis=1)\n",
    "true_y_test = np.argmax([label for _, label in ds_test], axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5228f184",
   "metadata": {},
   "source": [
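    "### Softmax\n",
    "\n",
    "Baseline confidence score: the maximum softmax probability of each trained model, using a temperature fitted on the validation set. The temperature is saved to `calibration/softmax_opt_temp.npy` so that the benchmark section below can reuse it. Every section reports the mean and std over 5 runs; AURC is scaled by 1000 and the other metrics by 100."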
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "5952f2ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUROC: 85.06, std: 0.40\n",
      "AUPRC (success): 96.70, std: 0.08\n",
      "AUPRC (error): 47.99, std: 1.87\n",
      "FPR95: 63.00, std: 1.41\n",
      "AURC: 39.08, std: 1.06\n"
     ]
    }
   ],
   "source": [
    "# Baseline: maximum softmax probability, temperature-scaled on the validation set.\n",
    "n_runs = 5\n",
    "all_auroc = []\n",
    "all_auprc_succ = []\n",
    "all_auprc_error = []\n",
    "all_fpr95 = []\n",
    "all_aurc = []\n",
    "for run in range(n_runs):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'\n",
    "    model = tf.keras.models.load_model(f'{run_dir}/saved_models/trained_model.keras')\n",
    "\n",
    "    # Single validation pass: the same outputs give the predicted labels and feed the\n",
    "    # temperature fit, so the two can never disagree.\n",
    "    logits_val = model.predict(ds_val.batch(cfg['batch_size']), verbose=0)\n",
    "    pred_y_val = np.argmax(logits_val, axis=1)\n",
    "    true_label_val = np.equal(true_y_val, pred_y_val).astype(int)  # 1 = correct prediction\n",
    "    pred_y_test = np.argmax(model.predict(ds_test.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "    true_label_test = np.equal(true_y_test, pred_y_test).astype(int)\n",
    "\n",
    "    # Fit the temperature on validation data and persist it (the benchmark cell loads it).\n",
    "    calib_dir = f'{run_dir}/calibration'\n",
    "    os.makedirs(calib_dir, exist_ok=True)\n",
    "    opt_temp = utils.temp_scaling_aurc(logits_val, pred_y_val, true_label_val)\n",
    "    np.save(f'{calib_dir}/softmax_opt_temp.npy', opt_temp)\n",
    "    softmax_test = methods.max_softmax_prob(model, ds_test, opt_temp, cfg['batch_size'])\n",
    "\n",
    "    all_auroc.append(metrics.compute_auroc(true_label_test, softmax_test))\n",
    "    all_auprc_succ.append(metrics.compute_auprc_success(true_label_test, softmax_test))\n",
    "    all_auprc_error.append(metrics.compute_auprc_error(true_label_test, softmax_test))\n",
    "    all_fpr95.append(metrics.compute_fpr95(true_label_test, softmax_test))\n",
    "    all_aurc.append(metrics.compute_aurc(true_label_test, softmax_test))\n",
    "\n",
    "# Mean and population std (np.std default) across runs; AURC is reported x1000, the rest x100.\n",
    "print(f'AUROC: {np.mean(all_auroc, axis=0)*100:.2f}, std: {np.std(all_auroc, axis=0)*100:.2f}')\n",
    "print(f'AUPRC (success): {np.mean(all_auprc_succ, axis=0)*100:.2f}, std: {np.std(all_auprc_succ, axis=0)*100:.2f}')\n",
    "print(f'AUPRC (error): {np.mean(all_auprc_error, axis=0)*100:.2f}, std: {np.std(all_auprc_error, axis=0)*100:.2f}')\n",
    "print(f'FPR95: {np.mean(all_fpr95, axis=0)*100:.2f}, std: {np.std(all_fpr95, axis=0)*100:.2f}')\n",
    "print(f'AURC: {np.mean(all_aurc, axis=0)*1000:.2f}, std: {np.std(all_aurc, axis=0)*1000:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a37dfe28",
   "metadata": {},
   "source": [
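    "### PMI\n",
    "\n",
    "Pointwise mutual information (PMI): class-wise scores are precomputed with the `separable_variational_f_js` estimator and loaded from disk. Each test example's confidence is the softmax-normalized PMI of its predicted class."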
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fe1bbfb1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUROC: 91.96, std: 0.35\n",
      "AUPRC (success): 99.37, std: 0.01\n",
      "AUPRC (error): 42.48, std: 2.39\n",
      "FPR95: 47.12, std: 1.80\n",
      "AURC: 8.11, std: 0.09\n"
     ]
    }
   ],
   "source": [
    "estimator = 'separable_variational_f_js'\n",
    "# Temperature scaling of the PMI scores is disabled; set to True to fit (and save) a\n",
    "# per-run temperature on the validation set before applying the softmax.\n",
    "use_temp_scaling = False\n",
    "n_runs = 5\n",
    "\n",
    "all_auroc = []\n",
    "all_auprc_succ = []\n",
    "all_auprc_error = []\n",
    "all_fpr95 = []\n",
    "all_aurc = []\n",
    "for run in range(n_runs):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'\n",
    "    exp_name = f'{run_dir}/calibration/pmi/{estimator}'\n",
    "    model = tf.keras.models.load_model(f'{run_dir}/saved_models/trained_model.keras')\n",
    "\n",
    "    pred_y_test = np.argmax(model.predict(ds_test.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "    true_label_test = np.equal(true_y_test, pred_y_test).astype(int)  # 1 = correct prediction\n",
    "\n",
    "    # Precomputed class-wise PMI scores, one row per test example (aligned with pred_y_test).\n",
    "    pmi_class_test = np.load(f'{exp_name}/pmi_output_class_test.npy')\n",
    "    if use_temp_scaling:\n",
    "        # Validation inference is only needed to fit the temperature.\n",
    "        pred_y_val = np.argmax(model.predict(ds_val.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "        true_label_val = np.equal(true_y_val, pred_y_val).astype(int)\n",
    "        pmi_class_val = np.load(f'{exp_name}/pmi_output_class_val.npy')\n",
    "        opt_temp = utils.temp_scaling_aurc(pmi_class_val, pred_y_val, true_label_val)\n",
    "        np.save(f'{exp_name}/pmi_opt_temp.npy', opt_temp)\n",
    "        pmi_class_test = pmi_class_test / opt_temp\n",
    "    pmi_class_test = np.array([utils.softmax(x) for x in pmi_class_test])\n",
    "    assert len(pmi_class_test) == len(pred_y_test), 'PMI scores are not aligned with test predictions'\n",
    "    # Confidence = normalized PMI of the predicted class for each example.\n",
    "    pmi_test = pmi_class_test[np.arange(len(pred_y_test)), pred_y_test]\n",
    "\n",
    "    all_auroc.append(metrics.compute_auroc(true_label_test, pmi_test))\n",
    "    all_auprc_succ.append(metrics.compute_auprc_success(true_label_test, pmi_test))\n",
    "    all_auprc_error.append(metrics.compute_auprc_error(true_label_test, pmi_test))\n",
    "    all_fpr95.append(metrics.compute_fpr95(true_label_test, pmi_test))\n",
    "    all_aurc.append(metrics.compute_aurc(true_label_test, pmi_test))\n",
    "\n",
    "# Mean and population std (np.std default) across runs; AURC is reported x1000, the rest x100.\n",
    "print(f'AUROC: {np.mean(all_auroc, axis=0)*100:.2f}, std: {np.std(all_auroc, axis=0)*100:.2f}')\n",
    "print(f'AUPRC (success): {np.mean(all_auprc_succ, axis=0)*100:.2f}, std: {np.std(all_auprc_succ, axis=0)*100:.2f}')\n",
    "print(f'AUPRC (error): {np.mean(all_auprc_error, axis=0)*100:.2f}, std: {np.std(all_auprc_error, axis=0)*100:.2f}')\n",
    "print(f'FPR95: {np.mean(all_fpr95, axis=0)*100:.2f}, std: {np.std(all_fpr95, axis=0)*100:.2f}')\n",
    "print(f'AURC: {np.mean(all_aurc, axis=0)*1000:.2f}, std: {np.std(all_aurc, axis=0)*1000:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b586b6e",
   "metadata": {},
   "source": [
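    "### PSI\n",
    "\n",
    "Pointwise sliced mutual information (PSI): class-wise scores are precomputed with the `gaussian` estimator over 500 random projections and loaded from disk. Each test example's confidence is the softmax-normalized PSI of its predicted class."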
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "51c3a2ec",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUROC: 89.42, std: 0.42\n",
      "AUPRC (success): 99.18, std: 0.02\n",
      "AUPRC (error): 33.39, std: 2.78\n",
      "FPR95: 57.65, std: 3.65\n",
      "AURC: 9.99, std: 0.26\n"
     ]
    }
   ],
   "source": [
    "estimator = 'gaussian'\n",
    "n_projs = 500\n",
    "# Temperature scaling of the PSI scores is disabled; set to True to fit (and save) a\n",
    "# per-run temperature on the validation set before applying the softmax.\n",
    "use_temp_scaling = False\n",
    "n_runs = 5\n",
    "\n",
    "all_auroc = []\n",
    "all_auprc_succ = []\n",
    "all_auprc_error = []\n",
    "all_fpr95 = []\n",
    "all_aurc = []\n",
    "for run in range(n_runs):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'\n",
    "    exp_name = f'{run_dir}/calibration/psi/{estimator}'\n",
    "    model = tf.keras.models.load_model(f'{run_dir}/saved_models/trained_model.keras')\n",
    "\n",
    "    pred_y_test = np.argmax(model.predict(ds_test.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "    true_label_test = np.equal(true_y_test, pred_y_test).astype(int)  # 1 = correct prediction\n",
    "\n",
    "    # Precomputed class-wise PSI scores, one row per test example (aligned with pred_y_test).\n",
    "    psi_class_test = np.load(f'{exp_name}/psi_output_class_{n_projs}_projs_test.npy')\n",
    "    if use_temp_scaling:\n",
    "        # Validation inference is only needed to fit the temperature.\n",
    "        pred_y_val = np.argmax(model.predict(ds_val.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "        true_label_val = np.equal(true_y_val, pred_y_val).astype(int)\n",
    "        psi_class_val = np.load(f'{exp_name}/psi_output_class_{n_projs}_projs_val.npy')\n",
    "        opt_temp = utils.temp_scaling_aurc(psi_class_val, pred_y_val, true_label_val)\n",
    "        np.save(f'{exp_name}/psi_opt_temp.npy', opt_temp)\n",
    "        psi_class_test = psi_class_test / opt_temp\n",
    "    psi_class_test = np.array([utils.softmax(x) for x in psi_class_test])\n",
    "    assert len(psi_class_test) == len(pred_y_test), 'PSI scores are not aligned with test predictions'\n",
    "    # Confidence = normalized PSI of the predicted class for each example.\n",
    "    psi_test = psi_class_test[np.arange(len(pred_y_test)), pred_y_test]\n",
    "\n",
    "    all_auroc.append(metrics.compute_auroc(true_label_test, psi_test))\n",
    "    all_auprc_succ.append(metrics.compute_auprc_success(true_label_test, psi_test))\n",
    "    all_auprc_error.append(metrics.compute_auprc_error(true_label_test, psi_test))\n",
    "    all_fpr95.append(metrics.compute_fpr95(true_label_test, psi_test))\n",
    "    all_aurc.append(metrics.compute_aurc(true_label_test, psi_test))\n",
    "\n",
    "# Mean and population std (np.std default) across runs; AURC is reported x1000, the rest x100.\n",
    "print(f'AUROC: {np.mean(all_auroc, axis=0)*100:.2f}, std: {np.std(all_auroc, axis=0)*100:.2f}')\n",
    "print(f'AUPRC (success): {np.mean(all_auprc_succ, axis=0)*100:.2f}, std: {np.std(all_auprc_succ, axis=0)*100:.2f}')\n",
    "print(f'AUPRC (error): {np.mean(all_auprc_error, axis=0)*100:.2f}, std: {np.std(all_auprc_error, axis=0)*100:.2f}')\n",
    "print(f'FPR95: {np.mean(all_fpr95, axis=0)*100:.2f}, std: {np.std(all_fpr95, axis=0)*100:.2f}')\n",
    "print(f'AURC: {np.mean(all_aurc, axis=0)*1000:.2f}, std: {np.std(all_aurc, axis=0)*1000:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98c274ee",
   "metadata": {},
   "source": [
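    "### PVI\n",
    "\n",
    "Pointwise V-information (PVI): class-wise scores are precomputed with the `training_from_scratch` estimator and loaded from disk. Each test example's confidence is the softmax-normalized PVI of its predicted class."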
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "b88fc898",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUROC: 86.02, std: 0.91\n",
      "AUPRC (success): 96.77, std: 0.30\n",
      "AUPRC (error): 56.07, std: 3.21\n",
      "FPR95: 52.49, std: 3.78\n",
      "AURC: 38.47, std: 2.76\n"
     ]
    }
   ],
   "source": [
    "estimator = 'training_from_scratch'\n",
    "# Temperature scaling of the PVI scores is disabled; set to True to fit (and save) a\n",
    "# per-run temperature on the validation set before applying the softmax.\n",
    "use_temp_scaling = False\n",
    "n_runs = 5\n",
    "\n",
    "all_auroc = []\n",
    "all_auprc_succ = []\n",
    "all_auprc_error = []\n",
    "all_fpr95 = []\n",
    "all_aurc = []\n",
    "for run in range(n_runs):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'\n",
    "    exp_name = f'{run_dir}/calibration/pvi/{estimator}'\n",
    "    model = tf.keras.models.load_model(f'{run_dir}/saved_models/trained_model.keras')\n",
    "\n",
    "    pred_y_test = np.argmax(model.predict(ds_test.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "    true_label_test = np.equal(true_y_test, pred_y_test).astype(int)  # 1 = correct prediction\n",
    "\n",
    "    # Precomputed class-wise PVI scores, one row per test example (aligned with pred_y_test).\n",
    "    pvi_class_test = np.load(f'{exp_name}/pvi_class_test.npy')\n",
    "    if use_temp_scaling:\n",
    "        # Validation inference is only needed to fit the temperature.\n",
    "        pred_y_val = np.argmax(model.predict(ds_val.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "        true_label_val = np.equal(true_y_val, pred_y_val).astype(int)\n",
    "        pvi_class_val = np.load(f'{exp_name}/pvi_class_val.npy')\n",
    "        opt_temp = utils.temp_scaling_aurc(pvi_class_val, pred_y_val, true_label_val)\n",
    "        np.save(f'{exp_name}/pvi_opt_temp.npy', opt_temp)\n",
    "        pvi_class_test = pvi_class_test / opt_temp\n",
    "    pvi_class_test = np.array([utils.softmax(x) for x in pvi_class_test])\n",
    "    assert len(pvi_class_test) == len(pred_y_test), 'PVI scores are not aligned with test predictions'\n",
    "    # Confidence = normalized PVI of the predicted class for each example.\n",
    "    pvi_test = pvi_class_test[np.arange(len(pred_y_test)), pred_y_test]\n",
    "\n",
    "    all_auroc.append(metrics.compute_auroc(true_label_test, pvi_test))\n",
    "    all_auprc_succ.append(metrics.compute_auprc_success(true_label_test, pvi_test))\n",
    "    all_auprc_error.append(metrics.compute_auprc_error(true_label_test, pvi_test))\n",
    "    all_fpr95.append(metrics.compute_fpr95(true_label_test, pvi_test))\n",
    "    all_aurc.append(metrics.compute_aurc(true_label_test, pvi_test))\n",
    "\n",
    "# Mean and population std (np.std default) across runs; AURC is reported x1000, the rest x100.\n",
    "print(f'AUROC: {np.mean(all_auroc, axis=0)*100:.2f}, std: {np.std(all_auroc, axis=0)*100:.2f}')\n",
    "print(f'AUPRC (success): {np.mean(all_auprc_succ, axis=0)*100:.2f}, std: {np.std(all_auprc_succ, axis=0)*100:.2f}')\n",
    "print(f'AUPRC (error): {np.mean(all_auprc_error, axis=0)*100:.2f}, std: {np.std(all_auprc_error, axis=0)*100:.2f}')\n",
    "print(f'FPR95: {np.mean(all_fpr95, axis=0)*100:.2f}, std: {np.std(all_fpr95, axis=0)*100:.2f}')\n",
    "print(f'AURC: {np.mean(all_aurc, axis=0)*1000:.2f}, std: {np.std(all_aurc, axis=0)*1000:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c49ca22e",
   "metadata": {},
   "source": [
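    "### Other Benchmark Methods\n",
    "\n",
    "Additional post-hoc confidence scores computed from the same trained models: softmax margin, max logit, logit margin, negative entropy and negative Gini. The temperature-dependent methods reuse the temperature saved by the Softmax section, so run that section first."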
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "37ecab7d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Method: softmax_margin\n",
      "AUROC: 85.14, std: 0.38\n",
      "AUPRC (success): 96.75, std: 0.07\n",
      "AUPRC (error): 47.38, std: 1.78\n",
      "FPR95: 63.06, std: 1.49\n",
      "AURC: 38.63, std: 0.98\n",
      "Method: max_logits\n",
      "AUROC: 79.22, std: 1.05\n",
      "AUPRC (success): 95.04, std: 0.35\n",
      "AUPRC (error): 41.65, std: 2.08\n",
      "FPR95: 68.12, std: 1.45\n",
      "AURC: 54.08, std: 3.33\n",
      "Method: logits_margin\n",
      "AUROC: 85.24, std: 0.36\n",
      "AUPRC (success): 96.80, std: 0.09\n",
      "AUPRC (error): 47.22, std: 1.77\n",
      "FPR95: 63.42, std: 1.47\n",
      "AURC: 38.28, std: 1.17\n",
      "Method: negative_entropy\n",
      "AUROC: 85.07, std: 0.41\n",
      "AUPRC (success): 96.72, std: 0.10\n",
      "AUPRC (error): 48.54, std: 1.83\n",
      "FPR95: 62.76, std: 1.29\n",
      "AURC: 39.00, std: 1.21\n",
      "Method: negative_gini\n",
      "AUROC: 85.08, std: 0.40\n",
      "AUPRC (success): 96.70, std: 0.08\n",
      "AUPRC (error): 48.25, std: 1.83\n",
      "FPR95: 63.01, std: 1.45\n",
      "AURC: 39.07, std: 1.06\n"
     ]
    }
   ],
   "source": [
    "methods_list = ['softmax_margin','max_logits','logits_margin','negative_entropy','negative_gini']\n",
    "n_runs = 5\n",
    "\n",
    "for method in methods_list:\n",
    "    print(f'Method: {method}')\n",
    "    all_auroc = []\n",
    "    all_auprc_succ = []\n",
    "    all_auprc_error = []\n",
    "    all_fpr95 = []\n",
    "    all_aurc = []\n",
    "    for run in range(n_runs):\n",
    "        tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "        run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'\n",
    "        model = tf.keras.models.load_model(f'{run_dir}/saved_models/trained_model.keras')\n",
    "\n",
    "        pred_y_test = np.argmax(model.predict(ds_test.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "        true_label_test = np.equal(true_y_test, pred_y_test).astype(int)  # 1 = correct prediction\n",
    "\n",
    "        # Temperature fitted and saved by the Softmax cell; that cell must have been run first.\n",
    "        opt_temp = np.load(f'{run_dir}/calibration/softmax_opt_temp.npy')\n",
    "        if method == 'softmax_margin':\n",
    "            conf_test = methods.softmax_margin(model, ds_test, opt_temp, cfg['batch_size'])\n",
    "        elif method == 'max_logits':\n",
    "            conf_test = methods.max_logits(model, ds_test, cfg['batch_size'])\n",
    "        elif method == 'logits_margin':\n",
    "            conf_test = methods.logits_margin(model, ds_test, cfg['batch_size'])\n",
    "        elif method == 'negative_entropy':\n",
    "            conf_test = methods.negative_entropy(model, ds_test, opt_temp, cfg['batch_size'])\n",
    "        elif method == 'negative_gini':\n",
    "            conf_test = methods.negative_gini(model, ds_test, opt_temp, cfg['batch_size'])\n",
    "        else:\n",
    "            # Fail loudly instead of silently scoring with the previous method's conf_test.\n",
    "            raise ValueError(f'Unknown method: {method}')\n",
    "\n",
    "        all_auroc.append(metrics.compute_auroc(true_label_test, conf_test))\n",
    "        all_auprc_succ.append(metrics.compute_auprc_success(true_label_test, conf_test))\n",
    "        all_auprc_error.append(metrics.compute_auprc_error(true_label_test, conf_test))\n",
    "        all_fpr95.append(metrics.compute_fpr95(true_label_test, conf_test))\n",
    "        all_aurc.append(metrics.compute_aurc(true_label_test, conf_test))\n",
    "    # Mean and population std (np.std default) across runs; AURC is reported x1000, the rest x100.\n",
    "    print(f'AUROC: {np.mean(all_auroc, axis=0)*100:.2f}, std: {np.std(all_auroc, axis=0)*100:.2f}')\n",
    "    print(f'AUPRC (success): {np.mean(all_auprc_succ, axis=0)*100:.2f}, std: {np.std(all_auprc_succ, axis=0)*100:.2f}')\n",
    "    print(f'AUPRC (error): {np.mean(all_auprc_error, axis=0)*100:.2f}, std: {np.std(all_auprc_error, axis=0)*100:.2f}')\n",
    "    print(f'FPR95: {np.mean(all_fpr95, axis=0)*100:.2f}, std: {np.std(all_fpr95, axis=0)*100:.2f}')\n",
    "    print(f'AURC: {np.mean(all_aurc, axis=0)*1000:.2f}, std: {np.std(all_aurc, axis=0)*1000:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05b177e8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
