{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "766b0c1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "from src.models import mlp\n",
    "from src.datasets import load_dataset, preprocess_dataset, prefetch_dataset\n",
    "from src.pvi_estimators import train_pvi_null_model, train_pvi_model_from_scratch, neural_pvi_class, neural_pvi_ensemble_class\n",
    "from src.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e48e859b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration; `cfg` is handed to load_dataset / preprocess_dataset in later cells.\n",
    "cfg = dict(\n",
    "    dataset='mnist',\n",
    "    model='mlp',\n",
    "    batch_size=512,\n",
    "    optimizer='Adam',\n",
    "    # Staircase exponential decay: lr = 0.001 * 0.95 ** floor(step / 100000).\n",
    "    learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(\n",
    "        0.001, decay_steps=100000, decay_rate=0.95, staircase=True\n",
    "    ),\n",
    "    epoch=300,\n",
    "    epoch_save_period=1,\n",
    ")\n",
    "\n",
    "# Short aliases used when building result paths.\n",
    "dataset_name = cfg['dataset']\n",
    "model_name = cfg['model']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "175282a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the run index and seed Python, NumPy and TensorFlow in a single call.\n",
    "run = 0\n",
    "tf.keras.utils.set_random_seed(run+10)\n",
    "\n",
    "# Directory that holds the saved PVI models for this run.\n",
    "exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pvi/no_training'\n",
    "\n",
    "# Load the raw splits, then push each one through the same preprocessing settings.\n",
    "ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "n_classes = ds_info.features['label'].num_classes\n",
    "ds_train, ds_val, ds_test = (\n",
    "    preprocess_dataset(split, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    for split in (ds_train, ds_val, ds_test)\n",
    ")\n",
    "\n",
    "# PVI model and its null counterpart, both stored under exp_name.\n",
    "null_model = tf.keras.models.load_model(f'{exp_name}/pvi_null_model.keras')\n",
    "pvi_model = tf.keras.models.load_model(f'{exp_name}/pvi_model.keras')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0610042e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUROC: 94.890\n",
      "AUPRC (success): 99.867\n",
      "AUPRC (error): 38.041\n",
      "AURC: 1.425\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the PVI score of the predicted class as a confidence measure.\n",
    "#\n",
    "# Batch into a NEW name: rebinding `ds_test` would make this cell non-idempotent and\n",
    "# would hand an already-batched dataset to any other cell that batches `ds_test` itself.\n",
    "ds_test_batched = ds_test.batch(cfg['batch_size'])\n",
    "\n",
    "# Correctness labels of the trained classifier. They are computed here so that this cell\n",
    "# runs on a fresh kernel instead of relying on variables left over from another cell.\n",
    "model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "true_y = np.concatenate([y.numpy() for _, y in ds_test_batched]).argmax(axis=1)  # labels are one-hot (onehot=True)\n",
    "pred_y = np.argmax(model.predict(ds_test_batched, verbose=0), axis=1)\n",
    "true_label = np.equal(true_y, pred_y).astype(int)  # 1 = correct prediction, 0 = misclassified\n",
    "\n",
    "# The null model is fed an all-zero input in place of every test input.\n",
    "ds_null = ds_test_batched.map(lambda x, y: (tf.zeros_like(x), y))\n",
    "prob_null = tf.nn.softmax(null_model.predict(ds_null, verbose=0)).numpy()\n",
    "prob_null = np.clip(prob_null, 1e-40, 1.0)  # same floor as prob_cond, avoids log2(0) = -inf\n",
    "v_null_entropy = -1 * np.log2(prob_null)\n",
    "\n",
    "prob_cond = tf.nn.softmax(pvi_model.predict(ds_test_batched, verbose=0)).numpy()\n",
    "prob_cond = np.clip(prob_cond, 1e-40, 1.0)\n",
    "v_cond_entropy = -1 * np.log2(prob_cond)\n",
    "\n",
    "# NOTE(review): v_null_entropy is computed above but does not enter the score below.\n",
    "# If the full PVI (v_null_entropy - v_cond_entropy) is intended, change this line; it is\n",
    "# left as-is here so the reported numbers stay the same. TODO confirm.\n",
    "pvi_class = - v_cond_entropy\n",
    "# Pick, for every sample, the score of the class predicted by the trained model.\n",
    "pvi = pvi_class[np.arange(len(pred_y)), pred_y]\n",
    "\n",
    "auroc = compute_auroc(true_label, pvi)\n",
    "auprc_succ = compute_auprc_success(true_label, pvi)\n",
    "auprc_error = compute_auprc_error(true_label, pvi)\n",
    "aurc, _, _ = compute_aurc(true_label, pvi)\n",
    "print(f'AUROC: {auroc*100:.3f}')\n",
    "print(f'AUPRC (success): {auprc_succ*100:.3f}')\n",
    "print(f'AUPRC (error): {auprc_error*100:.3f}')\n",
    "print(f'AURC: {aurc*1000:.3f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0cd4893f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUROC: 94.890\n",
      "AUPRC (success): 99.867\n",
      "AUPRC (error): 38.041\n",
      "AURC: 1.425\n"
     ]
    }
   ],
   "source": [
    "# Baseline: softmax probability of the predicted class as a confidence measure.\n",
    "model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "\n",
    "# Assumes `ds_test` is still unbatched at this point (one (x, one-hot y) pair per element),\n",
    "# which is what the per-example argmax over axis=1 below requires. TODO confirm.\n",
    "true_y = np.argmax([y for x,y in ds_test], axis=1)\n",
    "\n",
    "# Run inference once and reuse the outputs for both the predicted class and the softmax.\n",
    "model_out = model.predict(ds_test.batch(cfg['batch_size']), verbose=0)\n",
    "pred_y = np.argmax(model_out, axis=1)\n",
    "true_label = np.equal(true_y, pred_y).astype(int)  # 1 = correct prediction, 0 = misclassified\n",
    "softmax_class = tf.nn.softmax(model_out).numpy()\n",
    "# Pick, for every sample, the softmax value of its predicted class.\n",
    "softmax_val = softmax_class[np.arange(len(pred_y)), pred_y]\n",
    "\n",
    "auroc = compute_auroc(true_label, softmax_val)\n",
    "auprc_succ = compute_auprc_success(true_label, softmax_val)\n",
    "auprc_error = compute_auprc_error(true_label, softmax_val)\n",
    "aurc, _, _ = compute_aurc(true_label, softmax_val)\n",
    "print(f'AUROC: {auroc*100:.3f}')\n",
    "print(f'AUPRC (success): {auprc_succ*100:.3f}')\n",
    "print(f'AUPRC (error): {auprc_error*100:.3f}')\n",
    "print(f'AURC: {aurc*1000:.3f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c961ffd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
