{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-03-16T07:38:57.355211Z",
     "iopub.status.busy": "2025-03-16T07:38:57.354906Z",
     "iopub.status.idle": "2025-03-16T07:39:08.815150Z",
     "shell.execute_reply": "2025-03-16T07:39:08.814525Z",
     "shell.execute_reply.started": "2025-03-16T07:38:57.355152Z"
    },
    "id": "RPFnfEqRAYA5"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-03-16 07:38:57.977917: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting pyHSICLasso\n",
      "  Downloading pyHSICLasso-1.4.2-py2.py3-none-any.whl (14 kB)\n",
      "Requirement already satisfied: scipy in /usr/local/lib/python3.9/dist-packages (from pyHSICLasso) (1.8.1)\n",
      "Requirement already satisfied: pandas in /usr/local/lib/python3.9/dist-packages (from pyHSICLasso) (1.4.3)\n",
      "Requirement already satisfied: future in /usr/lib/python3/dist-packages (from pyHSICLasso) (0.18.2)\n",
      "Requirement already satisfied: seaborn in /usr/local/lib/python3.9/dist-packages (from pyHSICLasso) (0.11.2)\n",
      "Collecting pytest\n",
      "  Downloading pytest-8.3.5-py3-none-any.whl (343 kB)\n",
      "     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 343.6/343.6 kB 32.9 MB/s eta 0:00:00\n",
      "Requirement already satisfied: numpy in /usr/local/lib/python3.9/dist-packages (from pyHSICLasso) (1.23.1)\n",
      "Requirement already satisfied: joblib in /usr/local/lib/python3.9/dist-packages (from pyHSICLasso) (1.1.0)\n",
      "Requirement already satisfied: matplotlib in /usr/local/lib/python3.9/dist-packages (from pyHSICLasso) (3.5.2)\n",
      "Requirement already satisfied: six in /usr/lib/python3/dist-packages (from pyHSICLasso) (1.14.0)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (9.2.0)\n",
      "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (4.34.4)\n",
      "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (0.11.0)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (1.4.3)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (2.8.2)\n",
      "Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (3.0.9)\n",
      "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (21.3)\n",
      "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas->pyHSICLasso) (2022.1)\n",
      "Collecting iniconfig\n",
      "  Downloading iniconfig-2.0.0-py3-none-any.whl (5.9 kB)\n",
      "Collecting tomli>=1\n",
      "  Downloading tomli-2.2.1-py3-none-any.whl (14 kB)\n",
      "Collecting exceptiongroup>=1.0.0rc8\n",
      "  Downloading exceptiongroup-1.2.2-py3-none-any.whl (16 kB)\n",
      "Collecting pluggy<2,>=1.5\n",
      "  Downloading pluggy-1.5.0-py3-none-any.whl (20 kB)\n",
      "Installing collected packages: tomli, pluggy, iniconfig, exceptiongroup, pytest, pyHSICLasso\n",
      "Successfully installed exceptiongroup-1.2.2 iniconfig-2.0.0 pluggy-1.5.0 pyHSICLasso-1.4.2 pytest-8.3.5 tomli-2.2.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting lassonet\n",
      "  Downloading lassonet-0.0.20-py3-none-any.whl (21 kB)\n",
      "Requirement already satisfied: matplotlib in /usr/local/lib/python3.9/dist-packages (from lassonet) (3.5.2)\n",
      "Requirement already satisfied: torch>=1.11 in /usr/local/lib/python3.9/dist-packages (from lassonet) (1.12.0+cu116)\n",
      "Requirement already satisfied: tqdm in /usr/local/lib/python3.9/dist-packages (from lassonet) (4.64.0)\n",
      "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.9/dist-packages (from lassonet) (1.1.1)\n",
      "Collecting sortedcontainers\n",
      "  Downloading sortedcontainers-2.4.0-py2.py3-none-any.whl (29 kB)\n",
      "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch>=1.11->lassonet) (4.3.0)\n",
      "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (0.11.0)\n",
      "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (4.34.4)\n",
      "Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (3.0.9)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (1.4.3)\n",
      "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (1.23.1)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (2.8.2)\n",
      "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (21.3)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (9.2.0)\n",
      "Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.9/dist-packages (from scikit-learn->lassonet) (1.8.1)\n",
      "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.9/dist-packages (from scikit-learn->lassonet) (3.1.0)\n",
      "Requirement already satisfied: joblib>=1.0.0 in /usr/local/lib/python3.9/dist-packages (from scikit-learn->lassonet) (1.1.0)\n",
      "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.7->matplotlib->lassonet) (1.14.0)\n",
      "Installing collected packages: sortedcontainers, lassonet\n",
      "Successfully installed lassonet-0.0.20 sortedcontainers-2.4.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Defining configs successful!\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import random\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import keras\n",
    "import pandas as pd\n",
    "from tensorflow.keras.datasets import cifar10, fashion_mnist, mnist\n",
    "from tensorflow.keras.models import load_model\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping  #ReduceLROnPlateau\n",
    "from tensorflow.keras.regularizers import l2\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder, MinMaxScaler, StandardScaler \n",
    "\n",
    "try:\n",
    "    import pyHSICLasso\n",
    "except:\n",
    "    os.system(\"python -m pip install pyHSICLasso\")\n",
    "try:\n",
    "    import lassonet\n",
    "except:\n",
    "    os.system(\"python -m pip install lassonet\")\n",
    "\n",
    "from layers       import HadamardDense, HadamardConv2D, StrHadamardDenseV2\n",
    "from initializers import TwiceTruncatedNormalInitializer, equivar_initializer, equivar_initializer_conv2d\n",
    "from callbacks    import HadamardCallback, PrintLRCallback, TerminateBadRuns\n",
    "from models       import LeNet300100, HadamardLeNet300100,\\\n",
    "                         InpHadamardLeNet300100, hsic_dnn, hsic_svm, lassoNet\n",
    "from utils        import get_optimizer, process_sparsity_callback, create_and_save_trajectory_plot, compute_input_sparsity\n",
    "\n",
    "################################################################################\n",
    "\n",
    "# Hadamard\n",
    "DEPTH              = 2\n",
    "LA                 = 5e-4\n",
    "INIT_TYPE          ='ones' #vanilla, equivar, root, ones\n",
    "#INIT_REST         = tf.keras.initializers.Ones() \n",
    "MINPROD            = 3e-3\n",
    "#INIT               = TwiceTruncatedNormalInitializer(minprod=MINPROD,depth=DEPTH) \n",
    "INIT               = tf.keras.initializers.HeNormal()\n",
    "USE_BIAS           = True # only for dense layers\n",
    "FACTORIZE_BIAS     = True\n",
    "\n",
    "# Training\n",
    "PRETRAIN_OPT       = 'sgd' # sgd, adam\n",
    "LR_SCHEDULE        = 'cosine' # piecewise, constant, cosine\n",
    "BATCH_SIZE         = 256 #256\n",
    "EPOCHS             = 100 #200\n",
    "INIT_LR            = 0.2 if PRETRAIN_OPT == 'sgd' else 2e-3 # depth 2=0.2, depth3 3 = 0.5, depth 4 = 0.7\n",
    "WARMUP             = False # linear lr warmup from 0.01 to INIT_LR over 5 epochs\n",
    "MOMENTUM           = 0.9    # only for SGD\n",
    "LR_DECAY_FACT      = 0.1   # only for piecewise\n",
    "LARGE_LRSTART      = False  # only for piecewise\n",
    "\n",
    "# Misc\n",
    "CLASS_NUM          = 26\n",
    "PAT                = 100\n",
    "RESTORE_WEIGHTS    = False\n",
    "GRACE              = 10\n",
    "FINE_GRACE         = 20\n",
    "MINACC             = (1 / CLASS_NUM) + 0.01\n",
    "SEED               = 123\n",
    "SAVE_METRICS       = True\n",
    "VERBOSE            = 1\n",
    "\n",
    "# Directories and saving path\n",
    "LENET_FILE_PATH = './results/input_sparsity/ISOLET/'\n",
    "\n",
    "print('Defining configs successful!')\n",
    "\n",
    "# Lambda grid\n",
    "LAMBDA_LIST = [\n",
    "    0,\n",
    "    #1e-6,\n",
    "    1e-5,\n",
    "    1e-4,\n",
    "    2e-4,\n",
    "    #4e-4,\n",
    "    5e-4,\n",
    "    7e-4,\n",
    "    8e-4,\n",
    "    9e-4,\n",
    "    1e-3,\n",
    "    1.5e-3,\n",
    "    2e-3,\n",
    "    4e-3,\n",
    "    5e-3,\n",
    "    8e-3,\n",
    "    1e-2,\n",
    "    1.5e-2,\n",
    "    2e-2,\n",
    "    2.5e-2,\n",
    "    3e-2,\n",
    "    3.5e-2,\n",
    "    4e-2,\n",
    "    4.5e-2,\n",
    "    5e-2,\n",
    "    5.5e-2,\n",
    "    6e-2,\n",
    "    6.5e-2,\n",
    "    7e-2,\n",
    "    8e-2,\n",
    "    9e-2,\n",
    "    9.5e-2,\n",
    "    1e-1,\n",
    "    1.25e-1,\n",
    "    1.5e-1,\n",
    "    1.75e-1,\n",
    "    2e-1,\n",
    "    5e-1,\n",
    "    7e-1,\n",
    "    1,\n",
    "    2\n",
    "    #5\n",
    "]\n",
    "\n",
    "#LAMBDA_LIST.reverse()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-19T23:41:35.956579Z",
     "iopub.status.busy": "2025-01-19T23:41:35.955698Z",
     "iopub.status.idle": "2025-01-19T23:41:38.990918Z",
     "shell.execute_reply": "2025-01-19T23:41:38.990412Z",
     "shell.execute_reply.started": "2025-01-19T23:41:35.956553Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train shape: (6238, 617), y_train shape: (6238, 26)\n",
      "x_test shape: (1559, 617), y_test shape: (1559, 26)\n",
      "Normalized Training Set Mean and SD: [-8.81148312e-16 -1.21400165e-15 -1.65448681e-15  3.01490157e-15\n",
      " -8.34856390e-16  2.57835860e-16 -7.38250257e-17  2.11793107e-16\n",
      "  1.50515488e-16 -1.48157287e-16  2.56941524e-16  2.04318056e-17\n",
      " -1.10434977e-16  5.29411576e-16  1.91859638e-17 -5.79263046e-16\n",
      " -9.26100954e-16 -3.59016012e-16  1.20745208e-15 -2.97133270e-16\n",
      " -9.91387514e-16 -1.30350648e-16  3.26257492e-15 -1.86774823e-15\n",
      "  3.14810736e-16  3.03968492e-15  3.98924775e-15  2.07614197e-15\n",
      "  2.53973751e-16 -1.73193368e-15 -8.14139821e-16 -1.26768853e-16\n",
      "  6.39864350e-16  7.48508430e-16 -3.70916026e-16 -4.42398425e-16\n",
      " -1.95330197e-15  7.93636825e-16 -6.85106206e-16 -1.21738322e-15\n",
      "  4.05788473e-17 -6.62467480e-16 -3.01342436e-16 -1.76393402e-16\n",
      "  6.08148778e-17  2.37501942e-16  1.83329181e-15 -3.20572894e-16\n",
      "  2.81410747e-15 -5.78123990e-16  1.34352915e-15  3.05233912e-15\n",
      "  6.59953550e-16  3.42072564e-15  6.56159961e-15  4.98705135e-15\n",
      "  2.98904145e-16 -2.02761643e-15  6.22493756e-16  3.30504033e-16\n",
      " -2.07446899e-15  2.45114036e-15 -2.95531474e-16  6.48004592e-16\n",
      " -1.37633483e-15 -6.95900535e-16  6.17791371e-15 -5.18177642e-15\n",
      "  1.30735079e-15 -3.22495050e-17  7.57649794e-16 -6.54244924e-17\n",
      "  2.06204616e-16  7.37716325e-17  6.48549648e-16  5.37491750e-18\n",
      "  3.89236575e-17 -4.18015521e-16  1.15551827e-16 -3.40079217e-16\n",
      " -1.00361456e-15  1.28108133e-15 -1.03417328e-15  5.43382802e-15\n",
      " -5.12379138e-16  3.96613738e-15  1.82496247e-15  2.43035260e-15\n",
      " -6.59282575e-15  1.04751269e-15  8.16275550e-16  3.55047116e-15\n",
      " -1.90482093e-15 -4.72209639e-16 -3.28599675e-16  4.13002676e-16\n",
      "  2.72444245e-15  2.24091345e-16 -5.14327991e-16 -1.22078259e-15\n",
      " -1.17892230e-16  6.42889966e-16  2.86899570e-16  2.21047931e-17\n",
      " -1.80647062e-16  1.54982720e-16  7.47576273e-16  3.68448815e-16\n",
      "  1.52526632e-17 -3.61294123e-18 -1.85185485e-16 -2.20840588e-15\n",
      " -7.24047661e-16  1.40356538e-15 -3.19026270e-15 -2.32327249e-15\n",
      "  4.51288396e-16  4.77793680e-15 -4.96575746e-15  1.39657086e-15\n",
      " -1.93208706e-15  2.28540780e-16 -6.90267550e-16  2.18862369e-15\n",
      "  9.44134514e-16 -1.42136312e-15  6.41964484e-17  7.35171248e-16\n",
      " -1.73946212e-15 -3.64426525e-16 -3.10065108e-15 -1.27033149e-15\n",
      " -3.24417206e-16 -1.17358298e-16  1.32771141e-17  2.43152725e-16\n",
      "  6.32745254e-16  1.88068719e-16 -3.25645250e-16 -5.94569102e-16\n",
      "  2.59437657e-16 -3.24452801e-16  8.27683901e-17  4.02668863e-16\n",
      "  7.10361199e-16  1.12855470e-16  7.65899047e-16  1.40167882e-15\n",
      " -4.33392768e-16 -2.70863806e-16  1.41694038e-15 -4.10257486e-15\n",
      "  2.19310872e-15 -6.37782015e-16 -2.06663798e-15  1.89218453e-15\n",
      "  7.49427238e-16 -2.87682670e-16  1.88999541e-15  4.20444912e-16\n",
      " -1.32842332e-16  2.61982734e-16 -9.37620542e-16  1.11588983e-14\n",
      " -1.03251809e-15 -2.85475750e-17  7.23300156e-17  1.04483413e-15\n",
      "  6.88203012e-16  5.25816432e-16  4.85397765e-16 -1.66337679e-16\n",
      "  1.66355476e-16  6.20802971e-16 -8.06593579e-17  8.55848825e-17\n",
      "  5.34190269e-16 -9.71667618e-16 -6.98062960e-16 -4.91075244e-16\n",
      "  2.22995004e-15 -3.17689660e-17  3.22639211e-15  1.42775028e-15\n",
      "  6.60821189e-16 -3.83505703e-16 -6.23348048e-16 -2.33613136e-16\n",
      "  1.29247188e-16 -3.47233908e-16  2.87308918e-16  9.20321138e-17\n",
      "  2.85964743e-15  2.39344008e-16  6.03924484e-15 -2.33061406e-15\n",
      " -7.61921252e-17 -1.88656045e-17 -3.35060254e-16 -4.80503386e-16\n",
      " -4.44338378e-16 -8.47350404e-16  2.11116793e-16 -1.89720349e-15\n",
      "  1.27908798e-15  1.83494700e-16  3.23669701e-16 -7.88795839e-17\n",
      " -1.67298757e-18  1.08099914e-15  1.38146058e-16  1.95953118e-17\n",
      " -5.84433289e-17 -8.38482680e-16 -1.50902700e-15 -8.49691919e-16\n",
      " -3.89768282e-16 -4.07479259e-16  3.56666711e-16 -1.60339840e-16\n",
      "  1.38822372e-18  4.42451818e-17 -5.07947501e-17 -1.49785780e-16\n",
      " -4.39559685e-16 -1.38359631e-16  1.66294964e-15  2.21889765e-15\n",
      "  6.42596304e-15  1.47099211e-15 -2.88358984e-16 -1.14403873e-16\n",
      " -3.56880284e-16  4.21050035e-16  2.75936162e-16  5.37402761e-17\n",
      " -2.38240548e-16  7.77262904e-16  4.37966787e-16  3.43398940e-15\n",
      "  1.02949248e-15 -1.59270196e-15 -3.07966755e-15  1.54986280e-15\n",
      " -2.32539042e-15  1.96410520e-15  3.13453798e-16  2.70909635e-15\n",
      " -5.62381889e-16  1.84736982e-15 -5.11816285e-15 -1.20975689e-15\n",
      "  5.86684703e-16 -1.34967382e-15  3.47917342e-15  1.47977530e-15\n",
      " -6.72722538e-15 -1.05718576e-16  3.48874860e-15  1.05268293e-15\n",
      "  6.42729787e-15  5.74582240e-16 -3.29204798e-16  9.78875703e-17\n",
      " -4.85451158e-16 -6.95268715e-17 -2.17897731e-16  7.39340369e-17\n",
      "  6.36625162e-16  6.99664757e-16  1.68519681e-15  2.87079327e-15\n",
      "  1.07509029e-15  2.07316975e-16  3.08296013e-15  3.69310225e-15\n",
      " -3.74948883e-15 -1.54519979e-15 -1.42221741e-15 -2.52505437e-15\n",
      " -4.96296656e-16 -1.14304206e-15  2.91842113e-15  1.36042365e-15\n",
      "  5.47828677e-15  6.28035972e-15  4.28683486e-15  3.03469266e-15\n",
      "  4.68335071e-15  1.16317130e-16  5.29429373e-16 -4.20126333e-15\n",
      "  1.07711923e-15  3.82117479e-16 -8.11932901e-17 -1.01981051e-16\n",
      " -1.62778131e-16  1.23030438e-15  2.51749033e-16  9.91156144e-17\n",
      "  2.52727909e-17  2.66236393e-16 -1.30635412e-17  6.65644377e-16\n",
      "  1.97510420e-15 -1.69482539e-15  3.89987640e-15 -5.65968134e-17\n",
      " -4.51279497e-16 -4.17944330e-16  1.35745143e-15  1.74110842e-16\n",
      "  2.13658310e-15  1.35038573e-15  1.06736607e-15  8.86291859e-16\n",
      "  5.54221625e-16 -1.96508407e-15  1.94920849e-16 -5.75472127e-16\n",
      "  4.44556623e-15  1.94798935e-15  2.16360007e-15 -5.70018899e-15\n",
      " -4.39819532e-15 -1.76073932e-15 -5.09363756e-15  3.35838905e-15\n",
      "  1.82693802e-16 -1.24229293e-14  2.63758948e-15 -1.42850001e-15\n",
      "  2.56287457e-15 -5.74176451e-15 -9.82345372e-15  4.91698164e-16\n",
      "  1.38857078e-15  1.82165209e-15  1.20213055e-15  9.61077963e-17\n",
      " -1.49410248e-15  3.62044298e-15  1.82021270e-15 -1.01088383e-15\n",
      " -2.52259828e-15  7.02542652e-15  1.89386998e-14  1.24866631e-14\n",
      "  2.30739157e-14  1.55237050e-14 -1.03532124e-14 -5.28672969e-15\n",
      " -3.78047136e-15  5.17616123e-16 -9.02075785e-15 -2.31217560e-15\n",
      "  7.96128508e-16 -4.78282228e-15 -2.29896968e-15  5.41186560e-15\n",
      "  1.49145062e-16  2.78499036e-16 -1.98494635e-15 -1.61204810e-15\n",
      " -4.00346814e-16 -7.34299159e-16 -2.25468445e-15  4.20429117e-15\n",
      " -6.07131192e-15  3.43195601e-15  4.65446498e-16  7.03010732e-17\n",
      "  1.11126086e-15 -7.50609676e-16 -1.03199039e-14 -1.03906232e-14\n",
      "  2.48378319e-14  1.11200413e-14 -9.23358990e-15 -6.07817962e-15\n",
      " -7.38962167e-15 -1.08085675e-16  2.81818315e-15 -2.13334391e-15\n",
      "  2.84443482e-15  3.50731164e-15  4.30833008e-15  5.98288830e-15\n",
      "  7.31415925e-16  1.81294899e-15 -1.42915853e-15  3.39267640e-15\n",
      "  2.50649133e-15  3.98178160e-15  3.35359256e-15  1.38875766e-15\n",
      "  6.91548988e-16 -5.41791684e-15  8.67844502e-16 -1.45727896e-16\n",
      " -1.56412546e-15  1.14302426e-15  7.48263266e-15 -1.83743868e-16\n",
      "  5.69492086e-15  2.56429839e-15 -1.20746988e-15  1.71383694e-14\n",
      "  1.32910902e-15  3.42485471e-15  5.48553045e-15 -1.43031538e-15\n",
      "  7.51954517e-16  2.97800685e-15  3.18168419e-15  6.03218804e-16\n",
      "  6.23792991e-16  3.09894250e-16  4.04934182e-16  1.08637405e-15\n",
      "  1.90108340e-15  1.83302484e-15  1.17009462e-15  1.07138279e-15\n",
      " -2.14505482e-15 -9.93238479e-16 -1.84918519e-15  1.11711075e-15\n",
      " -6.88995012e-16  1.24785295e-15  1.87290068e-15 -1.30973569e-16\n",
      " -1.01094545e-14 -3.25892638e-15 -4.16420398e-15 -8.60161217e-15\n",
      " -5.40264638e-15 -3.85894159e-15 -4.04489238e-15  1.10393152e-15\n",
      " -7.30661746e-16 -1.15688870e-15 -4.56344734e-15 -8.76536918e-15\n",
      " -2.63611672e-15 -3.43503502e-15  2.28861139e-16  8.63457359e-16\n",
      "  9.08556832e-16  6.68791017e-15  1.00265348e-15  6.22901324e-15\n",
      "  1.94618732e-15 -1.49591785e-15 -2.90144988e-15  9.48699635e-16\n",
      "  2.24465098e-16  6.24148946e-16  1.75397618e-15  4.25161313e-16\n",
      " -8.27737294e-16 -7.48163599e-16  4.58367447e-16 -4.93086388e-16\n",
      "  2.47336028e-16 -1.01686387e-15 -3.85098600e-17  5.71098332e-16\n",
      "  3.38593106e-16 -6.97284309e-16 -7.95959430e-16  8.71181578e-16\n",
      "  1.10128856e-15  6.18827422e-17  4.06507502e-15 -1.51901932e-15\n",
      " -2.92488060e-16 -3.09460430e-15  1.77179172e-15 -1.85178366e-15\n",
      " -1.93536185e-15  3.34953468e-17  1.24363488e-15  7.05502416e-17\n",
      "  7.30419252e-17 -1.92856311e-16  5.02145438e-16 -4.91288816e-16\n",
      "  1.38516251e-15 -6.42569607e-16  6.88633718e-15  5.91361949e-15\n",
      "  6.13288765e-15  1.01393013e-14  3.32225964e-15  9.60152480e-16\n",
      " -5.94444518e-18  1.78520232e-15  8.65571730e-15  5.66541221e-15\n",
      " -5.04290066e-16 -1.35424339e-15  4.08878606e-16  1.05444213e-15\n",
      " -1.81821935e-15  1.48059400e-15  1.34833009e-15 -4.83969496e-15\n",
      " -2.74431363e-15 -2.54004007e-15  5.21358098e-16  1.52031855e-15\n",
      "  1.41630856e-15 -7.90219658e-18  1.39534282e-17  6.18257894e-16\n",
      "  1.87125439e-16  4.39034652e-16 -6.52927892e-16  8.79671100e-16\n",
      "  1.26634480e-15 -7.67438551e-17 -3.46607428e-15 -2.13085223e-15\n",
      "  3.16942155e-16  4.24798239e-15  1.04131018e-15  5.66188826e-15\n",
      "  3.58983976e-15 -3.25050805e-15  2.98112146e-15  2.78383351e-15\n",
      " -1.45464489e-15 -7.25916424e-16  3.71928273e-15 -2.26832197e-15\n",
      " -1.89697212e-15 -1.45138790e-15  2.26885590e-15 -2.92808419e-15\n",
      "  9.09375528e-16 -9.26265583e-16 -2.64492215e-16  1.72958438e-16\n",
      "  7.90219658e-18 -1.04034910e-15  2.35108146e-16  7.26930895e-16\n",
      " -2.32046935e-16 -5.63476450e-16 -4.59003716e-16  2.73693647e-15\n",
      "  5.07242711e-15 -2.09137684e-15  1.34137295e-14  3.18202235e-15\n",
      " -4.48826968e-15  2.50262922e-15  3.52324062e-15  3.13631775e-16\n",
      "  9.46355672e-15  4.96206332e-15  1.36978527e-15  5.52728395e-15\n",
      " -5.15319325e-15  6.88398788e-16  3.73394807e-15  1.58099995e-15\n",
      " -5.37669727e-17 -2.81996292e-16  1.97474825e-15  4.86412236e-16\n",
      " -1.08388237e-17 -5.82431043e-16  1.95170017e-16  7.26859704e-17\n",
      " -8.59506260e-16  2.14427172e-16 -7.65445204e-16 -1.65874937e-17\n",
      "  2.91609564e-14  5.28494992e-15  2.23198966e-14 -4.32250153e-15\n",
      " -2.50930337e-16 -1.31073681e-15 -2.84443482e-16  6.02061951e-16\n",
      " -2.66681337e-16 -2.37350661e-16  7.75411939e-16 -1.39534282e-15\n",
      "  5.29020025e-16  8.64543021e-16  1.44019313e-16 -1.97006744e-15\n",
      " -4.94542243e-15  1.17536275e-15  8.24284533e-15  3.87382050e-15\n",
      "  2.06464463e-15  1.82028166e-15 -1.28734613e-15 -5.52555757e-15\n",
      "  1.03824896e-15  5.64145645e-15  6.33012221e-15 -2.54845840e-15\n",
      "  4.92819422e-16 -9.73038044e-16  3.48141593e-16 -2.76897240e-16\n",
      " -1.64664691e-15  2.33542167e-15  5.83605694e-16  6.27557213e-16\n",
      " -1.26166399e-15] [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      "Train data shape:  (6238, 617)\n",
      "Train labels shape:  (6238, 26)\n",
      "Test data shape:  (1559, 617)\n",
      "Test labels shape:  (1559, 26)\n"
     ]
    }
   ],
   "source": [
    "def one_hot_encode(y, num_classes):\n",
    "    return np.eye(num_classes)[y]\n",
    "\n",
    "def load_isolet(one_hot=True):\n",
    "    isolet_base_path = os.path.join('data', 'isolet')\n",
    "    isolet1_4_path = os.path.join(isolet_base_path, 'isolet1+2+3+4.data')\n",
    "    isolet5_path = os.path.join(isolet_base_path, 'isolet5.data')\n",
    "\n",
    "    x_train = np.genfromtxt(isolet1_4_path, delimiter=',', usecols=range(0, 617), encoding='UTF-8')\n",
    "    y_train = np.genfromtxt(isolet1_4_path, delimiter=',', usecols=[617], encoding='UTF-8') - 1\n",
    "    x_test = np.genfromtxt(isolet5_path, delimiter=',', usecols=range(0, 617), encoding='UTF-8')\n",
    "    y_test = np.genfromtxt(isolet5_path, delimiter=',', usecols=[617], encoding='UTF-8') - 1\n",
    "\n",
    "    # Ensure y_train and y_test are integers\n",
    "    y_train = y_train.astype(int)\n",
    "    y_test = y_test.astype(int)\n",
    "\n",
    "    X = np.concatenate((x_train, x_test))\n",
    "    x_train = X[:len(y_train)]\n",
    "    x_test = X[len(y_train):]\n",
    "\n",
    "    # normalize data\n",
    "    scaler = StandardScaler().fit(x_train)\n",
    "    x_train = scaler.transform(x_train)\n",
    "    x_test = scaler.transform(x_test)\n",
    "\n",
    "    num_classes = np.unique(y_train).shape[0]\n",
    "\n",
    "    if one_hot:\n",
    "        y_train = one_hot_encode(y_train, num_classes)\n",
    "        y_test = one_hot_encode(y_test, num_classes)\n",
    "\n",
    "    print(\"x_train shape: {}, y_train shape: {}\".format(x_train.shape, y_train.shape))\n",
    "    print(\"x_test shape: {}, y_test shape: {}\".format(x_test.shape, y_test.shape))\n",
    "\n",
    "    return (x_train, y_train), (x_test, y_test)\n",
    "\n",
    "(X_train, Y_train), (X_test, Y_test) = load_isolet()\n",
    "\n",
    "############\n",
    "\n",
    "def calculate_mean_std(dataset):\n",
    "    means = np.mean(dataset, axis=(0))  \n",
    "    stds = np.std(dataset, axis=(0))  \n",
    "    return means, stds\n",
    "\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "print(\"Normalized Training Set Mean and SD:\", train_mean, train_std)\n",
    "\n",
    "\n",
    "# Check sample sizes of split\n",
    "print('Train data shape: ', X_train.shape)\n",
    "print('Train labels shape: ', Y_train.shape)\n",
    "#print('Validation data shape: ', X_val.shape)\n",
    "#print('Validation labels shape: ', Y_val.shape)\n",
    "print('Test data shape: ', X_test.shape)\n",
    "print('Test labels shape: ', Y_test.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2025-01-19T10:38:35.894674Z",
     "iopub.status.busy": "2025-01-19T10:38:35.894044Z",
     "iopub.status.idle": "2025-01-19T10:38:46.434974Z",
     "shell.execute_reply": "2025-01-19T10:38:46.434501Z",
     "shell.execute_reply.started": "2025-01-19T10:38:35.894643Z"
    },
    "id": "6yy9KHxyWOBY",
    "outputId": "7cfef7ad-a3ae-49a2-92d7-820a174079d9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"VanillaLeNet300100\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " input_1 (InputLayer)        [(None, 617)]             0         \n",
      "                                                                 \n",
      " dense (Dense)               (None, 300)               185400    \n",
      "                                                                 \n",
      " dense_1 (Dense)             (None, 100)               30100     \n",
      "                                                                 \n",
      " dense_2 (Dense)             (None, 26)                2626      \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 218,126\n",
      "Trainable params: 218,126\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n",
      "\n",
      "Epoch 1: Current learning rate = 1.000e-01\n",
      "Epoch 1/100\n",
      "25/25 [==============================] - 3s 2ms/step - loss: 0.9322 - accuracy: 0.7304\n",
      "\n",
      "Epoch 2: Current learning rate = 9.997e-02\n",
      "Epoch 2/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 0.1282 - accuracy: 0.9598\n",
      "\n",
      "Epoch 3: Current learning rate = 9.989e-02\n",
      "Epoch 3/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 0.0517 - accuracy: 0.9845\n",
      "\n",
      "Epoch 4: Current learning rate = 9.976e-02\n",
      "Epoch 4/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 0.0193 - accuracy: 0.9954\n",
      "\n",
      "Epoch 5: Current learning rate = 9.957e-02\n",
      "Epoch 5/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 0.0092 - accuracy: 0.9989\n",
      "\n",
      "Epoch 6: Current learning rate = 9.933e-02\n",
      "Epoch 6/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 0.0045 - accuracy: 0.9998\n",
      "\n",
      "Epoch 7: Current learning rate = 9.904e-02\n",
      "Epoch 7/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 0.0028 - accuracy: 1.0000\n",
      "\n",
      "Epoch 8: Current learning rate = 9.869e-02\n",
      "Epoch 8/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 0.0020 - accuracy: 1.0000\n",
      "\n",
      "Epoch 9: Current learning rate = 9.830e-02\n",
      "Epoch 9/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 0.0016 - accuracy: 1.0000\n",
      "\n",
      "Epoch 10: Current learning rate = 9.785e-02\n",
      "Epoch 10/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 0.0015 - accuracy: 1.0000\n",
      "\n",
      "Epoch 11: Current learning rate = 9.735e-02\n",
      "Epoch 11/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 0.0012 - accuracy: 1.0000\n",
      "\n",
      "Epoch 12: Current learning rate = 9.680e-02\n",
      "Epoch 12/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 0.0011 - accuracy: 1.0000\n",
      "\n",
      "Epoch 13: Current learning rate = 9.619e-02\n",
      "Epoch 13/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 9.5528e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 14: Current learning rate = 9.554e-02\n",
      "Epoch 14/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 8.7712e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 15: Current learning rate = 9.484e-02\n",
      "Epoch 15/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 8.0768e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 16: Current learning rate = 9.410e-02\n",
      "Epoch 16/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 7.5614e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 17: Current learning rate = 9.330e-02\n",
      "Epoch 17/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 7.1167e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 18: Current learning rate = 9.246e-02\n",
      "Epoch 18/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 6.7209e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 19: Current learning rate = 9.157e-02\n",
      "Epoch 19/100\n",
      "25/25 [==============================] - 0s 3ms/step - loss: 6.3248e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 20: Current learning rate = 9.064e-02\n",
      "Epoch 20/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 5.9585e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 21: Current learning rate = 8.967e-02\n",
      "Epoch 21/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 5.7244e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 22: Current learning rate = 8.865e-02\n",
      "Epoch 22/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 5.4675e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 23: Current learning rate = 8.759e-02\n",
      "Epoch 23/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 5.2157e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 24: Current learning rate = 8.649e-02\n",
      "Epoch 24/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 4.9839e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 25: Current learning rate = 8.536e-02\n",
      "Epoch 25/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 4.8010e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 26: Current learning rate = 8.418e-02\n",
      "Epoch 26/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 4.6201e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 27: Current learning rate = 8.297e-02\n",
      "Epoch 27/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 4.4664e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 28: Current learning rate = 8.172e-02\n",
      "Epoch 28/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 4.3057e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 29: Current learning rate = 8.044e-02\n",
      "Epoch 29/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 4.1539e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 30: Current learning rate = 7.912e-02\n",
      "Epoch 30/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 4.0279e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 31: Current learning rate = 7.778e-02\n",
      "Epoch 31/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 3.9137e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 32: Current learning rate = 7.640e-02\n",
      "Epoch 32/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 3.8166e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 33: Current learning rate = 7.500e-02\n",
      "Epoch 33/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 3.7061e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 34: Current learning rate = 7.357e-02\n",
      "Epoch 34/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 3.6005e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 35: Current learning rate = 7.211e-02\n",
      "Epoch 35/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 3.5133e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 36: Current learning rate = 7.064e-02\n",
      "Epoch 36/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 3.4356e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 37: Current learning rate = 6.913e-02\n",
      "Epoch 37/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 3.3504e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 38: Current learning rate = 6.761e-02\n",
      "Epoch 38/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 3.2755e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 39: Current learning rate = 6.607e-02\n",
      "Epoch 39/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 3.2052e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 40: Current learning rate = 6.451e-02\n",
      "Epoch 40/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 3.1398e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 41: Current learning rate = 6.294e-02\n",
      "Epoch 41/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 3.0814e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 42: Current learning rate = 6.135e-02\n",
      "Epoch 42/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 3.0225e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 43: Current learning rate = 5.975e-02\n",
      "Epoch 43/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.9704e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 44: Current learning rate = 5.814e-02\n",
      "Epoch 44/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.9204e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 45: Current learning rate = 5.653e-02\n",
      "Epoch 45/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.8714e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 46: Current learning rate = 5.490e-02\n",
      "Epoch 46/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.8303e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 47: Current learning rate = 5.327e-02\n",
      "Epoch 47/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.7856e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 48: Current learning rate = 5.164e-02\n",
      "Epoch 48/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.7440e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 49: Current learning rate = 5.000e-02\n",
      "Epoch 49/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.7087e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 50: Current learning rate = 4.836e-02\n",
      "Epoch 50/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.6706e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 51: Current learning rate = 4.673e-02\n",
      "Epoch 51/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.6388e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 52: Current learning rate = 4.510e-02\n",
      "Epoch 52/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.6060e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 53: Current learning rate = 4.347e-02\n",
      "Epoch 53/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.5799e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 54: Current learning rate = 4.186e-02\n",
      "Epoch 54/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.5479e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 55: Current learning rate = 4.025e-02\n",
      "Epoch 55/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.5241e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 56: Current learning rate = 3.865e-02\n",
      "Epoch 56/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.4950e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 57: Current learning rate = 3.706e-02\n",
      "Epoch 57/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.4730e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 58: Current learning rate = 3.549e-02\n",
      "Epoch 58/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.4527e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 59: Current learning rate = 3.393e-02\n",
      "Epoch 59/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.4322e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 60: Current learning rate = 3.239e-02\n",
      "Epoch 60/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.4112e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 61: Current learning rate = 3.087e-02\n",
      "Epoch 61/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.3951e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 62: Current learning rate = 2.936e-02\n",
      "Epoch 62/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.3756e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 63: Current learning rate = 2.789e-02\n",
      "Epoch 63/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.3591e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 64: Current learning rate = 2.643e-02\n",
      "Epoch 64/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.3451e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 65: Current learning rate = 2.500e-02\n",
      "Epoch 65/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.3297e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 66: Current learning rate = 2.360e-02\n",
      "Epoch 66/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.3186e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 67: Current learning rate = 2.222e-02\n",
      "Epoch 67/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.3052e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 68: Current learning rate = 2.088e-02\n",
      "Epoch 68/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.2935e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 69: Current learning rate = 1.956e-02\n",
      "Epoch 69/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.2823e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 70: Current learning rate = 1.828e-02\n",
      "Epoch 70/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.2724e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 71: Current learning rate = 1.703e-02\n",
      "Epoch 71/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.2635e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 72: Current learning rate = 1.582e-02\n",
      "Epoch 72/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.2551e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 73: Current learning rate = 1.464e-02\n",
      "Epoch 73/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.2467e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 74: Current learning rate = 1.351e-02\n",
      "Epoch 74/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.2392e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 75: Current learning rate = 1.241e-02\n",
      "Epoch 75/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.2326e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 76: Current learning rate = 1.135e-02\n",
      "Epoch 76/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.2269e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 77: Current learning rate = 1.033e-02\n",
      "Epoch 77/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.2205e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 78: Current learning rate = 9.358e-03\n",
      "Epoch 78/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.2159e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 79: Current learning rate = 8.427e-03\n",
      "Epoch 79/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.2110e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 80: Current learning rate = 7.540e-03\n",
      "Epoch 80/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.2072e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 81: Current learning rate = 6.699e-03\n",
      "Epoch 81/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.2034e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 82: Current learning rate = 5.904e-03\n",
      "Epoch 82/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.2003e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 83: Current learning rate = 5.156e-03\n",
      "Epoch 83/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.1975e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 84: Current learning rate = 4.457e-03\n",
      "Epoch 84/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.1950e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 85: Current learning rate = 3.806e-03\n",
      "Epoch 85/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.1931e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 86: Current learning rate = 3.205e-03\n",
      "Epoch 86/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.1912e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 87: Current learning rate = 2.653e-03\n",
      "Epoch 87/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.1898e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 88: Current learning rate = 2.153e-03\n",
      "Epoch 88/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.1885e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 89: Current learning rate = 1.704e-03\n",
      "Epoch 89/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.1875e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 90: Current learning rate = 1.306e-03\n",
      "Epoch 90/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.1867e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 91: Current learning rate = 9.607e-04\n",
      "Epoch 91/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.1861e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 92: Current learning rate = 6.678e-04\n",
      "Epoch 92/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.1856e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 93: Current learning rate = 4.278e-04\n",
      "Epoch 93/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.1854e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 94: Current learning rate = 2.408e-04\n",
      "Epoch 94/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.1852e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 95: Current learning rate = 1.071e-04\n",
      "Epoch 95/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.1851e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 96: Current learning rate = 2.677e-05\n",
      "Epoch 96/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.1850e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 97: Current learning rate = 0.000e+00\n",
      "Epoch 97/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.1851e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 98: Current learning rate = 0.000e+00\n",
      "Epoch 98/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.1850e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 99: Current learning rate = 0.000e+00\n",
      "Epoch 99/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.1851e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 100: Current learning rate = 0.000e+00\n",
      "Epoch 100/100\n",
      "25/25 [==============================] - 0s 2ms/step - loss: 2.1850e-04 - accuracy: 1.0000\n",
      "49/49 [==============================] - 0s 1ms/step - loss: 0.1933 - accuracy: 0.9564\n",
      "\n",
      "Test loss 0.19330178201198578\n",
      "Test accuracy 0.9563822746276855\n"
     ]
    }
   ],
   "source": [
    "# Vanilla LeNet-300-100 on ISOLET\n",
    "\n",
    "# Model definition\n",
    "################################################################################\n",
    "MODEL = 'lenet300100_isolet'\n",
    "#DEPTH = DEPTH\n",
    "LA = 0 #lambdas[0] #LA\n",
    "#print(f'Starting run with lambda={LA:.2e}')\n",
    "#INIT_TYPE = 'equivar'\n",
    "INIT_LR = 0.1\n",
    "#INIT = TwiceTruncatedNormalInitializer(minprod=MINPROD,depth=DEPTH)\n",
    "INIT = tf.keras.initializers.HeNormal\n",
    "#INIT = tf.keras.initializers.HeUniform\n",
    "################################################################################\n",
    "\n",
    "################################################################################\n",
    "\n",
    "# Deirectories and saving paths\n",
    "fmt_la = f\"{LA:.1e}\"\n",
    "RUN_NAME = f\"{MODEL}_dep1-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{INIT_LR:.1e}-vanilla-bs{BATCH_SIZE}\"\n",
    "RUN_PATH = os.path.join(LENET_FILE_PATH, RUN_NAME)\n",
    "\n",
    "# Create dir\n",
    "if not os.path.exists(RUN_PATH):\n",
    "    os.makedirs(RUN_PATH)\n",
    "    \n",
    "################################################################################\n",
    "# Set seed\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "# Callbacks\n",
    "early_stopping = EarlyStopping(monitor='val_accuracy', patience=PAT, restore_best_weights=RESTORE_WEIGHTS)\n",
    "#custom_sparsity_callback = HadamardCallback(save_metrics=SAVE_METRICS, verbose=VERBOSE)\n",
    "print_lr_cb = PrintLRCallback()\n",
    "terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "early_abort_cb = TerminateBadRuns(grace=GRACE, minacc = MINACC)\n",
    "\n",
    "# Define model\n",
    "vanilla_lenet300100 = LeNet300100(input_shape=(X_train.shape[1],), n_classes = CLASS_NUM, la=LA, units1=300, units2=100)\n",
    "\n",
    "#hadamard_net  = hadamard_resnet18(use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS, depth=DEPTH,\\\n",
    "#                                 init_type=INIT_TYPE, init=INIT, la=LA,\\\n",
    "#                                 input_shape=(IMG_ROWS,IMG_COLS,IMG_CHANNELS), n_classes=CLASS_NUM)\n",
    "\n",
    "# Pretrain optimizer\n",
    "optimizer = get_optimizer(lr_schedule=LR_SCHEDULE, init_lr=INIT_LR, lr_decay_fact=LR_DECAY_FACT, epochs=EPOCHS,\\\n",
    "                          dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT, momentum=MOMENTUM, alpha=0,\\\n",
    "                          large_lr_start=LARGE_LRSTART, warmup = WARMUP)\n",
    "\n",
    "# Compile model\n",
    "vanilla_lenet300100.compile(optimizer=optimizer,\n",
    "               loss='categorical_crossentropy',\n",
    "               metrics=['accuracy'])\n",
    "\n",
    "print(vanilla_lenet300100.summary())\n",
    "\n",
    "################################################################################\n",
    "\n",
    "# Training\n",
    "pre_hist = vanilla_lenet300100.fit(x=X_train, y=Y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, callbacks=[print_lr_cb])\n",
    "\n",
    "# Evaluate after training\n",
    "pretrain_loss, pretrain_acc = vanilla_lenet300100.evaluate(X_test, Y_test)\n",
    "print('\\nTest loss', pretrain_loss)\n",
    "print('Test accuracy', pretrain_acc)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-19T10:38:58.369079Z",
     "iopub.status.busy": "2025-01-19T10:38:58.368412Z",
     "iopub.status.idle": "2025-01-19T11:46:56.023282Z",
     "shell.execute_reply": "2025-01-19T11:46:56.017563Z",
     "shell.execute_reply.started": "2025-01-19T10:38:58.369055Z"
    }
   },
   "outputs": [],
   "source": [
    "DEPTH_LIST = [2, 3, 4]\n",
    "REPS = 5\n",
    "BASE_SEED = SEED  \n",
    "\n",
    "for depth in DEPTH_LIST:\n",
    "    for rep in range(REPS):\n",
    "        current_seed = BASE_SEED + rep\n",
    "        \n",
    "        for LA_ITER in LAMBDA_LIST:\n",
    "            # Model definition\n",
    "            ################################################################################\n",
    "            MODEL = 'lenet300100_isolet'\n",
    "            DEPTH = depth  # Use the loop variable\n",
    "            LA = LA_ITER\n",
    "            print(f'Starting run with depth={DEPTH}, lambda={LA:.2e}, repetition={rep+1}/{REPS}')\n",
    "            INIT_TYPE = 'ones'\n",
    "            INIT_LR = INIT_LR\n",
    "            INIT = tf.keras.initializers.HeNormal()\n",
    "            EPOCHS = EPOCHS\n",
    "            ################################################################################\n",
    "\n",
    "            # Directories and saving paths\n",
    "            fmt_la = f\"{LA:.1e}\"\n",
    "            RUN_NAME = f\"{MODEL}_dep{DEPTH}-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{INIT_LR:.1e}-{INIT_TYPE}-bs{BATCH_SIZE}_rep{rep+1}\"\n",
    "            RUN_PATH = os.path.join(LENET_FILE_PATH, f\"depth_{DEPTH}\", f\"rep_{rep+1}\", RUN_NAME)\n",
    "\n",
    "            # Check if results already exist\n",
    "            pretrain_csv_file_path = os.path.join(RUN_PATH, f'pretraining_{MODEL}_depth{DEPTH}_rep{rep+1}.csv')\n",
    "            #if os.path.exists(pretrain_csv_file_path):\n",
    "            #    print(f'Results already exist for depth={DEPTH}, lambda={LA:.2e}, repetition={rep+1}/{REPS}. Skipping...')\n",
    "            #    continue\n",
    "\n",
    "            # Create dir\n",
    "            if not os.path.exists(RUN_PATH):\n",
    "                os.makedirs(RUN_PATH)\n",
    "\n",
    "            ################################################################################\n",
    "            # Set seed for this repetition\n",
    "            np.random.seed(current_seed)\n",
    "            random.seed(current_seed)\n",
    "            tf.random.set_seed(current_seed)\n",
    "\n",
    "            # Callbacks\n",
    "            terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "\n",
    "            # Define model\n",
    "            hadamard_lenet300100 = InpHadamardLeNet300100(input_shape=(X_train.shape[1],), n_classes=CLASS_NUM, depth=DEPTH, la=LA,\n",
    "                                   init=INIT, use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS)\n",
    "\n",
    "            # Pretrain optimizer\n",
    "            optimizer = get_optimizer(lr_schedule=LR_SCHEDULE, init_lr=INIT_LR, lr_decay_fact=LR_DECAY_FACT, epochs=EPOCHS,\\\n",
    "                              dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT, momentum=MOMENTUM, alpha=0,\\\n",
    "                              large_lr_start=LARGE_LRSTART, warmup=WARMUP)\n",
    "\n",
    "            # Compile model\n",
    "            hadamard_lenet300100.compile(optimizer=optimizer,\n",
    "                   loss='categorical_crossentropy',\n",
    "                   metrics=['accuracy'])\n",
    "\n",
    "            print(hadamard_lenet300100.summary())\n",
    "\n",
    "            ################################################################################\n",
    "\n",
    "            # Training\n",
    "            pre_hist = hadamard_lenet300100.fit(x=X_train, y=Y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, verbose=0,\n",
    "                                   callbacks=[terminate_nan_cb])\n",
    "\n",
    "            # Evaluate after training\n",
    "            pretrain_loss, pretrain_acc = hadamard_lenet300100.evaluate(X_test, Y_test)\n",
    "            print('\\nTest loss', pretrain_loss)\n",
    "            print('Test accuracy', pretrain_acc)\n",
    "\n",
    "            ################################################################################\n",
    "\n",
    "            # Evaluate after pretraining\n",
    "            pretrain_loss, pretrain_acc = hadamard_lenet300100.evaluate(X_test, Y_test)\n",
    "            df, pretrain_sparsity = compute_input_sparsity(hadamard_lenet300100, DEPTH)\n",
    "            pretrain_compression_rate = 1 / (1 - pretrain_sparsity)\n",
    "            print('\\nTest loss', pretrain_loss)\n",
    "            print('Test accuracy', pretrain_acc)\n",
    "            print('Sparsity (pretrain)', pretrain_sparsity)\n",
    "            print('Compression rate (pretrain)', pretrain_compression_rate)\n",
    "\n",
    "            # Initialize df to store results with added run number column\n",
    "            pretrain_res_df = pd.DataFrame(columns=['Run', 'Pre Opt', 'Depth', 'Lambda', 'Init LR', 'LR Schedule', 'Batch size',\\\n",
    "                                            'Pre Epochs', 'Pre Loss', 'Pre Acc', 'Pre Sparsity', 'Pre CR'])\n",
    "\n",
    "            # Store formatted results in dict\n",
    "            pretrain_res_dict = {\n",
    "                'Run': int(rep + 1),\n",
    "                'Pre Opt': PRETRAIN_OPT,\n",
    "                'Depth': int(DEPTH),\n",
    "                'Lambda': f'{LA:.2e}',\n",
    "                'Init LR': f'{INIT_LR:.2e}',\n",
    "                'LR Schedule': LR_SCHEDULE,\n",
    "                'Batch size': int(BATCH_SIZE),\n",
    "                'Pre Epochs': int(EPOCHS),\n",
    "                'Pre Loss': f'{pretrain_loss:.3f}',\n",
    "                'Pre Acc': f'{pretrain_acc * 100:.4f}%',\n",
    "                'Pre Sparsity': f'{pretrain_sparsity * 100:.4f}%',\n",
    "                'Pre CR': f'{pretrain_compression_rate:.2f}'\n",
    "            }\n",
    "\n",
    "            # Append results to df\n",
    "            pretrain_res_df = pd.concat([pretrain_res_df, pd.DataFrame([pretrain_res_dict])], ignore_index=True)\n",
    "\n",
    "            # Save df to CSV\n",
    "            pretrain_csv_file_path = os.path.join(RUN_PATH, f'pretraining_{MODEL}_depth{DEPTH}_rep{rep+1}.csv')\n",
    "            pretrain_res_df.to_csv(pretrain_csv_file_path, index=False)\n",
    "            print(f'Pretrain results saved to {pretrain_csv_file_path}')\n",
    "            print(\"\\nPretraining Results:\")\n",
    "            print(pretrain_res_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-19T23:42:03.424006Z",
     "iopub.status.busy": "2025-01-19T23:42:03.423373Z",
     "iopub.status.idle": "2025-01-19T23:42:06.244487Z",
     "shell.execute_reply": "2025-01-19T23:42:06.243929Z",
     "shell.execute_reply.started": "2025-01-19T23:42:03.423982Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train shape: (6238, 617), y_train shape: (6238,)\n",
      "x_test shape: (1559, 617), y_test shape: (1559,)\n",
      "Normalized Training Set Mean and SD: [-8.81148312e-16 -1.21400165e-15 -1.65448681e-15  3.01490157e-15\n",
      " -8.34856390e-16  2.57835860e-16 -7.38250257e-17  2.11793107e-16\n",
      "  1.50515488e-16 -1.48157287e-16  2.56941524e-16  2.04318056e-17\n",
      " -1.10434977e-16  5.29411576e-16  1.91859638e-17 -5.79263046e-16\n",
      " -9.26100954e-16 -3.59016012e-16  1.20745208e-15 -2.97133270e-16\n",
      " -9.91387514e-16 -1.30350648e-16  3.26257492e-15 -1.86774823e-15\n",
      "  3.14810736e-16  3.03968492e-15  3.98924775e-15  2.07614197e-15\n",
      "  2.53973751e-16 -1.73193368e-15 -8.14139821e-16 -1.26768853e-16\n",
      "  6.39864350e-16  7.48508430e-16 -3.70916026e-16 -4.42398425e-16\n",
      " -1.95330197e-15  7.93636825e-16 -6.85106206e-16 -1.21738322e-15\n",
      "  4.05788473e-17 -6.62467480e-16 -3.01342436e-16 -1.76393402e-16\n",
      "  6.08148778e-17  2.37501942e-16  1.83329181e-15 -3.20572894e-16\n",
      "  2.81410747e-15 -5.78123990e-16  1.34352915e-15  3.05233912e-15\n",
      "  6.59953550e-16  3.42072564e-15  6.56159961e-15  4.98705135e-15\n",
      "  2.98904145e-16 -2.02761643e-15  6.22493756e-16  3.30504033e-16\n",
      " -2.07446899e-15  2.45114036e-15 -2.95531474e-16  6.48004592e-16\n",
      " -1.37633483e-15 -6.95900535e-16  6.17791371e-15 -5.18177642e-15\n",
      "  1.30735079e-15 -3.22495050e-17  7.57649794e-16 -6.54244924e-17\n",
      "  2.06204616e-16  7.37716325e-17  6.48549648e-16  5.37491750e-18\n",
      "  3.89236575e-17 -4.18015521e-16  1.15551827e-16 -3.40079217e-16\n",
      " -1.00361456e-15  1.28108133e-15 -1.03417328e-15  5.43382802e-15\n",
      " -5.12379138e-16  3.96613738e-15  1.82496247e-15  2.43035260e-15\n",
      " -6.59282575e-15  1.04751269e-15  8.16275550e-16  3.55047116e-15\n",
      " -1.90482093e-15 -4.72209639e-16 -3.28599675e-16  4.13002676e-16\n",
      "  2.72444245e-15  2.24091345e-16 -5.14327991e-16 -1.22078259e-15\n",
      " -1.17892230e-16  6.42889966e-16  2.86899570e-16  2.21047931e-17\n",
      " -1.80647062e-16  1.54982720e-16  7.47576273e-16  3.68448815e-16\n",
      "  1.52526632e-17 -3.61294123e-18 -1.85185485e-16 -2.20840588e-15\n",
      " -7.24047661e-16  1.40356538e-15 -3.19026270e-15 -2.32327249e-15\n",
      "  4.51288396e-16  4.77793680e-15 -4.96575746e-15  1.39657086e-15\n",
      " -1.93208706e-15  2.28540780e-16 -6.90267550e-16  2.18862369e-15\n",
      "  9.44134514e-16 -1.42136312e-15  6.41964484e-17  7.35171248e-16\n",
      " -1.73946212e-15 -3.64426525e-16 -3.10065108e-15 -1.27033149e-15\n",
      " -3.24417206e-16 -1.17358298e-16  1.32771141e-17  2.43152725e-16\n",
      "  6.32745254e-16  1.88068719e-16 -3.25645250e-16 -5.94569102e-16\n",
      "  2.59437657e-16 -3.24452801e-16  8.27683901e-17  4.02668863e-16\n",
      "  7.10361199e-16  1.12855470e-16  7.65899047e-16  1.40167882e-15\n",
      " -4.33392768e-16 -2.70863806e-16  1.41694038e-15 -4.10257486e-15\n",
      "  2.19310872e-15 -6.37782015e-16 -2.06663798e-15  1.89218453e-15\n",
      "  7.49427238e-16 -2.87682670e-16  1.88999541e-15  4.20444912e-16\n",
      " -1.32842332e-16  2.61982734e-16 -9.37620542e-16  1.11588983e-14\n",
      " -1.03251809e-15 -2.85475750e-17  7.23300156e-17  1.04483413e-15\n",
      "  6.88203012e-16  5.25816432e-16  4.85397765e-16 -1.66337679e-16\n",
      "  1.66355476e-16  6.20802971e-16 -8.06593579e-17  8.55848825e-17\n",
      "  5.34190269e-16 -9.71667618e-16 -6.98062960e-16 -4.91075244e-16\n",
      "  2.22995004e-15 -3.17689660e-17  3.22639211e-15  1.42775028e-15\n",
      "  6.60821189e-16 -3.83505703e-16 -6.23348048e-16 -2.33613136e-16\n",
      "  1.29247188e-16 -3.47233908e-16  2.87308918e-16  9.20321138e-17\n",
      "  2.85964743e-15  2.39344008e-16  6.03924484e-15 -2.33061406e-15\n",
      " -7.61921252e-17 -1.88656045e-17 -3.35060254e-16 -4.80503386e-16\n",
      " -4.44338378e-16 -8.47350404e-16  2.11116793e-16 -1.89720349e-15\n",
      "  1.27908798e-15  1.83494700e-16  3.23669701e-16 -7.88795839e-17\n",
      " -1.67298757e-18  1.08099914e-15  1.38146058e-16  1.95953118e-17\n",
      " -5.84433289e-17 -8.38482680e-16 -1.50902700e-15 -8.49691919e-16\n",
      " -3.89768282e-16 -4.07479259e-16  3.56666711e-16 -1.60339840e-16\n",
      "  1.38822372e-18  4.42451818e-17 -5.07947501e-17 -1.49785780e-16\n",
      " -4.39559685e-16 -1.38359631e-16  1.66294964e-15  2.21889765e-15\n",
      "  6.42596304e-15  1.47099211e-15 -2.88358984e-16 -1.14403873e-16\n",
      " -3.56880284e-16  4.21050035e-16  2.75936162e-16  5.37402761e-17\n",
      " -2.38240548e-16  7.77262904e-16  4.37966787e-16  3.43398940e-15\n",
      "  1.02949248e-15 -1.59270196e-15 -3.07966755e-15  1.54986280e-15\n",
      " -2.32539042e-15  1.96410520e-15  3.13453798e-16  2.70909635e-15\n",
      " -5.62381889e-16  1.84736982e-15 -5.11816285e-15 -1.20975689e-15\n",
      "  5.86684703e-16 -1.34967382e-15  3.47917342e-15  1.47977530e-15\n",
      " -6.72722538e-15 -1.05718576e-16  3.48874860e-15  1.05268293e-15\n",
      "  6.42729787e-15  5.74582240e-16 -3.29204798e-16  9.78875703e-17\n",
      " -4.85451158e-16 -6.95268715e-17 -2.17897731e-16  7.39340369e-17\n",
      "  6.36625162e-16  6.99664757e-16  1.68519681e-15  2.87079327e-15\n",
      "  1.07509029e-15  2.07316975e-16  3.08296013e-15  3.69310225e-15\n",
      " -3.74948883e-15 -1.54519979e-15 -1.42221741e-15 -2.52505437e-15\n",
      " -4.96296656e-16 -1.14304206e-15  2.91842113e-15  1.36042365e-15\n",
      "  5.47828677e-15  6.28035972e-15  4.28683486e-15  3.03469266e-15\n",
      "  4.68335071e-15  1.16317130e-16  5.29429373e-16 -4.20126333e-15\n",
      "  1.07711923e-15  3.82117479e-16 -8.11932901e-17 -1.01981051e-16\n",
      " -1.62778131e-16  1.23030438e-15  2.51749033e-16  9.91156144e-17\n",
      "  2.52727909e-17  2.66236393e-16 -1.30635412e-17  6.65644377e-16\n",
      "  1.97510420e-15 -1.69482539e-15  3.89987640e-15 -5.65968134e-17\n",
      " -4.51279497e-16 -4.17944330e-16  1.35745143e-15  1.74110842e-16\n",
      "  2.13658310e-15  1.35038573e-15  1.06736607e-15  8.86291859e-16\n",
      "  5.54221625e-16 -1.96508407e-15  1.94920849e-16 -5.75472127e-16\n",
      "  4.44556623e-15  1.94798935e-15  2.16360007e-15 -5.70018899e-15\n",
      " -4.39819532e-15 -1.76073932e-15 -5.09363756e-15  3.35838905e-15\n",
      "  1.82693802e-16 -1.24229293e-14  2.63758948e-15 -1.42850001e-15\n",
      "  2.56287457e-15 -5.74176451e-15 -9.82345372e-15  4.91698164e-16\n",
      "  1.38857078e-15  1.82165209e-15  1.20213055e-15  9.61077963e-17\n",
      " -1.49410248e-15  3.62044298e-15  1.82021270e-15 -1.01088383e-15\n",
      " -2.52259828e-15  7.02542652e-15  1.89386998e-14  1.24866631e-14\n",
      "  2.30739157e-14  1.55237050e-14 -1.03532124e-14 -5.28672969e-15\n",
      " -3.78047136e-15  5.17616123e-16 -9.02075785e-15 -2.31217560e-15\n",
      "  7.96128508e-16 -4.78282228e-15 -2.29896968e-15  5.41186560e-15\n",
      "  1.49145062e-16  2.78499036e-16 -1.98494635e-15 -1.61204810e-15\n",
      " -4.00346814e-16 -7.34299159e-16 -2.25468445e-15  4.20429117e-15\n",
      " -6.07131192e-15  3.43195601e-15  4.65446498e-16  7.03010732e-17\n",
      "  1.11126086e-15 -7.50609676e-16 -1.03199039e-14 -1.03906232e-14\n",
      "  2.48378319e-14  1.11200413e-14 -9.23358990e-15 -6.07817962e-15\n",
      " -7.38962167e-15 -1.08085675e-16  2.81818315e-15 -2.13334391e-15\n",
      "  2.84443482e-15  3.50731164e-15  4.30833008e-15  5.98288830e-15\n",
      "  7.31415925e-16  1.81294899e-15 -1.42915853e-15  3.39267640e-15\n",
      "  2.50649133e-15  3.98178160e-15  3.35359256e-15  1.38875766e-15\n",
      "  6.91548988e-16 -5.41791684e-15  8.67844502e-16 -1.45727896e-16\n",
      " -1.56412546e-15  1.14302426e-15  7.48263266e-15 -1.83743868e-16\n",
      "  5.69492086e-15  2.56429839e-15 -1.20746988e-15  1.71383694e-14\n",
      "  1.32910902e-15  3.42485471e-15  5.48553045e-15 -1.43031538e-15\n",
      "  7.51954517e-16  2.97800685e-15  3.18168419e-15  6.03218804e-16\n",
      "  6.23792991e-16  3.09894250e-16  4.04934182e-16  1.08637405e-15\n",
      "  1.90108340e-15  1.83302484e-15  1.17009462e-15  1.07138279e-15\n",
      " -2.14505482e-15 -9.93238479e-16 -1.84918519e-15  1.11711075e-15\n",
      " -6.88995012e-16  1.24785295e-15  1.87290068e-15 -1.30973569e-16\n",
      " -1.01094545e-14 -3.25892638e-15 -4.16420398e-15 -8.60161217e-15\n",
      " -5.40264638e-15 -3.85894159e-15 -4.04489238e-15  1.10393152e-15\n",
      " -7.30661746e-16 -1.15688870e-15 -4.56344734e-15 -8.76536918e-15\n",
      " -2.63611672e-15 -3.43503502e-15  2.28861139e-16  8.63457359e-16\n",
      "  9.08556832e-16  6.68791017e-15  1.00265348e-15  6.22901324e-15\n",
      "  1.94618732e-15 -1.49591785e-15 -2.90144988e-15  9.48699635e-16\n",
      "  2.24465098e-16  6.24148946e-16  1.75397618e-15  4.25161313e-16\n",
      " -8.27737294e-16 -7.48163599e-16  4.58367447e-16 -4.93086388e-16\n",
      "  2.47336028e-16 -1.01686387e-15 -3.85098600e-17  5.71098332e-16\n",
      "  3.38593106e-16 -6.97284309e-16 -7.95959430e-16  8.71181578e-16\n",
      "  1.10128856e-15  6.18827422e-17  4.06507502e-15 -1.51901932e-15\n",
      " -2.92488060e-16 -3.09460430e-15  1.77179172e-15 -1.85178366e-15\n",
      " -1.93536185e-15  3.34953468e-17  1.24363488e-15  7.05502416e-17\n",
      "  7.30419252e-17 -1.92856311e-16  5.02145438e-16 -4.91288816e-16\n",
      "  1.38516251e-15 -6.42569607e-16  6.88633718e-15  5.91361949e-15\n",
      "  6.13288765e-15  1.01393013e-14  3.32225964e-15  9.60152480e-16\n",
      " -5.94444518e-18  1.78520232e-15  8.65571730e-15  5.66541221e-15\n",
      " -5.04290066e-16 -1.35424339e-15  4.08878606e-16  1.05444213e-15\n",
      " -1.81821935e-15  1.48059400e-15  1.34833009e-15 -4.83969496e-15\n",
      " -2.74431363e-15 -2.54004007e-15  5.21358098e-16  1.52031855e-15\n",
      "  1.41630856e-15 -7.90219658e-18  1.39534282e-17  6.18257894e-16\n",
      "  1.87125439e-16  4.39034652e-16 -6.52927892e-16  8.79671100e-16\n",
      "  1.26634480e-15 -7.67438551e-17 -3.46607428e-15 -2.13085223e-15\n",
      "  3.16942155e-16  4.24798239e-15  1.04131018e-15  5.66188826e-15\n",
      "  3.58983976e-15 -3.25050805e-15  2.98112146e-15  2.78383351e-15\n",
      " -1.45464489e-15 -7.25916424e-16  3.71928273e-15 -2.26832197e-15\n",
      " -1.89697212e-15 -1.45138790e-15  2.26885590e-15 -2.92808419e-15\n",
      "  9.09375528e-16 -9.26265583e-16 -2.64492215e-16  1.72958438e-16\n",
      "  7.90219658e-18 -1.04034910e-15  2.35108146e-16  7.26930895e-16\n",
      " -2.32046935e-16 -5.63476450e-16 -4.59003716e-16  2.73693647e-15\n",
      "  5.07242711e-15 -2.09137684e-15  1.34137295e-14  3.18202235e-15\n",
      " -4.48826968e-15  2.50262922e-15  3.52324062e-15  3.13631775e-16\n",
      "  9.46355672e-15  4.96206332e-15  1.36978527e-15  5.52728395e-15\n",
      " -5.15319325e-15  6.88398788e-16  3.73394807e-15  1.58099995e-15\n",
      " -5.37669727e-17 -2.81996292e-16  1.97474825e-15  4.86412236e-16\n",
      " -1.08388237e-17 -5.82431043e-16  1.95170017e-16  7.26859704e-17\n",
      " -8.59506260e-16  2.14427172e-16 -7.65445204e-16 -1.65874937e-17\n",
      "  2.91609564e-14  5.28494992e-15  2.23198966e-14 -4.32250153e-15\n",
      " -2.50930337e-16 -1.31073681e-15 -2.84443482e-16  6.02061951e-16\n",
      " -2.66681337e-16 -2.37350661e-16  7.75411939e-16 -1.39534282e-15\n",
      "  5.29020025e-16  8.64543021e-16  1.44019313e-16 -1.97006744e-15\n",
      " -4.94542243e-15  1.17536275e-15  8.24284533e-15  3.87382050e-15\n",
      "  2.06464463e-15  1.82028166e-15 -1.28734613e-15 -5.52555757e-15\n",
      "  1.03824896e-15  5.64145645e-15  6.33012221e-15 -2.54845840e-15\n",
      "  4.92819422e-16 -9.73038044e-16  3.48141593e-16 -2.76897240e-16\n",
      " -1.64664691e-15  2.33542167e-15  5.83605694e-16  6.27557213e-16\n",
      " -1.26166399e-15] [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      "Train data shape:  (6238, 617)\n",
      "Train labels shape:  (6238,)\n",
      "Test data shape:  (1559, 617)\n",
      "Test labels shape:  (1559,)\n"
     ]
    }
   ],
   "source": [
    "# HSIC lasso + SVM (following Ziyin and Liu, 2023)\n",
    "\n",
    "(X_train, Y_train), (X_test, Y_test) = load_isolet(one_hot=False)\n",
    "############\n",
    "\n",
    "def calculate_mean_std(dataset):\n",
    "    means = np.mean(dataset, axis=(0))  \n",
    "    stds = np.std(dataset, axis=(0))  \n",
    "    return means, stds\n",
    "\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "print(\"Normalized Training Set Mean and SD:\", train_mean, train_std)\n",
    "\n",
    "\n",
    "# Check sample sizes of split\n",
    "print('Train data shape: ', X_train.shape)\n",
    "print('Train labels shape: ', Y_train.shape)\n",
    "#print('Validation data shape: ', X_val.shape)\n",
    "#print('Validation labels shape: ', Y_val.shape)\n",
    "print('Test data shape: ', X_test.shape)\n",
    "print('Test labels shape: ', Y_test.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-19T23:42:12.413236Z",
     "iopub.status.busy": "2025-01-19T23:42:12.412794Z",
     "iopub.status.idle": "2025-01-20T04:18:40.033742Z",
     "shell.execute_reply": "2025-01-20T04:18:40.032815Z",
     "shell.execute_reply.started": "2025-01-19T23:42:12.413212Z"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "from models import hadamard_nn_depth_2, hadamard_nn_depth_3, hadamard_nn_depth_4, hsic_dnn, hsic_svm, lassoNet\n",
    "import tensorflow as tf\n",
    "import config_inputsparse_compar\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "# Constants\n",
    "REPS = 5\n",
    "BASE_SEED = config_inputsparse_compar.SEED\n",
    "LENET_FILE_PATH = './results/input_sparsity/ISOLET/'\n",
    "\n",
    "def run_methods_on_dataset(dataset_name, load_func, results_dir, rep):\n",
    "    # Set seeds for this repetition\n",
    "    current_seed = BASE_SEED + rep\n",
    "    np.random.seed(current_seed)\n",
    "    random.seed(current_seed)\n",
    "    tf.random.set_seed(current_seed)\n",
    "    \n",
    "    print(f\"Loading dataset: {dataset_name} for repetition {rep + 1}/{REPS}\")\n",
    "    \n",
    "    for method_name, method_func, one_hot in [\n",
    "        (\"HSIC_dnn\", hsic_dnn, False),\n",
    "        (\"HSIC_svm\", hsic_svm, False),\n",
    "        (\"LassoNet\", lassoNet, False)\n",
    "    ]:\n",
    "        # Create method-specific directory\n",
    "        method_dir = os.path.join(results_dir, method_name, f'rep_{rep + 1}')\n",
    "        os.makedirs(method_dir, exist_ok=True)\n",
    "        \n",
    "        # Generate a unique filename for the method-dataset-repetition combination\n",
    "        result_filename = os.path.join(method_dir, f'{dataset_name}_{method_name}_rep{rep + 1}_res.csv')\n",
    "        \n",
    "        # Check if results already exist\n",
    "        if os.path.exists(result_filename):\n",
    "            print(f\"Results already exist for {method_name} on {dataset_name} (repetition {rep + 1}). Skipping...\")\n",
    "            continue\n",
    "            \n",
    "        results = []\n",
    "        print(f\"Loading dataset: {dataset_name} with one_hot = {one_hot} for repetition {rep + 1}\")\n",
    "        (train_X, train_y), (test_X, test_y) = load_func(one_hot=one_hot)\n",
    "        \n",
    "        print(f\"Running {method_name} on {dataset_name} (repetition {rep + 1})\")\n",
    "        sparsity, accuracy, value_seq = method_func(train_X, train_y, test_X, test_y)\n",
    "        print(f'Repetition {rep + 1}: sparsity = {sparsity} and test accuracy = {accuracy}')\n",
    "        \n",
    "        for s, a, v in zip(sparsity, accuracy, value_seq):\n",
    "            result = {\n",
    "                \"method\": method_name,\n",
    "                \"dataset\": dataset_name,\n",
    "                \"repetition\": rep + 1,\n",
    "                \"sparsity\": s,\n",
    "                \"accuracy\": a,\n",
    "                \"value\": v,\n",
    "                \"seed\": current_seed\n",
    "            }\n",
    "            results.append(result)\n",
    "            \n",
    "        # Save the results\n",
    "        save_results_to_csv(results, result_filename)\n",
    "        print(f'Results successfully saved to {result_filename}')\n",
    "\n",
    "def save_results_to_csv(results, result_filename):\n",
    "    df = pd.DataFrame(results)\n",
    "    df.to_csv(result_filename, index=False)\n",
    "\n",
    "def main():\n",
    "    datasets = {\n",
    "        \"ISOLET\": load_isolet\n",
    "    }\n",
    "    \n",
    "    # Create base results directory with timestamp\n",
    "    base_results_dir = os.path.join('results', 'input_sparsity', 'ISOLET')\n",
    "    os.makedirs(base_results_dir, exist_ok=True)\n",
    "    \n",
    "    # Run experiments for each repetition\n",
    "    for rep in range(REPS):\n",
    "        print(f\"\\nStarting repetition {rep + 1}/{REPS}\")\n",
    "        for dataset_name, load_func in datasets.items():\n",
    "            run_methods_on_dataset(dataset_name, load_func, base_results_dir, rep)\n",
    "            \n",
    "    # Optionally, combine all results into a single summary file\n",
    "    combine_all_results(base_results_dir)\n",
    "\n",
    "def combine_all_results(base_results_dir):\n",
    "    \"\"\"Combines all individual result files into a single summary file\"\"\"\n",
    "    all_results = []\n",
    "    \n",
    "    for method in ['HSIC_dnn', 'HSIC_svm', 'LassoNet']:\n",
    "        method_dir = os.path.join(base_results_dir, method)\n",
    "        if not os.path.exists(method_dir):\n",
    "            continue\n",
    "            \n",
    "        for rep_dir in os.listdir(method_dir):\n",
    "            if not rep_dir.startswith('rep_'):\n",
    "                continue\n",
    "                \n",
    "            rep_path = os.path.join(method_dir, rep_dir)\n",
    "            for result_file in os.listdir(rep_path):\n",
    "                if result_file.endswith('_res.csv'):\n",
    "                    df = pd.read_csv(os.path.join(rep_path, result_file))\n",
    "                    all_results.append(df)\n",
    "    \n",
    "    if all_results:\n",
    "        combined_df = pd.concat(all_results, ignore_index=True)\n",
    "        summary_file = os.path.join(base_results_dir, 'all_results_summary.csv')\n",
    "        combined_df.to_csv(summary_file, index=False)\n",
    "        print(f\"\\nCombined results saved to {summary_file}\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
