{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8e52ab75",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-04 13:05:26.225740: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9373] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-08-04 13:05:26.225862: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-08-04 13:05:26.228615: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1534] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-08-04 13:05:26.241646: I tensorflow/core/platform/cpu_feature_guard.cc:183] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "from src.models import mlp\n",
    "from src.datasets import load_dataset, preprocess_dataset, prefetch_dataset\n",
    "from src.psi_estimators import psi_bin_train, psi_bin_val_class, psi_gaussian_train, psi_gaussian_val_class\n",
    "from src.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6f21b35d",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = { 'dataset' : 'fashion_mnist',\n",
    "        'model' : 'cnn',\n",
    "        'batch_size' : 512}    \n",
    "\n",
    "model_name = cfg['model']\n",
    "dataset_name = cfg['dataset']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd06863f",
   "metadata": {},
   "source": [
    "### Binning (Penultimate Layer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dda59439",
   "metadata": {},
   "source": [
    "#### Vary the number of bins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1d829f0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Binning PSI estimator on penultimate-layer activations: sweep the number of bins\n",
    "# with the number of projections held fixed, once for each of the 5 runs.\n",
    "# NOTE(review): the n_bins=30 / n_projs=500 files written here use the same paths as the\n",
    "# 'vary the number of projections' sweep, so whichever cell runs last overwrites them.\n",
    "for run in range(5):\n",
    "    print(f'Run: {run+1}')\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/psi/binning'\n",
    "    if not os.path.exists(exp_name):\n",
    "        print(\"Making directory\", exp_name)\n",
    "        os.makedirs(exp_name)\n",
    "\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_train = preprocess_dataset(ds_train, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "    # Sub-model that stops at the second-to-last layer of the loaded network.\n",
    "    int_model = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-2].output)\n",
    "\n",
    "    # The activations depend only on this run's model and data, not on n_bins, so they are\n",
    "    # extracted once per run instead of being recomputed in every sweep iteration.\n",
    "    ds_activity = ds_train.batch(cfg['batch_size']).map(lambda x, y: (int_model(x), tf.argmax(y, axis=1))).unbatch()\n",
    "    x_train, y_train = zip(*ds_activity)\n",
    "    x_train = np.array([val.numpy() for val in x_train])\n",
    "    y_train = np.array([val.numpy() for val in y_train])  # class ids (argmax of the one-hot labels)\n",
    "\n",
    "    ds_activity = ds_val.batch(cfg['batch_size']).map(lambda x, y: int_model(x)).unbatch()\n",
    "    x_val = np.array([val.numpy() for val in ds_activity])\n",
    "\n",
    "    ds_activity = ds_test.batch(cfg['batch_size']).map(lambda x, y: int_model(x)).unbatch()\n",
    "    x_test = np.array([val.numpy() for val in ds_activity])\n",
    "\n",
    "    print(f'Training PSI model (binning)...')\n",
    "    n_projs = 500\n",
    "    for n_bins in range(10, 101, 10):\n",
    "        print(f'N_bins: {n_bins}, N_projs: {n_projs}')\n",
    "        model_path = f'{exp_name}/binning_model_{n_bins}_bins_{n_projs}_projs.npy'\n",
    "\n",
    "        ##############################################################\n",
    "        #\n",
    "        # Train PSI Model\n",
    "        #\n",
    "        ##############################################################\n",
    "\n",
    "        psi_data = psi_bin_train(x_train, y_train, n_projs, n_bins)\n",
    "        np.save(model_path, psi_data)\n",
    "\n",
    "        ###################################################################\n",
    "        #\n",
    "        # Compute PSI for all validation and test samples for all classes\n",
    "        #\n",
    "        ###################################################################\n",
    "\n",
    "        # Reload from disk so scoring uses exactly what was persisted.\n",
    "        psi_data = np.load(model_path, allow_pickle=True).item()\n",
    "\n",
    "        print(f'Computing PSI for all validation samples...')\n",
    "        psi_class, pmi_arr = psi_bin_val_class(x_val, psi_data)\n",
    "        np.save(f'{exp_name}/psi_class_{n_bins}_bins_{n_projs}_projs_val.npy', np.array(psi_class))\n",
    "\n",
    "        print(f'Computing PSI for all test samples...')\n",
    "        psi_class, pmi_arr = psi_bin_val_class(x_test, psi_data)\n",
    "        np.save(f'{exp_name}/psi_class_{n_bins}_bins_{n_projs}_projs_test.npy', np.array(psi_class))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "371bc24a",
   "metadata": {},
   "source": [
    "#### Vary the number of projections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8fa381ab",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-01 04:25:36.313220: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1926] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78835 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-80GB, pci bus id: 0000:47:00.0, compute capability: 8.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training PSI model (binning)...\n",
      "N_bins: 30, N_projs: 250\n",
      "Computing PSI for all validation samples...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/shelvia/ICML2024/src/psi_estimators.py:94: RuntimeWarning: divide by zero encountered in log\n",
      "  pmi = np.log(joint_probs / (x_marginal_probs[:, None] * y_marginal_probs + 1e-9))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 750\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 1000\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 2\n",
      "Training PSI model (binning)...\n",
      "N_bins: 30, N_projs: 250\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 750\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 1000\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 3\n",
      "Training PSI model (binning)...\n",
      "N_bins: 30, N_projs: 250\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 750\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 1000\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 4\n",
      "Training PSI model (binning)...\n",
      "N_bins: 30, N_projs: 250\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 750\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 1000\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 5\n",
      "Training PSI model (binning)...\n",
      "N_bins: 30, N_projs: 250\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 750\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 1000\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n"
     ]
    }
   ],
   "source": [
    "# Binning PSI estimator on penultimate-layer activations: sweep the number of projections\n",
    "# with the number of bins held fixed, once for each of the 5 runs.\n",
    "# NOTE(review): the n_bins=30 / n_projs=500 files written here use the same paths as the\n",
    "# 'vary the number of bins' sweep, so whichever cell runs last overwrites them.\n",
    "for run in range(5):\n",
    "    print(f'Run: {run+1}')\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/psi/binning'\n",
    "    if not os.path.exists(exp_name):\n",
    "        print(\"Making directory\", exp_name)\n",
    "        os.makedirs(exp_name)\n",
    "\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_train = preprocess_dataset(ds_train, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "    # Sub-model that stops at the second-to-last layer of the loaded network.\n",
    "    int_model = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-2].output)\n",
    "\n",
    "    # The activations depend only on this run's model and data, not on n_projs, so they are\n",
    "    # extracted once per run instead of being recomputed in every sweep iteration.\n",
    "    ds_activity = ds_train.batch(cfg['batch_size']).map(lambda x, y: (int_model(x), tf.argmax(y, axis=1))).unbatch()\n",
    "    x_train, y_train = zip(*ds_activity)\n",
    "    x_train = np.array([val.numpy() for val in x_train])\n",
    "    y_train = np.array([val.numpy() for val in y_train])  # class ids (argmax of the one-hot labels)\n",
    "\n",
    "    ds_activity = ds_val.batch(cfg['batch_size']).map(lambda x, y: int_model(x)).unbatch()\n",
    "    x_val = np.array([val.numpy() for val in ds_activity])\n",
    "\n",
    "    ds_activity = ds_test.batch(cfg['batch_size']).map(lambda x, y: int_model(x)).unbatch()\n",
    "    x_test = np.array([val.numpy() for val in ds_activity])\n",
    "\n",
    "    print(f'Training PSI model (binning)...')\n",
    "    n_bins = 30\n",
    "    for n_projs in [250,500,750,1000]:\n",
    "        print(f'N_bins: {n_bins}, N_projs: {n_projs}')\n",
    "        model_path = f'{exp_name}/binning_model_{n_bins}_bins_{n_projs}_projs.npy'\n",
    "\n",
    "        ##############################################################\n",
    "        #\n",
    "        # Train PSI Model\n",
    "        #\n",
    "        ##############################################################\n",
    "\n",
    "        psi_data = psi_bin_train(x_train, y_train, n_projs, n_bins)\n",
    "        np.save(model_path, psi_data)\n",
    "\n",
    "        ###################################################################\n",
    "        #\n",
    "        # Compute PSI for all validation and test samples for all classes\n",
    "        #\n",
    "        ###################################################################\n",
    "\n",
    "        # Reload from disk so scoring uses exactly what was persisted.\n",
    "        psi_data = np.load(model_path, allow_pickle=True).item()\n",
    "\n",
    "        print(f'Computing PSI for all validation samples...')\n",
    "        psi_class, pmi_arr = psi_bin_val_class(x_val, psi_data)\n",
    "        np.save(f'{exp_name}/psi_class_{n_bins}_bins_{n_projs}_projs_val.npy', np.array(psi_class))\n",
    "\n",
    "        print(f'Computing PSI for all test samples...')\n",
    "        psi_class, pmi_arr = psi_bin_val_class(x_test, psi_data)\n",
    "        np.save(f'{exp_name}/psi_class_{n_bins}_bins_{n_projs}_projs_test.npy', np.array(psi_class))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bdfb98c",
   "metadata": {},
   "source": [
    "### Gaussian"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b7ec956",
   "metadata": {},
   "source": [
    "#### Vary the number of projections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64c41324",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "Training PSI model (gaussian)...\n",
      "N_projs: 250\n"
     ]
    }
   ],
   "source": [
    "# Gaussian PSI estimator on penultimate-layer activations: sweep the number of projections,\n",
    "# once for each of the 5 runs. A saved PSI model is reused when its file already exists and\n",
    "# is fitted (and saved) otherwise, so the cell also works when no model files are on disk.\n",
    "for run in range(5):\n",
    "    print(f'Run: {run+1}')\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/psi/gaussian'\n",
    "    if not os.path.exists(exp_name):\n",
    "        print(\"Making directory\", exp_name)\n",
    "        os.makedirs(exp_name)\n",
    "\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_train = preprocess_dataset(ds_train, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "    # Sub-model that stops at the second-to-last layer of the loaded network.\n",
    "    int_model = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-2].output)\n",
    "\n",
    "    # Validation/test activations do not depend on n_projs: extract them once per run.\n",
    "    ds_activity = ds_val.batch(cfg['batch_size']).map(lambda x, y: int_model(x)).unbatch()\n",
    "    x_val = np.array([val.numpy() for val in ds_activity])\n",
    "\n",
    "    ds_activity = ds_test.batch(cfg['batch_size']).map(lambda x, y: int_model(x)).unbatch()\n",
    "    x_test = np.array([val.numpy() for val in ds_activity])\n",
    "\n",
    "    # Training activations are only needed when a PSI model has to be fitted,\n",
    "    # so they are extracted lazily (at most once per run).\n",
    "    x_train = y_train = None\n",
    "\n",
    "    print(f'Training PSI model (gaussian)...')\n",
    "\n",
    "    for n_projs in [250,500,750,1000]:\n",
    "        print(f'N_projs: {n_projs}')\n",
    "        model_path = f'{exp_name}/gaussian_model_{n_projs}_projs.npy'\n",
    "\n",
    "        ##############################################################\n",
    "        #\n",
    "        # Train PSI Model (skipped when a saved model already exists)\n",
    "        #\n",
    "        ##############################################################\n",
    "\n",
    "        # NOTE(review): if psi_gaussian_train draws from the global RNG, a partially cached sweep\n",
    "        # will not reproduce a full one; delete the saved models to force a clean refit.\n",
    "        if not os.path.exists(model_path):\n",
    "            if x_train is None:\n",
    "                ds_activity = ds_train.batch(cfg['batch_size']).map(lambda x, y: (int_model(x), tf.argmax(y, axis=1))).unbatch()\n",
    "                x_train, y_train = zip(*ds_activity)\n",
    "                x_train = np.array([val.numpy() for val in x_train])\n",
    "                y_train = np.array([val.numpy() for val in y_train])  # class ids (argmax of the one-hot labels)\n",
    "            psi_data = psi_gaussian_train(x_train, y_train, n_projs)\n",
    "            np.save(model_path, psi_data)\n",
    "\n",
    "        ##############################################################\n",
    "        #\n",
    "        # Compute PSI for all validation and test samples\n",
    "        #\n",
    "        ##############################################################\n",
    "\n",
    "        psi_data = np.load(model_path, allow_pickle=True).item()\n",
    "\n",
    "        print(f'Computing PSI for all validation samples...')\n",
    "        psi_class, pmi_arr = psi_gaussian_val_class(x_val, psi_data)\n",
    "        np.save(f'{exp_name}/psi_class_{n_projs}_projs_val.npy', np.array(psi_class))\n",
    "\n",
    "        print(f'Computing PSI for all test samples...')\n",
    "        psi_class, pmi_arr = psi_gaussian_val_class(x_test, psi_data)\n",
    "        np.save(f'{exp_name}/psi_class_{n_projs}_projs_test.npy', np.array(psi_class))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a46a3af",
   "metadata": {},
   "source": [
    "### Compare Performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "428a34e3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "N_projs: 250\n",
      "AUROC: 95.653\n",
      "AUPRC (success): 99.923\n",
      "AUPRC (error): 28.997\n",
      "AURC: 0.879\n",
      "N_projs: 500\n",
      "AUROC: 95.691\n",
      "AUPRC (success): 99.925\n",
      "AUPRC (error): 30.604\n",
      "AURC: 0.862\n",
      "N_projs: 750\n",
      "AUROC: 95.603\n",
      "AUPRC (success): 99.922\n",
      "AUPRC (error): 30.110\n",
      "AURC: 0.889\n",
      "N_projs: 1000\n",
      "AUROC: 95.837\n",
      "AUPRC (success): 99.928\n",
      "AUPRC (error): 30.838\n",
      "AURC: 0.833\n",
      "Run: 2\n",
      "N_projs: 250\n",
      "AUROC: 96.166\n",
      "AUPRC (success): 99.933\n",
      "AUPRC (error): 31.693\n",
      "AURC: 0.775\n",
      "N_projs: 500\n",
      "AUROC: 95.982\n",
      "AUPRC (success): 99.928\n",
      "AUPRC (error): 33.255\n",
      "AURC: 0.824\n",
      "N_projs: 750\n",
      "AUROC: 96.055\n",
      "AUPRC (success): 99.931\n",
      "AUPRC (error): 30.789\n",
      "AURC: 0.802\n",
      "N_projs: 1000\n",
      "AUROC: 96.141\n",
      "AUPRC (success): 99.933\n",
      "AUPRC (error): 31.527\n",
      "AURC: 0.782\n",
      "Run: 3\n",
      "N_projs: 250\n",
      "AUROC: 96.155\n",
      "AUPRC (success): 99.931\n",
      "AUPRC (error): 35.166\n",
      "AURC: 0.804\n",
      "N_projs: 500\n",
      "AUROC: 95.948\n",
      "AUPRC (success): 99.926\n",
      "AUPRC (error): 36.245\n",
      "AURC: 0.854\n",
      "N_projs: 750\n",
      "AUROC: 96.207\n",
      "AUPRC (success): 99.931\n",
      "AUPRC (error): 36.903\n",
      "AURC: 0.808\n",
      "N_projs: 1000\n",
      "AUROC: 96.230\n",
      "AUPRC (success): 99.932\n",
      "AUPRC (error): 35.604\n",
      "AURC: 0.796\n",
      "Run: 4\n",
      "N_projs: 250\n",
      "AUROC: 96.447\n",
      "AUPRC (success): 99.941\n",
      "AUPRC (error): 36.633\n",
      "AURC: 0.696\n",
      "N_projs: 500\n",
      "AUROC: 96.198\n",
      "AUPRC (success): 99.936\n",
      "AUPRC (error): 35.441\n",
      "AURC: 0.748\n",
      "N_projs: 750\n",
      "AUROC: 96.157\n",
      "AUPRC (success): 99.935\n",
      "AUPRC (error): 35.706\n",
      "AURC: 0.758\n",
      "N_projs: 1000\n",
      "AUROC: 96.281\n",
      "AUPRC (success): 99.938\n",
      "AUPRC (error): 35.600\n",
      "AURC: 0.731\n",
      "Run: 5\n",
      "N_projs: 250\n",
      "AUROC: 95.823\n",
      "AUPRC (success): 99.930\n",
      "AUPRC (error): 29.186\n",
      "AURC: 0.799\n",
      "N_projs: 500\n",
      "AUROC: 95.913\n",
      "AUPRC (success): 99.932\n",
      "AUPRC (error): 30.788\n",
      "AURC: 0.780\n",
      "N_projs: 750\n",
      "AUROC: 95.964\n",
      "AUPRC (success): 99.933\n",
      "AUPRC (error): 30.192\n",
      "AURC: 0.769\n",
      "N_projs: 1000\n",
      "AUROC: 95.871\n",
      "AUPRC (success): 99.931\n",
      "AUPRC (error): 30.013\n",
      "AURC: 0.789\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the saved per-class PSI values as a confidence score for test predictions.\n",
    "# For every run and every n_projs: load the test PSI array, apply `softmax` to each row,\n",
    "# pick the entry of the predicted class, and compute AUROC / AUPRC / AURC against whether\n",
    "# the prediction was correct. Results are collected as all_*[run][n_projs index].\n",
    "estimator = 'gaussian' #'binning'\n",
    "# n_bins_list = range(10, 101, 10)\n",
    "n_bins = 30  # only referenced by the commented-out binning file names below\n",
    "n_projs_list = [250,500,750,1000] #[500]\n",
    "\n",
    "# One inner list per run, one entry per n_projs value.\n",
    "all_auroc = []\n",
    "all_auprc_succ = []\n",
    "all_auprc_error = []\n",
    "all_aurc = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    print(f'Run: {run+1}')\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/psi/{estimator}'\n",
    "\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "    \n",
    "    # Labels and predictions come from two separate passes over ds_test, so this relies on\n",
    "    # ds_test yielding samples in the same order on every iteration.\n",
    "    true_y_test = np.argmax([y for x,y in ds_test], axis=1)\n",
    "    pred_y_test = np.argmax(model.predict(ds_test.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "    true_label_test = np.equal(true_y_test, pred_y_test).astype(int) # 1 where the prediction is correct (true_y == pred_y), 0 where it is wrong\n",
    "    # NOTE(review): y_val is not used anywhere else in this cell.\n",
    "    y_val = np.argmax([y.numpy().astype(np.int32) for x,y in ds_val], axis=1)\n",
    "\n",
    "    auroc_list = []\n",
    "    auprc_succ_list = []\n",
    "    auprc_error_list = []\n",
    "    aurc_list = []\n",
    "#     for n_bins in n_bins_list:\n",
    "    for n_projs in n_projs_list:\n",
    "#         print(f'N_bins: {n_bins}')\n",
    "        print(f'N_projs: {n_projs}')\n",
    "\n",
    "#         psi_class_test = np.load(f'{exp_name}/psi_class_{n_bins}_bins_{n_projs}_projs_test.npy')\n",
    "        psi_class_test = np.load(f'{exp_name}/psi_class_{n_projs}_projs_test.npy')\n",
    "#         psi = np.load(f'{exp_name}/psi_class_{n_bins}_bins_{n_projs}_projs_test.npy')\n",
    "#         psi = np.load(f'{exp_name}/psi_pred_{n_projs}_projs_test.npy')\n",
    "        psi_class_test = np.array([softmax(x) for x in psi_class_test])\n",
    "        # Confidence score per sample: the (softmaxed) PSI entry of the predicted class.\n",
    "        psi_test = np.array([psi_value[pred_value] for psi_value, pred_value in zip(psi_class_test, pred_y_test)])\n",
    "\n",
    "        auroc = compute_auroc(true_label_test, psi_test)\n",
    "        auprc_succ = compute_auprc_success(true_label_test, psi_test)\n",
    "        auprc_error = compute_auprc_error(true_label_test, psi_test)\n",
    "        aurc, _, _ = compute_aurc(true_label_test, psi_test)\n",
    "        auroc_list.append(auroc)\n",
    "        auprc_succ_list.append(auprc_succ)\n",
    "        auprc_error_list.append(auprc_error)\n",
    "        aurc_list.append(aurc)\n",
    "        # AUROC/AUPRC are printed scaled by 100, AURC scaled by 1000.\n",
    "        print(f'AUROC: {auroc*100:.3f}')\n",
    "        print(f'AUPRC (success): {auprc_succ*100:.3f}')\n",
    "        print(f'AUPRC (error): {auprc_error*100:.3f}')\n",
    "        print(f'AURC: {aurc*1000:.3f}')\n",
    "    all_auroc.append(auroc_list)\n",
    "    all_auprc_succ.append(auprc_succ_list)\n",
    "    all_auprc_error.append(auprc_error_list)\n",
    "    all_aurc.append(aurc_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "5436044c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "N_projs: 250\n",
      "AUROC: 96.049, std: 0.280\n",
      "AUPRC (success): 99.932, std: 0.006\n",
      "AUPRC (error): 32.335, std: 0.006\n",
      "AURC: 0.790, std: 0.059\n",
      "N_projs: 500\n",
      "AUROC: 95.946, std: 0.162\n",
      "AUPRC (success): 99.929, std: 0.004\n",
      "AUPRC (error): 33.266, std: 0.004\n",
      "AURC: 0.814, std: 0.043\n",
      "N_projs: 750\n",
      "AUROC: 95.997, std: 0.214\n",
      "AUPRC (success): 99.930, std: 0.004\n",
      "AUPRC (error): 32.740, std: 0.004\n",
      "AURC: 0.805, std: 0.046\n",
      "N_projs: 1000\n",
      "AUROC: 96.072, std: 0.184\n",
      "AUPRC (success): 99.932, std: 0.003\n",
      "AUPRC (error): 32.717, std: 0.003\n",
      "AURC: 0.786, std: 0.033\n"
     ]
    }
   ],
   "source": [
    "# Aggregate every metric across runs (axis 0 of the all_* lists), giving one mean and one\n",
    "# standard deviation per n_projs setting. np.std uses its default ddof=0 (population std).\n",
    "mean_auroc = np.mean(all_auroc, axis=0)\n",
    "std_auroc = np.std(all_auroc, axis=0)\n",
    "mean_auprc_succ = np.mean(all_auprc_succ, axis=0)\n",
    "std_auprc_succ = np.std(all_auprc_succ, axis=0)\n",
    "mean_auprc_error = np.mean(all_auprc_error, axis=0)\n",
    "std_auprc_error = np.std(all_auprc_error, axis=0)\n",
    "mean_aurc = np.mean(all_aurc, axis=0)\n",
    "std_aurc = np.std(all_aurc, axis=0)\n",
    "# for i, n_bins in enumerate(n_bins_list):\n",
    "for i, n_projs in enumerate(n_projs_list):\n",
    "#     print(f'N_bins: {n_bins}')\n",
    "    print(f'N_projs: {n_projs}')\n",
    "    print(f'AUROC: {mean_auroc[i]*100:.3f}, std: {std_auroc[i]*100:.3f}')\n",
    "    print(f'AUPRC (success): {mean_auprc_succ[i]*100:.3f}, std: {std_auprc_succ[i]*100:.3f}')\n",
    "    # The error-AUPRC line must report its own spread (std_auprc_error), not the success-AUPRC std.\n",
    "    print(f'AUPRC (error): {mean_auprc_error[i]*100:.3f}, std: {std_auprc_error[i]*100:.3f}')\n",
    "    print(f'AURC: {mean_aurc[i]*1000:.3f}, std: {std_aurc[i]*1000:.3f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddb0fa5e",
   "metadata": {},
   "source": [
    "### Binning (Output Layer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "ece1ebfa",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "Training PSI model (binning)...\n",
      "N_bins: 10, N_projs: 500\n",
      "Computing PSI for all validation samples...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/shelvia/ICML2024/src/psi_estimators.py:94: RuntimeWarning: divide by zero encountered in log\n",
      "  pmi = np.log(joint_probs / (x_marginal_probs[:, None] * y_marginal_probs + 1e-9))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing PSI for all test samples...\n",
      "N_bins: 20, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 40, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 50, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 60, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 70, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 80, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 90, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 100, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 2\n",
      "Training PSI model (binning)...\n",
      "N_bins: 10, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 20, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 40, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 50, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 60, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 70, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 80, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 90, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 100, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 3\n",
      "Training PSI model (binning)...\n",
      "N_bins: 10, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 20, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 40, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 50, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 60, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 70, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 80, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 90, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 100, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 4\n",
      "Training PSI model (binning)...\n",
      "N_bins: 10, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 20, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 40, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 50, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 60, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 70, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 80, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 90, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 100, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 5\n",
      "Training PSI model (binning)...\n",
      "N_bins: 10, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 20, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 40, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 50, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 60, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 70, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 80, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 90, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 100, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n"
     ]
    }
   ],
   "source": [
    "# Sweep over the number of bins of the binning PSI estimator (output layer), for 5 runs.\n",
    "# Per run: fit the estimator on the training activations, then score validation and test.\n",
    "for run in range(5):\n",
    "    print(f'Run: {run+1}')\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/psi/binning'\n",
    "    if not os.path.exists(exp_name):\n",
    "        print(\"Making directory\", exp_name)\n",
    "        os.makedirs(exp_name)\n",
    "\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_train = preprocess_dataset(ds_train, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "    int_model = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-1].output)\n",
    "\n",
    "    # The activations do not depend on n_bins, so every split is pushed through the\n",
    "    # network once per run instead of once per sweep value.\n",
    "    ds_activity = ds_train.batch(cfg['batch_size']).map(lambda x, y: (int_model(x), tf.argmax(y, axis=1))).unbatch()\n",
    "    x_train, y_train = zip(*ds_activity)\n",
    "    x_train = np.array([val.numpy() for val in x_train])\n",
    "    y_train = np.array([val.numpy() for val in y_train])\n",
    "\n",
    "    ds_activity = ds_val.batch(cfg['batch_size']).map(lambda x, y: (int_model(x),y)).unbatch()\n",
    "    x_val = np.array([act.numpy() for act, _label in ds_activity])\n",
    "\n",
    "    ds_activity = ds_test.batch(cfg['batch_size']).map(lambda x, y: (int_model(x),y)).unbatch()\n",
    "    x_test = np.array([act.numpy() for act, _label in ds_activity])\n",
    "\n",
    "    print(f'Training PSI model (binning)...')\n",
    "    n_projs = 500\n",
    "    # n_bins = 30 is not in this list; the n_projs sweep cell (n_bins = 30, n_projs = 500)\n",
    "    # writes the corresponding files.\n",
    "    for n_bins in [10,20,40,50,60,70,80,90,100]:\n",
    "        print(f'N_bins: {n_bins}, N_projs: {n_projs}')\n",
    "        model_path = f'{exp_name}/binning_output_model_{n_bins}_bins_{n_projs}_projs.npy'\n",
    "\n",
    "        ##############################################################\n",
    "        #\n",
    "        # Train PSI Model\n",
    "        #\n",
    "        ###############################################################\n",
    "\n",
    "        psi_data = psi_bin_train(x_train, y_train, n_projs, n_bins)\n",
    "        np.save(model_path, psi_data)\n",
    "\n",
    "        ###################################################################\n",
    "        #\n",
    "        # Compute PSI for all validation and test samples for all classes\n",
    "        #\n",
    "        ###################################################################\n",
    "\n",
    "        # Reload the artifact that was just written so scoring uses exactly what is on disk.\n",
    "        # allow_pickle is required because the saved object is unwrapped with .item().\n",
    "        psi_data = np.load(model_path, allow_pickle=True).item()\n",
    "\n",
    "        print(f'Computing PSI for all validation samples...')\n",
    "        psi_class, pmi_arr = psi_bin_val_class(x_val, psi_data)\n",
    "        np.save(f'{exp_name}/psi_output_class_{n_bins}_bins_{n_projs}_projs_val.npy', np.array(psi_class))\n",
    "\n",
    "        print(f'Computing PSI for all test samples...')\n",
    "        psi_class, pmi_arr = psi_bin_val_class(x_test, psi_data)\n",
    "        np.save(f'{exp_name}/psi_output_class_{n_bins}_bins_{n_projs}_projs_test.npy', np.array(psi_class))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "66c230b4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-02 19:34:36.617204: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1926] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78835 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-80GB, pci bus id: 0000:47:00.0, compute capability: 8.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training PSI model (binning)...\n",
      "N_bins: 30, N_projs: 250\n",
      "Computing PSI for all validation samples...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/shelvia/ICML2024/src/psi_estimators.py:94: RuntimeWarning: divide by zero encountered in log\n",
      "  pmi = np.log(joint_probs / (x_marginal_probs[:, None] * y_marginal_probs + 1e-9))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 750\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 1000\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 2\n",
      "Training PSI model (binning)...\n",
      "N_bins: 30, N_projs: 250\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 750\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 1000\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 3\n",
      "Training PSI model (binning)...\n",
      "N_bins: 30, N_projs: 250\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 750\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 1000\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 4\n",
      "Training PSI model (binning)...\n",
      "N_bins: 30, N_projs: 250\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 750\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 1000\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 5\n",
      "Training PSI model (binning)...\n",
      "N_bins: 30, N_projs: 250\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 750\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_bins: 30, N_projs: 1000\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n"
     ]
    }
   ],
   "source": [
    "# Sweep over the number of projections of the binning PSI estimator (output layer) at a\n",
    "# fixed n_bins = 30, for 5 runs.\n",
    "for run in range(5):\n",
    "    print(f'Run: {run+1}')\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/psi/binning'\n",
    "    if not os.path.exists(exp_name):\n",
    "        print(\"Making directory\", exp_name)\n",
    "        os.makedirs(exp_name)\n",
    "\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_train = preprocess_dataset(ds_train, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "    int_model = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-1].output)\n",
    "\n",
    "    # The activations do not depend on n_projs, so every split is pushed through the\n",
    "    # network once per run instead of once per sweep value.\n",
    "    ds_activity = ds_train.batch(cfg['batch_size']).map(lambda x, y: (int_model(x), tf.argmax(y, axis=1))).unbatch()\n",
    "    x_train, y_train = zip(*ds_activity)\n",
    "    x_train = np.array([val.numpy() for val in x_train])\n",
    "    y_train = np.array([val.numpy() for val in y_train])\n",
    "\n",
    "    ds_activity = ds_val.batch(cfg['batch_size']).map(lambda x, y: (int_model(x),y)).unbatch()\n",
    "    x_val = np.array([act.numpy() for act, _label in ds_activity])\n",
    "\n",
    "    ds_activity = ds_test.batch(cfg['batch_size']).map(lambda x, y: (int_model(x),y)).unbatch()\n",
    "    x_test = np.array([act.numpy() for act, _label in ds_activity])\n",
    "\n",
    "    print(f'Training PSI model (binning)...')\n",
    "    n_bins = 30\n",
    "    for n_projs in [250,500,750,1000]:\n",
    "        print(f'N_bins: {n_bins}, N_projs: {n_projs}')\n",
    "        model_path = f'{exp_name}/binning_output_model_{n_bins}_bins_{n_projs}_projs.npy'\n",
    "\n",
    "        ##############################################################\n",
    "        #\n",
    "        # Train PSI Model\n",
    "        #\n",
    "        ###############################################################\n",
    "\n",
    "        psi_data = psi_bin_train(x_train, y_train, n_projs, n_bins)\n",
    "        np.save(model_path, psi_data)\n",
    "\n",
    "        ###################################################################\n",
    "        #\n",
    "        # Compute PSI for all validation and test samples for all classes\n",
    "        #\n",
    "        ###################################################################\n",
    "\n",
    "        # Reload the artifact that was just written so scoring uses exactly what is on disk.\n",
    "        # allow_pickle is required because the saved object is unwrapped with .item().\n",
    "        psi_data = np.load(model_path, allow_pickle=True).item()\n",
    "\n",
    "        print(f'Computing PSI for all validation samples...')\n",
    "        psi_class, pmi_arr = psi_bin_val_class(x_val, psi_data)\n",
    "        np.save(f'{exp_name}/psi_output_class_{n_bins}_bins_{n_projs}_projs_val.npy', np.array(psi_class))\n",
    "\n",
    "        print(f'Computing PSI for all test samples...')\n",
    "        psi_class, pmi_arr = psi_bin_val_class(x_test, psi_data)\n",
    "        np.save(f'{exp_name}/psi_output_class_{n_bins}_bins_{n_projs}_projs_test.npy', np.array(psi_class))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0a3a536",
   "metadata": {},
   "source": [
    "### Gaussian (Output Layer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5c17f06b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "Training PSI model (gaussian)...\n",
      "N_projs: 250\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_projs: 750\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_projs: 1000\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 2\n",
      "Training PSI model (gaussian)...\n",
      "N_projs: 250\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_projs: 750\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_projs: 1000\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 3\n",
      "Training PSI model (gaussian)...\n",
      "N_projs: 250\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_projs: 750\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_projs: 1000\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 4\n",
      "Training PSI model (gaussian)...\n",
      "N_projs: 250\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_projs: 750\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_projs: 1000\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "Run: 5\n",
      "Training PSI model (gaussian)...\n",
      "N_projs: 250\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_projs: 500\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_projs: 750\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n",
      "N_projs: 1000\n",
      "Computing PSI for all validation samples...\n",
      "Computing PSI for all test samples...\n"
     ]
    }
   ],
   "source": [
    "# Sweep over the number of projections of the Gaussian PSI estimator (output layer), for 5 runs.\n",
    "# Per run: fit the estimator on the training activations, then score validation and test.\n",
    "for run in range(5):\n",
    "    print(f'Run: {run+1}')\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/psi/gaussian'\n",
    "    if not os.path.exists(exp_name):\n",
    "        print(\"Making directory\", exp_name)\n",
    "        os.makedirs(exp_name)\n",
    "\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_train = preprocess_dataset(ds_train, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "    int_model = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-1].output)\n",
    "\n",
    "    # The activations do not depend on n_projs, so every split is pushed through the\n",
    "    # network once per run instead of once per sweep value.\n",
    "    ds_activity = ds_train.batch(cfg['batch_size']).map(lambda x, y: (int_model(x), tf.argmax(y, axis=1))).unbatch()\n",
    "    x_train, y_train = zip(*ds_activity)\n",
    "    x_train = np.array([val.numpy() for val in x_train])\n",
    "    y_train = np.array([val.numpy() for val in y_train])\n",
    "\n",
    "    ds_activity = ds_val.batch(cfg['batch_size']).map(lambda x, y: (int_model(x),y)).unbatch()\n",
    "    x_val = np.array([act.numpy() for act, _label in ds_activity])\n",
    "\n",
    "    ds_activity = ds_test.batch(cfg['batch_size']).map(lambda x, y: (int_model(x),y)).unbatch()\n",
    "    x_test = np.array([act.numpy() for act, _label in ds_activity])\n",
    "\n",
    "    print(f'Training PSI model (gaussian)...')\n",
    "\n",
    "    for n_projs in [250,500,750,1000]:\n",
    "        print(f'N_projs: {n_projs}')\n",
    "        model_path = f'{exp_name}/gaussian_output_model_{n_projs}_projs.npy'\n",
    "\n",
    "        ##############################################################\n",
    "        #\n",
    "        # Train PSI Model\n",
    "        #\n",
    "        ##############################################################\n",
    "\n",
    "        psi_data = psi_gaussian_train(x_train, y_train, n_projs)\n",
    "        np.save(model_path, psi_data)\n",
    "\n",
    "        ##############################################################\n",
    "        #\n",
    "        # Compute PSI for all validation and test samples\n",
    "        #\n",
    "        ##############################################################\n",
    "\n",
    "        # Reload the artifact that was just written so scoring uses exactly what is on disk.\n",
    "        # allow_pickle is required because the saved object is unwrapped with .item().\n",
    "        psi_data = np.load(model_path, allow_pickle=True).item()\n",
    "\n",
    "        print(f'Computing PSI for all validation samples...')\n",
    "        psi_class, pmi_arr = psi_gaussian_val_class(x_val, psi_data)\n",
    "        np.save(f'{exp_name}/psi_output_class_{n_projs}_projs_val.npy', np.array(psi_class))\n",
    "\n",
    "        print(f'Computing PSI for all test samples...')\n",
    "        psi_class, pmi_arr = psi_gaussian_val_class(x_test, psi_data)\n",
    "        np.save(f'{exp_name}/psi_output_class_{n_projs}_projs_test.npy', np.array(psi_class))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f6089d95",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "N_bins: 10\n",
      "AUROC: 86.142\n",
      "AUPRC (success): 98.832\n",
      "AUPRC (error): 27.444\n",
      "AURC: 13.128\n",
      "N_bins: 20\n",
      "AUROC: 84.655\n",
      "AUPRC (success): 98.697\n",
      "AUPRC (error): 24.283\n",
      "AURC: 14.413\n",
      "N_bins: 30\n",
      "AUROC: 84.653\n",
      "AUPRC (success): 98.699\n",
      "AUPRC (error): 24.290\n",
      "AURC: 14.400\n",
      "N_bins: 40\n",
      "AUROC: 85.028\n",
      "AUPRC (success): 98.734\n",
      "AUPRC (error): 24.393\n",
      "AURC: 14.068\n",
      "N_bins: 50\n",
      "AUROC: 84.296\n",
      "AUPRC (success): 98.668\n",
      "AUPRC (error): 23.163\n",
      "AURC: 14.698\n",
      "N_bins: 60\n",
      "AUROC: 84.386\n",
      "AUPRC (success): 98.688\n",
      "AUPRC (error): 23.062\n",
      "AURC: 14.513\n",
      "N_bins: 70\n",
      "AUROC: 84.310\n",
      "AUPRC (success): 98.686\n",
      "AUPRC (error): 22.805\n",
      "AURC: 14.531\n",
      "N_bins: 80\n",
      "AUROC: 84.307\n",
      "AUPRC (success): 98.687\n",
      "AUPRC (error): 22.923\n",
      "AURC: 14.523\n",
      "N_bins: 90\n",
      "AUROC: 84.162\n",
      "AUPRC (success): 98.692\n",
      "AUPRC (error): 22.071\n",
      "AURC: 14.483\n",
      "N_bins: 100\n",
      "AUROC: 83.534\n",
      "AUPRC (success): 98.643\n",
      "AUPRC (error): 21.325\n",
      "AURC: 14.955\n",
      "Run: 2\n",
      "N_bins: 10\n",
      "AUROC: 87.238\n",
      "AUPRC (success): 98.965\n",
      "AUPRC (error): 27.955\n",
      "AURC: 11.989\n",
      "N_bins: 20\n",
      "AUROC: 85.679\n",
      "AUPRC (success): 98.785\n",
      "AUPRC (error): 26.417\n",
      "AURC: 13.698\n",
      "N_bins: 30\n",
      "AUROC: 85.309\n",
      "AUPRC (success): 98.760\n",
      "AUPRC (error): 25.097\n",
      "AURC: 13.941\n",
      "N_bins: 40\n",
      "AUROC: 85.459\n",
      "AUPRC (success): 98.777\n",
      "AUPRC (error): 25.474\n",
      "AURC: 13.777\n",
      "N_bins: 50\n",
      "AUROC: 84.882\n",
      "AUPRC (success): 98.731\n",
      "AUPRC (error): 24.879\n",
      "AURC: 14.215\n",
      "N_bins: 60\n",
      "AUROC: 85.249\n",
      "AUPRC (success): 98.769\n",
      "AUPRC (error): 24.999\n",
      "AURC: 13.852\n",
      "N_bins: 70\n",
      "AUROC: 85.397\n",
      "AUPRC (success): 98.785\n",
      "AUPRC (error): 25.136\n",
      "AURC: 13.707\n",
      "N_bins: 80\n",
      "AUROC: 85.057\n",
      "AUPRC (success): 98.748\n",
      "AUPRC (error): 24.581\n",
      "AURC: 14.056\n",
      "N_bins: 90\n",
      "AUROC: 84.792\n",
      "AUPRC (success): 98.729\n",
      "AUPRC (error): 24.500\n",
      "AURC: 14.240\n",
      "N_bins: 100\n",
      "AUROC: 84.631\n",
      "AUPRC (success): 98.728\n",
      "AUPRC (error): 23.552\n",
      "AURC: 14.252\n",
      "Run: 3\n",
      "N_bins: 10\n",
      "AUROC: 88.460\n",
      "AUPRC (success): 99.058\n",
      "AUPRC (error): 32.655\n",
      "AURC: 11.238\n",
      "N_bins: 20\n",
      "AUROC: 87.148\n",
      "AUPRC (success): 98.903\n",
      "AUPRC (error): 31.108\n",
      "AURC: 12.700\n",
      "N_bins: 30\n",
      "AUROC: 86.706\n",
      "AUPRC (success): 98.867\n",
      "AUPRC (error): 30.089\n",
      "AURC: 13.046\n",
      "N_bins: 40\n",
      "AUROC: 87.015\n",
      "AUPRC (success): 98.902\n",
      "AUPRC (error): 29.915\n",
      "AURC: 12.718\n",
      "N_bins: 50\n",
      "AUROC: 86.396\n",
      "AUPRC (success): 98.839\n",
      "AUPRC (error): 29.064\n",
      "AURC: 13.319\n",
      "N_bins: 60\n",
      "AUROC: 86.355\n",
      "AUPRC (success): 98.844\n",
      "AUPRC (error): 28.566\n",
      "AURC: 13.268\n",
      "N_bins: 70\n",
      "AUROC: 86.244\n",
      "AUPRC (success): 98.828\n",
      "AUPRC (error): 28.018\n",
      "AURC: 13.418\n",
      "N_bins: 80\n",
      "AUROC: 86.558\n",
      "AUPRC (success): 98.863\n",
      "AUPRC (error): 28.761\n",
      "AURC: 13.090\n",
      "N_bins: 90\n",
      "AUROC: 85.839\n",
      "AUPRC (success): 98.788\n",
      "AUPRC (error): 27.663\n",
      "AURC: 13.800\n",
      "N_bins: 100\n",
      "AUROC: 86.261\n",
      "AUPRC (success): 98.832\n",
      "AUPRC (error): 28.406\n",
      "AURC: 13.383\n",
      "Run: 4\n",
      "N_bins: 10\n",
      "AUROC: 88.369\n",
      "AUPRC (success): 99.044\n",
      "AUPRC (error): 29.408\n",
      "AURC: 11.403\n",
      "N_bins: 20\n",
      "AUROC: 87.686\n",
      "AUPRC (success): 98.974\n",
      "AUPRC (error): 28.535\n",
      "AURC: 12.065\n",
      "N_bins: 30\n",
      "AUROC: 86.951\n",
      "AUPRC (success): 98.904\n",
      "AUPRC (error): 27.096\n",
      "AURC: 12.734\n",
      "N_bins: 40\n",
      "AUROC: 86.883\n",
      "AUPRC (success): 98.895\n",
      "AUPRC (error): 26.878\n",
      "AURC: 12.820\n",
      "N_bins: 50\n",
      "AUROC: 86.746\n",
      "AUPRC (success): 98.894\n",
      "AUPRC (error): 26.693\n",
      "AURC: 12.827\n",
      "N_bins: 60\n",
      "AUROC: 86.539\n",
      "AUPRC (success): 98.861\n",
      "AUPRC (error): 26.304\n",
      "AURC: 13.139\n",
      "N_bins: 70\n",
      "AUROC: 86.546\n",
      "AUPRC (success): 98.879\n",
      "AUPRC (error): 26.078\n",
      "AURC: 12.976\n",
      "N_bins: 80\n",
      "AUROC: 86.761\n",
      "AUPRC (success): 98.885\n",
      "AUPRC (error): 26.412\n",
      "AURC: 12.916\n",
      "N_bins: 90\n",
      "AUROC: 86.287\n",
      "AUPRC (success): 98.839\n",
      "AUPRC (error): 26.227\n",
      "AURC: 13.345\n",
      "N_bins: 100\n",
      "AUROC: 86.287\n",
      "AUPRC (success): 98.849\n",
      "AUPRC (error): 25.935\n",
      "AURC: 13.254\n",
      "Run: 5\n",
      "N_bins: 10\n",
      "AUROC: 87.060\n",
      "AUPRC (success): 98.983\n",
      "AUPRC (error): 26.653\n",
      "AURC: 11.727\n",
      "N_bins: 20\n",
      "AUROC: 86.402\n",
      "AUPRC (success): 98.924\n",
      "AUPRC (error): 25.087\n",
      "AURC: 12.294\n",
      "N_bins: 30\n",
      "AUROC: 85.808\n",
      "AUPRC (success): 98.871\n",
      "AUPRC (error): 24.279\n",
      "AURC: 12.799\n",
      "N_bins: 40\n",
      "AUROC: 85.776\n",
      "AUPRC (success): 98.869\n",
      "AUPRC (error): 24.040\n",
      "AURC: 12.822\n",
      "N_bins: 50\n",
      "AUROC: 85.274\n",
      "AUPRC (success): 98.799\n",
      "AUPRC (error): 24.267\n",
      "AURC: 13.479\n",
      "N_bins: 60\n",
      "AUROC: 85.269\n",
      "AUPRC (success): 98.818\n",
      "AUPRC (error): 23.717\n",
      "AURC: 13.307\n",
      "N_bins: 70\n",
      "AUROC: 85.086\n",
      "AUPRC (success): 98.800\n",
      "AUPRC (error): 22.970\n",
      "AURC: 13.478\n",
      "N_bins: 80\n",
      "AUROC: 84.657\n",
      "AUPRC (success): 98.758\n",
      "AUPRC (error): 22.963\n",
      "AURC: 13.877\n",
      "N_bins: 90\n",
      "AUROC: 85.007\n",
      "AUPRC (success): 98.792\n",
      "AUPRC (error): 23.046\n",
      "AURC: 13.548\n",
      "N_bins: 100\n",
      "AUROC: 84.976\n",
      "AUPRC (success): 98.803\n",
      "AUPRC (error): 22.967\n",
      "AURC: 13.450\n"
     ]
    }
   ],
   "source": [
    "# Failure-prediction metrics on the test split, using the PSI of the predicted class as the\n",
    "# confidence score. Results are collected per run (rows) and per sweep value (columns).\n",
    "# To evaluate the n_projs sweep / gaussian estimator instead, switch to the commented lines.\n",
    "estimator = 'binning' #'gaussian'\n",
    "n_bins_list = range(10, 101, 10)\n",
    "# n_bins = 30\n",
    "# n_projs_list = [250,500,750,1000]\n",
    "n_projs = 500\n",
    "\n",
    "all_auroc = []\n",
    "all_auprc_succ = []\n",
    "all_auprc_error = []\n",
    "all_aurc = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    print(f'Run: {run+1}')\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/psi/{estimator}'\n",
    "\n",
    "    # Only the test split is scored in this cell.\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "\n",
    "    true_y_test = np.argmax([y for x,y in ds_test], axis=1)\n",
    "    pred_y_test = np.argmax(model.predict(ds_test.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "    # 1 if the prediction is correct (true_y == pred_y), 0 if it is an error (true_y != pred_y)\n",
    "    true_label_test = np.equal(true_y_test, pred_y_test).astype(int)\n",
    "\n",
    "    auroc_list = []\n",
    "    auprc_succ_list = []\n",
    "    auprc_error_list = []\n",
    "    aurc_list = []\n",
    "    for n_bins in n_bins_list:\n",
    "        print(f'N_bins: {n_bins}')\n",
    "#     for n_projs in n_projs_list:\n",
    "#         print(f'N_projs: {n_projs}')\n",
    "        psi_class_test = np.load(f'{exp_name}/psi_output_class_{n_bins}_bins_{n_projs}_projs_test.npy')\n",
    "#         psi_class_test = np.load(f'{exp_name}/psi_output_class_{n_projs}_projs_test.npy')\n",
    "        # Normalise the per-class PSI of each sample, then keep the entry of the predicted class.\n",
    "        psi_class_test = np.array([softmax(x) for x in psi_class_test])\n",
    "        psi_test = np.array([psi_value[pred_value] for psi_value, pred_value in zip(psi_class_test, pred_y_test)])\n",
    "\n",
    "        auroc = compute_auroc(true_label_test, psi_test)\n",
    "        auprc_succ = compute_auprc_success(true_label_test, psi_test)\n",
    "        auprc_error = compute_auprc_error(true_label_test, psi_test)\n",
    "        aurc, _, _ = compute_aurc(true_label_test, psi_test)\n",
    "        auroc_list.append(auroc)\n",
    "        auprc_succ_list.append(auprc_succ)\n",
    "        auprc_error_list.append(auprc_error)\n",
    "        aurc_list.append(aurc)\n",
    "        print(f'AUROC: {auroc*100:.3f}')\n",
    "        print(f'AUPRC (success): {auprc_succ*100:.3f}')\n",
    "        print(f'AUPRC (error): {auprc_error*100:.3f}')\n",
    "        print(f'AURC: {aurc*1000:.3f}')\n",
    "    all_auroc.append(auroc_list)\n",
    "    all_auprc_succ.append(auprc_succ_list)\n",
    "    all_auprc_error.append(auprc_error_list)\n",
    "    all_aurc.append(aurc_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ccf8a712",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "N_bins: 10\n",
      "AUROC: 87.454, std: 0.869\n",
      "AUPRC (success): 98.976, std: 0.080\n",
      "AUPRC (error): 28.823, std: 0.080\n",
      "AURC: 11.897, std: 0.668\n",
      "N_bins: 20\n",
      "AUROC: 86.314, std: 1.071\n",
      "AUPRC (success): 98.857, std: 0.101\n",
      "AUPRC (error): 27.086, std: 0.101\n",
      "AURC: 13.034, std: 0.888\n",
      "N_bins: 30\n",
      "AUROC: 85.886, std: 0.856\n",
      "AUPRC (success): 98.820, std: 0.078\n",
      "AUPRC (error): 26.170, std: 0.078\n",
      "AURC: 13.384, std: 0.666\n",
      "N_bins: 40\n",
      "AUROC: 86.032, std: 0.787\n",
      "AUPRC (success): 98.835, std: 0.067\n",
      "AUPRC (error): 26.140, std: 0.067\n",
      "AURC: 13.241, std: 0.565\n",
      "N_bins: 50\n",
      "AUROC: 85.519, std: 0.920\n",
      "AUPRC (success): 98.786, std: 0.080\n",
      "AUPRC (error): 25.613, std: 0.080\n",
      "AURC: 13.708, std: 0.666\n",
      "N_bins: 60\n",
      "AUROC: 85.559, std: 0.794\n",
      "AUPRC (success): 98.796, std: 0.062\n",
      "AUPRC (error): 25.330, std: 0.062\n",
      "AURC: 13.616, std: 0.511\n",
      "N_bins: 70\n",
      "AUROC: 85.517, std: 0.806\n",
      "AUPRC (success): 98.795, std: 0.064\n",
      "AUPRC (error): 25.001, std: 0.064\n",
      "AURC: 13.622, std: 0.512\n",
      "N_bins: 80\n",
      "AUROC: 85.468, std: 1.003\n",
      "AUPRC (success): 98.788, std: 0.074\n",
      "AUPRC (error): 25.128, std: 0.074\n",
      "AURC: 13.692, std: 0.604\n",
      "N_bins: 90\n",
      "AUROC: 85.217, std: 0.758\n",
      "AUPRC (success): 98.768, std: 0.052\n",
      "AUPRC (error): 24.702, std: 0.052\n",
      "AURC: 13.883, std: 0.423\n",
      "N_bins: 100\n",
      "AUROC: 85.138, std: 1.043\n",
      "AUPRC (success): 98.771, std: 0.076\n",
      "AUPRC (error): 24.437, std: 0.076\n",
      "AURC: 13.859, std: 0.651\n"
     ]
    }
   ],
   "source": [
    "# Aggregate the per-run metrics collected by the evaluation cell: axis 0 indexes the runs,\n",
    "# so mean/std are taken across runs for each sweep value.\n",
    "# np.std is used with its default ddof=0 (population standard deviation).\n",
    "mean_auroc = np.mean(all_auroc, axis=0)\n",
    "std_auroc = np.std(all_auroc, axis=0)\n",
    "mean_auprc_succ = np.mean(all_auprc_succ, axis=0)\n",
    "std_auprc_succ = np.std(all_auprc_succ, axis=0)\n",
    "mean_auprc_error = np.mean(all_auprc_error, axis=0)\n",
    "std_auprc_error = np.std(all_auprc_error, axis=0)\n",
    "mean_aurc = np.mean(all_aurc, axis=0)\n",
    "std_aurc = np.std(all_aurc, axis=0)\n",
    "for i, n_bins in enumerate(n_bins_list):\n",
    "# for i, n_projs in enumerate(n_projs_list):\n",
    "    print(f'N_bins: {n_bins}')\n",
    "#     print(f'N_projs: {n_projs}')\n",
    "    print(f'AUROC: {mean_auroc[i]*100:.3f}, std: {std_auroc[i]*100:.3f}')\n",
    "    print(f'AUPRC (success): {mean_auprc_succ[i]*100:.3f}, std: {std_auprc_succ[i]*100:.3f}')\n",
    "    # The error line must report the spread of the error metric, not of the success metric.\n",
    "    print(f'AUPRC (error): {mean_auprc_error[i]*100:.3f}, std: {std_auprc_error[i]*100:.3f}')\n",
    "    print(f'AURC: {mean_aurc[i]*1000:.3f}, std: {std_aurc[i]*1000:.3f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f06aaed",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
