{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "551e05a1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-11 01:46:03.494308: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9373] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-08-11 01:46:03.494377: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-08-11 01:46:03.496021: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1534] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-08-11 01:46:03.504186: I tensorflow/core/platform/cpu_feature_guard.cc:183] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from src.datasets import load_dataset, preprocess_dataset, prefetch_dataset\n",
    "from src.models import train_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b330cfc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration; passed as `cfg` to load_dataset, preprocess_dataset and train_model.\n",
    "cfg = {\n",
    "    'dataset': 'stl10',\n",
    "    'model': 'pretrained_vgg16',\n",
    "    'batch_size': 128,            # batch size handed to prefetch_dataset\n",
    "    'optimizer': 'SGD',\n",
    "    'learning_rate': 0.005,\n",
    "    'epoch': 100,                 # presumably the number of training epochs (read inside train_model)\n",
    "    'epoch_save_period': 1,       # presumably a save interval in epochs -- confirm in train_model\n",
    "}\n",
    "\n",
    "# Short aliases used to build the per-run results directory paths.\n",
    "model_name, dataset_name = cfg['model'], cfg['dataset']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1432540",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Making directory ../results/PI_Explainability/pretrained_vgg16_stl10/run_1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-11 01:46:53.716448: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1926] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 77101 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-80GB, pci bus id: 0000:07:00.0, compute capability: 8.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5\n",
      " 2301952/58889256 [>.............................] - ETA: 28s"
     ]
    }
   ],
   "source": [
    "# Train one model per run; each run gets its own seed and results directory.\n",
    "preprocess_kwargs = dict(size=(224, 224), resize=True, normalize=True, onehot=True)\n",
    "\n",
    "for run in range(5):\n",
    "    # One call seeds Python, NumPy and TensorFlow; run r uses seed r + 10.\n",
    "    tf.keras.utils.set_random_seed(run + 10)\n",
    "\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'\n",
    "    if not os.path.exists(exp_name):\n",
    "        print(\"Making directory\", exp_name)\n",
    "        os.makedirs(exp_name)\n",
    "\n",
    "    # Load all splits; only train/val are used here (ds_test is left untouched).\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "\n",
    "    # Identical preprocessing flags for train and val.\n",
    "    ds_train = preprocess_dataset(ds_train, cfg, n_classes, **preprocess_kwargs)\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, **preprocess_kwargs)\n",
    "\n",
    "    # ---- Train model ----\n",
    "    ds_train = prefetch_dataset(ds_train, batch_size=cfg['batch_size'])\n",
    "    ds_val = prefetch_dataset(ds_val, batch_size=cfg['batch_size'])\n",
    "\n",
    "    # The evaluation cell loads <exp_name>/saved_models/trained_model.keras,\n",
    "    # presumably written by train_model when save=True -- confirm in src.models.\n",
    "    model, history = train_model(\n",
    "        ds_train=ds_train,\n",
    "        ds_val=ds_val,\n",
    "        cfg=cfg,\n",
    "        exp_name=exp_name,\n",
    "        save=True,\n",
    "        checkpoint=False,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0fa5e8de",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-07 15:57:54.798400: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:467] Loaded cuDNN version 90100\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "333/333 [==============================] - 10s 22ms/step - loss: 0.0037 - accuracy: 0.9989\n",
      "59/59 [==============================] - 1s 14ms/step - loss: 0.9229 - accuracy: 0.8595\n",
      "79/79 [==============================] - 1s 12ms/step - loss: 0.9221 - accuracy: 0.8553\n",
      "333/333 [==============================] - 4s 10ms/step - loss: 0.0041 - accuracy: 0.9990\n",
      "59/59 [==============================] - 1s 10ms/step - loss: 0.8957 - accuracy: 0.8648\n",
      "79/79 [==============================] - 1s 9ms/step - loss: 0.8952 - accuracy: 0.8689\n",
      "333/333 [==============================] - 4s 10ms/step - loss: 0.0037 - accuracy: 0.9988\n",
      "59/59 [==============================] - 1s 10ms/step - loss: 0.8128 - accuracy: 0.8723\n",
      "79/79 [==============================] - 1s 10ms/step - loss: 0.8759 - accuracy: 0.8655\n",
      "333/333 [==============================] - 4s 10ms/step - loss: 0.0040 - accuracy: 0.9988\n",
      "59/59 [==============================] - 1s 9ms/step - loss: 0.8439 - accuracy: 0.8648\n",
      "79/79 [==============================] - 1s 9ms/step - loss: 0.8777 - accuracy: 0.8636\n",
      "333/333 [==============================] - 4s 9ms/step - loss: 0.0060 - accuracy: 0.9982\n",
      "59/59 [==============================] - 1s 9ms/step - loss: 0.9498 - accuracy: 0.8644\n",
      "79/79 [==============================] - 1s 9ms/step - loss: 0.9392 - accuracy: 0.8610\n",
      "Average train error: 0.12, std: 0.03\n",
      "Average validation error: 13.49, std: 0.41\n",
      "Average test error: 13.71, std: 0.46\n"
     ]
    }
   ],
   "source": [
    "# Evaluate every saved run on train/val/test and report mean error and std over runs.\n",
    "# Preprocess with exactly the same flags as the training cell, including size=(224, 224):\n",
    "# training passes the size explicitly, so relying on preprocess_dataset's default here\n",
    "# could evaluate the models on inputs of a different resolution than they were trained on.\n",
    "preprocess_kwargs = dict(size=(224, 224), resize=True, normalize=True, onehot=True)\n",
    "\n",
    "ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "n_classes = ds_info.features['label'].num_classes\n",
    "ds_train = preprocess_dataset(ds_train, cfg, n_classes, **preprocess_kwargs)\n",
    "ds_val = preprocess_dataset(ds_val, cfg, n_classes, **preprocess_kwargs)\n",
    "ds_test = preprocess_dataset(ds_test, cfg, n_classes, **preprocess_kwargs)\n",
    "ds_train = prefetch_dataset(ds_train, batch_size=cfg['batch_size'])\n",
    "ds_val = prefetch_dataset(ds_val, batch_size=cfg['batch_size'])\n",
    "ds_test = prefetch_dataset(ds_test, batch_size=cfg['batch_size'])\n",
    "\n",
    "# ---- Compute the train, validation, test error of every saved run ----\n",
    "train_acc, val_acc, test_acc = [], [], []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run + 10)  # same per-run seed as in training\n",
    "    model_path = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras'\n",
    "    model = tf.keras.models.load_model(model_path)\n",
    "    # evaluate() returns [loss, metric, ...]; index 1 is the first compiled metric\n",
    "    # (accuracy, judging by the logged 'loss - accuracy' output).\n",
    "    train_acc.append(model.evaluate(ds_train, verbose=1)[1])\n",
    "    val_acc.append(model.evaluate(ds_val, verbose=1)[1])\n",
    "    test_acc.append(model.evaluate(ds_test, verbose=1)[1])\n",
    "\n",
    "# Error is reported in percent; std is NumPy's population std (ddof=0) of the accuracies,\n",
    "# in percentage points (the std of the error equals the std of the accuracy).\n",
    "for split_name, accs in (('train', train_acc), ('validation', val_acc), ('test', test_acc)):\n",
    "    print(f'Average {split_name} error: {(100 - np.mean(accs) * 100):.2f}, std: {(np.std(accs) * 100):.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2306790e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
