{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4192cbce",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-07 23:08:03.228085: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9373] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-08-07 23:08:03.228149: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-08-07 23:08:03.229683: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1534] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-08-07 23:08:03.237631: I tensorflow/core/platform/cpu_feature_guard.cc:183] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "from src.datasets import load_dataset, preprocess_dataset, prefetch_dataset\n",
    "from src.pmi_estimators import train_critic_model, neural_pmi\n",
    "from src.psi_estimators import psi_gaussian_train, psi_gaussian_val_class\n",
    "from src.pvi_estimators import train_pvi_null_model, neural_pvi_class\n",
    "\n",
    "import src.utils as utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "969ed319",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = { 'dataset' : 'mnist',\n",
    "        'model' : 'mlp',\n",
    "        'batch_size' : 128,\n",
    "        'optimizer' : 'SGD',\n",
    "        'learning_rate' : 0.005,\n",
    "        'epoch' : 50,\n",
    "        'epoch_save_period' : 1\n",
    "        }  \n",
    "\n",
    "model_name = cfg['model']\n",
    "dataset_name = cfg['dataset']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cbb818af",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "Training PMI model (separable, variational_f_js)...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs:   0%|          | 0/200 [00:00<?, ?it/s]2024-08-07 01:31:46.331779: I external/local_xla/xla/service/service.cc:168] XLA service 0x7f13268a4330 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n",
      "2024-08-07 01:31:46.331823: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA A100-SXM4-80GB, Compute Capability 8.0\n",
      "2024-08-07 01:31:46.337558: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n",
      "2024-08-07 01:31:46.379926: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:467] Loaded cuDNN version 90100\n",
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "I0000 00:00:1722994306.474542 2813598 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ../results/PI_Explainability/resnet50_cifar10/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ../results/PI_Explainability/resnet50_cifar10/run_1/calibration/pmi/separable_variational_f_js/pmi_output_model/assets\n",
      "Epochs:   5%|▌         | 10/200 [18:00<5:42:18, 108.10s/it] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing PMI for all validation samples and for all classes...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing classes: 100%|██████████| 10/10 [26:27<00:00, 158.73s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing PMI for all test samples and for all classes...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing classes: 100%|██████████| 10/10 [35:36<00:00, 213.69s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 2\n",
      "Making directory ../results/PI_Explainability/resnet50_cifar10/run_2/calibration/pmi/separable_variational_f_js\n",
      "Training PMI model (separable, variational_f_js)...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs:   0%|          | 0/200 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ../results/PI_Explainability/resnet50_cifar10/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ../results/PI_Explainability/resnet50_cifar10/run_2/calibration/pmi/separable_variational_f_js/pmi_output_model/assets\n",
      "Epochs:   5%|▌         | 10/200 [17:36<5:34:38, 105.68s/it] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing PMI for all validation samples and for all classes...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing classes: 100%|██████████| 10/10 [26:34<00:00, 159.46s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing PMI for all test samples and for all classes...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing classes: 100%|██████████| 10/10 [35:19<00:00, 211.94s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 3\n",
      "Making directory ../results/PI_Explainability/resnet50_cifar10/run_3/calibration/pmi/separable_variational_f_js\n",
      "Training PMI model (separable, variational_f_js)...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs:   0%|          | 0/200 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ../results/PI_Explainability/resnet50_cifar10/run_3/calibration/pmi/separable_variational_f_js/pmi_output_model/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ../results/PI_Explainability/resnet50_cifar10/run_3/calibration/pmi/separable_variational_f_js/pmi_output_model/assets\n",
      "Epochs:   5%|▌         | 10/200 [17:48<5:38:12, 106.80s/it] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing PMI for all validation samples and for all classes...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing classes: 100%|██████████| 10/10 [26:11<00:00, 157.13s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing PMI for all test samples and for all classes...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing classes: 100%|██████████| 10/10 [33:31<00:00, 201.17s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 4\n",
      "Making directory ../results/PI_Explainability/resnet50_cifar10/run_4/calibration/pmi/separable_variational_f_js\n",
      "Training PMI model (separable, variational_f_js)...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs:   0%|          | 0/200 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ../results/PI_Explainability/resnet50_cifar10/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ../results/PI_Explainability/resnet50_cifar10/run_4/calibration/pmi/separable_variational_f_js/pmi_output_model/assets\n",
      "Epochs:   5%|▌         | 10/200 [18:26<5:50:25, 110.66s/it] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing PMI for all validation samples and for all classes...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing classes: 100%|██████████| 10/10 [27:24<00:00, 164.47s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing PMI for all test samples and for all classes...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing classes: 100%|██████████| 10/10 [35:17<00:00, 211.73s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 5\n",
      "Making directory ../results/PI_Explainability/resnet50_cifar10/run_5/calibration/pmi/separable_variational_f_js\n",
      "Training PMI model (separable, variational_f_js)...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs:   0%|          | 0/200 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ../results/PI_Explainability/resnet50_cifar10/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ../results/PI_Explainability/resnet50_cifar10/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets\n",
      "Epochs:   0%|          | 1/200 [17:56<59:30:49, 1076.63s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ../results/PI_Explainability/resnet50_cifar10/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ../results/PI_Explainability/resnet50_cifar10/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets\n",
      "Epochs:   1%|          | 2/200 [17:58<24:26:02, 444.25s/it] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ../results/PI_Explainability/resnet50_cifar10/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ../results/PI_Explainability/resnet50_cifar10/run_5/calibration/pmi/separable_variational_f_js/pmi_output_model/assets\n",
      "Epochs:   6%|▌         | 12/200 [18:10<4:44:42, 90.86s/it] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing PMI for all validation samples and for all classes...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing classes: 100%|██████████| 10/10 [27:16<00:00, 163.65s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing PMI for all test samples and for all classes...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing classes: 100%|██████████| 10/10 [36:26<00:00, 218.62s/it]\n"
     ]
    }
   ],
   "source": [
    "# Pointwise mutual information (PMI) between the classifier's output activations\n",
    "# and the one-hot class label, estimated with a trained critic, for every run.\n",
    "critic = 'separable'\n",
    "estimator = 'variational_f_js'\n",
    "N_RUNS = 5          # independent repetitions, stored under run_1 ... run_N_RUNS\n",
    "SEED_OFFSET = 10    # run r (0-based) is seeded with r + SEED_OFFSET\n",
    "PMI_EPOCHS = 200    # epochs passed to train_critic_model\n",
    "\n",
    "\n",
    "def _pmi_all_classes(ds, int_model, pmi_model, n_classes, estimator, batch_size):\n",
    "    \"\"\"Return the PMI of every sample in `ds` paired with every class label.\n",
    "\n",
    "    The activations int_model(x) do not depend on the candidate class, so they\n",
    "    are computed once and cached, then re-used for each of the n_classes passes:\n",
    "    the classifier runs once per split instead of once per class. The cache also\n",
    "    pins the sample order, so every class column refers to the same samples even\n",
    "    if `ds` were reshuffled between iterations.\n",
    "\n",
    "    Returns np.array(per_class_lists).T, i.e. the class index is on the last axis.\n",
    "    \"\"\"\n",
    "    ds_act = (ds.batch(batch_size)\n",
    "                .map(lambda x, y: int_model(x))\n",
    "                .cache()\n",
    "                .prefetch(tf.data.AUTOTUNE))\n",
    "    pmi_class = []\n",
    "    for k in tqdm(range(n_classes), desc=\"Processing classes\"):\n",
    "        pmi_list = []\n",
    "        for x_batch in ds_act:\n",
    "            # Pair every sample in the batch with the same candidate class k.\n",
    "            y_batch = tf.one_hot(tf.fill([tf.shape(x_batch)[0]], k), depth=n_classes)\n",
    "            pmi = neural_pmi(x_batch, y_batch, pmi_model, estimator=estimator)\n",
    "            pmi_list += np.array(pmi).tolist()\n",
    "        pmi_class.append(pmi_list)\n",
    "    return np.array(pmi_class).T\n",
    "\n",
    "\n",
    "for run in range(N_RUNS):\n",
    "    print(f'Run: {run+1}')\n",
    "    tf.keras.utils.set_random_seed(run + SEED_OFFSET)  # seeds Python, NumPy, and TensorFlow\n",
    "    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'\n",
    "    exp_name = f'{run_dir}/calibration/pmi/{critic}_{estimator}'\n",
    "    if not os.path.exists(exp_name):\n",
    "        print(\"Making directory\", exp_name)\n",
    "        os.makedirs(exp_name)\n",
    "\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_train = preprocess_dataset(ds_train, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "    ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "\n",
    "    model = tf.keras.models.load_model(f'{run_dir}/saved_models/trained_model.keras')\n",
    "    # NOTE(review): layers[-1] is the final layer, so int_model reproduces the full\n",
    "    # model output. If an intermediate representation was intended (as the name\n",
    "    # suggests), point this at an earlier layer - confirm.\n",
    "    int_model = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-1].output)\n",
    "\n",
    "    # ---- Train the PMI critic on (activation, one-hot label) pairs ----\n",
    "    print(f'Training PMI model ({critic}, {estimator})...')\n",
    "    ds_activity_trn = ds_train.batch(cfg['batch_size']).map(lambda x, y: (int_model(x), y)).cache().prefetch(tf.data.AUTOTUNE)\n",
    "    ds_activity_val = ds_val.batch(cfg['batch_size']).map(lambda x, y: (int_model(x), y)).cache().prefetch(tf.data.AUTOTUNE)\n",
    "    train_critic_model(ds_activity_trn, ds_activity_val, critic=critic, estimator=estimator,\n",
    "                       epochs=PMI_EPOCHS, save_path=f'{exp_name}/pmi_output_model')\n",
    "\n",
    "    # ---- PMI of every validation / test sample against every class ----\n",
    "    # The saved critic has no training configuration (see the 'No training\n",
    "    # configuration found' warning in this cell's output), so it was never compiled\n",
    "    # on load; compile=False makes that explicit and silences the warning.\n",
    "    pmi_model = tf.keras.models.load_model(f'{exp_name}/pmi_output_model', compile=False)\n",
    "\n",
    "    print('Computing PMI for all validation samples and for all classes...')\n",
    "    pmi_val = _pmi_all_classes(ds_val, int_model, pmi_model, n_classes, estimator, cfg['batch_size'])\n",
    "    np.save(f'{exp_name}/pmi_output_class_val.npy', pmi_val)\n",
    "\n",
    "    print('Computing PMI for all test samples and for all classes...')\n",
    "    pmi_test = _pmi_all_classes(ds_test, int_model, pmi_model, n_classes, estimator, cfg['batch_size'])\n",
    "    np.save(f'{exp_name}/pmi_output_class_test.npy', pmi_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9dbf1db8",
   "metadata": {},
   "source": [
    "### PSI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8ad27d68",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "Training PSI model (gaussian)...\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 2\n",
      "Making directory ../results/PI_Explainability/vgg16_stl10/run_2/calibration/psi/gaussian\n",
      "Training PSI model (gaussian)...\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 3\n",
      "Making directory ../results/PI_Explainability/vgg16_stl10/run_3/calibration/psi/gaussian\n",
      "Training PSI model (gaussian)...\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 4\n",
      "Making directory ../results/PI_Explainability/vgg16_stl10/run_4/calibration/psi/gaussian\n",
      "Training PSI model (gaussian)...\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 5\n",
      "Making directory ../results/PI_Explainability/vgg16_stl10/run_5/calibration/psi/gaussian\n",
      "Training PSI model (gaussian)...\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n"
     ]
    }
   ],
   "source": [
    "# PSI scores from the Gaussian estimator on the classifier's output activations, for every run.\n",
    "# NOTE(review): the PMI cell preprocesses with resize=True while this cell uses\n",
    "# resize=False for the same trained model - confirm which one is intended.\n",
    "n_projs = 500\n",
    "estimator = 'gaussian'\n",
    "N_RUNS = 5          # independent repetitions, stored under run_1 ... run_N_RUNS\n",
    "SEED_OFFSET = 10    # run r (0-based) is seeded with r + SEED_OFFSET\n",
    "\n",
    "\n",
    "def _activations_and_labels(ds, int_model, batch_size):\n",
    "    \"\"\"Run int_model over `ds` in batches and return (activations, labels) as NumPy arrays.\n",
    "\n",
    "    Labels are the argmax of the one-hot targets. Concatenating whole batches gives\n",
    "    the same arrays as unbatching and stacking one sample at a time, without a\n",
    "    Python-level .numpy() call per sample.\n",
    "    \"\"\"\n",
    "    acts, labels = [], []\n",
    "    ds_act = ds.batch(batch_size).map(lambda x, y: (int_model(x), tf.argmax(y, axis=1)))\n",
    "    for act_batch, label_batch in ds_act:\n",
    "        acts.append(act_batch.numpy())\n",
    "        labels.append(label_batch.numpy())\n",
    "    return np.concatenate(acts), np.concatenate(labels)\n",
    "\n",
    "\n",
    "for run in range(N_RUNS):\n",
    "    print(f'Run: {run+1}')\n",
    "    tf.keras.utils.set_random_seed(run + SEED_OFFSET)  # seeds Python, NumPy, and TensorFlow\n",
    "    run_dir = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'\n",
    "    exp_name = f'{run_dir}/calibration/psi/{estimator}'\n",
    "    if not os.path.exists(exp_name):\n",
    "        print(\"Making directory\", exp_name)\n",
    "        os.makedirs(exp_name)\n",
    "\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_train = preprocess_dataset(ds_train, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "\n",
    "    model = tf.keras.models.load_model(f'{run_dir}/saved_models/trained_model.keras')\n",
    "    int_model = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-1].output)\n",
    "\n",
    "    # ---- Fit the Gaussian PSI model on the training activations ----\n",
    "    print('Training PSI model (gaussian)...')\n",
    "    x_train, y_train = _activations_and_labels(ds_train, int_model, cfg['batch_size'])\n",
    "    psi_model_path = f'{exp_name}/gaussian_output_model_{n_projs}_projs.npy'\n",
    "    np.save(psi_model_path, psi_gaussian_train(x_train, y_train, n_projs))\n",
    "\n",
    "    # ---- PSI of every validation / test sample against every class ----\n",
    "    # The saved model is a Python object (hence .item()), which np.save pickles;\n",
    "    # allow_pickle=True is required to read it back. Only load files this notebook wrote.\n",
    "    psi_data = np.load(psi_model_path, allow_pickle=True).item()\n",
    "\n",
    "    print('Computing PSI for all validation samples...')\n",
    "    x_val, _ = _activations_and_labels(ds_val, int_model, cfg['batch_size'])\n",
    "    psi_class, _ = psi_gaussian_val_class(x_val, psi_data)\n",
    "    np.save(f'{exp_name}/psi_output_class_{n_projs}_projs_val.npy', np.array(psi_class))\n",
    "\n",
    "    print('Computing PSI for all test samples...')\n",
    "    x_test, _ = _activations_and_labels(ds_test, int_model, cfg['batch_size'])\n",
    "    psi_class, _ = psi_gaussian_val_class(x_test, psi_data)\n",
    "    np.save(f'{exp_name}/psi_output_class_{n_projs}_projs_test.npy', np.array(psi_class))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd3f8e38",
   "metadata": {},
   "source": [
    "### PVI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0241574a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-07 23:08:31.928931: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1926] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78835 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-80GB, pci bus id: 0000:47:00.0, compute capability: 8.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-07 23:08:34.390382: I external/local_xla/xla/service/service.cc:168] XLA service 0x7f0b9d5acaf0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n",
      "2024-08-07 23:08:34.390414: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA A100-SXM4-80GB, Compute Capability 8.0\n",
      "2024-08-07 23:08:34.429226: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:467] Loaded cuDNN version 90100\n",
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "I0000 00:00:1723072114.486637 3291842 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "399/399 [==============================] - 2s 3ms/step - loss: 2.3024 - accuracy: 0.1116\n",
      "Epoch 2/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3020 - accuracy: 0.1122\n",
      "Epoch 3/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3017 - accuracy: 0.1122\n",
      "Epoch 4/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3015 - accuracy: 0.1122\n",
      "Epoch 5/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3014 - accuracy: 0.1122\n",
      "Epoch 6/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3013 - accuracy: 0.1122\n",
      "Epoch 7/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3013 - accuracy: 0.1122\n",
      "Epoch 8/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 9/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 10/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 11/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 12/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 13/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 14/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 15/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 16/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-07 23:08:51.406667: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing PVI for all validation samples and for all classes...\n",
      "71/71 [==============================] - 0s 1ms/step\n",
      "71/71 [==============================] - 0s 2ms/step\n",
      "Computing PVI for all test samples and for all classes...\n",
      "79/79 [==============================] - 0s 2ms/step\n",
      "79/79 [==============================] - 0s 1ms/step\n",
      "Epoch 1/100\n",
      "399/399 [==============================] - 2s 3ms/step - loss: 2.3024 - accuracy: 0.1116\n",
      "Epoch 2/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3020 - accuracy: 0.1122\n",
      "Epoch 3/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3017 - accuracy: 0.1122\n",
      "Epoch 4/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3015 - accuracy: 0.1122\n",
      "Epoch 5/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3014 - accuracy: 0.1122\n",
      "Epoch 6/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3013 - accuracy: 0.1122\n",
      "Epoch 7/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3013 - accuracy: 0.1122\n",
      "Epoch 8/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 9/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 10/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 11/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 12/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 13/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 14/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 15/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 16/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Computing PVI for all validation samples and for all classes...\n",
      "71/71 [==============================] - 0s 2ms/step\n",
      "71/71 [==============================] - 0s 2ms/step\n",
      "Computing PVI for all test samples and for all classes...\n",
      "79/79 [==============================] - 0s 2ms/step\n",
      "79/79 [==============================] - 0s 1ms/step\n",
      "Epoch 1/100\n",
      "399/399 [==============================] - 2s 2ms/step - loss: 2.3024 - accuracy: 0.1116\n",
      "Epoch 2/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3020 - accuracy: 0.1122\n",
      "Epoch 3/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3017 - accuracy: 0.1122\n",
      "Epoch 4/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3015 - accuracy: 0.1122\n",
      "Epoch 5/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3014 - accuracy: 0.1122\n",
      "Epoch 6/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3013 - accuracy: 0.1122\n",
      "Epoch 7/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3013 - accuracy: 0.1122\n",
      "Epoch 8/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 9/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 10/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 11/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 12/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 13/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 14/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 15/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 16/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Computing PVI for all validation samples and for all classes...\n",
      "71/71 [==============================] - 0s 2ms/step\n",
      "71/71 [==============================] - 0s 2ms/step\n",
      "Computing PVI for all test samples and for all classes...\n",
      "79/79 [==============================] - 0s 2ms/step\n",
      "79/79 [==============================] - 0s 2ms/step\n",
      "Epoch 1/100\n",
      "399/399 [==============================] - 2s 3ms/step - loss: 2.3024 - accuracy: 0.1116\n",
      "Epoch 2/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3020 - accuracy: 0.1122\n",
      "Epoch 3/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3017 - accuracy: 0.1122\n",
      "Epoch 4/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3015 - accuracy: 0.1122\n",
      "Epoch 5/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3014 - accuracy: 0.1122\n",
      "Epoch 6/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3013 - accuracy: 0.1122\n",
      "Epoch 7/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3013 - accuracy: 0.1122\n",
      "Epoch 8/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 9/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 10/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 11/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 12/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 13/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 14/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 15/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 16/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Computing PVI for all validation samples and for all classes...\n",
      "71/71 [==============================] - 0s 2ms/step\n",
      "71/71 [==============================] - 0s 2ms/step\n",
      "Computing PVI for all test samples and for all classes...\n",
      "79/79 [==============================] - 0s 2ms/step\n",
      "79/79 [==============================] - 0s 2ms/step\n",
      "Epoch 1/100\n",
      "399/399 [==============================] - 2s 3ms/step - loss: 2.3024 - accuracy: 0.1116\n",
      "Epoch 2/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3020 - accuracy: 0.1122\n",
      "Epoch 3/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3017 - accuracy: 0.1122\n",
      "Epoch 4/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3015 - accuracy: 0.1122\n",
      "Epoch 5/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3014 - accuracy: 0.1122\n",
      "Epoch 6/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3013 - accuracy: 0.1122\n",
      "Epoch 7/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3013 - accuracy: 0.1122\n",
      "Epoch 8/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 9/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 10/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 11/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 12/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 13/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 14/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 15/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Epoch 16/100\n",
      "399/399 [==============================] - 1s 2ms/step - loss: 2.3012 - accuracy: 0.1122\n",
      "Computing PVI for all validation samples and for all classes...\n",
      "71/71 [==============================] - 0s 2ms/step\n",
      "71/71 [==============================] - 0s 2ms/step\n",
      "Computing PVI for all test samples and for all classes...\n",
      "79/79 [==============================] - 0s 2ms/step\n",
      "79/79 [==============================] - 0s 2ms/step\n"
     ]
    }
   ],
   "source": [
    "# Number of independent runs; results live under run_1 ... run_{N_RUNS}.\n",
    "N_RUNS = 5\n",
    "\n",
    "# Pair every run with a *different* run's saved models: reshuffle until no index maps to\n",
    "# itself (a derangement), so random_runs[run] != run for every run.\n",
    "# NOTE(review): np.random.shuffle draws from NumPy's global RNG, which is re-seeded inside the\n",
    "# loop below; the mapping is only reproducible on a fresh kernel if that global state is\n",
    "# seeded before this cell — TODO confirm.\n",
    "random_runs = list(range(N_RUNS))\n",
    "while any(random_runs[i] == i for i in range(N_RUNS)):\n",
    "    np.random.shuffle(random_runs)\n",
    "\n",
    "results_root = f'../results/PI_Explainability/{model_name}_{dataset_name}'\n",
    "\n",
    "for run in range(N_RUNS):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'{results_root}/run_{run+1}/calibration/pvi/training_from_scratch'\n",
    "    # Saved models of the paired (different) run chosen by the derangement above.\n",
    "    source_models_dir = f'{results_root}/run_{random_runs[run]+1}/saved_models'\n",
    "    if not os.path.exists(exp_name):\n",
    "        print(\"Making directory\", exp_name)\n",
    "        os.makedirs(exp_name, exist_ok=True)\n",
    "\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_train = preprocess_dataset(ds_train, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "\n",
    "    ##############################################################\n",
    "    #\n",
    "    # Prepare PVI models: copy the trained model of the paired run and\n",
    "    # train the null model starting from that run's untrained model\n",
    "    #\n",
    "    # #############################################################\n",
    "\n",
    "    pvi_model = tf.keras.models.load_model(f'{source_models_dir}/trained_model.keras')\n",
    "    pvi_model.save(f'{exp_name}/pvi_model.keras')\n",
    "    untrained_model = tf.keras.models.load_model(f'{source_models_dir}/untrained_model.keras')\n",
    "    train_pvi_null_model(ds_train, untrained_model, cfg, epochs=100, save_path=f'{exp_name}/pvi_null_model.keras')\n",
    "\n",
    "    ##############################################################\n",
    "    #\n",
    "    # Compute PVI for all validation and test samples\n",
    "    #\n",
    "    # #############################################################\n",
    "\n",
    "    # Reload both models from exp_name so PVI is computed from the saved artifacts.\n",
    "    pvi_model = tf.keras.models.load_model(f'{exp_name}/pvi_model.keras')\n",
    "    null_model = tf.keras.models.load_model(f'{exp_name}/pvi_null_model.keras')\n",
    "\n",
    "    # Fit a temperature for each model on the validation set (utils.temp_scaling_nll);\n",
    "    # the null model is evaluated on all-zero inputs.\n",
    "    # NOTE(review): labels and predictions come from separate passes over ds_val, which\n",
    "    # assumes ds_val is not reshuffled between iterations — TODO confirm in preprocess_dataset.\n",
    "    ds_val_batched = ds_val.batch(cfg['batch_size'])\n",
    "    true_y_val = np.argmax([y for x,y in ds_val], axis=1)\n",
    "    opt_temp_pvi = utils.temp_scaling_nll(pvi_model.predict(ds_val_batched, verbose=0), true_y_val)\n",
    "    ds_null = ds_val.map(lambda x, y: (tf.zeros_like(x), y))\n",
    "    opt_temp_null = utils.temp_scaling_nll(null_model.predict(ds_null.batch(cfg['batch_size']), verbose=0), true_y_val)\n",
    "\n",
    "    print('Computing PVI for all validation samples and for all classes...')\n",
    "    pvi_class = neural_pvi_class(ds_val_batched, pvi_model, null_model, opt_temp_pvi, opt_temp_null)\n",
    "    np.save(f'{exp_name}/pvi_class_val.npy', np.array(pvi_class))\n",
    "\n",
    "    print('Computing PVI for all test samples and for all classes...')\n",
    "    pvi_class = neural_pvi_class(ds_test.batch(cfg['batch_size']), pvi_model, null_model, opt_temp_pvi, opt_temp_null)\n",
    "    np.save(f'{exp_name}/pvi_class_test.npy', np.array(pvi_class))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3ef1427",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
