{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "20cf5460",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-06 13:32:18.172514: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9373] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-08-06 13:32:18.172592: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-08-06 13:32:18.174761: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1534] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-08-06 13:32:18.184484: I tensorflow/core/platform/cpu_feature_guard.cc:183] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from src.datasets import load_dataset, preprocess_dataset, prefetch_dataset\n",
    "from src.models import train_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c540c5df",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = { 'dataset' : 'cifar10',\n",
    "        'model' : 'resnet50',\n",
    "        'batch_size' : 128,\n",
    "        'optimizer' : 'SGD',\n",
    "        'learning_rate' : 0.005,\n",
    "        'epoch' : 50,\n",
    "        'epoch_save_period' : 1\n",
    "        }    \n",
    "\n",
    "model_name = cfg['model']\n",
    "dataset_name = cfg['dataset']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47e3f2e2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/50\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-06 13:32:50.788580: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:467] Loaded cuDNN version 90100\n",
      "2024-08-06 13:32:53.314477: I external/local_xla/xla/service/service.cc:168] XLA service 0x7f755e511050 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n",
      "2024-08-06 13:32:53.314524: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA A100-SXM4-80GB, Compute Capability 8.0\n",
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "I0000 00:00:1722951173.400474  712906 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "333/333 [==============================] - 50s 62ms/step - loss: 1.5044 - accuracy: 0.4880 - val_loss: 1.4193 - val_accuracy: 0.5267\n",
      "Epoch 2/50\n",
      "333/333 [==============================] - 13s 39ms/step - loss: 0.6220 - accuracy: 0.7931 - val_loss: 0.6827 - val_accuracy: 0.7740\n",
      "Epoch 3/50\n",
      "333/333 [==============================] - 13s 40ms/step - loss: 0.3353 - accuracy: 0.8888 - val_loss: 0.5419 - val_accuracy: 0.8241\n",
      "Epoch 4/50\n",
      "333/333 [==============================] - 14s 42ms/step - loss: 0.1788 - accuracy: 0.9435 - val_loss: 0.5729 - val_accuracy: 0.8329\n",
      "Epoch 5/50\n",
      "333/333 [==============================] - 13s 39ms/step - loss: 0.0959 - accuracy: 0.9715 - val_loss: 0.6194 - val_accuracy: 0.8317\n",
      "Epoch 6/50\n",
      "333/333 [==============================] - 13s 40ms/step - loss: 0.0524 - accuracy: 0.9875 - val_loss: 0.6363 - val_accuracy: 0.8397\n",
      "Epoch 7/50\n",
      "333/333 [==============================] - 13s 40ms/step - loss: 0.0311 - accuracy: 0.9937 - val_loss: 0.6659 - val_accuracy: 0.8436\n",
      "Epoch 8/50\n",
      "333/333 [==============================] - 13s 39ms/step - loss: 0.0196 - accuracy: 0.9967 - val_loss: 0.6717 - val_accuracy: 0.8480\n",
      "Epoch 9/50\n",
      "333/333 [==============================] - 13s 40ms/step - loss: 0.0141 - accuracy: 0.9978 - val_loss: 0.7189 - val_accuracy: 0.8439\n",
      "Epoch 10/50\n",
      "333/333 [==============================] - 13s 40ms/step - loss: 0.0098 - accuracy: 0.9989 - val_loss: 0.7410 - val_accuracy: 0.8447\n",
      "Epoch 11/50\n",
      "333/333 [==============================] - 14s 41ms/step - loss: 0.0079 - accuracy: 0.9988 - val_loss: 0.7645 - val_accuracy: 0.8436\n",
      "Epoch 12/50\n",
      "333/333 [==============================] - 15s 44ms/step - loss: 0.0062 - accuracy: 0.9994 - val_loss: 0.7881 - val_accuracy: 0.8407\n",
      "Epoch 13/50\n",
      "333/333 [==============================] - 14s 42ms/step - loss: 0.0051 - accuracy: 0.9995 - val_loss: 0.8064 - val_accuracy: 0.8423\n",
      "Epoch 14/50\n",
      "333/333 [==============================] - 14s 42ms/step - loss: 0.0040 - accuracy: 0.9996 - val_loss: 0.8082 - val_accuracy: 0.8463\n",
      "Epoch 15/50\n",
      "333/333 [==============================] - 14s 42ms/step - loss: 0.0035 - accuracy: 0.9998 - val_loss: 0.8102 - val_accuracy: 0.8452\n",
      "Epoch 16/50\n",
      "333/333 [==============================] - 13s 40ms/step - loss: 0.0030 - accuracy: 0.9998 - val_loss: 0.8133 - val_accuracy: 0.8473\n",
      "Epoch 17/50\n",
      "333/333 [==============================] - 16s 49ms/step - loss: 0.0026 - accuracy: 0.9999 - val_loss: 0.8110 - val_accuracy: 0.8508\n",
      "Epoch 18/50\n",
      "333/333 [==============================] - 17s 50ms/step - loss: 0.0022 - accuracy: 0.9999 - val_loss: 0.8148 - val_accuracy: 0.8508\n",
      "Epoch 19/50\n",
      "333/333 [==============================] - 14s 41ms/step - loss: 0.0020 - accuracy: 0.9999 - val_loss: 0.8151 - val_accuracy: 0.8536\n",
      "Epoch 20/50\n",
      "333/333 [==============================] - 14s 42ms/step - loss: 0.0019 - accuracy: 0.9999 - val_loss: 0.8094 - val_accuracy: 0.8567\n",
      "Epoch 21/50\n",
      "333/333 [==============================] - 14s 41ms/step - loss: 0.0017 - accuracy: 0.9999 - val_loss: 0.8198 - val_accuracy: 0.8555\n",
      "Epoch 22/50\n",
      "333/333 [==============================] - 13s 39ms/step - loss: 0.0015 - accuracy: 1.0000 - val_loss: 0.8159 - val_accuracy: 0.8597\n",
      "Epoch 23/50\n",
      "333/333 [==============================] - 13s 40ms/step - loss: 0.0014 - accuracy: 1.0000 - val_loss: 0.8213 - val_accuracy: 0.8596\n",
      "Epoch 24/50\n",
      "333/333 [==============================] - 14s 41ms/step - loss: 0.0013 - accuracy: 0.9999 - val_loss: 0.8319 - val_accuracy: 0.8596\n",
      "Epoch 25/50\n",
      "333/333 [==============================] - 13s 40ms/step - loss: 0.0011 - accuracy: 1.0000 - val_loss: 0.8356 - val_accuracy: 0.8595\n",
      "Epoch 26/50\n",
      "333/333 [==============================] - 14s 42ms/step - loss: 0.0010 - accuracy: 1.0000 - val_loss: 0.8384 - val_accuracy: 0.8588\n",
      "Epoch 27/50\n",
      "333/333 [==============================] - 14s 42ms/step - loss: 9.4770e-04 - accuracy: 1.0000 - val_loss: 0.8430 - val_accuracy: 0.8588\n",
      "Epoch 28/50\n",
      "333/333 [==============================] - 14s 41ms/step - loss: 9.1956e-04 - accuracy: 1.0000 - val_loss: 0.8467 - val_accuracy: 0.8592\n",
      "Epoch 29/50\n",
      "333/333 [==============================] - 14s 41ms/step - loss: 8.6708e-04 - accuracy: 1.0000 - val_loss: 0.8562 - val_accuracy: 0.8595\n",
      "Epoch 30/50\n",
      "333/333 [==============================] - 14s 42ms/step - loss: 7.9962e-04 - accuracy: 1.0000 - val_loss: 0.8566 - val_accuracy: 0.8581\n",
      "Epoch 31/50\n",
      "333/333 [==============================] - 14s 43ms/step - loss: 7.2579e-04 - accuracy: 1.0000 - val_loss: 0.8631 - val_accuracy: 0.8593\n",
      "Epoch 32/50\n",
      "333/333 [==============================] - 13s 40ms/step - loss: 7.3074e-04 - accuracy: 1.0000 - val_loss: 0.8622 - val_accuracy: 0.8583\n",
      "Epoch 33/50\n",
      "333/333 [==============================] - 14s 42ms/step - loss: 6.6892e-04 - accuracy: 1.0000 - val_loss: 0.8645 - val_accuracy: 0.8599\n",
      "Epoch 34/50\n",
      "333/333 [==============================] - 14s 43ms/step - loss: 5.9241e-04 - accuracy: 1.0000 - val_loss: 0.8715 - val_accuracy: 0.8584\n",
      "Epoch 35/50\n",
      "333/333 [==============================] - 13s 38ms/step - loss: 7.1832e-04 - accuracy: 1.0000 - val_loss: 0.8762 - val_accuracy: 0.8588\n",
      "Epoch 36/50\n",
      "333/333 [==============================] - 19s 56ms/step - loss: 5.6931e-04 - accuracy: 1.0000 - val_loss: 0.8761 - val_accuracy: 0.8599\n",
      "Epoch 37/50\n",
      "333/333 [==============================] - 15s 43ms/step - loss: 5.6690e-04 - accuracy: 1.0000 - val_loss: 0.8845 - val_accuracy: 0.8588\n",
      "Epoch 38/50\n",
      "112/333 [=========>....................] - ETA: 8s - loss: 6.5745e-04 - accuracy: 0.9999"
     ]
    }
   ],
   "source": [
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}'\n",
    "    if not os.path.exists(exp_name):\n",
    "        print(\"Making directory\", exp_name)\n",
    "        os.makedirs(exp_name)\n",
    "        \n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_train = preprocess_dataset(ds_train, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "    \n",
    "    ##############################################################\n",
    "    #\n",
    "    # Train Model\n",
    "    #\n",
    "    # #############################################################\n",
    "    \n",
    "    ds_train = prefetch_dataset(ds_train, batch_size=cfg['batch_size'])\n",
    "    ds_val = prefetch_dataset(ds_val, batch_size=cfg['batch_size'])\n",
    "    model, history = train_model(ds_train=ds_train,\n",
    "                                 ds_val=ds_val,\n",
    "                                 cfg=cfg,\n",
    "                                 exp_name=exp_name,\n",
    "                                 save=True,\n",
    "                                 checkpoint=False) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e62cc166",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "333/333 [==============================] - 6s 13ms/step - loss: 0.0037 - accuracy: 0.9989\n",
      "59/59 [==============================] - 1s 12ms/step - loss: 0.9230 - accuracy: 0.8593\n",
      "79/79 [==============================] - 1s 14ms/step - loss: 0.9221 - accuracy: 0.8553\n"
     ]
    }
   ],
   "source": [
    "ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "n_classes = ds_info.features['label'].num_classes\n",
    "ds_train = preprocess_dataset(ds_train, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "ds_train = prefetch_dataset(ds_train, batch_size=cfg['batch_size'])\n",
    "ds_val = prefetch_dataset(ds_val, batch_size=cfg['batch_size'])\n",
    "ds_test = prefetch_dataset(ds_test, batch_size=cfg['batch_size'])\n",
    "\n",
    "##############################################################\n",
    "#\n",
    "# Compute the train, validation, test error\n",
    "#\n",
    "# #############################################################\n",
    "\n",
    "train_acc = []\n",
    "val_acc = []\n",
    "test_acc = []\n",
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "    train_acc.append(model.evaluate(ds_train, verbose=1)[1])\n",
    "    val_acc.append(model.evaluate(ds_val, verbose=1)[1])\n",
    "    test_acc.append(model.evaluate(ds_test, verbose=1)[1])\n",
    "print(f'Average train error: {(100-np.mean(train_acc)*100):.2f}, std: {(np.std(train_acc)*100):.2f}')\n",
    "print(f'Average validation error: {(100-np.mean(val_acc)*100):.2f}, std: {(np.std(val_acc)*100):.2f}')\n",
    "print(f'Average test error: {(100-np.mean(test_acc)*100):.2f}, std: {(np.std(test_acc)*100):.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bbef3e3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
