{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "3780ad45",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from src.pmi_estimators import neural_pmi\n",
    "from src.psi_estimators import psi_gaussian_val_class\n",
    "from src.pvi_estimators import neural_pvi_class\n",
    "\n",
    "import src.utils as utils\n",
    "import src.metrics as metrics\n",
    "import src.methods as methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "649a06a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation data: Omniglot test set (out-of-distribution, OOD) and MNIST test set (in-distribution).\n",
    "TFDS_DIR = '../tensorflow_datasets/'\n",
    "N_EVAL = 5000  # number of examples kept from each test set\n",
    "\n",
    "ds_test_ood, ds_info = tfds.load(\n",
    "    'omniglot',\n",
    "    split='test',\n",
    "    data_dir=TFDS_DIR,\n",
    "    shuffle_files=False,\n",
    "    as_supervised=True,\n",
    "    with_info=True,\n",
    ")\n",
    "\n",
    "def preprocess_omniglot(image, label):\n",
    "    '''Resize an Omniglot image to 28x28x1 grayscale, scale it to [0, 1] and invert it.'''\n",
    "    resized = tf.image.resize(image, (28, 28))\n",
    "    # Python-level check on the static channel count, resolved when the map function is traced.\n",
    "    gray = tf.image.rgb_to_grayscale(resized) if resized.shape[-1] == 3 else resized\n",
    "    # 1 - x flips intensities, presumably to match MNIST's light-strokes-on-dark polarity.\n",
    "    inverted = 1 - gray / 255.0\n",
    "    return tf.reshape(inverted, (28, 28, 1)), label\n",
    "\n",
    "ds_test_ood = ds_test_ood.map(preprocess_omniglot, num_parallel_calls=tf.data.AUTOTUNE)\n",
    "\n",
    "# ds_info is rebound here; from this point on it describes MNIST, not Omniglot.\n",
    "(ds_train, ds_val, ds_test), ds_info = tfds.load(\n",
    "    'mnist',\n",
    "    split=['train[:85%]', 'train[85%:]', 'test'],\n",
    "    data_dir=TFDS_DIR,\n",
    "    shuffle_files=False,\n",
    "    as_supervised=True,\n",
    "    with_info=True,\n",
    ")\n",
    "\n",
    "def preprocess_mnist(image, label):\n",
    "    '''Resize an MNIST image to 28x28x1 and scale it to float32 values in [0, 1].'''\n",
    "    scaled = tf.cast(tf.image.resize(image, (28, 28)), tf.float32) / 255.0\n",
    "    return tf.reshape(scaled, (28, 28, 1)), label\n",
    "\n",
    "ds_test = ds_test.map(preprocess_mnist, num_parallel_calls=tf.data.AUTOTUNE)\n",
    "\n",
    "def prefetch_and_cache_dataset(dataset, take_amount=None):\n",
    "    '''Keep the first `take_amount` elements (all of them if None), then cache and prefetch.'''\n",
    "    subset = dataset if take_amount is None else dataset.take(take_amount)\n",
    "    return subset.cache().prefetch(tf.data.AUTOTUNE)\n",
    "\n",
    "ds_test_ood = prefetch_and_cache_dataset(ds_test_ood, take_amount=N_EVAL)\n",
    "ds_test = prefetch_and_cache_dataset(ds_test, take_amount=N_EVAL)\n",
    "\n",
    "# OOD examples come first, followed by the in-distribution ones; the evaluation cells rely on this order.\n",
    "combined_ds_test = ds_test_ood.concatenate(ds_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "ede8b812",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUROC: 94.17, std: 0.47\n",
      "AURC: 203.53, std: 4.05\n"
     ]
    }
   ],
   "source": [
    "# Maximum softmax probability (MSP) baseline, averaged over 5 trained runs.\n",
    "# combined_ds_test lists OOD examples first (label 0), then in-distribution (ID) examples (label 1);\n",
    "# misclassified ID examples are dropped before scoring.\n",
    "true_y_test = np.array([y.numpy() for _, y in ds_test])  # ID labels do not change between runs\n",
    "n_id = len(true_y_test)\n",
    "n_ood = sum(1 for _ in ds_test_ood)\n",
    "true_label_test = np.concatenate((np.zeros(n_ood), np.ones(n_id)))\n",
    "\n",
    "all_auroc = []\n",
    "all_aurc = []\n",
    "\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/mlp_mnist/run_{run+1}/saved_models/trained_model.keras')\n",
    "\n",
    "    pred_y_test = np.argmax(model.predict(ds_test.batch(128), verbose=0), axis=1)\n",
    "    # Keep flag: 1 for every OOD example and for ID examples the model classifies correctly.\n",
    "    mask = np.concatenate([np.ones(n_ood, dtype=int), np.equal(true_y_test, pred_y_test).astype(int)])\n",
    "\n",
    "    softmax_test = methods.max_softmax_prob(model, combined_ds_test)\n",
    "    assert len(softmax_test) == len(true_label_test), 'scores and labels are misaligned'\n",
    "\n",
    "    indices_to_keep = ~((true_label_test == 1) & (mask == 0))\n",
    "    filtered_true_label_test = true_label_test[indices_to_keep]\n",
    "    filtered_softmax_test = softmax_test[indices_to_keep]\n",
    "\n",
    "    auroc = metrics.compute_auroc(filtered_true_label_test, filtered_softmax_test)\n",
    "    aurc = metrics.compute_aurc(filtered_true_label_test, filtered_softmax_test)\n",
    "    all_auroc.append(auroc)\n",
    "    all_aurc.append(aurc)\n",
    "print(f'AUROC: {np.mean(all_auroc, axis=0)*100:.2f}, std: {np.std(all_auroc, axis=0)*100:.2f}')\n",
    "print(f'AURC: {np.mean(all_aurc, axis=0)*1000:.2f}, std: {np.std(all_aurc, axis=0)*1000:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "b0303719",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUROC: 96.28, std: 0.33\n",
      "AURC: 173.53, std: 2.40\n"
     ]
    }
   ],
   "source": [
    "# Neural PMI confidence score (critic='separable', estimator='variational_f_js'), averaged over 5 runs.\n",
    "# For each combined test example the PMI with every class is estimated, normalised across classes\n",
    "# with utils.softmax, and the value for the model's predicted class is used as the confidence score.\n",
    "critic = 'separable'\n",
    "estimator = 'variational_f_js'\n",
    "num_classes = 10\n",
    "batch_size = 128\n",
    "\n",
    "# Run-independent bookkeeping. combined_ds_test lists OOD examples first (label 0),\n",
    "# then in-distribution (ID) examples (label 1).\n",
    "true_y_test = np.array([y.numpy() for _, y in ds_test])\n",
    "n_id = len(true_y_test)\n",
    "n_ood = sum(1 for _ in ds_test_ood)\n",
    "true_label_test = np.concatenate((np.zeros(n_ood), np.ones(n_id)))\n",
    "\n",
    "all_auroc = []\n",
    "all_aurc = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/mlp_mnist/run_{run+1}/calibration/pmi/{critic}_{estimator}'\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/mlp_mnist/run_{run+1}/saved_models/trained_model.keras')\n",
    "    # Sub-model exposing the output of the last layer of `model`.\n",
    "    int_model = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-1].output)\n",
    "\n",
    "    pred_y_test = np.argmax(model.predict(ds_test.batch(batch_size), verbose=0), axis=1)\n",
    "    # Keep flag: 1 for every OOD example and for ID examples the model classifies correctly.\n",
    "    mask = np.concatenate([np.ones(n_ood, dtype=int), np.equal(true_y_test, pred_y_test).astype(int)])\n",
    "\n",
    "    pmi_model = tf.keras.models.load_model(f'{exp_name}/pmi_output_model')\n",
    "    pmi_class = []\n",
    "    for k in range(num_classes):\n",
    "        # Pair every batch of activations with a one-hot vector for class k. The dataset is built\n",
    "        # and fully consumed inside this iteration, so the lambda always sees the current k.\n",
    "        ds_activity = combined_ds_test.batch(batch_size).map(lambda x, y: (int_model(x), tf.one_hot(tf.fill([tf.shape(x)[0]], k), depth=num_classes))).prefetch(tf.data.AUTOTUNE)\n",
    "        pmi_list = []\n",
    "        for (x_batch, y_batch) in ds_activity:\n",
    "            pmi = neural_pmi(x_batch, y_batch, pmi_model, estimator=estimator)\n",
    "            pmi_list += np.array(pmi).tolist()\n",
    "        pmi_class.append(pmi_list)\n",
    "    # Transpose so rows are examples and columns are classes.\n",
    "    np.save(f'{exp_name}/pmi_output_class_test_ood.npy', np.array(pmi_class).T)\n",
    "\n",
    "    pmi_class_test = np.load(f'{exp_name}/pmi_output_class_test_ood.npy')\n",
    "    pmi_class_test = np.array([utils.softmax(row) for row in pmi_class_test])\n",
    "    # Predictions on the combined set (OOD + ID), used to pick each example's score column.\n",
    "    pred_y_combined = np.argmax(model.predict(combined_ds_test.batch(batch_size), verbose=0), axis=1)\n",
    "    pmi_test = np.array([pmi_value[pred_value] for pmi_value, pred_value in zip(pmi_class_test, pred_y_combined)])\n",
    "    assert len(pmi_test) == len(true_label_test), 'scores and labels are misaligned'\n",
    "\n",
    "    # Score only OOD examples and correctly classified ID examples, the same subset the\n",
    "    # MSP and post-hoc baseline cells use, so the numbers are comparable.\n",
    "    indices_to_keep = ~((true_label_test == 1) & (mask == 0))\n",
    "    filtered_true_label_test = true_label_test[indices_to_keep]\n",
    "    filtered_pmi_test = pmi_test[indices_to_keep]\n",
    "\n",
    "    auroc = metrics.compute_auroc(filtered_true_label_test, filtered_pmi_test)\n",
    "    aurc = metrics.compute_aurc(filtered_true_label_test, filtered_pmi_test)\n",
    "    all_auroc.append(auroc)\n",
    "    all_aurc.append(aurc)\n",
    "print(f'AUROC: {np.mean(all_auroc, axis=0)*100:.2f}, std: {np.std(all_auroc, axis=0)*100:.2f}')\n",
    "print(f'AURC: {np.mean(all_aurc, axis=0)*1000:.2f}, std: {np.std(all_aurc, axis=0)*1000:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "ff43bebc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Projections: 500it [00:03, 147.09it/s]\n",
      "Projections: 500it [00:03, 137.24it/s]\n",
      "Projections: 500it [00:03, 129.89it/s]\n",
      "Projections: 500it [00:04, 117.58it/s]\n",
      "Projections: 500it [00:03, 136.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUROC: 95.80, std: 0.40\n",
      "AURC: 174.58, std: 2.53\n"
     ]
    }
   ],
   "source": [
    "# PSI confidence score (estimator='gaussian', n_projs projections), averaged over 5 runs.\n",
    "n_projs = 500\n",
    "estimator = 'gaussian'\n",
    "batch_size = 128\n",
    "# NOTE(review): the PMI cell scores with utils.softmax-normalised per-class values, but this cell\n",
    "# has always scored with the raw per-class PSI values. Set to True to score like the PMI cell.\n",
    "use_softmax_scores = False\n",
    "\n",
    "# Run-independent bookkeeping. combined_ds_test lists OOD examples first (label 0),\n",
    "# then in-distribution (ID) examples (label 1).\n",
    "true_y_test = np.array([y.numpy() for _, y in ds_test])\n",
    "n_id = len(true_y_test)\n",
    "n_ood = sum(1 for _ in ds_test_ood)\n",
    "true_label_test = np.concatenate((np.zeros(n_ood), np.ones(n_id)))\n",
    "\n",
    "all_auroc = []\n",
    "all_aurc = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/mlp_mnist/run_{run+1}/calibration/psi/{estimator}'\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/mlp_mnist/run_{run+1}/saved_models/trained_model.keras')\n",
    "    # Sub-model exposing the output of the last layer of `model`.\n",
    "    int_model = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-1].output)\n",
    "\n",
    "    pred_y_test = np.argmax(model.predict(ds_test.batch(batch_size), verbose=0), axis=1)\n",
    "    # Keep flag: 1 for every OOD example and for ID examples the model classifies correctly.\n",
    "    mask = np.concatenate([np.ones(n_ood, dtype=int), np.equal(true_y_test, pred_y_test).astype(int)])\n",
    "\n",
    "    # allow_pickle can execute code stored in the file: only load result files you produced yourself.\n",
    "    psi_data = np.load(f'{exp_name}/gaussian_output_model_{n_projs}_projs.npy', allow_pickle=True).item()\n",
    "    # Stack the activations of every combined test example into one array, batch by batch.\n",
    "    ds_activity = combined_ds_test.batch(batch_size).map(lambda x, y: (int_model(x), y))\n",
    "    x = np.concatenate([x_batch.numpy() for x_batch, _ in ds_activity], axis=0)\n",
    "    psi_class, _ = psi_gaussian_val_class(x, psi_data)\n",
    "    np.save(f'{exp_name}/psi_output_class_{n_projs}_projs_test_ood.npy', np.array(psi_class))\n",
    "\n",
    "    # psi_class is indexed per example below, so presumably one row of class values per example.\n",
    "    if use_softmax_scores:\n",
    "        psi_class_test = np.load(f'{exp_name}/psi_output_class_{n_projs}_projs_test_ood.npy')\n",
    "        psi_scores = np.array([utils.softmax(row) for row in psi_class_test])\n",
    "    else:\n",
    "        psi_scores = psi_class\n",
    "    # Predictions on the combined set (OOD + ID), used to pick each example's score column.\n",
    "    pred_y_combined = np.argmax(model.predict(combined_ds_test.batch(batch_size), verbose=0), axis=1)\n",
    "    psi_test = np.array([psi_value[pred_value] for psi_value, pred_value in zip(psi_scores, pred_y_combined)])\n",
    "    assert len(psi_test) == len(true_label_test), 'scores and labels are misaligned'\n",
    "\n",
    "    # Score only OOD examples and correctly classified ID examples, the same subset the\n",
    "    # MSP and post-hoc baseline cells use, so the numbers are comparable.\n",
    "    indices_to_keep = ~((true_label_test == 1) & (mask == 0))\n",
    "    filtered_true_label_test = true_label_test[indices_to_keep]\n",
    "    filtered_psi_test = psi_test[indices_to_keep]\n",
    "\n",
    "    auroc = metrics.compute_auroc(filtered_true_label_test, filtered_psi_test)\n",
    "    aurc = metrics.compute_aurc(filtered_true_label_test, filtered_psi_test)\n",
    "    all_auroc.append(auroc)\n",
    "    all_aurc.append(aurc)\n",
    "print(f'AUROC: {np.mean(all_auroc, axis=0)*100:.2f}, std: {np.std(all_auroc, axis=0)*100:.2f}')\n",
    "print(f'AURC: {np.mean(all_aurc, axis=0)*1000:.2f}, std: {np.std(all_aurc, axis=0)*1000:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "32951fe3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "79/79 [==============================] - 0s 1ms/step\n",
      "79/79 [==============================] - 0s 2ms/step\n",
      "79/79 [==============================] - 0s 991us/step\n",
      "79/79 [==============================] - 0s 1ms/step\n",
      "79/79 [==============================] - 0s 1ms/step\n",
      "79/79 [==============================] - 0s 1ms/step\n",
      "79/79 [==============================] - 0s 1ms/step\n",
      "79/79 [==============================] - 0s 1ms/step\n",
      "79/79 [==============================] - 0s 1ms/step\n",
      "79/79 [==============================] - 0s 1ms/step\n",
      "AUROC: 75.77, std: 1.14\n",
      "AURC: 337.60, std: 9.49\n"
     ]
    }
   ],
   "source": [
    "# PVI confidence score (models under calibration/pvi/training_from_scratch), averaged over 5 runs.\n",
    "batch_size = 128\n",
    "# NOTE(review): the PMI cell scores with utils.softmax-normalised per-class values, but this cell\n",
    "# has always scored with the raw per-class PVI values. Set to True to score like the PMI cell.\n",
    "use_softmax_scores = False\n",
    "\n",
    "# Run-independent bookkeeping. combined_ds_test lists OOD examples first (label 0),\n",
    "# then in-distribution (ID) examples (label 1).\n",
    "true_y_test = np.array([y.numpy() for _, y in ds_test])\n",
    "n_id = len(true_y_test)\n",
    "n_ood = sum(1 for _ in ds_test_ood)\n",
    "true_label_test = np.concatenate((np.zeros(n_ood), np.ones(n_id)))\n",
    "\n",
    "all_auroc = []\n",
    "all_aurc = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/mlp_mnist/run_{run+1}/calibration/pvi/training_from_scratch'\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/mlp_mnist/run_{run+1}/saved_models/trained_model.keras')\n",
    "\n",
    "    pred_y_test = np.argmax(model.predict(ds_test.batch(batch_size), verbose=0), axis=1)\n",
    "    # Keep flag: 1 for every OOD example and for ID examples the model classifies correctly.\n",
    "    mask = np.concatenate([np.ones(n_ood, dtype=int), np.equal(true_y_test, pred_y_test).astype(int)])\n",
    "\n",
    "    pvi_model = tf.keras.models.load_model(f'{exp_name}/pvi_model.keras')\n",
    "    null_model = tf.keras.models.load_model(f'{exp_name}/pvi_null_model.keras')\n",
    "\n",
    "    pvi_class = neural_pvi_class(combined_ds_test.batch(batch_size), pvi_model, null_model)\n",
    "    np.save(f'{exp_name}/pvi_class_test_ood.npy', np.array(pvi_class))\n",
    "\n",
    "    # pvi_class is indexed per example below, so presumably one row of class values per example.\n",
    "    if use_softmax_scores:\n",
    "        pvi_class_test = np.load(f'{exp_name}/pvi_class_test_ood.npy')\n",
    "        pvi_scores = np.array([utils.softmax(row) for row in pvi_class_test])\n",
    "    else:\n",
    "        pvi_scores = pvi_class\n",
    "    # Predictions on the combined set (OOD + ID), used to pick each example's score column.\n",
    "    pred_y_combined = np.argmax(model.predict(combined_ds_test.batch(batch_size), verbose=0), axis=1)\n",
    "    pvi_test = np.array([pvi_value[pred_value] for pvi_value, pred_value in zip(pvi_scores, pred_y_combined)])\n",
    "    assert len(pvi_test) == len(true_label_test), 'scores and labels are misaligned'\n",
    "\n",
    "    # Score only OOD examples and correctly classified ID examples, the same subset the\n",
    "    # MSP and post-hoc baseline cells use, so the numbers are comparable.\n",
    "    indices_to_keep = ~((true_label_test == 1) & (mask == 0))\n",
    "    filtered_true_label_test = true_label_test[indices_to_keep]\n",
    "    filtered_pvi_test = pvi_test[indices_to_keep]\n",
    "\n",
    "    auroc = metrics.compute_auroc(filtered_true_label_test, filtered_pvi_test)\n",
    "    aurc = metrics.compute_aurc(filtered_true_label_test, filtered_pvi_test)\n",
    "    all_auroc.append(auroc)\n",
    "    all_aurc.append(aurc)\n",
    "print(f'AUROC: {np.mean(all_auroc, axis=0)*100:.2f}, std: {np.std(all_auroc, axis=0)*100:.2f}')\n",
    "print(f'AURC: {np.mean(all_aurc, axis=0)*1000:.2f}, std: {np.std(all_aurc, axis=0)*1000:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "1c6934f2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Method: softmax_margin\n",
      "AUROC: 94.41, std: 0.43\n",
      "AURC: 200.85, std: 3.67\n",
      "Method: max_logits\n",
      "AUROC: 96.79, std: 0.31\n",
      "AURC: 175.49, std: 1.87\n",
      "Method: logits_margin\n",
      "AUROC: 96.35, std: 0.36\n",
      "AURC: 176.63, std: 2.49\n",
      "Method: negative_entropy\n",
      "AUROC: 96.46, std: 0.35\n",
      "AURC: 176.20, std: 2.42\n",
      "Method: negative_gini\n",
      "AUROC: 94.17, std: 0.47\n",
      "AURC: 203.51, std: 4.05\n"
     ]
    }
   ],
   "source": [
    "# Further post-hoc confidence baselines from src.methods, scored like the MSP cell and averaged over 5 runs.\n",
    "methods_list = ['softmax_margin','max_logits','logits_margin','negative_entropy','negative_gini']\n",
    "batch_size = 128\n",
    "\n",
    "# Run-independent bookkeeping. combined_ds_test lists OOD examples first (label 0),\n",
    "# then in-distribution (ID) examples (label 1).\n",
    "true_y_test = np.array([y.numpy() for _, y in ds_test])\n",
    "n_id = len(true_y_test)\n",
    "n_ood = sum(1 for _ in ds_test_ood)\n",
    "true_label_test = np.concatenate((np.zeros(n_ood), np.ones(n_id)))\n",
    "\n",
    "# Per-method metric lists, filled run by run so each trained model is loaded only once.\n",
    "auroc_by_method = {method: [] for method in methods_list}\n",
    "aurc_by_method = {method: [] for method in methods_list}\n",
    "\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/mlp_mnist/run_{run+1}/saved_models/trained_model.keras')\n",
    "\n",
    "    pred_y_test = np.argmax(model.predict(ds_test.batch(batch_size), verbose=0), axis=1)\n",
    "    # Keep flag: 1 for every OOD example and for ID examples the model classifies correctly.\n",
    "    mask = np.concatenate([np.ones(n_ood, dtype=int), np.equal(true_y_test, pred_y_test).astype(int)])\n",
    "    indices_to_keep = ~((true_label_test == 1) & (mask == 0))\n",
    "    filtered_true_label_test = true_label_test[indices_to_keep]\n",
    "\n",
    "    for method in methods_list:\n",
    "        # Re-seed so every scorer starts from the same RNG state for this run.\n",
    "        tf.keras.utils.set_random_seed(run+10)\n",
    "        # Look the scorer up by name: an unknown name raises AttributeError instead of\n",
    "        # silently reusing a previous method's scores.\n",
    "        score_fn = getattr(methods, method)\n",
    "        conf_test = score_fn(model, combined_ds_test, batch_size=batch_size)\n",
    "        filtered_conf_test = conf_test[indices_to_keep]\n",
    "\n",
    "        auroc_by_method[method].append(metrics.compute_auroc(filtered_true_label_test, filtered_conf_test))\n",
    "        aurc_by_method[method].append(metrics.compute_aurc(filtered_true_label_test, filtered_conf_test))\n",
    "\n",
    "for method in methods_list:\n",
    "    all_auroc = auroc_by_method[method]\n",
    "    all_aurc = aurc_by_method[method]\n",
    "    print(f'Method: {method}')\n",
    "    print(f'AUROC: {np.mean(all_auroc, axis=0)*100:.2f}, std: {np.std(all_auroc, axis=0)*100:.2f}')\n",
    "    print(f'AURC: {np.mean(all_aurc, axis=0)*1000:.2f}, std: {np.std(all_aurc, axis=0)*1000:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa258ecb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
