{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "97f6dee6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-07-30 15:50:09.303726: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9373] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-07-30 15:50:09.303861: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-07-30 15:50:09.306580: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1534] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-07-30 15:50:09.322306: I tensorflow/core/platform/cpu_feature_guard.cc:183] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from src.datasets import load_dataset, preprocess_dataset, prefetch_dataset\n",
    "from src.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "73431df4",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = { 'dataset' : 'fashion_mnist',\n",
    "        'model' : 'cnn',\n",
    "        'batch_size' : 512,\n",
    "        'optimizer' : 'Adam',\n",
    "        'learning_rate' : tf.keras.optimizers.schedules.ExponentialDecay(0.001,decay_steps=100000,decay_rate=0.95,staircase=True),\n",
    "        'epoch' : 300,\n",
    "        'epoch_save_period' : 1\n",
    "        }    \n",
    "\n",
    "model_name = cfg['model']\n",
    "dataset_name = cfg['dataset']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f91f5931",
   "metadata": {},
   "source": [
    "### Softmax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "589ad83e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-07-19 13:19:39.207770: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1926] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78833 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-80GB, pci bus id: 0000:47:00.0, compute capability: 8.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "Opt. threshold: 0.326, Test filtering error:1.55\n",
      "Run: 2\n",
      "Opt. threshold: 0.375, Test filtering error:1.53\n",
      "Run: 3\n",
      "Opt. threshold: 0.656, Test filtering error:1.49\n",
      "Run: 4\n",
      "Opt. threshold: 0.857, Test filtering error:1.44\n",
      "Run: 5\n",
      "Opt. threshold: 0.770, Test filtering error:1.40\n",
      "-----------------------------\n",
      "Average opt. threshold: 0.597, std: 0.211\n",
      "Average test filtering error: 1.48, std: 0.06\n"
     ]
    }
   ],
   "source": [
    "##############################################################\n",
    "#\n",
    "# Compute Filtering Accuracy (Softmax)\n",
    "#\n",
    "# #############################################################\n",
    "\n",
    "ds_train, ds_val, ds_test, ds_info = load_dataset(cfg, shuffle=False)\n",
    "n_classes = ds_info.features['label'].num_classes\n",
    "ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "\n",
    "filtering_acc = []\n",
    "threshold = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    print(f'Run: {run+1}')\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "    preds = model.predict(ds_val.batch(cfg['batch_size']), verbose=0)\n",
    "    preds = np.array([softmax(x) for x in preds])\n",
    "    true_y = np.argmax([y for x,y in ds_val], axis=1)\n",
    "    pred_y = np.argmax(preds, axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    softmax_val = np.max(preds, axis=1)\n",
    "    opt_threshold = compute_opt_threshold(softmax_val, true_label)\n",
    "    threshold.append(opt_threshold)\n",
    "    \n",
    "    preds = model.predict(ds_test.batch(cfg['batch_size']), verbose=0)\n",
    "    preds = np.array([softmax(x) for x in preds])\n",
    "    true_y = np.argmax([y for x,y in ds_test], axis=1)\n",
    "    pred_y = np.argmax(preds, axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    softmax_val = np.max(preds, axis=1)\n",
    "    test_filtering_acc = compute_filtering_acc(softmax_val, true_label, opt_threshold)\n",
    "    filtering_acc.append(test_filtering_acc)\n",
    "    \n",
    "    print(f'Opt. threshold: {opt_threshold:.3f}, Test filtering error:{100-test_filtering_acc:.2f}')\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration'\n",
    "    np.savez(f'{exp_name}/softmax_filtering_accuracy.npz', opt_threshold=opt_threshold, test_filtering_acc=test_filtering_acc)\n",
    "\n",
    "print('-----------------------------')\n",
    "print(f'Average opt. threshold: {(np.mean(threshold)):.3f}, std: {(np.std(threshold)):.3f}')\n",
    "print(f'Average test filtering error: {(100-np.mean(filtering_acc)):.2f}, std: {(np.std(filtering_acc)):.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cebc437a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------------------------\n",
      "Average opt. threshold: 0.597, std: 0.211\n",
      "Average test filtering error: 1.48, std: 0.06\n"
     ]
    }
   ],
   "source": [
    "threshold = []\n",
    "filtering_acc = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration'\n",
    "    f = np.load(f'{exp_name}/softmax_filtering_accuracy.npz')\n",
    "    opt_threshold = f['opt_threshold']\n",
    "    test_filtering_acc = f['test_filtering_acc']\n",
    "    threshold.append(opt_threshold)\n",
    "    filtering_acc.append(test_filtering_acc)\n",
    "\n",
    "print('-----------------------------')\n",
    "print(f'Average opt. threshold: {(np.mean(threshold)):.3f}, std: {(np.std(threshold)):.3f}')\n",
    "print(f'Average test filtering error: {(100-np.mean(filtering_acc)):.2f}, std: {(np.std(filtering_acc)):.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "38e27153",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "Opt. threshold: 0.318, Test filtering error:1.50\n",
      "Run: 2\n",
      "Opt. threshold: 0.265, Test filtering error:1.52\n",
      "Run: 3\n",
      "Opt. threshold: 0.449, Test filtering error:1.50\n",
      "Run: 4\n",
      "Opt. threshold: 0.605, Test filtering error:1.39\n",
      "Run: 5\n",
      "Opt. threshold: 0.638, Test filtering error:1.43\n",
      "-----------------------------\n",
      "Average opt. threshold: 0.455, std: 0.149\n",
      "Average test filtering error: 1.47, std: 0.05\n"
     ]
    }
   ],
   "source": [
    "##############################################################\n",
    "#\n",
    "# Compute Filtering Accuracy (Softmax with Temperature Scaling)\n",
    "#\n",
    "# #############################################################\n",
    "\n",
    "ds_train, ds_val, ds_test, ds_info = load_dataset(cfg, shuffle=False)\n",
    "n_classes = ds_info.features['label'].num_classes\n",
    "ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "\n",
    "filtering_acc = []\n",
    "threshold = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    print(f'Run: {run+1}')\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "    temp = np.load(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/opt_temp.npy')\n",
    "    preds = model.predict(ds_val.batch(cfg['batch_size']), verbose=0)\n",
    "    preds = np.array([softmax(x/temp) for x in preds])\n",
    "    true_y = np.argmax([y for x,y in ds_val], axis=1)\n",
    "    pred_y = np.argmax(preds, axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    softmax_val = np.max(preds, axis=1)\n",
    "    opt_threshold = compute_opt_threshold(softmax_val, true_label)\n",
    "    threshold.append(opt_threshold)\n",
    "    \n",
    "    preds = model.predict(ds_test.batch(cfg['batch_size']), verbose=0)\n",
    "    preds = np.array([softmax(x/temp) for x in preds])\n",
    "    true_y = np.argmax([y for x,y in ds_test], axis=1)\n",
    "    pred_y = np.argmax(preds, axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    softmax_val = np.max(preds, axis=1)\n",
    "    test_filtering_acc = compute_filtering_acc(softmax_val, true_label, opt_threshold)\n",
    "    filtering_acc.append(test_filtering_acc)\n",
    "    \n",
    "    print(f'Opt. threshold: {opt_threshold:.3f}, Test filtering error:{100-test_filtering_acc:.2f}')\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration'\n",
    "    np.savez(f'{exp_name}/calibrated_softmax_filtering_accuracy.npz', opt_threshold=opt_threshold, test_filtering_acc=test_filtering_acc)\n",
    "\n",
    "print('-----------------------------')\n",
    "print(f'Average opt. threshold: {(np.mean(threshold)):.3f}, std: {(np.std(threshold)):.3f}')\n",
    "print(f'Average test filtering error: {(100-np.mean(filtering_acc)):.2f}, std: {(np.std(filtering_acc)):.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "88268b8d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------------------------\n",
      "Average opt. threshold: 0.455, std: 0.149\n",
      "Average test filtering error: 1.47, std: 0.05\n"
     ]
    }
   ],
   "source": [
    "threshold = []\n",
    "filtering_acc = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration'\n",
    "    f = np.load(f'{exp_name}/calibrated_softmax_filtering_accuracy.npz')\n",
    "    opt_threshold = f['opt_threshold']\n",
    "    test_filtering_acc = f['test_filtering_acc']\n",
    "    threshold.append(opt_threshold)\n",
    "    filtering_acc.append(test_filtering_acc)\n",
    "\n",
    "print('-----------------------------')\n",
    "print(f'Average opt. threshold: {(np.mean(threshold)):.3f}, std: {(np.std(threshold)):.3f}')\n",
    "print(f'Average test filtering error: {(100-np.mean(filtering_acc)):.2f}, std: {(np.std(filtering_acc)):.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61d7c093",
   "metadata": {},
   "source": [
    "### PMI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "068e7eb2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "Critic: separable, Estimator: density_ratio_fitting\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-02-01 12:10:28.875073: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1883] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78835 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-80GB, pci bus id: 0000:47:00.0, compute capability: 8.0\n",
      "2024-02-01 12:10:57.619039: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:442] Loaded cuDNN version 8906\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Opt. threshold: 0.478, Test filtering error:4.29\n",
      "Run: 2\n",
      "Critic: separable, Estimator: density_ratio_fitting\n",
      "Opt. threshold: 0.431, Test filtering error:4.42\n",
      "Run: 3\n",
      "Critic: separable, Estimator: density_ratio_fitting\n",
      "Opt. threshold: 0.502, Test filtering error:4.38\n",
      "Run: 4\n",
      "Critic: separable, Estimator: density_ratio_fitting\n",
      "Opt. threshold: 0.391, Test filtering error:4.55\n",
      "Run: 5\n",
      "Critic: separable, Estimator: density_ratio_fitting\n",
      "Opt. threshold: 0.053, Test filtering error:4.81\n"
     ]
    }
   ],
   "source": [
    "##############################################################\n",
    "#\n",
    "# Compute Filtering Accuracy (with softmax scaling)\n",
    "#\n",
    "# #############################################################\n",
    "\n",
    "critic = 'separable'\n",
    "estimator = 'probabilistic_classifier'\n",
    "\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    print(f'Run: {run+1}')\n",
    "    print(f'Critic: {critic}, Estimator: {estimator}')\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pmi/{critic}_{estimator}'\n",
    "\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg, shuffle=False)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "    ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/trained_model.keras')\n",
    "\n",
    "    true_y = np.argmax([y for x,y in ds_val], axis=1)\n",
    "    pred_y = np.argmax(model.predict(ds_val.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    pmi_class = np.load(f'{exp_name}/pmi_class_val.npy')\n",
    "    pmi_class = np.array([softmax(x) for x in pmi_class])\n",
    "    pmi = [pmi_value[pred_value] for pmi_value, pred_value in zip(pmi_class, pred_y)]\n",
    "    opt_threshold = compute_opt_threshold(pmi, true_label)\n",
    "\n",
    "    true_y = np.argmax([y for x,y in ds_test], axis=1)\n",
    "    pred_y = np.argmax(model.predict(ds_test.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    pmi_class = np.load(f'{exp_name}/pmi_class_test.npy')\n",
    "    pmi_class = np.array([softmax(x) for x in pmi_class])\n",
    "    pmi = [pmi_value[pred_value] for pmi_value, pred_value in zip(pmi_class, pred_y)]\n",
    "    test_filtering_acc = compute_filtering_acc(pmi, true_label, opt_threshold)\n",
    "\n",
    "    np.savez(f'{exp_name}/scaled_filtering_accuracy.npz', opt_threshold=opt_threshold, test_filtering_acc=test_filtering_acc)\n",
    "    print(f'Opt. threshold: {opt_threshold:.3f}, Test filtering error:{100-test_filtering_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "a0573cd9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------------------------\n",
      "Critic: separable, Estimator: probabilistic_classifier\n",
      "Average opt. threshold: 0.392, std: 0.135\n",
      "Average test filtering error: 1.47, std: 0.05\n"
     ]
    }
   ],
   "source": [
    "critic = 'separable'\n",
    "estimator = 'probabilistic_classifier'\n",
    "\n",
    "threshold = []\n",
    "filtering_acc = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pmi/{critic}_{estimator}'\n",
    "    f = np.load(f'{exp_name}/temp_scaled_filtering_accuracy.npz')\n",
    "    opt_threshold = f['opt_threshold']\n",
    "    test_filtering_acc = f['test_filtering_acc']\n",
    "    threshold.append(opt_threshold)\n",
    "    filtering_acc.append(test_filtering_acc)\n",
    "\n",
    "print('-----------------------------')\n",
    "print(f'Critic: {critic}, Estimator: {estimator}')\n",
    "print(f'Average opt. threshold: {(np.mean(threshold)):.3f}, std: {(np.std(threshold)):.3f}')\n",
    "print(f'Average test filtering error: {(100-np.mean(filtering_acc)):.2f}, std: {(np.std(filtering_acc)):.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e873b28",
   "metadata": {},
   "source": [
    "### PVI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "29645bbd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "Estimator: training_from_scratch\n",
      "Opt. threshold: 0.070, Test filtering error:1.38\n",
      "Run: 2\n",
      "Estimator: training_from_scratch\n",
      "Opt. threshold: 0.280, Test filtering error:1.42\n",
      "Run: 3\n",
      "Estimator: training_from_scratch\n",
      "Opt. threshold: 0.216, Test filtering error:1.47\n",
      "Run: 4\n",
      "Estimator: training_from_scratch\n",
      "Opt. threshold: 0.081, Test filtering error:1.30\n",
      "Run: 5\n",
      "Estimator: training_from_scratch\n",
      "Opt. threshold: 0.409, Test filtering error:1.61\n"
     ]
    }
   ],
   "source": [
    "##############################################################\n",
    "#\n",
    "# Compute Filtering Accuracy (with softmax scaling)\n",
    "#\n",
    "# #############################################################\n",
    "\n",
    "estimator = 'training_from_scratch'\n",
    "\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    print(f'Run: {run+1}')\n",
    "    print(f'Estimator: {estimator}')\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pvi/{estimator}'\n",
    "\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg, shuffle=False)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "\n",
    "    true_y = np.argmax([y for x,y in ds_val], axis=1)\n",
    "    pred_y = np.argmax(model.predict(ds_val.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    pvi_class = np.load(f'{exp_name}/pvi_class_val.npy')\n",
    "    pvi_class = np.array([softmax(x) for x in pvi_class])\n",
    "    pvi = [pvi_value[pred_value] for pvi_value, pred_value in zip(pvi_class, pred_y)]\n",
    "    opt_threshold = compute_opt_threshold(pvi, true_label)\n",
    "\n",
    "    true_y = np.argmax([y for x,y in ds_test], axis=1)\n",
    "    pred_y = np.argmax(model.predict(ds_test.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    pvi_class = np.load(f'{exp_name}/pvi_class_test.npy')\n",
    "    pvi_class = np.array([softmax(x) for x in pvi_class])\n",
    "    pvi = [pvi_value[pred_value] for pvi_value, pred_value in zip(pvi_class, pred_y)]\n",
    "    test_filtering_acc = compute_filtering_acc(pvi, true_label, opt_threshold)\n",
    "\n",
    "    np.savez(f'{exp_name}/scaled_filtering_accuracy.npz', opt_threshold=opt_threshold, test_filtering_acc=test_filtering_acc)\n",
    "    print(f'Opt. threshold: {opt_threshold:.3f}, Test filtering error:{100-test_filtering_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2cd8d0ff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------------------------\n",
      "Estimator: training_from_scratch\n",
      "Average opt. threshold: 0.582, std: 0.722\n",
      "Average test filtering error: 1.44, std: 0.11\n"
     ]
    }
   ],
   "source": [
    "estimator = 'training_from_scratch'\n",
    "\n",
    "threshold = []\n",
    "filtering_acc = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pvi/{estimator}'\n",
    "    f = np.load(f'{exp_name}/unscaled_filtering_accuracy.npz')\n",
    "    opt_threshold = f['opt_threshold']\n",
    "    test_filtering_acc = f['test_filtering_acc']\n",
    "    threshold.append(opt_threshold)\n",
    "    filtering_acc.append(test_filtering_acc)\n",
    "\n",
    "print('-----------------------------')\n",
    "print(f'Estimator: {estimator}')\n",
    "print(f'Average opt. threshold: {(np.mean(threshold)):.3f}, std: {(np.std(threshold)):.3f}')\n",
    "print(f'Average test filtering error: {(100-np.mean(filtering_acc)):.2f}, std: {(np.std(filtering_acc)):.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5931e7ef",
   "metadata": {},
   "source": [
    "### PSI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0383743f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "Estimator: gaussian\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-05-22 12:02:37.713537: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1883] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78835 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-80GB, pci bus id: 0000:47:00.0, compute capability: 8.0\n",
      "2024-05-22 12:03:03.037671: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:442] Loaded cuDNN version 8906\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Opt. threshold: 0.216, Test filtering error:4.19\n",
      "Run: 2\n",
      "Estimator: gaussian\n",
      "Opt. threshold: 0.326, Test filtering error:4.45\n",
      "Run: 3\n",
      "Estimator: gaussian\n",
      "Opt. threshold: 0.218, Test filtering error:4.56\n",
      "Run: 4\n",
      "Estimator: gaussian\n",
      "Opt. threshold: 0.302, Test filtering error:4.64\n",
      "Run: 5\n",
      "Estimator: gaussian\n",
      "Opt. threshold: 0.297, Test filtering error:4.64\n"
     ]
    }
   ],
   "source": [
    "##############################################################\n",
    "#\n",
    "# Compute Filtering Accuracy (with softmax scaling)\n",
    "#\n",
    "# #############################################################\n",
    "\n",
    "estimator = 'gaussian'\n",
    "\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    print(f'Run: {run+1}')\n",
    "    print(f'Estimator: {estimator}')\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/psi/{estimator}'\n",
    "\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg, shuffle=False)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "    ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/trained_model.keras')\n",
    "\n",
    "    true_y = np.argmax([y for x,y in ds_val], axis=1)\n",
    "    pred_y = np.argmax(model.predict(ds_val.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    psi_class = np.load(f'{exp_name}/psi_class_val.npy')\n",
    "    psi_class = np.array([softmax(x) for x in psi_class])\n",
    "    psi = [psi_value[pred_value] for psi_value, pred_value in zip(psi_class, pred_y)]\n",
    "    opt_threshold = compute_opt_threshold(psi, true_label)\n",
    "\n",
    "    true_y = np.argmax([y for x,y in ds_test], axis=1)\n",
    "    pred_y = np.argmax(model.predict(ds_test.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    psi_class = np.load(f'{exp_name}/psi_class_test.npy')\n",
    "    psi_class = np.array([softmax(x) for x in psi_class])\n",
    "    psi = [psi_value[pred_value] for psi_value, pred_value in zip(psi_class, pred_y)]\n",
    "    test_filtering_acc = compute_filtering_acc(psi, true_label, opt_threshold)\n",
    "\n",
    "    np.savez(f'{exp_name}/scaled_filtering_accuracy.npz', opt_threshold=opt_threshold, test_filtering_acc=test_filtering_acc)\n",
    "    print(f'Opt. threshold: {opt_threshold:.3f}, Test filtering error:{100-test_filtering_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "72830927",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------------------------\n",
      "Estimator: gaussian\n",
      "Average opt. threshold: 0.272, std: 0.046\n",
      "Average test filtering error: 4.49, std: 0.17\n"
     ]
    }
   ],
   "source": [
    "estimator = 'gaussian'\n",
    "\n",
    "threshold = []\n",
    "filtering_acc = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/psi/{estimator}'\n",
    "    f = np.load(f'{exp_name}/scaled_filtering_accuracy.npz')\n",
    "    opt_threshold = f['opt_threshold']\n",
    "    test_filtering_acc = f['test_filtering_acc']\n",
    "    threshold.append(opt_threshold)\n",
    "    filtering_acc.append(test_filtering_acc)\n",
    "\n",
    "print('-----------------------------')\n",
    "print(f'Estimator: {estimator}')\n",
    "print(f'Average opt. threshold: {(np.mean(threshold)):.3f}, std: {(np.std(threshold)):.3f}')\n",
    "print(f'Average test filtering error: {(100-np.mean(filtering_acc)):.2f}, std: {(np.std(filtering_acc)):.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4518ec1f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
