{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bab96b0a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-11 01:51:59.457836: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9373] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-08-11 01:51:59.457911: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-08-11 01:51:59.459531: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1534] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-08-11 01:51:59.467572: I tensorflow/core/platform/cpu_feature_guard.cc:183] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from src.datasets import load_dataset, preprocess_dataset, prefetch_dataset\n",
    "from src.models import train_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a0d0843d",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = { 'dataset' : 'stl10',\n",
    "        'model' : 'vgg16',\n",
    "        'batch_size' : 64,\n",
    "        'optimizer' : 'Adam',\n",
    "        'learning_rate' : 0.00005,\n",
    "        'epoch' : 50,\n",
    "        'epoch_save_period' : 1\n",
    "        }    \n",
    "\n",
    "model_name = cfg['model']\n",
    "dataset_name = cfg['dataset']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "be4bcb70",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Making directory ../results/PI_Explainability/vgg16_stl10_224x224/run_2\n",
      "Making directory ../results/PI_Explainability/vgg16_stl10_224x224/run_2/saved_models\n",
      "Epoch 1/50\n",
      "67/67 [==============================] - 12s 104ms/step - loss: 1.6857 - accuracy: 0.3652 - val_loss: 1.1442 - val_accuracy: 0.5827\n",
      "Epoch 2/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 1.0338 - accuracy: 0.6153 - val_loss: 0.8493 - val_accuracy: 0.6867\n",
      "Epoch 3/50\n",
      "67/67 [==============================] - 7s 97ms/step - loss: 0.7121 - accuracy: 0.7546 - val_loss: 0.6428 - val_accuracy: 0.7827\n",
      "Epoch 4/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.5196 - accuracy: 0.8162 - val_loss: 0.6034 - val_accuracy: 0.8040\n",
      "Epoch 5/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.4302 - accuracy: 0.8562 - val_loss: 0.6235 - val_accuracy: 0.7933\n",
      "Epoch 6/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.3490 - accuracy: 0.8899 - val_loss: 0.6077 - val_accuracy: 0.8173\n",
      "Epoch 7/50\n",
      "67/67 [==============================] - 7s 97ms/step - loss: 0.2603 - accuracy: 0.9125 - val_loss: 0.5392 - val_accuracy: 0.8253\n",
      "Epoch 8/50\n",
      "67/67 [==============================] - 7s 97ms/step - loss: 0.2180 - accuracy: 0.9252 - val_loss: 0.5615 - val_accuracy: 0.8280\n",
      "Epoch 9/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.2000 - accuracy: 0.9336 - val_loss: 0.5231 - val_accuracy: 0.8400\n",
      "Epoch 10/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.1746 - accuracy: 0.9468 - val_loss: 0.4443 - val_accuracy: 0.8747\n",
      "Epoch 11/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.1483 - accuracy: 0.9525 - val_loss: 0.5505 - val_accuracy: 0.8427\n",
      "Epoch 12/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.1162 - accuracy: 0.9621 - val_loss: 0.4707 - val_accuracy: 0.8733\n",
      "Epoch 13/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.1113 - accuracy: 0.9635 - val_loss: 0.4199 - val_accuracy: 0.8920\n",
      "Epoch 14/50\n",
      "67/67 [==============================] - 7s 100ms/step - loss: 0.0517 - accuracy: 0.9831 - val_loss: 0.4221 - val_accuracy: 0.8907\n",
      "Epoch 15/50\n",
      "67/67 [==============================] - 7s 99ms/step - loss: 0.0474 - accuracy: 0.9859 - val_loss: 0.5062 - val_accuracy: 0.8720\n",
      "Epoch 16/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0404 - accuracy: 0.9878 - val_loss: 0.4832 - val_accuracy: 0.9027\n",
      "Epoch 17/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0343 - accuracy: 0.9894 - val_loss: 0.4735 - val_accuracy: 0.9040\n",
      "Epoch 18/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0323 - accuracy: 0.9915 - val_loss: 0.5815 - val_accuracy: 0.8773\n",
      "Epoch 19/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0283 - accuracy: 0.9911 - val_loss: 0.4455 - val_accuracy: 0.8907\n",
      "Epoch 20/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0345 - accuracy: 0.9882 - val_loss: 0.4777 - val_accuracy: 0.8867\n",
      "Epoch 21/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0106 - accuracy: 0.9972 - val_loss: 0.5786 - val_accuracy: 0.8893\n",
      "Epoch 22/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0162 - accuracy: 0.9941 - val_loss: 0.5912 - val_accuracy: 0.8720\n",
      "Epoch 23/50\n",
      "67/67 [==============================] - 7s 103ms/step - loss: 0.0175 - accuracy: 0.9953 - val_loss: 0.4806 - val_accuracy: 0.8947\n",
      "Epoch 24/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0370 - accuracy: 0.9885 - val_loss: 0.5855 - val_accuracy: 0.8520\n",
      "Epoch 25/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0433 - accuracy: 0.9861 - val_loss: 0.6011 - val_accuracy: 0.8667\n",
      "Epoch 26/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0321 - accuracy: 0.9911 - val_loss: 0.5970 - val_accuracy: 0.8733\n",
      "Epoch 27/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0298 - accuracy: 0.9913 - val_loss: 0.4632 - val_accuracy: 0.9187\n",
      "Epoch 28/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0041 - accuracy: 0.9993 - val_loss: 0.4732 - val_accuracy: 0.9173\n",
      "Epoch 29/50\n",
      "67/67 [==============================] - 7s 103ms/step - loss: 0.0043 - accuracy: 0.9991 - val_loss: 0.4252 - val_accuracy: 0.9173\n",
      "Epoch 30/50\n",
      "67/67 [==============================] - 7s 103ms/step - loss: 0.0098 - accuracy: 0.9972 - val_loss: 0.4654 - val_accuracy: 0.8987\n",
      "Epoch 31/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0085 - accuracy: 0.9981 - val_loss: 0.5156 - val_accuracy: 0.8987\n",
      "Epoch 32/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0076 - accuracy: 0.9981 - val_loss: 0.4197 - val_accuracy: 0.8947\n",
      "Epoch 33/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0021 - accuracy: 0.9998 - val_loss: 0.5929 - val_accuracy: 0.8933\n",
      "Epoch 34/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0013 - accuracy: 1.0000 - val_loss: 0.4980 - val_accuracy: 0.9173\n",
      "Epoch 35/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 5.0523e-04 - accuracy: 1.0000 - val_loss: 0.5274 - val_accuracy: 0.9133\n",
      "Epoch 36/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 2.2784e-04 - accuracy: 1.0000 - val_loss: 0.5566 - val_accuracy: 0.9120\n",
      "Epoch 37/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 1.0860e-04 - accuracy: 1.0000 - val_loss: 0.5511 - val_accuracy: 0.9160\n",
      "Epoch 38/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0331 - accuracy: 0.9925 - val_loss: 0.6543 - val_accuracy: 0.8640\n",
      "Epoch 39/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0440 - accuracy: 0.9871 - val_loss: 0.6274 - val_accuracy: 0.8627\n",
      "Epoch 40/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0734 - accuracy: 0.9765 - val_loss: 0.5905 - val_accuracy: 0.8493\n",
      "Epoch 41/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0207 - accuracy: 0.9936 - val_loss: 0.5083 - val_accuracy: 0.8880\n",
      "Epoch 42/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0151 - accuracy: 0.9962 - val_loss: 0.5354 - val_accuracy: 0.8867\n",
      "Epoch 43/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0160 - accuracy: 0.9969 - val_loss: 0.4972 - val_accuracy: 0.9000\n",
      "Epoch 44/50\n",
      "67/67 [==============================] - 7s 103ms/step - loss: 0.0219 - accuracy: 0.9941 - val_loss: 0.6120 - val_accuracy: 0.8720\n",
      "Epoch 45/50\n",
      "67/67 [==============================] - 7s 103ms/step - loss: 0.0639 - accuracy: 0.9821 - val_loss: 0.7480 - val_accuracy: 0.8440\n",
      "Epoch 46/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0185 - accuracy: 0.9955 - val_loss: 0.7329 - val_accuracy: 0.8720\n",
      "Epoch 47/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0236 - accuracy: 0.9920 - val_loss: 0.5448 - val_accuracy: 0.8733\n",
      "Epoch 48/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0110 - accuracy: 0.9969 - val_loss: 0.5255 - val_accuracy: 0.9000\n",
      "Epoch 49/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0087 - accuracy: 0.9969 - val_loss: 0.4606 - val_accuracy: 0.8973\n",
      "Epoch 50/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0049 - accuracy: 0.9984 - val_loss: 0.4109 - val_accuracy: 0.9147\n",
      "Making directory ../results/PI_Explainability/vgg16_stl10_224x224/run_3\n",
      "Making directory ../results/PI_Explainability/vgg16_stl10_224x224/run_3/saved_models\n",
      "Epoch 1/50\n",
      "67/67 [==============================] - 18s 112ms/step - loss: 1.7730 - accuracy: 0.3438 - val_loss: 1.2333 - val_accuracy: 0.5587\n",
      "Epoch 2/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 1.1019 - accuracy: 0.5918 - val_loss: 0.9525 - val_accuracy: 0.6773\n",
      "Epoch 3/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.7743 - accuracy: 0.7264 - val_loss: 0.7589 - val_accuracy: 0.7560\n",
      "Epoch 4/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.5607 - accuracy: 0.8087 - val_loss: 0.5864 - val_accuracy: 0.7960\n",
      "Epoch 5/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.4273 - accuracy: 0.8602 - val_loss: 0.5521 - val_accuracy: 0.7947\n",
      "Epoch 6/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.3392 - accuracy: 0.8859 - val_loss: 0.5506 - val_accuracy: 0.8227\n",
      "Epoch 7/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.2890 - accuracy: 0.9071 - val_loss: 0.4495 - val_accuracy: 0.8600\n",
      "Epoch 8/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.2201 - accuracy: 0.9228 - val_loss: 0.5132 - val_accuracy: 0.8467\n",
      "Epoch 9/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.1989 - accuracy: 0.9346 - val_loss: 0.4692 - val_accuracy: 0.8587\n",
      "Epoch 10/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.1802 - accuracy: 0.9376 - val_loss: 0.4605 - val_accuracy: 0.8560\n",
      "Epoch 11/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.1386 - accuracy: 0.9527 - val_loss: 0.4243 - val_accuracy: 0.8693\n",
      "Epoch 12/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.1185 - accuracy: 0.9631 - val_loss: 0.4248 - val_accuracy: 0.8760\n",
      "Epoch 13/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.1075 - accuracy: 0.9654 - val_loss: 0.4467 - val_accuracy: 0.8800\n",
      "Epoch 14/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0719 - accuracy: 0.9786 - val_loss: 0.5085 - val_accuracy: 0.8680\n",
      "Epoch 15/50\n",
      "67/67 [==============================] - 7s 103ms/step - loss: 0.0575 - accuracy: 0.9802 - val_loss: 0.4614 - val_accuracy: 0.8800\n",
      "Epoch 16/50\n",
      "67/67 [==============================] - 7s 99ms/step - loss: 0.0504 - accuracy: 0.9852 - val_loss: 0.3707 - val_accuracy: 0.8920\n",
      "Epoch 17/50\n",
      "67/67 [==============================] - 7s 99ms/step - loss: 0.0351 - accuracy: 0.9904 - val_loss: 0.4861 - val_accuracy: 0.8533\n",
      "Epoch 18/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0630 - accuracy: 0.9800 - val_loss: 0.4784 - val_accuracy: 0.8800\n",
      "Epoch 19/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0383 - accuracy: 0.9873 - val_loss: 0.4722 - val_accuracy: 0.8773\n",
      "Epoch 20/50\n",
      "67/67 [==============================] - 7s 99ms/step - loss: 0.0306 - accuracy: 0.9913 - val_loss: 0.4654 - val_accuracy: 0.8853\n",
      "Epoch 21/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0211 - accuracy: 0.9944 - val_loss: 0.5669 - val_accuracy: 0.8707\n",
      "Epoch 22/50\n",
      "67/67 [==============================] - 7s 103ms/step - loss: 0.0195 - accuracy: 0.9941 - val_loss: 0.4318 - val_accuracy: 0.9080\n",
      "Epoch 23/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0303 - accuracy: 0.9915 - val_loss: 0.5357 - val_accuracy: 0.8693\n",
      "Epoch 24/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0090 - accuracy: 0.9979 - val_loss: 0.4371 - val_accuracy: 0.9013\n",
      "Epoch 25/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0233 - accuracy: 0.9911 - val_loss: 0.4220 - val_accuracy: 0.8907\n",
      "Epoch 26/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0173 - accuracy: 0.9951 - val_loss: 0.4575 - val_accuracy: 0.8947\n",
      "Epoch 27/50\n",
      "67/67 [==============================] - 7s 99ms/step - loss: 0.0103 - accuracy: 0.9972 - val_loss: 0.4566 - val_accuracy: 0.8973\n",
      "Epoch 28/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0305 - accuracy: 0.9901 - val_loss: 0.7350 - val_accuracy: 0.8507\n",
      "Epoch 29/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0351 - accuracy: 0.9887 - val_loss: 0.5523 - val_accuracy: 0.8693\n",
      "Epoch 30/50\n",
      "67/67 [==============================] - 7s 99ms/step - loss: 0.0306 - accuracy: 0.9906 - val_loss: 0.5426 - val_accuracy: 0.8680\n",
      "Epoch 31/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0254 - accuracy: 0.9918 - val_loss: 0.4845 - val_accuracy: 0.8867\n",
      "Epoch 32/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0242 - accuracy: 0.9932 - val_loss: 0.5703 - val_accuracy: 0.8653\n",
      "Epoch 33/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0094 - accuracy: 0.9981 - val_loss: 0.4851 - val_accuracy: 0.8987\n",
      "Epoch 34/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0053 - accuracy: 0.9979 - val_loss: 0.4978 - val_accuracy: 0.8960\n",
      "Epoch 35/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0125 - accuracy: 0.9972 - val_loss: 0.4784 - val_accuracy: 0.8880\n",
      "Epoch 36/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0093 - accuracy: 0.9976 - val_loss: 0.6801 - val_accuracy: 0.8653\n",
      "Epoch 37/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0688 - accuracy: 0.9762 - val_loss: 0.3678 - val_accuracy: 0.8867\n",
      "Epoch 38/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0237 - accuracy: 0.9925 - val_loss: 0.5090 - val_accuracy: 0.8907\n",
      "Epoch 39/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0219 - accuracy: 0.9925 - val_loss: 0.4249 - val_accuracy: 0.8960\n",
      "Epoch 40/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0240 - accuracy: 0.9936 - val_loss: 0.4980 - val_accuracy: 0.8800\n",
      "Epoch 41/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0084 - accuracy: 0.9976 - val_loss: 0.4655 - val_accuracy: 0.9027\n",
      "Epoch 42/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0069 - accuracy: 0.9972 - val_loss: 0.5052 - val_accuracy: 0.9000\n",
      "Epoch 43/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0148 - accuracy: 0.9955 - val_loss: 0.4776 - val_accuracy: 0.9013\n",
      "Epoch 44/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0190 - accuracy: 0.9934 - val_loss: 0.4990 - val_accuracy: 0.8893\n",
      "Epoch 45/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0034 - accuracy: 0.9995 - val_loss: 0.4598 - val_accuracy: 0.9067\n",
      "Epoch 46/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0025 - accuracy: 0.9993 - val_loss: 0.5604 - val_accuracy: 0.8960\n",
      "Epoch 47/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 4.5868e-04 - accuracy: 1.0000 - val_loss: 0.5087 - val_accuracy: 0.9120\n",
      "Epoch 48/50\n",
      "67/67 [==============================] - 7s 104ms/step - loss: 3.6720e-04 - accuracy: 1.0000 - val_loss: 0.5664 - val_accuracy: 0.9013\n",
      "Epoch 49/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 2.7317e-04 - accuracy: 1.0000 - val_loss: 0.5384 - val_accuracy: 0.9147\n",
      "Epoch 50/50\n",
      "67/67 [==============================] - 7s 99ms/step - loss: 8.1264e-05 - accuracy: 1.0000 - val_loss: 0.5465 - val_accuracy: 0.9173\n",
      "Making directory ../results/PI_Explainability/vgg16_stl10_224x224/run_4\n",
      "Making directory ../results/PI_Explainability/vgg16_stl10_224x224/run_4/saved_models\n",
      "Epoch 1/50\n",
      "67/67 [==============================] - 12s 105ms/step - loss: 1.5642 - accuracy: 0.4216 - val_loss: 1.0462 - val_accuracy: 0.6253\n",
      "Epoch 2/50\n",
      "67/67 [==============================] - 7s 97ms/step - loss: 0.9339 - accuracy: 0.6621 - val_loss: 0.6716 - val_accuracy: 0.7813\n",
      "Epoch 3/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.6711 - accuracy: 0.7696 - val_loss: 0.5963 - val_accuracy: 0.7987\n",
      "Epoch 4/50\n",
      "67/67 [==============================] - 7s 100ms/step - loss: 0.5170 - accuracy: 0.8278 - val_loss: 0.6071 - val_accuracy: 0.7693\n",
      "Epoch 5/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.3854 - accuracy: 0.8696 - val_loss: 0.6746 - val_accuracy: 0.7680\n",
      "Epoch 6/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.2996 - accuracy: 0.8979 - val_loss: 0.5760 - val_accuracy: 0.8267\n",
      "Epoch 7/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.2464 - accuracy: 0.9186 - val_loss: 0.4436 - val_accuracy: 0.8600\n",
      "Epoch 8/50\n",
      "67/67 [==============================] - 7s 100ms/step - loss: 0.1755 - accuracy: 0.9381 - val_loss: 0.4282 - val_accuracy: 0.8667\n",
      "Epoch 9/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.1420 - accuracy: 0.9539 - val_loss: 0.3748 - val_accuracy: 0.8893\n",
      "Epoch 10/50\n",
      "67/67 [==============================] - 7s 100ms/step - loss: 0.1727 - accuracy: 0.9428 - val_loss: 0.4885 - val_accuracy: 0.8427\n",
      "Epoch 11/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.1393 - accuracy: 0.9506 - val_loss: 0.5959 - val_accuracy: 0.8413\n",
      "Epoch 12/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.1135 - accuracy: 0.9652 - val_loss: 0.4350 - val_accuracy: 0.8773\n",
      "Epoch 13/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0710 - accuracy: 0.9769 - val_loss: 0.4182 - val_accuracy: 0.8880\n",
      "Epoch 14/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0558 - accuracy: 0.9824 - val_loss: 0.4847 - val_accuracy: 0.8800\n",
      "Epoch 15/50\n",
      "67/67 [==============================] - 7s 99ms/step - loss: 0.0404 - accuracy: 0.9896 - val_loss: 0.6337 - val_accuracy: 0.8653\n",
      "Epoch 16/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0475 - accuracy: 0.9864 - val_loss: 0.4853 - val_accuracy: 0.8693\n",
      "Epoch 17/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0424 - accuracy: 0.9875 - val_loss: 0.5491 - val_accuracy: 0.8547\n",
      "Epoch 18/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0474 - accuracy: 0.9847 - val_loss: 0.5601 - val_accuracy: 0.8587\n",
      "Epoch 19/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0363 - accuracy: 0.9906 - val_loss: 0.7253 - val_accuracy: 0.8440\n",
      "Epoch 20/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0389 - accuracy: 0.9894 - val_loss: 0.4581 - val_accuracy: 0.8920\n",
      "Epoch 21/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0250 - accuracy: 0.9927 - val_loss: 0.5433 - val_accuracy: 0.8653\n",
      "Epoch 22/50\n",
      "67/67 [==============================] - 7s 99ms/step - loss: 0.0264 - accuracy: 0.9906 - val_loss: 0.5815 - val_accuracy: 0.8653\n",
      "Epoch 23/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0490 - accuracy: 0.9856 - val_loss: 0.4502 - val_accuracy: 0.8653\n",
      "Epoch 24/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0491 - accuracy: 0.9849 - val_loss: 0.6247 - val_accuracy: 0.8520\n",
      "Epoch 25/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0364 - accuracy: 0.9887 - val_loss: 0.7262 - val_accuracy: 0.8533\n",
      "Epoch 26/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0178 - accuracy: 0.9951 - val_loss: 0.5039 - val_accuracy: 0.8920\n",
      "Epoch 27/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0110 - accuracy: 0.9965 - val_loss: 0.5080 - val_accuracy: 0.8933\n",
      "Epoch 28/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0217 - accuracy: 0.9932 - val_loss: 0.4847 - val_accuracy: 0.9013\n",
      "Epoch 29/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0082 - accuracy: 0.9979 - val_loss: 0.5611 - val_accuracy: 0.8773\n",
      "Epoch 30/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0253 - accuracy: 0.9920 - val_loss: 0.5322 - val_accuracy: 0.8760\n",
      "Epoch 31/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0095 - accuracy: 0.9974 - val_loss: 0.5121 - val_accuracy: 0.9067\n",
      "Epoch 32/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0024 - accuracy: 1.0000 - val_loss: 0.5366 - val_accuracy: 0.8973\n",
      "Epoch 33/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 6.7538e-04 - accuracy: 1.0000 - val_loss: 0.5371 - val_accuracy: 0.9027\n",
      "Epoch 34/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 2.2312e-04 - accuracy: 1.0000 - val_loss: 0.5443 - val_accuracy: 0.9080\n",
      "Epoch 35/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 1.9232e-04 - accuracy: 1.0000 - val_loss: 0.5451 - val_accuracy: 0.9107\n",
      "Epoch 36/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 1.6222e-04 - accuracy: 1.0000 - val_loss: 0.5796 - val_accuracy: 0.9053\n",
      "Epoch 37/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 4.3427e-04 - accuracy: 1.0000 - val_loss: 0.5694 - val_accuracy: 0.9040\n",
      "Epoch 38/50\n",
      "67/67 [==============================] - 7s 99ms/step - loss: 0.0014 - accuracy: 0.9998 - val_loss: 0.6966 - val_accuracy: 0.8813\n",
      "Epoch 39/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0078 - accuracy: 0.9979 - val_loss: 0.5616 - val_accuracy: 0.8907\n",
      "Epoch 40/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0107 - accuracy: 0.9979 - val_loss: 0.5671 - val_accuracy: 0.8733\n",
      "Epoch 41/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0863 - accuracy: 0.9713 - val_loss: 0.4979 - val_accuracy: 0.8333\n",
      "Epoch 42/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0684 - accuracy: 0.9819 - val_loss: 0.5187 - val_accuracy: 0.8827\n",
      "Epoch 43/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0270 - accuracy: 0.9929 - val_loss: 0.4596 - val_accuracy: 0.8880\n",
      "Epoch 44/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0165 - accuracy: 0.9939 - val_loss: 0.5000 - val_accuracy: 0.8773\n",
      "Epoch 45/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0329 - accuracy: 0.9901 - val_loss: 0.5575 - val_accuracy: 0.8960\n",
      "Epoch 46/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0174 - accuracy: 0.9955 - val_loss: 0.5377 - val_accuracy: 0.8867\n",
      "Epoch 47/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0184 - accuracy: 0.9958 - val_loss: 0.5227 - val_accuracy: 0.8973\n",
      "Epoch 48/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0179 - accuracy: 0.9946 - val_loss: 0.6187 - val_accuracy: 0.8587\n",
      "Epoch 49/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0283 - accuracy: 0.9913 - val_loss: 0.5798 - val_accuracy: 0.8733\n",
      "Epoch 50/50\n",
      "67/67 [==============================] - 7s 100ms/step - loss: 0.0123 - accuracy: 0.9967 - val_loss: 0.5053 - val_accuracy: 0.8973\n",
      "Making directory ../results/PI_Explainability/vgg16_stl10_224x224/run_5\n",
      "Making directory ../results/PI_Explainability/vgg16_stl10_224x224/run_5/saved_models\n",
      "Epoch 1/50\n",
      "67/67 [==============================] - 17s 108ms/step - loss: 1.6671 - accuracy: 0.3948 - val_loss: 1.1065 - val_accuracy: 0.6120\n",
      "Epoch 2/50\n",
      "67/67 [==============================] - 7s 99ms/step - loss: 0.9727 - accuracy: 0.6544 - val_loss: 0.6946 - val_accuracy: 0.7600\n",
      "Epoch 3/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.7071 - accuracy: 0.7614 - val_loss: 0.7205 - val_accuracy: 0.7453\n",
      "Epoch 4/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.5290 - accuracy: 0.8198 - val_loss: 0.5683 - val_accuracy: 0.8173\n",
      "Epoch 5/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.4187 - accuracy: 0.8598 - val_loss: 0.4565 - val_accuracy: 0.8360\n",
      "Epoch 6/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.3396 - accuracy: 0.8913 - val_loss: 0.4940 - val_accuracy: 0.8427\n",
      "Epoch 7/50\n",
      "67/67 [==============================] - 7s 100ms/step - loss: 0.2624 - accuracy: 0.9174 - val_loss: 0.4308 - val_accuracy: 0.8547\n",
      "Epoch 8/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.2400 - accuracy: 0.9191 - val_loss: 0.3995 - val_accuracy: 0.8600\n",
      "Epoch 9/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.1788 - accuracy: 0.9421 - val_loss: 0.4963 - val_accuracy: 0.8427\n",
      "Epoch 10/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.1307 - accuracy: 0.9569 - val_loss: 0.5393 - val_accuracy: 0.8480\n",
      "Epoch 11/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.1342 - accuracy: 0.9595 - val_loss: 0.5063 - val_accuracy: 0.8573\n",
      "Epoch 12/50\n",
      "67/67 [==============================] - 7s 99ms/step - loss: 0.1078 - accuracy: 0.9633 - val_loss: 0.4742 - val_accuracy: 0.8547\n",
      "Epoch 13/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0818 - accuracy: 0.9734 - val_loss: 0.6441 - val_accuracy: 0.8200\n",
      "Epoch 14/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0740 - accuracy: 0.9769 - val_loss: 0.4124 - val_accuracy: 0.8840\n",
      "Epoch 15/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0337 - accuracy: 0.9908 - val_loss: 0.5198 - val_accuracy: 0.8893\n",
      "Epoch 16/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0406 - accuracy: 0.9887 - val_loss: 0.4809 - val_accuracy: 0.8920\n",
      "Epoch 17/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0265 - accuracy: 0.9925 - val_loss: 0.5765 - val_accuracy: 0.8720\n",
      "Epoch 18/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0215 - accuracy: 0.9925 - val_loss: 0.4389 - val_accuracy: 0.9000\n",
      "Epoch 19/50\n",
      "67/67 [==============================] - 7s 103ms/step - loss: 0.0438 - accuracy: 0.9864 - val_loss: 0.6405 - val_accuracy: 0.8747\n",
      "Epoch 20/50\n",
      "67/67 [==============================] - 7s 100ms/step - loss: 0.0316 - accuracy: 0.9934 - val_loss: 0.5117 - val_accuracy: 0.8947\n",
      "Epoch 21/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0531 - accuracy: 0.9824 - val_loss: 0.5715 - val_accuracy: 0.8640\n",
      "Epoch 22/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0516 - accuracy: 0.9831 - val_loss: 0.4507 - val_accuracy: 0.8827\n",
      "Epoch 23/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0316 - accuracy: 0.9925 - val_loss: 0.4746 - val_accuracy: 0.8747\n",
      "Epoch 24/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0144 - accuracy: 0.9965 - val_loss: 0.4708 - val_accuracy: 0.8947\n",
      "Epoch 25/50\n",
      "67/67 [==============================] - 7s 100ms/step - loss: 0.0354 - accuracy: 0.9880 - val_loss: 0.6075 - val_accuracy: 0.8800\n",
      "Epoch 26/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0247 - accuracy: 0.9932 - val_loss: 0.6467 - val_accuracy: 0.8773\n",
      "Epoch 27/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0279 - accuracy: 0.9915 - val_loss: 0.5122 - val_accuracy: 0.8840\n",
      "Epoch 28/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0151 - accuracy: 0.9965 - val_loss: 0.5116 - val_accuracy: 0.8893\n",
      "Epoch 29/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0108 - accuracy: 0.9969 - val_loss: 0.4444 - val_accuracy: 0.8973\n",
      "Epoch 30/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0390 - accuracy: 0.9887 - val_loss: 0.5598 - val_accuracy: 0.8773\n",
      "Epoch 31/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0187 - accuracy: 0.9948 - val_loss: 0.5000 - val_accuracy: 0.8987\n",
      "Epoch 32/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0153 - accuracy: 0.9946 - val_loss: 0.4703 - val_accuracy: 0.8947\n",
      "Epoch 33/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0248 - accuracy: 0.9936 - val_loss: 0.6030 - val_accuracy: 0.8680\n",
      "Epoch 34/50\n",
      "67/67 [==============================] - 7s 103ms/step - loss: 0.0022 - accuracy: 0.9998 - val_loss: 0.5340 - val_accuracy: 0.9027\n",
      "Epoch 35/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 6.9813e-04 - accuracy: 1.0000 - val_loss: 0.5376 - val_accuracy: 0.9040\n",
      "Epoch 36/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 2.0454e-04 - accuracy: 1.0000 - val_loss: 0.5554 - val_accuracy: 0.9053\n",
      "Epoch 37/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0041 - accuracy: 0.9988 - val_loss: 0.7644 - val_accuracy: 0.8680\n",
      "Epoch 38/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0514 - accuracy: 0.9856 - val_loss: 0.5973 - val_accuracy: 0.8587\n",
      "Epoch 39/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0201 - accuracy: 0.9941 - val_loss: 0.5833 - val_accuracy: 0.8800\n",
      "Epoch 40/50\n",
      "67/67 [==============================] - 7s 103ms/step - loss: 0.0132 - accuracy: 0.9967 - val_loss: 0.7191 - val_accuracy: 0.8427\n",
      "Epoch 41/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0329 - accuracy: 0.9896 - val_loss: 0.5394 - val_accuracy: 0.8760\n",
      "Epoch 42/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0311 - accuracy: 0.9906 - val_loss: 0.5234 - val_accuracy: 0.8720\n",
      "Epoch 43/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0255 - accuracy: 0.9941 - val_loss: 0.5550 - val_accuracy: 0.8840\n",
      "Epoch 44/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0197 - accuracy: 0.9939 - val_loss: 0.5534 - val_accuracy: 0.8800\n",
      "Epoch 45/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0081 - accuracy: 0.9979 - val_loss: 0.4960 - val_accuracy: 0.8933\n",
      "Epoch 46/50\n",
      "67/67 [==============================] - 7s 98ms/step - loss: 0.0086 - accuracy: 0.9972 - val_loss: 0.4734 - val_accuracy: 0.8867\n",
      "Epoch 47/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0079 - accuracy: 0.9974 - val_loss: 0.5606 - val_accuracy: 0.8813\n",
      "Epoch 48/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0214 - accuracy: 0.9934 - val_loss: 0.6796 - val_accuracy: 0.8613\n",
      "Epoch 49/50\n",
      "67/67 [==============================] - 7s 101ms/step - loss: 0.0095 - accuracy: 0.9974 - val_loss: 0.5256 - val_accuracy: 0.8960\n",
      "Epoch 50/50\n",
      "67/67 [==============================] - 7s 102ms/step - loss: 0.0066 - accuracy: 0.9976 - val_loss: 0.5280 - val_accuracy: 0.8960\n"
     ]
    }
   ],
   "source": [
    "# Train N_RUNS independently seeded models at IMG_SIZE; run k (0-based) is saved under exp_root/run_{k+1}.\n",
    "N_RUNS = 5            # must match the number of runs the evaluation cell loads\n",
    "SEED_OFFSET = 10      # run k is seeded with k + SEED_OFFSET\n",
    "IMG_SIZE = (224, 224)\n",
    "SKIP_EXISTING = True  # resume an interrupted sweep instead of retraining finished runs\n",
    "exp_root = f'../results/PI_Explainability/{model_name}_{dataset_name}_{IMG_SIZE[0]}x{IMG_SIZE[1]}'\n",
    "\n",
    "for run in range(N_RUNS):\n",
    "    exp_name = f'{exp_root}/run_{run+1}'\n",
    "    # NOTE(review): assumes train_model(save=True) writes saved_models/trained_model.keras under\n",
    "    # exp_name (the layout the evaluation cell loads from) - confirm in src.models.\n",
    "    if SKIP_EXISTING and os.path.exists(os.path.join(exp_name, 'saved_models', 'trained_model.keras')):\n",
    "        print('Skipping finished run', exp_name)\n",
    "        continue\n",
    "    if not os.path.exists(exp_name):\n",
    "        print(\"Making directory\", exp_name)\n",
    "        os.makedirs(exp_name)\n",
    "\n",
    "    # Release Keras global state left by the previous run's model, then seed this run.\n",
    "    tf.keras.backend.clear_session()\n",
    "    tf.keras.utils.set_random_seed(run + SEED_OFFSET) # set random seed for Python, NumPy, and TensorFlow\n",
    "\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_train = preprocess_dataset(ds_train, cfg, n_classes, size=IMG_SIZE, resize=True, normalize=True, onehot=True)\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, size=IMG_SIZE, resize=True, normalize=True, onehot=True)\n",
    "\n",
    "    # ---- Train model ----\n",
    "    ds_train = prefetch_dataset(ds_train, batch_size=cfg['batch_size'])\n",
    "    ds_val = prefetch_dataset(ds_val, batch_size=cfg['batch_size'])\n",
    "    model, history = train_model(ds_train=ds_train,\n",
    "                                 ds_val=ds_val,\n",
    "                                 cfg=cfg,\n",
    "                                 exp_name=exp_name,\n",
    "                                 save=True,\n",
    "                                 checkpoint=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "db216b43",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-07 15:57:54.798400: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:467] Loaded cuDNN version 90100\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "333/333 [==============================] - 10s 22ms/step - loss: 0.0037 - accuracy: 0.9989\n",
      "59/59 [==============================] - 1s 14ms/step - loss: 0.9229 - accuracy: 0.8595\n",
      "79/79 [==============================] - 1s 12ms/step - loss: 0.9221 - accuracy: 0.8553\n",
      "333/333 [==============================] - 4s 10ms/step - loss: 0.0041 - accuracy: 0.9990\n",
      "59/59 [==============================] - 1s 10ms/step - loss: 0.8957 - accuracy: 0.8648\n",
      "79/79 [==============================] - 1s 9ms/step - loss: 0.8952 - accuracy: 0.8689\n",
      "333/333 [==============================] - 4s 10ms/step - loss: 0.0037 - accuracy: 0.9988\n",
      "59/59 [==============================] - 1s 10ms/step - loss: 0.8128 - accuracy: 0.8723\n",
      "79/79 [==============================] - 1s 10ms/step - loss: 0.8759 - accuracy: 0.8655\n",
      "333/333 [==============================] - 4s 10ms/step - loss: 0.0040 - accuracy: 0.9988\n",
      "59/59 [==============================] - 1s 9ms/step - loss: 0.8439 - accuracy: 0.8648\n",
      "79/79 [==============================] - 1s 9ms/step - loss: 0.8777 - accuracy: 0.8636\n",
      "333/333 [==============================] - 4s 9ms/step - loss: 0.0060 - accuracy: 0.9982\n",
      "59/59 [==============================] - 1s 9ms/step - loss: 0.9498 - accuracy: 0.8644\n",
      "79/79 [==============================] - 1s 9ms/step - loss: 0.9392 - accuracy: 0.8610\n",
      "Average train error: 0.12, std: 0.03\n",
      "Average validation error: 13.49, std: 0.41\n",
      "Average test error: 13.71, std: 0.46\n"
     ]
    }
   ],
   "source": [
    "# Evaluate every saved run on train/val/test and report the mean error and its std (in %).\n",
    "N_RUNS = 5\n",
    "SEED_OFFSET = 10\n",
    "IMG_SIZE = (224, 224)  # must match the resolution (and results folder) used by the training cell\n",
    "exp_root = f'../results/PI_Explainability/{model_name}_{dataset_name}_{IMG_SIZE[0]}x{IMG_SIZE[1]}'\n",
    "\n",
    "ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "n_classes = ds_info.features['label'].num_classes\n",
    "ds_train = preprocess_dataset(ds_train, cfg, n_classes, size=IMG_SIZE, resize=True, normalize=True, onehot=True)\n",
    "ds_val = preprocess_dataset(ds_val, cfg, n_classes, size=IMG_SIZE, resize=True, normalize=True, onehot=True)\n",
    "ds_test = preprocess_dataset(ds_test, cfg, n_classes, size=IMG_SIZE, resize=True, normalize=True, onehot=True)\n",
    "ds_train = prefetch_dataset(ds_train, batch_size=cfg['batch_size'])\n",
    "ds_val = prefetch_dataset(ds_val, batch_size=cfg['batch_size'])\n",
    "ds_test = prefetch_dataset(ds_test, batch_size=cfg['batch_size'])\n",
    "\n",
    "# ---- Compute the train, validation, test error ----\n",
    "train_acc, val_acc, test_acc = [], [], []\n",
    "for run in range(N_RUNS):\n",
    "    tf.keras.backend.clear_session()  # don't accumulate Keras state across loaded models\n",
    "    tf.keras.utils.set_random_seed(run + SEED_OFFSET) # set random seed for Python, NumPy, and TensorFlow\n",
    "    model = tf.keras.models.load_model(f'{exp_root}/run_{run+1}/saved_models/trained_model.keras')\n",
    "    for ds, acc in ((ds_train, train_acc), (ds_val, val_acc), (ds_test, test_acc)):\n",
    "        # return_dict=True looks the metric up by name instead of assuming it sits at index 1\n",
    "        acc.append(model.evaluate(ds, verbose=1, return_dict=True)['accuracy'])\n",
    "\n",
    "# The std of the error equals the std of the accuracy, so it is computed directly from acc.\n",
    "for label, acc in (('train', train_acc), ('validation', val_acc), ('test', test_acc)):\n",
    "    print(f'Average {label} error: {(100-np.mean(acc)*100):.2f}, std: {(np.std(acc)*100):.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dd616cc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
