{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c51dd850",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-07-19 13:20:34.924083: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9373] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-07-19 13:20:34.924144: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-07-19 13:20:34.925096: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1534] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-07-19 13:20:34.931602: I tensorflow/core/platform/cpu_feature_guard.cc:183] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from src.datasets import load_dataset, preprocess_dataset, prefetch_dataset\n",
    "from src.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e6b0cbce",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = { 'dataset' : 'cifar10',\n",
    "        'model' : 'cnn',\n",
    "        'batch_size' : 512,\n",
    "        'optimizer' : 'Adam',\n",
    "        'learning_rate' : 0.001,\n",
    "        'epoch' : 100,\n",
    "        'epoch_save_period' : 1\n",
    "        }    \n",
    "\n",
    "model_name = cfg['model']\n",
    "dataset_name = cfg['dataset']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6b90ae39",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average train error: 0.09, std: 0.03\n",
      "Average validation error: 1.86, std: 0.04\n",
      "Average test error: 1.94, std: 0.05\n"
     ]
    }
   ],
   "source": [
    "##############################################################\n",
    "#\n",
    "# Compute classification error\n",
    "#\n",
    "# #############################################################\n",
    "ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "n_classes = ds_info.features['label'].num_classes\n",
    "ds_train = preprocess_dataset(ds_train, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "ds_train = prefetch_dataset(ds_train, batch_size=cfg['batch_size'])\n",
    "ds_val = prefetch_dataset(ds_val, batch_size=cfg['batch_size'])\n",
    "ds_test = prefetch_dataset(ds_test, batch_size=cfg['batch_size'])\n",
    "\n",
    "train_acc = []\n",
    "val_acc = []\n",
    "test_acc = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/trained_model.keras')\n",
    "    train_acc.append(model.evaluate(ds_train, verbose=0)[1])\n",
    "    val_acc.append(model.evaluate(ds_val, verbose=0)[1])\n",
    "    test_acc.append(model.evaluate(ds_test, verbose=0)[1])\n",
    "print(f'Average train error: {(100-np.mean(train_acc)*100):.2f}, std: {(np.std(train_acc)*100):.2f}')\n",
    "print(f'Average validation error: {(100-np.mean(val_acc)*100):.2f}, std: {(np.std(val_acc)*100):.2f}')\n",
    "print(f'Average test error: {(100-np.mean(test_acc)*100):.2f}, std: {(np.std(test_acc)*100):.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "779e75a3",
   "metadata": {},
   "source": [
    "### Softmax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3fb331c6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-07-19 13:20:50.757476: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1926] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78419 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-80GB, pci bus id: 0000:0f:00.0, compute capability: 8.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-07-19 13:20:52.676927: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:467] Loaded cuDNN version 90100\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Opt. threshold: 0.734, Test filtering error:17.78\n",
      "Run: 2\n",
      "Opt. threshold: 0.712, Test filtering error:16.67\n",
      "Run: 3\n",
      "Opt. threshold: 0.726, Test filtering error:16.33\n",
      "Run: 4\n",
      "Opt. threshold: 0.749, Test filtering error:17.07\n",
      "Run: 5\n",
      "Opt. threshold: 0.821, Test filtering error:18.25\n",
      "-----------------------------\n",
      "Average opt. threshold: 0.749, std: 0.038\n",
      "Average test filtering error: 17.22, std: 0.71\n"
     ]
    }
   ],
   "source": [
    "##############################################################\n",
    "#\n",
    "# Compute Filtering Accuracy (Softmax)\n",
    "#\n",
    "# #############################################################\n",
    "\n",
    "ds_train, ds_val, ds_test, ds_info = load_dataset(cfg, shuffle=False)\n",
    "n_classes = ds_info.features['label'].num_classes\n",
    "ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "\n",
    "filtering_acc = []\n",
    "threshold = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    print(f'Run: {run+1}')\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "    preds = model.predict(ds_val.batch(cfg['batch_size']), verbose=0)\n",
    "    preds = np.array([softmax(x) for x in preds])\n",
    "    true_y = np.argmax([y for x,y in ds_val], axis=1)\n",
    "    pred_y = np.argmax(preds, axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    softmax_val = np.max(preds, axis=1)\n",
    "    opt_threshold = compute_opt_threshold(softmax_val, true_label)\n",
    "    threshold.append(opt_threshold)\n",
    "    \n",
    "    preds = model.predict(ds_test.batch(cfg['batch_size']), verbose=0)\n",
    "    preds = np.array([softmax(x) for x in preds])\n",
    "    true_y = np.argmax([y for x,y in ds_test], axis=1)\n",
    "    pred_y = np.argmax(preds, axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    softmax_val = np.max(preds, axis=1)\n",
    "    test_filtering_acc = compute_filtering_acc(softmax_val, true_label, opt_threshold)\n",
    "    filtering_acc.append(test_filtering_acc)\n",
    "    \n",
    "    print(f'Opt. threshold: {opt_threshold:.3f}, Test filtering error:{100-test_filtering_acc:.2f}')\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration'\n",
    "    np.savez(f'{exp_name}/softmax_filtering_accuracy.npz', opt_threshold=opt_threshold, test_filtering_acc=test_filtering_acc)\n",
    "\n",
    "print('-----------------------------')\n",
    "print(f'Average opt. threshold: {(np.mean(threshold)):.3f}, std: {(np.std(threshold)):.3f}')\n",
    "print(f'Average test filtering error: {(100-np.mean(filtering_acc)):.2f}, std: {(np.std(filtering_acc)):.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "17f5a72c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------------------------\n",
      "Average opt. threshold: 0.749, std: 0.038\n",
      "Average test filtering error: 17.22, std: 0.71\n"
     ]
    }
   ],
   "source": [
    "threshold = []\n",
    "filtering_acc = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration'\n",
    "    f = np.load(f'{exp_name}/softmax_filtering_accuracy.npz')\n",
    "    opt_threshold = f['opt_threshold']\n",
    "    test_filtering_acc = f['test_filtering_acc']\n",
    "    threshold.append(opt_threshold)\n",
    "    filtering_acc.append(test_filtering_acc)\n",
    "\n",
    "print('-----------------------------')\n",
    "print(f'Average opt. threshold: {(np.mean(threshold)):.3f}, std: {(np.std(threshold)):.3f}')\n",
    "print(f'Average test filtering error: {(100-np.mean(filtering_acc)):.2f}, std: {(np.std(filtering_acc)):.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "622fd78a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "Opt. threshold: 0.497, Test filtering error:17.57\n",
      "Run: 2\n",
      "Opt. threshold: 0.476, Test filtering error:16.53\n",
      "Run: 3\n",
      "Opt. threshold: 0.516, Test filtering error:16.19\n",
      "Run: 4\n",
      "Opt. threshold: 0.460, Test filtering error:17.10\n",
      "Run: 5\n",
      "Opt. threshold: 0.569, Test filtering error:17.96\n",
      "-----------------------------\n",
      "Average opt. threshold: 0.504, std: 0.038\n",
      "Average test filtering error: 17.07, std: 0.65\n"
     ]
    }
   ],
   "source": [
    "##############################################################\n",
    "#\n",
    "# Compute Filtering Accuracy (Softmax with Temperature Scaling)\n",
    "#\n",
    "# #############################################################\n",
    "\n",
    "ds_train, ds_val, ds_test, ds_info = load_dataset(cfg, shuffle=False)\n",
    "n_classes = ds_info.features['label'].num_classes\n",
    "ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "\n",
    "filtering_acc = []\n",
    "threshold = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    print(f'Run: {run+1}')\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "    temp = np.load(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/opt_temp.npy')\n",
    "    preds = model.predict(ds_val.batch(cfg['batch_size']), verbose=0)\n",
    "    preds = np.array([softmax(x/temp) for x in preds])\n",
    "    true_y = np.argmax([y for x,y in ds_val], axis=1)\n",
    "    pred_y = np.argmax(preds, axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    softmax_val = np.max(preds, axis=1)\n",
    "    opt_threshold = compute_opt_threshold(softmax_val, true_label)\n",
    "    threshold.append(opt_threshold)\n",
    "    \n",
    "    preds = model.predict(ds_test.batch(cfg['batch_size']), verbose=0)\n",
    "    preds = np.array([softmax(x/temp) for x in preds])\n",
    "    true_y = np.argmax([y for x,y in ds_test], axis=1)\n",
    "    pred_y = np.argmax(preds, axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    softmax_val = np.max(preds, axis=1)\n",
    "    test_filtering_acc = compute_filtering_acc(softmax_val, true_label, opt_threshold)\n",
    "    filtering_acc.append(test_filtering_acc)\n",
    "    \n",
    "    print(f'Opt. threshold: {opt_threshold:.3f}, Test filtering error:{100-test_filtering_acc:.2f}')\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration'\n",
    "    np.savez(f'{exp_name}/calibrated_softmax_filtering_accuracy.npz', opt_threshold=opt_threshold, test_filtering_acc=test_filtering_acc)\n",
    "\n",
    "print('-----------------------------')\n",
    "print(f'Average opt. threshold: {(np.mean(threshold)):.3f}, std: {(np.std(threshold)):.3f}')\n",
    "print(f'Average test filtering error: {(100-np.mean(filtering_acc)):.2f}, std: {(np.std(filtering_acc)):.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0a283e21",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------------------------\n",
      "Average opt. threshold: 0.504, std: 0.038\n",
      "Average test filtering error: 17.07, std: 0.65\n"
     ]
    }
   ],
   "source": [
    "threshold = []\n",
    "filtering_acc = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration'\n",
    "    f = np.load(f'{exp_name}/calibrated_softmax_filtering_accuracy.npz')\n",
    "    opt_threshold = f['opt_threshold']\n",
    "    test_filtering_acc = f['test_filtering_acc']\n",
    "    threshold.append(opt_threshold)\n",
    "    filtering_acc.append(test_filtering_acc)\n",
    "\n",
    "print('-----------------------------')\n",
    "print(f'Average opt. threshold: {(np.mean(threshold)):.3f}, std: {(np.std(threshold)):.3f}')\n",
    "print(f'Average test filtering error: {(100-np.mean(filtering_acc)):.2f}, std: {(np.std(filtering_acc)):.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6386da45",
   "metadata": {},
   "source": [
    "### PMI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e416a849",
   "metadata": {},
   "outputs": [],
   "source": [
    "##############################################################\n",
    "#\n",
    "# Compute Filtering Accuracy (with softmax scaling)\n",
    "#\n",
    "# #############################################################\n",
    "\n",
    "critic = 'separable'\n",
    "estimator = 'probabilistic_classifier'\n",
    "\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    print(f'Run: {run+1}')\n",
    "    print(f'Critic: {critic}, Estimator: {estimator}')\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pmi/{critic}_{estimator}'\n",
    "\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg, shuffle=False)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "\n",
    "    true_y = np.argmax([y for x,y in ds_val], axis=1)\n",
    "    pred_y = np.argmax(model.predict(ds_val.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    pmi_class = np.load(f'{exp_name}/pmi_class_val.npy')\n",
    "    pmi_class = np.array([softmax(x) for x in pmi_class])\n",
    "    pmi = [pmi_value[pred_value] for pmi_value, pred_value in zip(pmi_class, pred_y)]\n",
    "    opt_threshold = compute_opt_threshold(pmi, true_label)\n",
    "\n",
    "    true_y = np.argmax([y for x,y in ds_test], axis=1)\n",
    "    pred_y = np.argmax(model.predict(ds_test.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    pmi_class = np.load(f'{exp_name}/pmi_class_test.npy')\n",
    "    pmi_class = np.array([softmax(x) for x in pmi_class])\n",
    "    pmi = [pmi_value[pred_value] for pmi_value, pred_value in zip(pmi_class, pred_y)]\n",
    "    test_filtering_acc = compute_filtering_acc(pmi, true_label, opt_threshold)\n",
    "\n",
    "    np.savez(f'{exp_name}/scaled_filtering_accuracy.npz', opt_threshold=opt_threshold, test_filtering_acc=test_filtering_acc)\n",
    "    print(f'Opt. threshold: {opt_threshold:.3f}, Test filtering error:{100-test_filtering_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "21edb737",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------------------------\n",
      "Critic: separable, Estimator: probabilistic_classifier\n",
      "Average opt. threshold: 0.588, std: 0.043\n",
      "Average test filtering error: 16.58, std: 0.76\n"
     ]
    }
   ],
   "source": [
    "critic = 'separable'\n",
    "estimator = 'probabilistic_classifier'\n",
    "\n",
    "threshold = []\n",
    "filtering_acc = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pmi/{critic}_{estimator}'\n",
    "    f = np.load(f'{exp_name}/scaled_filtering_accuracy.npz')\n",
    "    opt_threshold = f['opt_threshold']\n",
    "    test_filtering_acc = f['test_filtering_acc']\n",
    "    threshold.append(opt_threshold)\n",
    "    filtering_acc.append(test_filtering_acc)\n",
    "\n",
    "print('-----------------------------')\n",
    "print(f'Critic: {critic}, Estimator: {estimator}')\n",
    "print(f'Average opt. threshold: {(np.mean(threshold)):.3f}, std: {(np.std(threshold)):.3f}')\n",
    "print(f'Average test filtering error: {(100-np.mean(filtering_acc)):.2f}, std: {(np.std(filtering_acc)):.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dacf186",
   "metadata": {},
   "source": [
    "### PVI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "8338eaec",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "Estimator: training_from_scratch\n",
      "Opt. threshold: nan, Test filtering error:94.05\n",
      "Run: 2\n",
      "Estimator: training_from_scratch\n",
      "Opt. threshold: nan, Test filtering error:94.17\n",
      "Run: 3\n",
      "Estimator: training_from_scratch\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[29], line 19\u001b[0m\n\u001b[1;32m     17\u001b[0m ds_val \u001b[38;5;241m=\u001b[39m preprocess_dataset(ds_val, cfg, n_classes, resize\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, normalize\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, onehot\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m     18\u001b[0m ds_test \u001b[38;5;241m=\u001b[39m preprocess_dataset(ds_test, cfg, n_classes, resize\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, normalize\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, onehot\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m---> 19\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mtf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkeras\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodels\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_model\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m../results/PI_Explainability/\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mmodel_name\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m_\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mdataset_name\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m/run_\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mrun\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m/trained_model.keras\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m     21\u001b[0m true_y \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39margmax([y \u001b[38;5;28;01mfor\u001b[39;00m x,y \u001b[38;5;129;01min\u001b[39;00m ds_val], axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m     22\u001b[0m pred_y \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39margmax(model\u001b[38;5;241m.\u001b[39mpredict(ds_val\u001b[38;5;241m.\u001b[39mbatch(cfg[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbatch_size\u001b[39m\u001b[38;5;124m'\u001b[39m]), verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m), axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_api.py:254\u001b[0m, in \u001b[0;36mload_model\u001b[0;34m(filepath, custom_objects, compile, safe_mode, **kwargs)\u001b[0m\n\u001b[1;32m    249\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m kwargs:\n\u001b[1;32m    250\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m    251\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe following argument(s) are not supported \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    252\u001b[0m             \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwith the native Keras format: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlist\u001b[39m(kwargs\u001b[38;5;241m.\u001b[39mkeys())\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    253\u001b[0m         )\n\u001b[0;32m--> 254\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43msaving_lib\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    255\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfilepath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    256\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcustom_objects\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcustom_objects\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    257\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mcompile\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mcompile\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m    258\u001b[0m \u001b[43m        \u001b[49m\u001b[43msafe_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msafe_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    259\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    261\u001b[0m \u001b[38;5;66;03m# Legacy case.\u001b[39;00m\n\u001b[1;32m    262\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m legacy_sm_saving_lib\u001b[38;5;241m.\u001b[39mload_model(\n\u001b[1;32m    263\u001b[0m     filepath, custom_objects\u001b[38;5;241m=\u001b[39mcustom_objects, \u001b[38;5;28mcompile\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mcompile\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m    264\u001b[0m )\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_lib.py:269\u001b[0m, in \u001b[0;36mload_model\u001b[0;34m(filepath, custom_objects, compile, safe_mode)\u001b[0m\n\u001b[1;32m    266\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    267\u001b[0m     asset_store \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 269\u001b[0m \u001b[43m_load_state\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    270\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    271\u001b[0m \u001b[43m    \u001b[49m\u001b[43mweights_store\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mweights_store\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    272\u001b[0m \u001b[43m    \u001b[49m\u001b[43massets_store\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43masset_store\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    273\u001b[0m \u001b[43m    \u001b[49m\u001b[43minner_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m    274\u001b[0m \u001b[43m    \u001b[49m\u001b[43mvisited_trackables\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mset\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    276\u001b[0m weights_store\u001b[38;5;241m.\u001b[39mclose()\n\u001b[1;32m    277\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m asset_store:\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_lib.py:466\u001b[0m, in \u001b[0;36m_load_state\u001b[0;34m(trackable, weights_store, assets_store, inner_path, skip_mismatch, visited_trackables)\u001b[0m\n\u001b[1;32m    457\u001b[0m     _load_state(\n\u001b[1;32m    458\u001b[0m         child_obj,\n\u001b[1;32m    459\u001b[0m         weights_store,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    463\u001b[0m         visited_trackables\u001b[38;5;241m=\u001b[39mvisited_trackables,\n\u001b[1;32m    464\u001b[0m     )\n\u001b[1;32m    465\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(child_obj, (\u001b[38;5;28mlist\u001b[39m, \u001b[38;5;28mdict\u001b[39m, \u001b[38;5;28mtuple\u001b[39m, \u001b[38;5;28mset\u001b[39m)):\n\u001b[0;32m--> 466\u001b[0m     \u001b[43m_load_container_state\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    467\u001b[0m \u001b[43m        \u001b[49m\u001b[43mchild_obj\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    468\u001b[0m \u001b[43m        \u001b[49m\u001b[43mweights_store\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    469\u001b[0m \u001b[43m        \u001b[49m\u001b[43massets_store\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    470\u001b[0m \u001b[43m        \u001b[49m\u001b[43minner_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mio\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgfile\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43minner_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mchild_attr\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    471\u001b[0m \u001b[43m        \u001b[49m\u001b[43mskip_mismatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskip_mismatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    472\u001b[0m \u001b[43m        \u001b[49m\u001b[43mvisited_trackables\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvisited_trackables\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    473\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_lib.py:534\u001b[0m, in \u001b[0;36m_load_container_state\u001b[0;34m(container, weights_store, assets_store, inner_path, skip_mismatch, visited_trackables)\u001b[0m\n\u001b[1;32m    532\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    533\u001b[0m     used_names[name] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m--> 534\u001b[0m \u001b[43m_load_state\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    535\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtrackable\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    536\u001b[0m \u001b[43m    \u001b[49m\u001b[43mweights_store\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    537\u001b[0m \u001b[43m    \u001b[49m\u001b[43massets_store\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    538\u001b[0m \u001b[43m    \u001b[49m\u001b[43minner_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mio\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgfile\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43minner_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    539\u001b[0m \u001b[43m    \u001b[49m\u001b[43mskip_mismatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskip_mismatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    540\u001b[0m \u001b[43m    \u001b[49m\u001b[43mvisited_trackables\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvisited_trackables\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    541\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_lib.py:435\u001b[0m, in \u001b[0;36m_load_state\u001b[0;34m(trackable, weights_store, assets_store, inner_path, skip_mismatch, visited_trackables)\u001b[0m\n\u001b[1;32m    428\u001b[0m             warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m    429\u001b[0m                 \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCould not load weights in object \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrackable\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    430\u001b[0m                 \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSkipping object. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    431\u001b[0m                 \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mException encountered: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m    432\u001b[0m                 stacklevel\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m,\n\u001b[1;32m    433\u001b[0m             )\n\u001b[1;32m    434\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 435\u001b[0m         trackable\u001b[38;5;241m.\u001b[39mload_own_variables(\u001b[43mweights_store\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43minner_path\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m    437\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(trackable, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mload_assets\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m assets_store:\n\u001b[1;32m    438\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m skip_mismatch:\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_lib.py:634\u001b[0m, in \u001b[0;36mH5IOStore.get\u001b[0;34m(self, path)\u001b[0m\n\u001b[1;32m    632\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mh5_file[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvars\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m    633\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m path \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mh5_file \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvars\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mh5_file[path]:\n\u001b[0;32m--> 634\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mh5_file\u001b[49m\u001b[43m[\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mvars\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[1;32m    635\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {}\n",
      "File \u001b[0;32mh5py/_objects.pyx:54\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32mh5py/_objects.pyx:55\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/h5py/_hl/group.py:328\u001b[0m, in \u001b[0;36mGroup.__getitem__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m    326\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInvalid HDF5 object reference\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m    327\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(name, (\u001b[38;5;28mbytes\u001b[39m, \u001b[38;5;28mstr\u001b[39m)):\n\u001b[0;32m--> 328\u001b[0m     oid \u001b[38;5;241m=\u001b[39m \u001b[43mh5o\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopen\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mid\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_e\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlapl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_lapl\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    329\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    330\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAccessing a group is done with bytes or str, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    331\u001b[0m                     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m not \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\u001b[38;5;28mtype\u001b[39m(name)))\n",
      "File \u001b[0;32mh5py/_objects.pyx:54\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32mh5py/_objects.pyx:55\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32mh5py/h5o.pyx:190\u001b[0m, in \u001b[0;36mh5py.h5o.open\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32mh5py/h5fd.pyx:160\u001b[0m, in \u001b[0;36mh5py.h5fd.H5FD_fileobj_read\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32m/usr/lib/python3.10/zipfile.py:1093\u001b[0m, in \u001b[0;36mZipExtFile.seek\u001b[0;34m(self, offset, whence)\u001b[0m\n\u001b[1;32m   1091\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m read_offset \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m   1092\u001b[0m     read_len \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmin\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mMAX_SEEK_READ, read_offset)\n\u001b[0;32m-> 1093\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread\u001b[49m\u001b[43m(\u001b[49m\u001b[43mread_len\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1094\u001b[0m     read_offset \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m read_len\n\u001b[1;32m   1096\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtell()\n",
      "File \u001b[0;32m/usr/lib/python3.10/zipfile.py:927\u001b[0m, in \u001b[0;36mZipExtFile.read\u001b[0;34m(self, n)\u001b[0m\n\u001b[1;32m    925\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_offset \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m    926\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m n \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_eof:\n\u001b[0;32m--> 927\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_read1\u001b[49m\u001b[43m(\u001b[49m\u001b[43mn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    928\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m n \u001b[38;5;241m<\u001b[39m \u001b[38;5;28mlen\u001b[39m(data):\n\u001b[1;32m    929\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_readbuffer \u001b[38;5;241m=\u001b[39m data\n",
      "File \u001b[0;32m/usr/lib/python3.10/zipfile.py:997\u001b[0m, in \u001b[0;36mZipExtFile._read1\u001b[0;34m(self, n)\u001b[0m\n\u001b[1;32m    995\u001b[0m         data \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_read2(n \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mlen\u001b[39m(data))\n\u001b[1;32m    996\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 997\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_read2\u001b[49m\u001b[43m(\u001b[49m\u001b[43mn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    999\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compress_type \u001b[38;5;241m==\u001b[39m ZIP_STORED:\n\u001b[1;32m   1000\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_eof \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compress_left \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n",
      "File \u001b[0;32m/usr/lib/python3.10/zipfile.py:1027\u001b[0m, in \u001b[0;36mZipExtFile._read2\u001b[0;34m(self, n)\u001b[0m\n\u001b[1;32m   1024\u001b[0m n \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmax\u001b[39m(n, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mMIN_READ_SIZE)\n\u001b[1;32m   1025\u001b[0m n \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmin\u001b[39m(n, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compress_left)\n\u001b[0;32m-> 1027\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fileobj\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread\u001b[49m\u001b[43m(\u001b[49m\u001b[43mn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1028\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compress_left \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(data)\n\u001b[1;32m   1029\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m data:\n",
      "File \u001b[0;32m/usr/lib/python3.10/zipfile.py:747\u001b[0m, in \u001b[0;36m_SharedFile.read\u001b[0;34m(self, n)\u001b[0m\n\u001b[1;32m    743\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCan\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt read from the ZIP file while there \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    744\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mis an open writing handle on it. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    745\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mClose the writing handle before trying to read.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m    746\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_file\u001b[38;5;241m.\u001b[39mseek(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pos)\n\u001b[0;32m--> 747\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_file\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread\u001b[49m\u001b[43m(\u001b[49m\u001b[43mn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    748\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pos \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_file\u001b[38;5;241m.\u001b[39mtell()\n\u001b[1;32m    749\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/tensorflow/python/lib/io/file_io.py:121\u001b[0m, in \u001b[0;36mFileIO.read\u001b[0;34m(self, n)\u001b[0m\n\u001b[1;32m    119\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    120\u001b[0m   length \u001b[38;5;241m=\u001b[39m n\n\u001b[0;32m--> 121\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_prepare_value(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_read_buf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlength\u001b[49m\u001b[43m)\u001b[49m)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "##############################################################\n",
    "#\n",
    "# Compute Filtering Accuracy (with softmax scaling)\n",
    "#\n",
    "# #############################################################\n",
    "\n",
    "estimator = 'training_from_scratch'\n",
    "\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    print(f'Run: {run+1}')\n",
    "    print(f'Estimator: {estimator}')\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pvi/{estimator}'\n",
    "\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg, shuffle=False)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "\n",
    "    true_y = np.argmax([y for x,y in ds_val], axis=1)\n",
    "    pred_y = np.argmax(model.predict(ds_val.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    pvi_class = np.load(f'{exp_name}/pvi_calibrated_class_val.npy')\n",
    "    pvi_class = np.array([softmax(x) for x in pvi_class])\n",
    "    pvi = [pvi_value[pred_value] for pvi_value, pred_value in zip(pvi_class, pred_y)]\n",
    "    opt_threshold = compute_opt_threshold(pvi, true_label)\n",
    "\n",
    "    true_y = np.argmax([y for x,y in ds_test], axis=1)\n",
    "    pred_y = np.argmax(model.predict(ds_test.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    pvi_class = np.load(f'{exp_name}/pvi_calibrated_class_test.npy')\n",
    "    pvi_class = np.array([softmax(x) for x in pvi_class])\n",
    "    pvi = [pvi_value[pred_value] for pvi_value, pred_value in zip(pvi_class, pred_y)]\n",
    "    test_filtering_acc = compute_filtering_acc(pvi, true_label, opt_threshold)\n",
    "\n",
    "    np.savez(f'{exp_name}/calibrated_scaled_filtering_accuracy.npz', opt_threshold=opt_threshold, test_filtering_acc=test_filtering_acc)\n",
    "    print(f'Opt. threshold: {opt_threshold:.3f}, Test filtering error:{100-test_filtering_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "29a0f3ff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[nan, nan, nan, ..., nan, nan, nan],\n",
       "       [nan, nan, nan, ..., nan, nan, nan],\n",
       "       [nan, nan, nan, ..., nan, nan, nan],\n",
       "       ...,\n",
       "       [nan, nan, nan, ..., nan, nan, nan],\n",
       "       [nan, nan, nan, ..., nan, nan, nan],\n",
       "       [nan, nan, nan, ..., nan, nan, nan]])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.load(f'{exp_name}/pvi_calibrated_class_test.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "d3c45c34",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------------------------\n",
      "Estimator: training_from_scratch\n",
      "Average opt. threshold: 0.144, std: 0.046\n",
      "Average test filtering error: 16.69, std: 0.51\n"
     ]
    }
   ],
   "source": [
    "estimator = 'training_from_scratch'\n",
    "\n",
    "threshold = []\n",
    "filtering_acc = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pvi/{estimator}'\n",
    "    f = np.load(f'{exp_name}/calibrated_scaled_filtering_accuracy.npz')\n",
    "    opt_threshold = f['opt_threshold']\n",
    "    test_filtering_acc = f['test_filtering_acc']\n",
    "    threshold.append(opt_threshold)\n",
    "    filtering_acc.append(test_filtering_acc)\n",
    "\n",
    "print('-----------------------------')\n",
    "print(f'Estimator: {estimator}')\n",
    "print(f'Average opt. threshold: {(np.mean(threshold)):.3f}, std: {(np.std(threshold)):.3f}')\n",
    "print(f'Average test filtering error: {(100-np.mean(filtering_acc)):.2f}, std: {(np.std(filtering_acc)):.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de827d35",
   "metadata": {},
   "source": [
    "### PSI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b09c2193",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run: 1\n",
      "Estimator: gaussian\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-05-22 12:02:37.713537: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1883] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78835 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-80GB, pci bus id: 0000:47:00.0, compute capability: 8.0\n",
      "2024-05-22 12:03:03.037671: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:442] Loaded cuDNN version 8906\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Opt. threshold: 0.216, Test filtering error:4.19\n",
      "Run: 2\n",
      "Estimator: gaussian\n",
      "Opt. threshold: 0.326, Test filtering error:4.45\n",
      "Run: 3\n",
      "Estimator: gaussian\n",
      "Opt. threshold: 0.218, Test filtering error:4.56\n",
      "Run: 4\n",
      "Estimator: gaussian\n",
      "Opt. threshold: 0.302, Test filtering error:4.64\n",
      "Run: 5\n",
      "Estimator: gaussian\n",
      "Opt. threshold: 0.297, Test filtering error:4.64\n"
     ]
    }
   ],
   "source": [
    "##############################################################\n",
    "#\n",
    "# Compute Filtering Accuracy (with softmax scaling)\n",
    "#\n",
    "# #############################################################\n",
    "\n",
    "estimator = 'gaussian'\n",
    "\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    print(f'Run: {run+1}')\n",
    "    print(f'Estimator: {estimator}')\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/psi/{estimator}'\n",
    "\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg, shuffle=False)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "    ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/trained_model.keras')\n",
    "\n",
    "    true_y = np.argmax([y for x,y in ds_val], axis=1)\n",
    "    pred_y = np.argmax(model.predict(ds_val.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    psi_class = np.load(f'{exp_name}/psi_class_val.npy')\n",
    "    psi_class = np.array([softmax(x) for x in psi_class])\n",
    "    psi = [psi_value[pred_value] for psi_value, pred_value in zip(psi_class, pred_y)]\n",
    "    opt_threshold = compute_opt_threshold(psi, true_label)\n",
    "\n",
    "    true_y = np.argmax([y for x,y in ds_test], axis=1)\n",
    "    pred_y = np.argmax(model.predict(ds_test.batch(cfg['batch_size']), verbose=0), axis=1)\n",
    "    true_label = np.equal(true_y, pred_y).astype(int) # assign 0 if true_y != pred_y, assign 1 if true_y == pred_y\n",
    "    psi_class = np.load(f'{exp_name}/psi_class_test.npy')\n",
    "    psi_class = np.array([softmax(x) for x in psi_class])\n",
    "    psi = [psi_value[pred_value] for psi_value, pred_value in zip(psi_class, pred_y)]\n",
    "    test_filtering_acc = compute_filtering_acc(psi, true_label, opt_threshold)\n",
    "\n",
    "    np.savez(f'{exp_name}/scaled_filtering_accuracy.npz', opt_threshold=opt_threshold, test_filtering_acc=test_filtering_acc)\n",
    "    print(f'Opt. threshold: {opt_threshold:.3f}, Test filtering error:{100-test_filtering_acc:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3d276df9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------------------------\n",
      "Estimator: gaussian\n",
      "Average opt. threshold: 0.272, std: 0.046\n",
      "Average test filtering error: 4.49, std: 0.17\n"
     ]
    }
   ],
   "source": [
    "estimator = 'gaussian'\n",
    "\n",
    "threshold = []\n",
    "filtering_acc = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/psi/{estimator}'\n",
    "    f = np.load(f'{exp_name}/scaled_filtering_accuracy.npz')\n",
    "    opt_threshold = f['opt_threshold']\n",
    "    test_filtering_acc = f['test_filtering_acc']\n",
    "    threshold.append(opt_threshold)\n",
    "    filtering_acc.append(test_filtering_acc)\n",
    "\n",
    "print('-----------------------------')\n",
    "print(f'Estimator: {estimator}')\n",
    "print(f'Average opt. threshold: {(np.mean(threshold)):.3f}, std: {(np.std(threshold)):.3f}')\n",
    "print(f'Average test filtering error: {(100-np.mean(filtering_acc)):.2f}, std: {(np.std(filtering_acc)):.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b742711d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
