{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-20T07:33:26.087252Z",
     "iopub.status.busy": "2025-01-20T07:33:26.086629Z",
     "iopub.status.idle": "2025-01-20T07:33:38.120397Z",
     "shell.execute_reply": "2025-01-20T07:33:38.119678Z",
     "shell.execute_reply.started": "2025-01-20T07:33:26.087192Z"
    },
    "id": "RPFnfEqRAYA5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting pyHSICLasso\n",
      "  Downloading pyHSICLasso-1.4.2-py2.py3-none-any.whl (14 kB)\n",
      "Requirement already satisfied: numpy in /usr/local/lib/python3.9/dist-packages (from pyHSICLasso) (1.23.1)\n",
      "Requirement already satisfied: six in /usr/lib/python3/dist-packages (from pyHSICLasso) (1.14.0)\n",
      "Requirement already satisfied: future in /usr/lib/python3/dist-packages (from pyHSICLasso) (0.18.2)\n",
      "Requirement already satisfied: matplotlib in /usr/local/lib/python3.9/dist-packages (from pyHSICLasso) (3.5.2)\n",
      "Requirement already satisfied: scipy in /usr/local/lib/python3.9/dist-packages (from pyHSICLasso) (1.8.1)\n",
      "Collecting pytest\n",
      "  Downloading pytest-8.3.4-py3-none-any.whl (343 kB)\n",
      "     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 343.1/343.1 kB 38.9 MB/s eta 0:00:00\n",
      "Requirement already satisfied: pandas in /usr/local/lib/python3.9/dist-packages (from pyHSICLasso) (1.4.3)\n",
      "Requirement already satisfied: joblib in /usr/local/lib/python3.9/dist-packages (from pyHSICLasso) (1.1.0)\n",
      "Requirement already satisfied: seaborn in /usr/local/lib/python3.9/dist-packages (from pyHSICLasso) (0.11.2)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (1.4.3)\n",
      "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (21.3)\n",
      "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (0.11.0)\n",
      "Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (3.0.9)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (2.8.2)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (9.2.0)\n",
      "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (4.34.4)\n",
      "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas->pyHSICLasso) (2022.1)\n",
      "Collecting exceptiongroup>=1.0.0rc8\n",
      "  Downloading exceptiongroup-1.2.2-py3-none-any.whl (16 kB)\n",
      "Collecting iniconfig\n",
      "  Downloading iniconfig-2.0.0-py3-none-any.whl (5.9 kB)\n",
      "Collecting pluggy<2,>=1.5\n",
      "  Downloading pluggy-1.5.0-py3-none-any.whl (20 kB)\n",
      "Collecting tomli>=1\n",
      "  Downloading tomli-2.2.1-py3-none-any.whl (14 kB)\n",
      "Installing collected packages: tomli, pluggy, iniconfig, exceptiongroup, pytest, pyHSICLasso\n",
      "Successfully installed exceptiongroup-1.2.2 iniconfig-2.0.0 pluggy-1.5.0 pyHSICLasso-1.4.2 pytest-8.3.4 tomli-2.2.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting lassonet\n",
      "  Downloading lassonet-0.0.20-py3-none-any.whl (21 kB)\n",
      "Requirement already satisfied: matplotlib in /usr/local/lib/python3.9/dist-packages (from lassonet) (3.5.2)\n",
      "Requirement already satisfied: torch>=1.11 in /usr/local/lib/python3.9/dist-packages (from lassonet) (1.12.0+cu116)\n",
      "Collecting sortedcontainers\n",
      "  Downloading sortedcontainers-2.4.0-py2.py3-none-any.whl (29 kB)\n",
      "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.9/dist-packages (from lassonet) (1.1.1)\n",
      "Requirement already satisfied: tqdm in /usr/local/lib/python3.9/dist-packages (from lassonet) (4.64.0)\n",
      "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch>=1.11->lassonet) (4.3.0)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (1.4.3)\n",
      "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (1.23.1)\n",
      "Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (3.0.9)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (2.8.2)\n",
      "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (21.3)\n",
      "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (0.11.0)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (9.2.0)\n",
      "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (4.34.4)\n",
      "Requirement already satisfied: joblib>=1.0.0 in /usr/local/lib/python3.9/dist-packages (from scikit-learn->lassonet) (1.1.0)\n",
      "Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.9/dist-packages (from scikit-learn->lassonet) (1.8.1)\n",
      "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.9/dist-packages (from scikit-learn->lassonet) (3.1.0)\n",
      "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.7->matplotlib->lassonet) (1.14.0)\n",
      "Installing collected packages: sortedcontainers, lassonet\n",
      "Successfully installed lassonet-0.0.20 sortedcontainers-2.4.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Defining configs successful!\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import random\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import keras\n",
    "import pandas as pd\n",
    "from tensorflow.keras.datasets import cifar10, fashion_mnist, mnist\n",
    "from tensorflow.keras.models import load_model\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping  #ReduceLROnPlateau\n",
    "from tensorflow.keras.regularizers import l2\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder, MinMaxScaler, StandardScaler \n",
    "from PIL import Image\n",
    "\n",
    "def _import_or_install(module_name, pip_name=None):\n",
    "    \"\"\"Import `module_name`, pip-installing it into *this kernel's* interpreter if missing.\n",
    "\n",
    "    `sys.executable` is used instead of a bare `python` so the package lands in the\n",
    "    environment the kernel actually runs in. Only ImportError triggers an install, and a\n",
    "    failed pip run raises subprocess.CalledProcessError instead of being silently ignored.\n",
    "    Returns the imported module object.\n",
    "    \"\"\"\n",
    "    import importlib\n",
    "    import subprocess\n",
    "    import sys\n",
    "\n",
    "    try:\n",
    "        return importlib.import_module(module_name)\n",
    "    except ImportError:\n",
    "        subprocess.check_call([sys.executable, '-m', 'pip', 'install', pip_name or module_name])\n",
    "        importlib.invalidate_caches()  # make the freshly installed package discoverable\n",
    "        return importlib.import_module(module_name)\n",
    "\n",
    "pyHSICLasso = _import_or_install('pyHSICLasso')\n",
    "lassonet = _import_or_install('lassonet')\n",
    "\n",
    "from layers       import HadamardDense, HadamardConv2D, StrHadamardDenseV2\n",
    "from initializers import TwiceTruncatedNormalInitializer, equivar_initializer, equivar_initializer_conv2d\n",
    "from callbacks    import HadamardCallback, PrintLRCallback, TerminateBadRuns\n",
    "from models       import LeNet300100, HadamardLeNet300100,\\\n",
    "                         InpHadamardLeNet300100, hsic_dnn, hsic_svm, lassoNet\n",
    "from utils        import get_optimizer, process_sparsity_callback, create_and_save_trajectory_plot, compute_input_sparsity\n",
    "\n",
    "################################################################################\n",
    "# Experiment configuration\n",
    "################################################################################\n",
    "\n",
    "# --- Hadamard factorization ---------------------------------------------------\n",
    "DEPTH          = 2\n",
    "LA             = 5e-4\n",
    "INIT_TYPE      = 'ones'   # one of: vanilla, equivar, root, ones\n",
    "MINPROD        = 3e-3\n",
    "INIT           = tf.keras.initializers.HeNormal()\n",
    "# Alternatives kept for reference:\n",
    "#   INIT_REST = tf.keras.initializers.Ones()\n",
    "#   INIT      = TwiceTruncatedNormalInitializer(minprod=MINPROD, depth=DEPTH)\n",
    "USE_BIAS       = True     # only applies to dense layers\n",
    "FACTORIZE_BIAS = True\n",
    "\n",
    "# --- Training -------------------------------------------------------------------\n",
    "PRETRAIN_OPT   = 'sgd'     # 'sgd' or 'adam'\n",
    "LR_SCHEDULE    = 'cosine'  # 'piecewise', 'constant' or 'cosine'\n",
    "BATCH_SIZE     = 256\n",
    "EPOCHS         = 100       # alternative: 200\n",
    "# Depth-specific SGD start LRs noted earlier: depth 2 = 0.2, depth 3 = 0.5, depth 4 = 0.7\n",
    "INIT_LR        = 2e-3 if PRETRAIN_OPT != 'sgd' else 0.1\n",
    "WARMUP         = False     # linear LR warm-up from 0.01 to INIT_LR over 5 epochs\n",
    "MOMENTUM       = 0.9       # SGD only\n",
    "LR_DECAY_FACT  = 0.1       # piecewise schedule only\n",
    "LARGE_LRSTART  = False     # piecewise schedule only\n",
    "\n",
    "# --- Misc -----------------------------------------------------------------------\n",
    "CLASS_NUM       = 6\n",
    "PAT             = 100\n",
    "RESTORE_WEIGHTS = False\n",
    "GRACE           = 10\n",
    "FINE_GRACE      = 20\n",
    "MINACC          = 1 / CLASS_NUM + 0.01   # chance accuracy plus one percentage point\n",
    "SEED            = 123\n",
    "SAVE_METRICS    = True\n",
    "VERBOSE         = 1\n",
    "\n",
    "# --- Output location ------------------------------------------------------------\n",
    "LENET_FILE_PATH = './results/input_sparsity/activity'\n",
    "\n",
    "# --- Regularization strengths to sweep ------------------------------------------\n",
    "# Values tried earlier and currently disabled: 1e-5, 1e-4, 4e-4, 5e-4, 5.\n",
    "LAMBDA_LIST = [\n",
    "    0, 1e-6,\n",
    "    2e-4, 7e-4, 8e-4, 9e-4,\n",
    "    1e-3, 1.5e-3, 2e-3, 4e-3, 5e-3, 8e-3,\n",
    "    1e-2, 1.5e-2, 2e-2, 2.5e-2, 3e-2, 3.5e-2, 4e-2, 4.5e-2,\n",
    "    5e-2, 5.5e-2, 6e-2, 6.5e-2, 7e-2, 8e-2, 9e-2, 9.5e-2,\n",
    "    1e-1, 1.25e-1, 1.5e-1, 1.75e-1, 2e-1, 5e-1, 7e-1,\n",
    "    1, 2,\n",
    "]\n",
    "# For a strongest-to-weakest sweep, call LAMBDA_LIST.reverse() here.\n",
    "\n",
    "print('Defining configs successful!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-20T07:33:38.122383Z",
     "iopub.status.busy": "2025-01-20T07:33:38.121626Z",
     "iopub.status.idle": "2025-01-20T07:33:38.528595Z",
     "shell.execute_reply": "2025-01-20T07:33:38.528085Z",
     "shell.execute_reply.started": "2025-01-20T07:33:38.122363Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train shape: (4252, 561), y_train shape: (4252, 6)\n",
      "x_test shape: (1492, 561), y_test shape: (1492, 6)\n",
      "Normalized Training Set Mean and SD: [-3.78062012e-16  4.24406708e-16 -8.94834993e-16 -3.73313798e-15\n",
      " -5.67801267e-16 -5.81274341e-15  5.46834450e-16 -7.79662736e-17\n",
      " -4.56235864e-15 -2.24002900e-16  1.97127252e-15  6.71956480e-16\n",
      " -1.78335448e-17  2.23373635e-15 -2.57043263e-15 -7.34804712e-16\n",
      " -2.64437787e-15  3.22325004e-15 -1.01107322e-14 -5.14684459e-15\n",
      "  2.41473508e-15  1.19889465e-15 -1.18489936e-16  2.81994559e-18\n",
      "  2.36875430e-16  8.53816860e-17 -8.81180777e-16  1.07732366e-16\n",
      " -3.89570262e-17 -1.72330009e-17 -3.69778421e-16 -1.13904914e-15\n",
      " -4.69468720e-17 -7.17519490e-17 -3.39281232e-16  3.04971894e-17\n",
      "  4.78816318e-16  4.12387669e-16 -4.45721123e-16 -2.58586400e-16\n",
      "  1.78606999e-15  1.17602176e-16 -1.85489755e-16 -3.31382773e-16\n",
      " -9.09638728e-15 -8.70887976e-15 -3.61235031e-15  5.50918149e-15\n",
      "  5.52013489e-15 -7.14877097e-15 -6.45454214e-17 -5.56599817e-16\n",
      "  1.01016718e-15  1.13320036e-16 -4.20641884e-17 -6.22790207e-16\n",
      "  6.79189118e-16 -3.19802719e-15  3.44464187e-16  3.13000906e-16\n",
      " -1.68258582e-14  1.24727760e-15 -2.37512529e-15  4.33545747e-15\n",
      "  4.34955719e-15 -3.34491241e-15 -8.10375337e-16  8.41268103e-15\n",
      "  6.55469264e-16  1.54674016e-15 -5.37095193e-16 -2.34614251e-15\n",
      "  3.99074523e-15  8.15695374e-16 -1.67368993e-15  1.91087869e-15\n",
      " -1.49493671e-15  1.15487216e-16 -1.32798549e-16 -1.37341795e-17\n",
      "  4.34240615e-16 -1.02221396e-16  1.44048957e-17 -5.35089899e-15\n",
      "  1.32991767e-15 -4.16396300e-15 -1.46020961e-15  3.16303897e-15\n",
      " -5.34149917e-15 -5.52239346e-17 -7.35026652e-15 -4.02831839e-15\n",
      "  1.47806926e-15 -2.16300271e-15  3.45683553e-15 -1.53616536e-15\n",
      "  6.08664368e-15 -3.08953761e-15 -1.26053135e-14  1.07802865e-15\n",
      "  2.00902846e-15 -1.32910825e-15  7.68174068e-17 -1.37498458e-16\n",
      "  1.05382411e-16  5.22212147e-19  4.22991839e-16  1.40213961e-16\n",
      "  5.04713960e-16  1.93166273e-16 -5.14875066e-16  2.05328594e-15\n",
      "  2.78545348e-15 -6.30832274e-17  1.59533200e-15  2.71445874e-16\n",
      "  1.44130553e-17  2.27074161e-16 -2.65779872e-16  4.14401449e-16\n",
      "  2.59605224e-16 -1.49368871e-15  1.48915811e-15 -4.36256028e-15\n",
      " -2.31005765e-15 -3.50644568e-15 -1.03883662e-15 -2.17733744e-15\n",
      " -3.01446962e-16  3.82739727e-15  6.27933996e-15  4.86101177e-15\n",
      "  5.01114776e-15  8.71258746e-16  2.46507633e-15 -1.25017588e-16\n",
      " -1.01216725e-14 -5.23212183e-15 -3.42652111e-15 -4.22600180e-16\n",
      " -1.38903209e-15 -8.01151765e-15 -1.03189120e-16 -2.63612692e-16\n",
      "  4.73646417e-17 -4.40172619e-16 -5.76483044e-16 -4.95187668e-17\n",
      "  6.78875791e-18 -6.56681775e-16  1.47472710e-16  1.46898277e-16\n",
      "  3.74559274e-15 -3.19593834e-17  1.79902085e-16 -5.48453307e-17\n",
      " -1.16589084e-15  4.00968358e-16 -1.50814868e-16 -7.86582046e-17\n",
      " -7.81022119e-16 -4.71294831e-16 -4.65927061e-16 -9.10215772e-17\n",
      " -1.55875104e-15 -1.11210299e-15  3.61806853e-15 -8.58323551e-15\n",
      " -7.61098094e-15  1.00904442e-15 -1.16218574e-14  4.33785964e-15\n",
      " -8.89045292e-15  1.40342165e-14 -1.19461251e-15 -7.28026398e-15\n",
      " -3.24899509e-15  2.73900271e-15 -1.34124446e-14 -1.79001269e-15\n",
      " -9.35240178e-15 -8.46009789e-15 -2.04446056e-16  4.35368267e-16\n",
      "  9.88547594e-17  1.31075249e-16  8.81833542e-16  4.08317678e-16\n",
      " -1.67603989e-16 -3.30978059e-16  1.58713327e-15  4.30720579e-16\n",
      "  2.78456572e-16  5.78976607e-16 -5.54589300e-17 -6.55154304e-16\n",
      "  9.31104258e-17 -1.24182049e-16  8.71154304e-16  3.38262918e-16\n",
      "  1.53566926e-15  1.31472130e-15 -4.19989119e-16  3.11630099e-15\n",
      " -3.16906791e-14  1.53566926e-15  4.75560325e-15  2.96224840e-16\n",
      " -2.06117134e-16  1.21362103e-16  5.99969536e-16 -2.95467633e-16\n",
      " -4.10876517e-16  1.53566926e-15  1.31472130e-15 -4.19989119e-16\n",
      "  3.11630099e-15 -3.16906791e-14  1.53566926e-15  4.75560325e-15\n",
      "  2.96224840e-16 -2.06117134e-16  1.21362103e-16  5.99969536e-16\n",
      " -2.95467633e-16 -4.10876517e-16  2.89300307e-15 -2.59098168e-15\n",
      " -6.23581358e-15  3.77366164e-15  3.02162393e-15  2.89300307e-15\n",
      " -6.70207070e-15 -3.80128666e-15 -8.98727105e-17 -3.27061468e-16\n",
      " -7.22167178e-16 -5.25528194e-16  1.51128195e-16 -2.95436300e-15\n",
      " -7.86817042e-16 -8.78621937e-16  2.15986944e-15  3.53057188e-15\n",
      " -2.95436300e-15 -8.47827087e-15  1.76768812e-17  4.31765003e-16\n",
      " -2.06587125e-16  5.28844241e-16  1.26662556e-16 -9.21704440e-17\n",
      "  1.19832021e-15  1.70152384e-15  7.82926561e-16 -4.45034414e-15\n",
      " -3.69548648e-15  1.19832021e-15 -1.35023173e-15 -5.26342845e-15\n",
      "  5.25606526e-16 -4.59724242e-15 -4.84195103e-16  1.27511151e-16\n",
      "  2.87999999e-16 -1.98887107e-15  3.58806744e-15 -1.23534506e-15\n",
      " -3.35424695e-15 -2.34525475e-16  2.76250226e-16 -2.31339981e-17\n",
      "  1.45691967e-15  2.62521268e-15 -6.51605873e-15 -1.74262193e-15\n",
      "  1.18458603e-15 -9.81628283e-15 -1.00990607e-15 -1.79698944e-14\n",
      "  3.52414867e-15 -1.05130444e-14 -2.37619582e-15  7.13106797e-16\n",
      "  4.35864369e-16 -2.09114632e-15 -3.13118403e-16  1.53499039e-15\n",
      " -1.68622302e-15  9.31156479e-16  2.02683720e-14  1.86394748e-14\n",
      " -2.55100634e-17  6.78823570e-16  6.49109699e-17  4.13905348e-16\n",
      " -4.64690479e-16 -2.83221758e-16  9.54133814e-16  2.50372003e-15\n",
      "  1.28526854e-15  7.91960832e-16 -2.07839129e-14 -3.54054614e-15\n",
      " -7.42664005e-16  3.24612293e-15  5.63560905e-15 -1.91490233e-14\n",
      "  6.39453996e-15 -7.50826964e-14 -1.76082886e-14  2.26747125e-15\n",
      "  5.61620887e-15  4.01427088e-15 -1.93636003e-14 -1.53331931e-15\n",
      " -6.42347051e-15 -9.89571130e-15 -1.24334012e-14  6.27947052e-15\n",
      "  2.31757751e-16 -8.67731203e-15  1.94647528e-14 -6.81675632e-14\n",
      " -6.23252364e-15  1.22867902e-14 -1.59036315e-14  1.20641450e-14\n",
      "  5.65320760e-16 -8.01441593e-15 -7.61098094e-16 -7.08049173e-15\n",
      "  1.22569980e-14  2.97068474e-14 -1.87324286e-14  1.35827902e-14\n",
      " -1.14096304e-14  2.25751006e-15 -9.25929136e-15  1.32254665e-14\n",
      " -1.50313022e-14 -1.62120500e-14  5.47526381e-15 -5.61276227e-15\n",
      "  1.61157280e-15  8.88857295e-16  4.67513036e-15 -4.70669808e-16\n",
      " -3.13100126e-15 -1.08089037e-14  9.54943243e-16 -4.43360724e-15\n",
      " -5.27434269e-18 -1.28508315e-14  1.16698749e-15 -6.55467632e-15\n",
      " -1.26397272e-14 -5.39640977e-15 -1.77954233e-15 -2.64845112e-15\n",
      "  2.87989555e-15  1.37707343e-15 -1.42744079e-14 -1.49799165e-15\n",
      " -4.01727360e-15 -4.87639092e-15  1.60580235e-15  2.54045765e-15\n",
      "  1.44610988e-15  2.64813780e-15 -1.25048921e-15  5.60119527e-15\n",
      " -2.53272891e-17  7.88070351e-16  2.70192565e-16 -3.23743789e-15\n",
      "  1.60951528e-14  4.79965184e-15  1.26394139e-14 -4.66635719e-16\n",
      " -1.46229454e-14 -2.84112130e-14  3.36617950e-15  1.15646752e-14\n",
      " -7.59338239e-15  7.20783316e-16 -3.14572764e-15 -1.01029512e-14\n",
      " -1.34395291e-13  1.57221367e-14 -2.15906001e-15  2.85336717e-15\n",
      "  4.13738240e-15  1.08738408e-14 -4.36151585e-16 -1.18120210e-14\n",
      "  1.17743173e-15 -2.61831948e-15 -2.22096826e-15 -9.32639562e-15\n",
      " -2.25700090e-16 -1.26950556e-14 -3.84114320e-14 -7.12921412e-15\n",
      " -4.55079164e-15 -4.43183172e-15 -5.82752201e-15 -1.47716061e-14\n",
      " -2.72276191e-15  3.75034487e-15 -7.27198692e-15 -1.44096348e-14\n",
      "  8.20382228e-15  1.63588177e-15  2.36120833e-15 -1.49411162e-14\n",
      " -2.44934861e-14  1.44624043e-15  6.98735519e-15  1.56029940e-14\n",
      "  2.00539909e-15 -1.51646230e-14  1.02820961e-15 -4.91662736e-17\n",
      " -3.88008847e-15  1.70442212e-15 -1.91544804e-15  4.59504912e-15\n",
      "  8.37954666e-15 -1.62405367e-15 -5.32095012e-15 -3.04982338e-15\n",
      " -4.44089210e-15  4.60358729e-15 -4.24884858e-15 -5.01161775e-15\n",
      " -6.92844966e-15 -2.03770052e-14 -1.59494034e-15 -1.30238665e-14\n",
      " -8.73603479e-15 -1.66254592e-14 -2.93115067e-15 -1.13696812e-14\n",
      "  2.86258422e-15 -7.79140523e-17 -3.31239165e-16  5.21846599e-16\n",
      " -4.56177377e-14  3.66257993e-14  3.97142860e-14 -1.76351042e-16\n",
      "  5.44145057e-16  1.67682320e-16  8.33137259e-16  1.27795757e-15\n",
      " -2.92289972e-15 -3.16301286e-15 -9.89461466e-16  4.40290116e-15\n",
      " -4.88816680e-16 -1.27630215e-14 -2.27676663e-15 -2.29653236e-15\n",
      " -2.97337152e-15  3.50239854e-15  8.91468356e-16 -4.33598751e-14\n",
      " -1.19064370e-14 -3.79230461e-16 -5.53137550e-15 -2.88553544e-14\n",
      " -1.47033008e-14 -1.16223535e-15 -5.36912419e-16 -1.66226393e-14\n",
      " -3.22958186e-14 -8.28920396e-15 -1.70968471e-14  1.77617929e-14\n",
      "  6.53940161e-16 -5.47713855e-14  1.12975115e-14 -4.51240905e-15\n",
      "  1.27824609e-14  3.40351767e-15  4.09725040e-15  2.82976318e-14\n",
      " -2.03290400e-14 -2.57011669e-14 -7.53369354e-15 -1.88534513e-14\n",
      " -5.99556988e-15 -1.14901294e-14 -1.50593450e-14 -3.31039810e-14\n",
      " -1.83885388e-14  4.94208521e-15  3.65754777e-15 -7.95152853e-15\n",
      " -2.23646752e-14 -2.20371959e-14  4.14062011e-16 -7.61933633e-16\n",
      " -3.87209863e-15 -3.45135230e-15 -1.74544710e-14  4.14062011e-16\n",
      " -3.42364895e-15 -7.41076480e-15  4.07012147e-16  1.47768478e-14\n",
      " -6.67700451e-16  2.28311151e-16 -4.02040688e-15 -3.09724024e-15\n",
      " -2.75250190e-15 -1.58621940e-16 -3.32568195e-15 -2.69354414e-15\n",
      " -3.09724024e-15  1.51624297e-16 -4.67458203e-15  1.30887253e-15\n",
      " -1.94809361e-13  8.01908973e-16 -7.41175700e-16  3.76230352e-15\n",
      " -2.26321522e-15  1.59700308e-15 -6.20711802e-15 -3.85011350e-15\n",
      "  7.44021757e-16 -2.26321522e-15 -2.52627959e-15 -6.17098094e-16\n",
      " -2.09354850e-16  4.12291712e-15  8.17784222e-17  5.36541648e-15\n",
      "  3.91747886e-15 -4.02325293e-15  6.26411748e-15 -1.86591622e-15\n",
      " -1.21977530e-14 -7.85824839e-16 -4.02325293e-15 -7.83735990e-15\n",
      " -5.91744694e-16 -1.14218241e-15 -2.32138809e-13 -5.06023570e-17\n",
      "  2.86908576e-15  5.63380742e-15  8.41446960e-17  1.07314596e-17\n",
      "  1.91129646e-17  7.83318221e-18  3.84974795e-16  1.03481559e-15\n",
      "  2.54839528e-16] [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      "Train data shape:  (4252, 561)\n",
      "Train labels shape:  (4252, 6)\n",
      "Test data shape:  (1492, 561)\n",
      "Test labels shape:  (1492, 6)\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from sklearn.model_selection import train_test_split\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n",
    "from PIL import Image\n",
    "import tensorflow as tf\n",
    "\n",
    "def one_hot_encode(y, num_classes):\n",
    "    \"\"\"Return a float one-hot encoding of integer labels `y`, shape y.shape + (num_classes,).\n",
    "\n",
    "    Labels must lie in [0, num_classes). Out-of-range labels raise ValueError instead of\n",
    "    being silently wrapped by NumPy fancy indexing (e.g. -1 would otherwise become the\n",
    "    last class, hiding an off-by-one in the label shift).\n",
    "    \"\"\"\n",
    "    y = np.asarray(y)\n",
    "    if y.size and (y.min() < 0 or y.max() >= num_classes):\n",
    "        raise ValueError(\n",
    "            'labels must be in [0, {}), got range [{}, {}]'.format(num_classes, y.min(), y.max()))\n",
    "    return np.eye(num_classes)[y]\n",
    "\n",
    "def preprocess_line(line):\n",
    "    \"\"\"Turn one raw text line into a list of its comma-separated fields.\n",
    "\n",
    "    Leading/trailing whitespace (including the newline) is removed first; the fields\n",
    "    are returned as unconverted strings.\n",
    "    \"\"\"\n",
    "    stripped = line.strip()\n",
    "    return stripped.split(',')\n",
    "\n",
    "def load_X_activity(filename):\n",
    "    \"\"\"Load a comma-separated numeric text file into a 2-D float NumPy array.\n",
    "\n",
    "    Each non-blank line becomes one row (fields split by preprocess_line, then\n",
    "    converted with float). Blank lines, e.g. a trailing empty line at the end of the\n",
    "    file, are skipped; previously they produced float('') and a ValueError.\n",
    "    Rows are expected to have equal length.\n",
    "    \"\"\"\n",
    "    with open(filename, 'r') as file:\n",
    "        rows = [list(map(float, preprocess_line(line))) for line in file if line.strip()]\n",
    "    return np.array(rows)\n",
    "\n",
    "def load_activity(one_hot=True, base_path=os.path.join('data', 'activity')):\n",
    "    \"\"\"Load the preprocessed UCI activity train/test split and standardize its features.\n",
    "\n",
    "    Reads final_X_train.txt, final_X_test.txt, final_y_train.txt and final_y_test.txt\n",
    "    (comma-delimited) from `base_path`. Labels are shifted by -1 (the files are presumably\n",
    "    1-based) and cast to int. Features are standardized with a StandardScaler fit on the\n",
    "    training split only and applied to both splits.\n",
    "\n",
    "    Args:\n",
    "        one_hot: if True, return labels one-hot encoded; otherwise as integer arrays.\n",
    "        base_path: directory holding the four text files.\n",
    "\n",
    "    Returns:\n",
    "        (x_train, y_train), (x_test, y_test)\n",
    "    \"\"\"\n",
    "    def _load(name):\n",
    "        return np.loadtxt(os.path.join(base_path, name), delimiter=',', encoding='UTF-8')\n",
    "\n",
    "    x_train = _load('final_X_train.txt')\n",
    "    x_test = _load('final_X_test.txt')\n",
    "    # Shift labels to start at 0 and make them integers so they can index np.eye.\n",
    "    y_train = (_load('final_y_train.txt') - 1).astype(int)\n",
    "    y_test = (_load('final_y_test.txt') - 1).astype(int)\n",
    "\n",
    "    # Fit the normalization on the training split only to avoid test-set leakage.\n",
    "    scaler = StandardScaler().fit(x_train)\n",
    "    x_train = scaler.transform(x_train)\n",
    "    x_test = scaler.transform(x_test)\n",
    "\n",
    "    if one_hot:\n",
    "        # Size the encoding by the largest label in either split. Counting the distinct\n",
    "        # training labels under-sizes it when a class is absent from train or ids have\n",
    "        # gaps, so test labels would index past the end of np.eye.\n",
    "        num_classes = int(max(y_train.max(), y_test.max())) + 1\n",
    "        y_train = one_hot_encode(y_train, num_classes)\n",
    "        y_test = one_hot_encode(y_test, num_classes)\n",
    "\n",
    "    print(\"x_train shape: {}, y_train shape: {}\".format(x_train.shape, y_train.shape))\n",
    "    print(\"x_test shape: {}, y_test shape: {}\".format(x_test.shape, y_test.shape))\n",
    "\n",
    "    return (x_train, y_train), (x_test, y_test)\n",
    "\n",
    "# Load both splits: standardized features and one-hot labels.\n",
    "(X_train, Y_train), (X_test, Y_test) = load_activity(one_hot=True)\n",
    "\n",
    "def calculate_mean_std(dataset):\n",
    "    \"\"\"Return (means, stds): the per-column statistics of `dataset` taken along axis 0.\"\"\"\n",
    "    return tuple(stat(dataset, axis=0) for stat in (np.mean, np.std))\n",
    "\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "# Summarize rather than dumping one value per feature column into the output.\n",
    "print('Normalized training set: max |mean| = {:.3e}, std range = [{:.3f}, {:.3f}]'.format(\n",
    "    np.abs(train_mean).max(), train_std.min(), train_std.max()))\n",
    "print('Test set (train-fitted scaler): max |mean| = {:.3f}, std range = [{:.3f}, {:.3f}]'.format(\n",
    "    np.abs(test_mean).max(), test_std.min(), test_std.max()))\n",
    "\n",
    "# Sanity-check that features and labels line up in both splits.\n",
    "assert X_train.shape[0] == Y_train.shape[0], 'train features/labels length mismatch'\n",
    "assert X_test.shape[0] == Y_test.shape[0], 'test features/labels length mismatch'\n",
    "assert X_train.shape[1] == X_test.shape[1], 'train/test feature count mismatch'\n",
    "\n",
    "# Check sample sizes of split\n",
    "print('Train data shape: ', X_train.shape)\n",
    "print('Train labels shape: ', Y_train.shape)\n",
    "print('Test data shape: ', X_test.shape)\n",
    "print('Test labels shape: ', Y_test.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2025-01-19T16:32:42.430006Z",
     "iopub.status.busy": "2025-01-19T16:32:42.429379Z",
     "iopub.status.idle": "2025-01-19T16:32:52.220085Z",
     "shell.execute_reply": "2025-01-19T16:32:52.219657Z",
     "shell.execute_reply.started": "2025-01-19T16:32:42.429983Z"
    },
    "id": "6yy9KHxyWOBY",
    "outputId": "7cfef7ad-a3ae-49a2-92d7-820a174079d9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"VanillaLeNet300100\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " input_1 (InputLayer)        [(None, 561)]             0         \n",
      "                                                                 \n",
      " dense (Dense)               (None, 300)               168600    \n",
      "                                                                 \n",
      " dense_1 (Dense)             (None, 100)               30100     \n",
      "                                                                 \n",
      " dense_2 (Dense)             (None, 6)                 606       \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 199,306\n",
      "Trainable params: 199,306\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n",
      "\n",
      "Epoch 1: Current learning rate = 1.000e-01\n",
      "Epoch 1/100\n",
      "17/17 [==============================] - 3s 3ms/step - loss: 1.5034 - accuracy: 0.5016\n",
      "\n",
      "Epoch 2: Current learning rate = 9.997e-02\n",
      "Epoch 2/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.8875 - accuracy: 0.6693\n",
      "\n",
      "Epoch 3: Current learning rate = 9.989e-02\n",
      "Epoch 3/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.4751 - accuracy: 0.8156\n",
      "\n",
      "Epoch 4: Current learning rate = 9.975e-02\n",
      "Epoch 4/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.3284 - accuracy: 0.8664\n",
      "\n",
      "Epoch 5: Current learning rate = 9.955e-02\n",
      "Epoch 5/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.2337 - accuracy: 0.9062\n",
      "\n",
      "Epoch 6: Current learning rate = 9.931e-02\n",
      "Epoch 6/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.1741 - accuracy: 0.9254\n",
      "\n",
      "Epoch 7: Current learning rate = 9.900e-02\n",
      "Epoch 7/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.1457 - accuracy: 0.9384\n",
      "\n",
      "Epoch 8: Current learning rate = 9.864e-02\n",
      "Epoch 8/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.1358 - accuracy: 0.9473\n",
      "\n",
      "Epoch 9: Current learning rate = 9.823e-02\n",
      "Epoch 9/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.1024 - accuracy: 0.9591\n",
      "\n",
      "Epoch 10: Current learning rate = 9.776e-02\n",
      "Epoch 10/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0865 - accuracy: 0.9687\n",
      "\n",
      "Epoch 11: Current learning rate = 9.724e-02\n",
      "Epoch 11/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0731 - accuracy: 0.9708\n",
      "\n",
      "Epoch 12: Current learning rate = 9.667e-02\n",
      "Epoch 12/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0572 - accuracy: 0.9805\n",
      "\n",
      "Epoch 13: Current learning rate = 9.604e-02\n",
      "Epoch 13/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0459 - accuracy: 0.9833\n",
      "\n",
      "Epoch 14: Current learning rate = 9.537e-02\n",
      "Epoch 14/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0336 - accuracy: 0.9885\n",
      "\n",
      "Epoch 15: Current learning rate = 9.464e-02\n",
      "Epoch 15/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0269 - accuracy: 0.9904\n",
      "\n",
      "Epoch 16: Current learning rate = 9.386e-02\n",
      "Epoch 16/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0179 - accuracy: 0.9953\n",
      "\n",
      "Epoch 17: Current learning rate = 9.304e-02\n",
      "Epoch 17/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0138 - accuracy: 0.9958\n",
      "\n",
      "Epoch 18: Current learning rate = 9.216e-02\n",
      "Epoch 18/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0132 - accuracy: 0.9962\n",
      "\n",
      "Epoch 19: Current learning rate = 9.124e-02\n",
      "Epoch 19/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0085 - accuracy: 0.9986\n",
      "\n",
      "Epoch 20: Current learning rate = 9.028e-02\n",
      "Epoch 20/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0057 - accuracy: 0.9988\n",
      "\n",
      "Epoch 21: Current learning rate = 8.927e-02\n",
      "Epoch 21/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0037 - accuracy: 0.9995\n",
      "\n",
      "Epoch 22: Current learning rate = 8.821e-02\n",
      "Epoch 22/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0029 - accuracy: 0.9993\n",
      "\n",
      "Epoch 23: Current learning rate = 8.711e-02\n",
      "Epoch 23/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0025 - accuracy: 0.9995\n",
      "\n",
      "Epoch 24: Current learning rate = 8.597e-02\n",
      "Epoch 24/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0022 - accuracy: 0.9995\n",
      "\n",
      "Epoch 25: Current learning rate = 8.480e-02\n",
      "Epoch 25/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0020 - accuracy: 0.9995\n",
      "\n",
      "Epoch 26: Current learning rate = 8.358e-02\n",
      "Epoch 26/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0019 - accuracy: 0.9995\n",
      "\n",
      "Epoch 27: Current learning rate = 8.232e-02\n",
      "Epoch 27/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0018 - accuracy: 0.9995\n",
      "\n",
      "Epoch 28: Current learning rate = 8.103e-02\n",
      "Epoch 28/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0017 - accuracy: 0.9995\n",
      "\n",
      "Epoch 29: Current learning rate = 7.971e-02\n",
      "Epoch 29/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0016 - accuracy: 0.9998\n",
      "\n",
      "Epoch 30: Current learning rate = 7.835e-02\n",
      "Epoch 30/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0016 - accuracy: 0.9995\n",
      "\n",
      "Epoch 31: Current learning rate = 7.696e-02\n",
      "Epoch 31/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0014 - accuracy: 0.9998\n",
      "\n",
      "Epoch 32: Current learning rate = 7.554e-02\n",
      "Epoch 32/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0014 - accuracy: 0.9998\n",
      "\n",
      "Epoch 33: Current learning rate = 7.409e-02\n",
      "Epoch 33/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0013 - accuracy: 0.9998\n",
      "\n",
      "Epoch 34: Current learning rate = 7.261e-02\n",
      "Epoch 34/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0012 - accuracy: 0.9998\n",
      "\n",
      "Epoch 35: Current learning rate = 7.111e-02\n",
      "Epoch 35/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0012 - accuracy: 0.9998\n",
      "\n",
      "Epoch 36: Current learning rate = 6.959e-02\n",
      "Epoch 36/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0011 - accuracy: 0.9998\n",
      "\n",
      "Epoch 37: Current learning rate = 6.804e-02\n",
      "Epoch 37/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0011 - accuracy: 0.9998\n",
      "\n",
      "Epoch 38: Current learning rate = 6.647e-02\n",
      "Epoch 38/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0011 - accuracy: 0.9998\n",
      "\n",
      "Epoch 39: Current learning rate = 6.489e-02\n",
      "Epoch 39/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0011 - accuracy: 0.9998\n",
      "\n",
      "Epoch 40: Current learning rate = 6.329e-02\n",
      "Epoch 40/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0010 - accuracy: 0.9998\n",
      "\n",
      "Epoch 41: Current learning rate = 6.167e-02\n",
      "Epoch 41/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.0010 - accuracy: 0.9998\n",
      "\n",
      "Epoch 42: Current learning rate = 6.004e-02\n",
      "Epoch 42/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 9.6196e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 43: Current learning rate = 5.840e-02\n",
      "Epoch 43/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 9.3525e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 44: Current learning rate = 5.675e-02\n",
      "Epoch 44/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 9.1772e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 45: Current learning rate = 5.510e-02\n",
      "Epoch 45/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 8.9652e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 46: Current learning rate = 5.343e-02\n",
      "Epoch 46/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 8.7981e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 47: Current learning rate = 5.177e-02\n",
      "Epoch 47/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 8.6249e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 48: Current learning rate = 5.010e-02\n",
      "Epoch 48/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 8.5020e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 49: Current learning rate = 4.843e-02\n",
      "Epoch 49/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 8.3522e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 50: Current learning rate = 4.676e-02\n",
      "Epoch 50/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 8.2504e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 51: Current learning rate = 4.510e-02\n",
      "Epoch 51/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 8.1639e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 52: Current learning rate = 4.344e-02\n",
      "Epoch 52/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 8.0459e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 53: Current learning rate = 4.179e-02\n",
      "Epoch 53/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.9331e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 54: Current learning rate = 4.015e-02\n",
      "Epoch 54/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.8874e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 55: Current learning rate = 3.852e-02\n",
      "Epoch 55/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.7790e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 56: Current learning rate = 3.690e-02\n",
      "Epoch 56/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.7046e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 57: Current learning rate = 3.530e-02\n",
      "Epoch 57/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.6311e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 58: Current learning rate = 3.371e-02\n",
      "Epoch 58/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.5742e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 59: Current learning rate = 3.214e-02\n",
      "Epoch 59/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.5204e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 60: Current learning rate = 3.059e-02\n",
      "Epoch 60/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.4522e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 61: Current learning rate = 2.907e-02\n",
      "Epoch 61/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.4121e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 62: Current learning rate = 2.756e-02\n",
      "Epoch 62/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.3499e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 63: Current learning rate = 2.608e-02\n",
      "Epoch 63/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.3045e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 64: Current learning rate = 2.463e-02\n",
      "Epoch 64/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.2693e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 65: Current learning rate = 2.321e-02\n",
      "Epoch 65/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.2457e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 66: Current learning rate = 2.181e-02\n",
      "Epoch 66/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.1897e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 67: Current learning rate = 2.045e-02\n",
      "Epoch 67/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.1566e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 68: Current learning rate = 1.912e-02\n",
      "Epoch 68/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.1347e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 69: Current learning rate = 1.783e-02\n",
      "Epoch 69/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.1063e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 70: Current learning rate = 1.657e-02\n",
      "Epoch 70/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.0816e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 71: Current learning rate = 1.535e-02\n",
      "Epoch 71/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.0513e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 72: Current learning rate = 1.416e-02\n",
      "Epoch 72/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.0273e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 73: Current learning rate = 1.302e-02\n",
      "Epoch 73/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 7.0097e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 74: Current learning rate = 1.192e-02\n",
      "Epoch 74/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.9927e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 75: Current learning rate = 1.086e-02\n",
      "Epoch 75/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.9819e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 76: Current learning rate = 9.840e-03\n",
      "Epoch 76/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.9643e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 77: Current learning rate = 8.868e-03\n",
      "Epoch 77/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.9463e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 78: Current learning rate = 7.942e-03\n",
      "Epoch 78/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.9349e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 79: Current learning rate = 7.063e-03\n",
      "Epoch 79/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.9239e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 80: Current learning rate = 6.232e-03\n",
      "Epoch 80/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.9146e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 81: Current learning rate = 5.450e-03\n",
      "Epoch 81/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.9089e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 82: Current learning rate = 4.717e-03\n",
      "Epoch 82/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8984e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 83: Current learning rate = 4.035e-03\n",
      "Epoch 83/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8921e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 84: Current learning rate = 3.404e-03\n",
      "Epoch 84/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8869e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 85: Current learning rate = 2.824e-03\n",
      "Epoch 85/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8832e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 86: Current learning rate = 2.298e-03\n",
      "Epoch 86/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8785e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 87: Current learning rate = 1.824e-03\n",
      "Epoch 87/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8757e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 88: Current learning rate = 1.405e-03\n",
      "Epoch 88/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8731e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 89: Current learning rate = 1.039e-03\n",
      "Epoch 89/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8710e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 90: Current learning rate = 7.277e-04\n",
      "Epoch 90/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8698e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 91: Current learning rate = 4.715e-04\n",
      "Epoch 91/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8687e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 92: Current learning rate = 2.705e-04\n",
      "Epoch 92/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8680e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 93: Current learning rate = 1.249e-04\n",
      "Epoch 93/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8675e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 94: Current learning rate = 3.479e-05\n",
      "Epoch 94/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8675e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 95: Current learning rate = 3.844e-07\n",
      "Epoch 95/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8673e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 96: Current learning rate = 0.000e+00\n",
      "Epoch 96/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8674e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 97: Current learning rate = 0.000e+00\n",
      "Epoch 97/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8673e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 98: Current learning rate = 0.000e+00\n",
      "Epoch 98/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8673e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 99: Current learning rate = 0.000e+00\n",
      "Epoch 99/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8674e-04 - accuracy: 0.9998\n",
      "\n",
      "Epoch 100: Current learning rate = 0.000e+00\n",
      "Epoch 100/100\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 6.8674e-04 - accuracy: 0.9998\n",
      "47/47 [==============================] - 0s 2ms/step - loss: 0.9284 - accuracy: 0.8592\n",
      "\n",
      "Test loss 0.9284167289733887\n",
      "Test accuracy 0.8592493534088135\n"
     ]
    }
   ],
   "source": [
    "# Vanilla LeNet-300-100 on activity\n",
    "\n",
    "# Model definition\n",
    "################################################################################\n",
    "MODEL = 'lenet300100_activity'\n",
    "#DEPTH = DEPTH\n",
    "LA = 0 #lambdas[0] #LA\n",
    "#print(f'Starting run with lambda={LA:.2e}')\n",
    "#INIT_TYPE = 'equivar'\n",
    "INIT_LR = INIT_LR/1 #INIT_LR\n",
    "#INIT = TwiceTruncatedNormalInitializer(minprod=MINPROD,depth=DEPTH)\n",
    "INIT = tf.keras.initializers.HeNormal\n",
    "#INIT = tf.keras.initializers.HeUniform\n",
    "EPOCHS = EPOCHS\n",
    "################################################################################\n",
    "\n",
    "################################################################################\n",
    "\n",
    "# Deirectories and saving paths\n",
    "fmt_la = f\"{LA:.1e}\"\n",
    "RUN_NAME = f\"{MODEL}_dep1-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{INIT_LR:.1e}-vanilla-bs{BATCH_SIZE}\"\n",
    "RUN_PATH = os.path.join(LENET_FILE_PATH, RUN_NAME)\n",
    "\n",
    "# Create dir\n",
    "if not os.path.exists(RUN_PATH):\n",
    "    os.makedirs(RUN_PATH)\n",
    "    \n",
    "################################################################################\n",
    "# Set seed\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "# Callbacks\n",
    "early_stopping = EarlyStopping(monitor='val_accuracy', patience=PAT, restore_best_weights=RESTORE_WEIGHTS)\n",
    "#custom_sparsity_callback = HadamardCallback(save_metrics=SAVE_METRICS, verbose=VERBOSE)\n",
    "print_lr_cb = PrintLRCallback()\n",
    "terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "early_abort_cb = TerminateBadRuns(grace=GRACE, minacc = MINACC)\n",
    "\n",
    "# Define model\n",
    "vanilla_lenet300100 = LeNet300100(input_shape=(X_train.shape[1],), n_classes = CLASS_NUM, la=LA, units1=300, units2=100, use_bias=True)\n",
    "\n",
    "#hadamard_net  = hadamard_resnet18(use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS, depth=DEPTH,\\\n",
    "#                                 init_type=INIT_TYPE, init=INIT, la=LA,\\\n",
    "#                                 input_shape=(IMG_ROWS,IMG_COLS,IMG_CHANNELS), n_classes=CLASS_NUM)\n",
    "\n",
    "# Pretrain optimizer\n",
    "optimizer = get_optimizer(lr_schedule=LR_SCHEDULE, init_lr=INIT_LR, lr_decay_fact=LR_DECAY_FACT, epochs=EPOCHS,\\\n",
    "                          dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT, momentum=MOMENTUM, alpha=0,\\\n",
    "                          large_lr_start=LARGE_LRSTART, warmup = WARMUP)\n",
    "\n",
    "# Compile model\n",
    "vanilla_lenet300100.compile(optimizer=optimizer,\n",
    "               loss='categorical_crossentropy',\n",
    "               metrics=['accuracy'])\n",
    "\n",
    "print(vanilla_lenet300100.summary())\n",
    "\n",
    "################################################################################\n",
    "\n",
    "# Training\n",
    "pre_hist = vanilla_lenet300100.fit(x=X_train, y=Y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, callbacks=[print_lr_cb])\n",
    "\n",
    "# Evaluate after training\n",
    "pretrain_loss, pretrain_acc = vanilla_lenet300100.evaluate(X_test, Y_test)\n",
    "print('\\nTest loss', pretrain_loss)\n",
    "print('Test accuracy', pretrain_acc)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-19T16:34:56.523126Z",
     "iopub.status.busy": "2025-01-19T16:34:56.522435Z",
     "iopub.status.idle": "2025-01-19T17:23:22.362540Z",
     "shell.execute_reply": "2025-01-19T17:23:22.361839Z",
     "shell.execute_reply.started": "2025-01-19T16:34:56.523096Z"
    }
   },
   "outputs": [],
   "source": [
    "DEPTH_LIST = [2, 3, 4]\n",
    "REPS = 5\n",
    "BASE_SEED = SEED\n",
    "\n",
    "for depth in DEPTH_LIST:\n",
    "    for rep in range(REPS):\n",
    "        current_seed = BASE_SEED + rep\n",
    "        \n",
    "        for LA_ITER in LAMBDA_LIST:\n",
    "            # Model definition\n",
    "            ################################################################################\n",
    "            MODEL = 'lenet300100_activity'  # Keep original dataset name\n",
    "            DEPTH = depth  # Use the loop variable\n",
    "            LA = LA_ITER\n",
    "            print(f'Starting run with depth={DEPTH}, lambda={LA:.2e}, repetition={rep+1}/{REPS}')\n",
    "            INIT_TYPE = 'ones'\n",
    "            INIT_LR = INIT_LR\n",
    "            INIT = tf.keras.initializers.HeNormal()\n",
    "            EPOCHS = EPOCHS\n",
    "            ################################################################################\n",
    "\n",
    "            # Directories and saving paths\n",
    "            fmt_la = f\"{LA:.1e}\"\n",
    "            RUN_NAME = f\"{MODEL}_dep{DEPTH}-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{INIT_LR:.1e}-{INIT_TYPE}-bs{BATCH_SIZE}_rep{rep+1}\"\n",
    "            RUN_PATH = os.path.join(LENET_FILE_PATH, f\"depth_{DEPTH}\", f\"rep_{rep+1}\", RUN_NAME)\n",
    "\n",
    "            # Create dir\n",
    "            if not os.path.exists(RUN_PATH):\n",
    "                os.makedirs(RUN_PATH)\n",
    "\n",
    "            ################################################################################\n",
    "            # Set seed for this repetition\n",
    "            np.random.seed(current_seed)\n",
    "            random.seed(current_seed)\n",
    "            tf.random.set_seed(current_seed)\n",
    "\n",
    "            # Callbacks\n",
    "            terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "\n",
    "            # Define model\n",
    "            hadamard_lenet300100 = InpHadamardLeNet300100(input_shape=(X_train.shape[1],), n_classes=CLASS_NUM, depth=DEPTH, la=LA,\n",
    "                                   init=INIT, use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS)\n",
    "\n",
    "            # Pretrain optimizer\n",
    "            optimizer = get_optimizer(lr_schedule=LR_SCHEDULE, init_lr=INIT_LR, lr_decay_fact=LR_DECAY_FACT, epochs=EPOCHS,\\\n",
    "                              dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT, momentum=MOMENTUM, alpha=0,\\\n",
    "                              large_lr_start=LARGE_LRSTART, warmup=WARMUP)\n",
    "\n",
    "            # Compile model\n",
    "            hadamard_lenet300100.compile(optimizer=optimizer,\n",
    "                   loss='categorical_crossentropy',\n",
    "                   metrics=['accuracy'])\n",
    "\n",
    "            print(hadamard_lenet300100.summary())\n",
    "\n",
    "            ################################################################################\n",
    "\n",
    "            # Training\n",
    "            pre_hist = hadamard_lenet300100.fit(x=X_train, y=Y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, verbose=0,\n",
    "                                   callbacks=[terminate_nan_cb])\n",
    "\n",
    "            # Evaluate after training\n",
    "            pretrain_loss, pretrain_acc = hadamard_lenet300100.evaluate(X_test, Y_test)\n",
    "            print('\\nTest loss', pretrain_loss)\n",
    "            print('Test accuracy', pretrain_acc)\n",
    "\n",
    "            ################################################################################\n",
    "\n",
    "            # Evaluate after pretraining\n",
    "            pretrain_loss, pretrain_acc = hadamard_lenet300100.evaluate(X_test, Y_test)\n",
    "            df, pretrain_sparsity = compute_input_sparsity(hadamard_lenet300100, DEPTH)\n",
    "            pretrain_compression_rate = 1 / (1 - pretrain_sparsity)\n",
    "            print('\\nTest loss', pretrain_loss)\n",
    "            print('Test accuracy', pretrain_acc)\n",
    "            print('Sparsity (pretrain)', pretrain_sparsity)\n",
    "            print('Compression rate (pretrain)', pretrain_compression_rate)\n",
    "\n",
    "            # Initialize df to store results with added run number column\n",
    "            pretrain_res_df = pd.DataFrame(columns=['Run', 'Pre Opt', 'Depth', 'Lambda', 'Init LR', 'LR Schedule', 'Batch size',\\\n",
    "                                            'Pre Epochs', 'Pre Loss', 'Pre Acc', 'Pre Sparsity', 'Pre CR'])\n",
    "\n",
    "            # Store formatted results in dict\n",
    "            pretrain_res_dict = {\n",
    "                'Run': int(rep + 1),\n",
    "                'Pre Opt': PRETRAIN_OPT,\n",
    "                'Depth': int(DEPTH),\n",
    "                'Lambda': f'{LA:.2e}',\n",
    "                'Init LR': f'{INIT_LR:.2e}',\n",
    "                'LR Schedule': LR_SCHEDULE,\n",
    "                'Batch size': int(BATCH_SIZE),\n",
    "                'Pre Epochs': int(EPOCHS),\n",
    "                'Pre Loss': f'{pretrain_loss:.3f}',\n",
    "                'Pre Acc': f'{pretrain_acc * 100:.4f}%',\n",
    "                'Pre Sparsity': f'{pretrain_sparsity * 100:.4f}%',\n",
    "                'Pre CR': f'{pretrain_compression_rate:.2f}'\n",
    "            }\n",
    "\n",
    "            # Append results to df\n",
    "            pretrain_res_df = pd.concat([pretrain_res_df, pd.DataFrame([pretrain_res_dict])], ignore_index=True)\n",
    "\n",
    "            # Save df to CSV\n",
    "            pretrain_csv_file_path = os.path.join(RUN_PATH, f'pretraining_{MODEL}_depth{DEPTH}_rep{rep+1}.csv')\n",
    "            pretrain_res_df.to_csv(pretrain_csv_file_path, index=False)\n",
    "            print(f'Pretrain results saved to {pretrain_csv_file_path}')\n",
    "            print(\"\\nPretraining Results:\")\n",
    "            print(pretrain_res_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-20T07:33:50.375235Z",
     "iopub.status.busy": "2025-01-20T07:33:50.374425Z",
     "iopub.status.idle": "2025-01-20T07:33:50.624337Z",
     "shell.execute_reply": "2025-01-20T07:33:50.623662Z",
     "shell.execute_reply.started": "2025-01-20T07:33:50.375210Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train shape: (4252, 561), y_train shape: (4252,)\n",
      "x_test shape: (1492, 561), y_test shape: (1492,)\n",
      "Normalized Training Set Mean and SD: [-3.78062012e-16  4.24406708e-16 -8.94834993e-16 -3.73313798e-15\n",
      " -5.67801267e-16 -5.81274341e-15  5.46834450e-16 -7.79662736e-17\n",
      " -4.56235864e-15 -2.24002900e-16  1.97127252e-15  6.71956480e-16\n",
      " -1.78335448e-17  2.23373635e-15 -2.57043263e-15 -7.34804712e-16\n",
      " -2.64437787e-15  3.22325004e-15 -1.01107322e-14 -5.14684459e-15\n",
      "  2.41473508e-15  1.19889465e-15 -1.18489936e-16  2.81994559e-18\n",
      "  2.36875430e-16  8.53816860e-17 -8.81180777e-16  1.07732366e-16\n",
      " -3.89570262e-17 -1.72330009e-17 -3.69778421e-16 -1.13904914e-15\n",
      " -4.69468720e-17 -7.17519490e-17 -3.39281232e-16  3.04971894e-17\n",
      "  4.78816318e-16  4.12387669e-16 -4.45721123e-16 -2.58586400e-16\n",
      "  1.78606999e-15  1.17602176e-16 -1.85489755e-16 -3.31382773e-16\n",
      " -9.09638728e-15 -8.70887976e-15 -3.61235031e-15  5.50918149e-15\n",
      "  5.52013489e-15 -7.14877097e-15 -6.45454214e-17 -5.56599817e-16\n",
      "  1.01016718e-15  1.13320036e-16 -4.20641884e-17 -6.22790207e-16\n",
      "  6.79189118e-16 -3.19802719e-15  3.44464187e-16  3.13000906e-16\n",
      " -1.68258582e-14  1.24727760e-15 -2.37512529e-15  4.33545747e-15\n",
      "  4.34955719e-15 -3.34491241e-15 -8.10375337e-16  8.41268103e-15\n",
      "  6.55469264e-16  1.54674016e-15 -5.37095193e-16 -2.34614251e-15\n",
      "  3.99074523e-15  8.15695374e-16 -1.67368993e-15  1.91087869e-15\n",
      " -1.49493671e-15  1.15487216e-16 -1.32798549e-16 -1.37341795e-17\n",
      "  4.34240615e-16 -1.02221396e-16  1.44048957e-17 -5.35089899e-15\n",
      "  1.32991767e-15 -4.16396300e-15 -1.46020961e-15  3.16303897e-15\n",
      " -5.34149917e-15 -5.52239346e-17 -7.35026652e-15 -4.02831839e-15\n",
      "  1.47806926e-15 -2.16300271e-15  3.45683553e-15 -1.53616536e-15\n",
      "  6.08664368e-15 -3.08953761e-15 -1.26053135e-14  1.07802865e-15\n",
      "  2.00902846e-15 -1.32910825e-15  7.68174068e-17 -1.37498458e-16\n",
      "  1.05382411e-16  5.22212147e-19  4.22991839e-16  1.40213961e-16\n",
      "  5.04713960e-16  1.93166273e-16 -5.14875066e-16  2.05328594e-15\n",
      "  2.78545348e-15 -6.30832274e-17  1.59533200e-15  2.71445874e-16\n",
      "  1.44130553e-17  2.27074161e-16 -2.65779872e-16  4.14401449e-16\n",
      "  2.59605224e-16 -1.49368871e-15  1.48915811e-15 -4.36256028e-15\n",
      " -2.31005765e-15 -3.50644568e-15 -1.03883662e-15 -2.17733744e-15\n",
      " -3.01446962e-16  3.82739727e-15  6.27933996e-15  4.86101177e-15\n",
      "  5.01114776e-15  8.71258746e-16  2.46507633e-15 -1.25017588e-16\n",
      " -1.01216725e-14 -5.23212183e-15 -3.42652111e-15 -4.22600180e-16\n",
      " -1.38903209e-15 -8.01151765e-15 -1.03189120e-16 -2.63612692e-16\n",
      "  4.73646417e-17 -4.40172619e-16 -5.76483044e-16 -4.95187668e-17\n",
      "  6.78875791e-18 -6.56681775e-16  1.47472710e-16  1.46898277e-16\n",
      "  3.74559274e-15 -3.19593834e-17  1.79902085e-16 -5.48453307e-17\n",
      " -1.16589084e-15  4.00968358e-16 -1.50814868e-16 -7.86582046e-17\n",
      " -7.81022119e-16 -4.71294831e-16 -4.65927061e-16 -9.10215772e-17\n",
      " -1.55875104e-15 -1.11210299e-15  3.61806853e-15 -8.58323551e-15\n",
      " -7.61098094e-15  1.00904442e-15 -1.16218574e-14  4.33785964e-15\n",
      " -8.89045292e-15  1.40342165e-14 -1.19461251e-15 -7.28026398e-15\n",
      " -3.24899509e-15  2.73900271e-15 -1.34124446e-14 -1.79001269e-15\n",
      " -9.35240178e-15 -8.46009789e-15 -2.04446056e-16  4.35368267e-16\n",
      "  9.88547594e-17  1.31075249e-16  8.81833542e-16  4.08317678e-16\n",
      " -1.67603989e-16 -3.30978059e-16  1.58713327e-15  4.30720579e-16\n",
      "  2.78456572e-16  5.78976607e-16 -5.54589300e-17 -6.55154304e-16\n",
      "  9.31104258e-17 -1.24182049e-16  8.71154304e-16  3.38262918e-16\n",
      "  1.53566926e-15  1.31472130e-15 -4.19989119e-16  3.11630099e-15\n",
      " -3.16906791e-14  1.53566926e-15  4.75560325e-15  2.96224840e-16\n",
      " -2.06117134e-16  1.21362103e-16  5.99969536e-16 -2.95467633e-16\n",
      " -4.10876517e-16  1.53566926e-15  1.31472130e-15 -4.19989119e-16\n",
      "  3.11630099e-15 -3.16906791e-14  1.53566926e-15  4.75560325e-15\n",
      "  2.96224840e-16 -2.06117134e-16  1.21362103e-16  5.99969536e-16\n",
      " -2.95467633e-16 -4.10876517e-16  2.89300307e-15 -2.59098168e-15\n",
      " -6.23581358e-15  3.77366164e-15  3.02162393e-15  2.89300307e-15\n",
      " -6.70207070e-15 -3.80128666e-15 -8.98727105e-17 -3.27061468e-16\n",
      " -7.22167178e-16 -5.25528194e-16  1.51128195e-16 -2.95436300e-15\n",
      " -7.86817042e-16 -8.78621937e-16  2.15986944e-15  3.53057188e-15\n",
      " -2.95436300e-15 -8.47827087e-15  1.76768812e-17  4.31765003e-16\n",
      " -2.06587125e-16  5.28844241e-16  1.26662556e-16 -9.21704440e-17\n",
      "  1.19832021e-15  1.70152384e-15  7.82926561e-16 -4.45034414e-15\n",
      " -3.69548648e-15  1.19832021e-15 -1.35023173e-15 -5.26342845e-15\n",
      "  5.25606526e-16 -4.59724242e-15 -4.84195103e-16  1.27511151e-16\n",
      "  2.87999999e-16 -1.98887107e-15  3.58806744e-15 -1.23534506e-15\n",
      " -3.35424695e-15 -2.34525475e-16  2.76250226e-16 -2.31339981e-17\n",
      "  1.45691967e-15  2.62521268e-15 -6.51605873e-15 -1.74262193e-15\n",
      "  1.18458603e-15 -9.81628283e-15 -1.00990607e-15 -1.79698944e-14\n",
      "  3.52414867e-15 -1.05130444e-14 -2.37619582e-15  7.13106797e-16\n",
      "  4.35864369e-16 -2.09114632e-15 -3.13118403e-16  1.53499039e-15\n",
      " -1.68622302e-15  9.31156479e-16  2.02683720e-14  1.86394748e-14\n",
      " -2.55100634e-17  6.78823570e-16  6.49109699e-17  4.13905348e-16\n",
      " -4.64690479e-16 -2.83221758e-16  9.54133814e-16  2.50372003e-15\n",
      "  1.28526854e-15  7.91960832e-16 -2.07839129e-14 -3.54054614e-15\n",
      " -7.42664005e-16  3.24612293e-15  5.63560905e-15 -1.91490233e-14\n",
      "  6.39453996e-15 -7.50826964e-14 -1.76082886e-14  2.26747125e-15\n",
      "  5.61620887e-15  4.01427088e-15 -1.93636003e-14 -1.53331931e-15\n",
      " -6.42347051e-15 -9.89571130e-15 -1.24334012e-14  6.27947052e-15\n",
      "  2.31757751e-16 -8.67731203e-15  1.94647528e-14 -6.81675632e-14\n",
      " -6.23252364e-15  1.22867902e-14 -1.59036315e-14  1.20641450e-14\n",
      "  5.65320760e-16 -8.01441593e-15 -7.61098094e-16 -7.08049173e-15\n",
      "  1.22569980e-14  2.97068474e-14 -1.87324286e-14  1.35827902e-14\n",
      " -1.14096304e-14  2.25751006e-15 -9.25929136e-15  1.32254665e-14\n",
      " -1.50313022e-14 -1.62120500e-14  5.47526381e-15 -5.61276227e-15\n",
      "  1.61157280e-15  8.88857295e-16  4.67513036e-15 -4.70669808e-16\n",
      " -3.13100126e-15 -1.08089037e-14  9.54943243e-16 -4.43360724e-15\n",
      " -5.27434269e-18 -1.28508315e-14  1.16698749e-15 -6.55467632e-15\n",
      " -1.26397272e-14 -5.39640977e-15 -1.77954233e-15 -2.64845112e-15\n",
      "  2.87989555e-15  1.37707343e-15 -1.42744079e-14 -1.49799165e-15\n",
      " -4.01727360e-15 -4.87639092e-15  1.60580235e-15  2.54045765e-15\n",
      "  1.44610988e-15  2.64813780e-15 -1.25048921e-15  5.60119527e-15\n",
      " -2.53272891e-17  7.88070351e-16  2.70192565e-16 -3.23743789e-15\n",
      "  1.60951528e-14  4.79965184e-15  1.26394139e-14 -4.66635719e-16\n",
      " -1.46229454e-14 -2.84112130e-14  3.36617950e-15  1.15646752e-14\n",
      " -7.59338239e-15  7.20783316e-16 -3.14572764e-15 -1.01029512e-14\n",
      " -1.34395291e-13  1.57221367e-14 -2.15906001e-15  2.85336717e-15\n",
      "  4.13738240e-15  1.08738408e-14 -4.36151585e-16 -1.18120210e-14\n",
      "  1.17743173e-15 -2.61831948e-15 -2.22096826e-15 -9.32639562e-15\n",
      " -2.25700090e-16 -1.26950556e-14 -3.84114320e-14 -7.12921412e-15\n",
      " -4.55079164e-15 -4.43183172e-15 -5.82752201e-15 -1.47716061e-14\n",
      " -2.72276191e-15  3.75034487e-15 -7.27198692e-15 -1.44096348e-14\n",
      "  8.20382228e-15  1.63588177e-15  2.36120833e-15 -1.49411162e-14\n",
      " -2.44934861e-14  1.44624043e-15  6.98735519e-15  1.56029940e-14\n",
      "  2.00539909e-15 -1.51646230e-14  1.02820961e-15 -4.91662736e-17\n",
      " -3.88008847e-15  1.70442212e-15 -1.91544804e-15  4.59504912e-15\n",
      "  8.37954666e-15 -1.62405367e-15 -5.32095012e-15 -3.04982338e-15\n",
      " -4.44089210e-15  4.60358729e-15 -4.24884858e-15 -5.01161775e-15\n",
      " -6.92844966e-15 -2.03770052e-14 -1.59494034e-15 -1.30238665e-14\n",
      " -8.73603479e-15 -1.66254592e-14 -2.93115067e-15 -1.13696812e-14\n",
      "  2.86258422e-15 -7.79140523e-17 -3.31239165e-16  5.21846599e-16\n",
      " -4.56177377e-14  3.66257993e-14  3.97142860e-14 -1.76351042e-16\n",
      "  5.44145057e-16  1.67682320e-16  8.33137259e-16  1.27795757e-15\n",
      " -2.92289972e-15 -3.16301286e-15 -9.89461466e-16  4.40290116e-15\n",
      " -4.88816680e-16 -1.27630215e-14 -2.27676663e-15 -2.29653236e-15\n",
      " -2.97337152e-15  3.50239854e-15  8.91468356e-16 -4.33598751e-14\n",
      " -1.19064370e-14 -3.79230461e-16 -5.53137550e-15 -2.88553544e-14\n",
      " -1.47033008e-14 -1.16223535e-15 -5.36912419e-16 -1.66226393e-14\n",
      " -3.22958186e-14 -8.28920396e-15 -1.70968471e-14  1.77617929e-14\n",
      "  6.53940161e-16 -5.47713855e-14  1.12975115e-14 -4.51240905e-15\n",
      "  1.27824609e-14  3.40351767e-15  4.09725040e-15  2.82976318e-14\n",
      " -2.03290400e-14 -2.57011669e-14 -7.53369354e-15 -1.88534513e-14\n",
      " -5.99556988e-15 -1.14901294e-14 -1.50593450e-14 -3.31039810e-14\n",
      " -1.83885388e-14  4.94208521e-15  3.65754777e-15 -7.95152853e-15\n",
      " -2.23646752e-14 -2.20371959e-14  4.14062011e-16 -7.61933633e-16\n",
      " -3.87209863e-15 -3.45135230e-15 -1.74544710e-14  4.14062011e-16\n",
      " -3.42364895e-15 -7.41076480e-15  4.07012147e-16  1.47768478e-14\n",
      " -6.67700451e-16  2.28311151e-16 -4.02040688e-15 -3.09724024e-15\n",
      " -2.75250190e-15 -1.58621940e-16 -3.32568195e-15 -2.69354414e-15\n",
      " -3.09724024e-15  1.51624297e-16 -4.67458203e-15  1.30887253e-15\n",
      " -1.94809361e-13  8.01908973e-16 -7.41175700e-16  3.76230352e-15\n",
      " -2.26321522e-15  1.59700308e-15 -6.20711802e-15 -3.85011350e-15\n",
      "  7.44021757e-16 -2.26321522e-15 -2.52627959e-15 -6.17098094e-16\n",
      " -2.09354850e-16  4.12291712e-15  8.17784222e-17  5.36541648e-15\n",
      "  3.91747886e-15 -4.02325293e-15  6.26411748e-15 -1.86591622e-15\n",
      " -1.21977530e-14 -7.85824839e-16 -4.02325293e-15 -7.83735990e-15\n",
      " -5.91744694e-16 -1.14218241e-15 -2.32138809e-13 -5.06023570e-17\n",
      "  2.86908576e-15  5.63380742e-15  8.41446960e-17  1.07314596e-17\n",
      "  1.91129646e-17  7.83318221e-18  3.84974795e-16  1.03481559e-15\n",
      "  2.54839528e-16] [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      "Train data shape:  (4252, 561)\n",
      "Train labels shape:  (4252,)\n",
      "Test data shape:  (1492, 561)\n",
      "Test labels shape:  (1492,)\n"
     ]
    }
   ],
   "source": [
    "# Load the activity dataset and sanity-check the train/test split that the experiment\n",
    "# cell below reloads via load_activity(one_hot=False) for the HSIC lasso + SVM comparison\n",
    "# (following Ziyin and Liu, 2023).\n",
    "\n",
    "(X_train, Y_train), (X_test, Y_test) = load_activity(one_hot=False)\n",
    "\n",
    "# Fail fast on a malformed split here rather than deep inside the long-running experiment cell.\n",
    "assert X_train.shape[0] == Y_train.shape[0], \"train features and labels differ in length\"\n",
    "assert X_test.shape[0] == Y_test.shape[0], \"test features and labels differ in length\"\n",
    "assert X_train.shape[1:] == X_test.shape[1:], \"train and test feature dimensions differ\"\n",
    "\n",
    "def calculate_mean_std(dataset):\n",
    "    means = np.mean(dataset, axis=(0))  \n",
    "    stds = np.std(dataset, axis=(0))  \n",
    "    return means, stds\n",
    "\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "\n",
    "# Summarise the per-feature statistics rather than printing the full per-feature arrays.\n",
    "# Standardised features should show a max |mean| near 0 and SDs near 1; a minimum SD of 0\n",
    "# flags a constant feature.\n",
    "split_stats = [(\"Training\", train_mean, train_std), (\"Test\", test_mean, test_std)]\n",
    "for split_name, split_mean, split_std in split_stats:\n",
    "    print(f\"{split_name} set: max |feature mean| = {np.max(np.abs(split_mean)):.3e}, \"\n",
    "          f\"feature SD range = [{np.min(split_std):.3e}, {np.max(split_std):.3e}]\")\n",
    "\n",
    "# Check sample sizes of split\n",
    "print('Train data shape: ', X_train.shape)\n",
    "print('Train labels shape: ', Y_train.shape)\n",
    "print('Test data shape: ', X_test.shape)\n",
    "print('Test labels shape: ', Y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-20T07:33:55.131191Z",
     "iopub.status.busy": "2025-01-20T07:33:55.130619Z",
     "iopub.status.idle": "2025-01-20T10:28:47.709772Z",
     "shell.execute_reply": "2025-01-20T10:28:47.709140Z",
     "shell.execute_reply.started": "2025-01-20T07:33:55.131166Z"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "from models import hadamard_nn_depth_2, hadamard_nn_depth_3, hadamard_nn_depth_4, hsic_dnn, hsic_svm, lassoNet\n",
    "import tensorflow as tf\n",
    "import config_inputsparse_compar\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "# ---- Experiment configuration ----\n",
    "BASE_SEED = config_inputsparse_compar.SEED  # repetition r (0-based) is seeded with BASE_SEED + r\n",
    "REPS = 5  # independent repetitions of every method\n",
    "# NOTE(review): not referenced in this cell; main() writes to os.path.join('results', 'input_sparsity', 'activity').\n",
    "LENET_FILE_PATH = './results/input_sparsity/activity/'\n",
    "\n",
    "def run_methods_on_dataset(dataset_name, load_func, results_dir, rep):\n",
    "    \"\"\"Run every feature-selection method on one dataset for a single repetition.\n",
    "\n",
    "    For each method the dataset is reloaded via ``load_func``, the method is run, and its\n",
    "    (sparsity, accuracy, value) sequences are saved as one CSV row per entry under\n",
    "    ``<results_dir>/<method>/rep_<rep + 1>/<dataset>_<method>_rep<rep + 1>_res.csv``.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: Label used in log messages, file names and the ``dataset`` column.\n",
    "        load_func: Called as ``load_func(one_hot=...)``; must return\n",
    "            ``((train_X, train_y), (test_X, test_y))``.\n",
    "        results_dir: Base directory under which per-method result folders are created.\n",
    "        rep: Zero-based repetition index; all RNGs are seeded with ``BASE_SEED + rep``.\n",
    "    \"\"\"\n",
    "    current_seed = BASE_SEED + rep\n",
    "    print(f\"Loading dataset: {dataset_name} for repetition {rep + 1}/{REPS}\")\n",
    "\n",
    "    for method_name, method_func, one_hot in [\n",
    "        (\"LassoNet\", lassoNet, False),\n",
    "        (\"HSIC_dnn\", hsic_dnn, False),\n",
    "        (\"HSIC_svm\", hsic_svm, False),\n",
    "    ]:\n",
    "        # Re-seed every RNG before *each* method (not once per repetition) so that a\n",
    "        # method's result depends only on (BASE_SEED, rep), not on how much randomness\n",
    "        # the methods listed before it consumed or on the order of this list. If\n",
    "        # load_func draws from these RNGs, every method also sees the same split.\n",
    "        np.random.seed(current_seed)\n",
    "        random.seed(current_seed)\n",
    "        tf.random.set_seed(current_seed)\n",
    "\n",
    "        method_dir = os.path.join(results_dir, method_name, f'rep_{rep + 1}')\n",
    "        os.makedirs(method_dir, exist_ok=True)\n",
    "        result_filename = os.path.join(method_dir, f'{dataset_name}_{method_name}_rep{rep + 1}_res.csv')\n",
    "\n",
    "        print(f\"Loading dataset: {dataset_name} with one_hot = {one_hot} for repetition {rep + 1}\")\n",
    "        (train_X, train_y), (test_X, test_y) = load_func(one_hot=one_hot)\n",
    "\n",
    "        print(f\"Running {method_name} on {dataset_name} (repetition {rep + 1})\")\n",
    "        sparsity, accuracy, value_seq = method_func(train_X, train_y, test_X, test_y)\n",
    "        print(f'Repetition {rep + 1}: sparsity = {sparsity} and test accuracy = {accuracy}')\n",
    "\n",
    "        # zip() silently truncates to the shortest input; warn instead of quietly\n",
    "        # dropping rows at the end of a long training run (never crash here).\n",
    "        try:\n",
    "            seq_lengths = tuple(len(seq) for seq in (sparsity, accuracy, value_seq))\n",
    "        except TypeError:  # unsized iterables: lengths unknown, nothing to compare\n",
    "            seq_lengths = ()\n",
    "        if len(set(seq_lengths)) > 1:\n",
    "            print(f\"WARNING: {method_name} returned sequences of unequal length \"\n",
    "                  f\"(sparsity, accuracy, value) = {seq_lengths}; extra entries are dropped\")\n",
    "\n",
    "        results = [\n",
    "            {\n",
    "                \"method\": method_name,\n",
    "                \"dataset\": dataset_name,\n",
    "                \"repetition\": rep + 1,\n",
    "                \"sparsity\": s,\n",
    "                \"accuracy\": a,\n",
    "                \"value\": v,\n",
    "                \"seed\": current_seed,\n",
    "            }\n",
    "            for s, a, v in zip(sparsity, accuracy, value_seq)\n",
    "        ]\n",
    "\n",
    "        save_results_to_csv(results, result_filename)\n",
    "        print(f'Results successfully saved to {result_filename}')\n",
    "\n",
    "def save_results_to_csv(results, result_filename):\n",
    "    df = pd.DataFrame(results)\n",
    "    df.to_csv(result_filename, index=False)\n",
    "\n",
    "def main():\n",
    "    \"\"\"Run all REPS repetitions of every method on every dataset, then merge the CSVs.\n",
    "\n",
    "    Results go to the fixed directory ``results/input_sparsity/activity`` (relative to the\n",
    "    working directory; no timestamp), so a re-run overwrites same-named files.\n",
    "    \"\"\"\n",
    "    datasets = dict(ACTIVITY=load_activity)\n",
    "\n",
    "    results_root = os.path.join('results', 'input_sparsity', 'activity')\n",
    "    os.makedirs(results_root, exist_ok=True)\n",
    "\n",
    "    for rep in range(REPS):\n",
    "        print(f\"\\nStarting repetition {rep + 1}/{REPS}\")\n",
    "        for name, loader in datasets.items():\n",
    "            run_methods_on_dataset(name, loader, results_root, rep)\n",
    "\n",
    "    # Merge every per-repetition CSV into a single summary file.\n",
    "    combine_all_results(results_root)\n",
    "\n",
    "def combine_all_results(base_results_dir):\n",
    "    \"\"\"Combines all individual result files into a single summary file\"\"\"\n",
    "    all_results = []\n",
    "    \n",
    "    for method in ['HSIC_dnn', 'HSIC_svm', 'LassoNet']:\n",
    "        method_dir = os.path.join(base_results_dir, method)\n",
    "        if not os.path.exists(method_dir):\n",
    "            continue\n",
    "            \n",
    "        for rep_dir in os.listdir(method_dir):\n",
    "            if not rep_dir.startswith('rep_'):\n",
    "                continue\n",
    "                \n",
    "            rep_path = os.path.join(method_dir, rep_dir)\n",
    "            for result_file in os.listdir(rep_path):\n",
    "                if result_file.endswith('_res.csv'):\n",
    "                    df = pd.read_csv(os.path.join(rep_path, result_file))\n",
    "                    all_results.append(df)\n",
    "    \n",
    "    if all_results:\n",
    "        combined_df = pd.concat(all_results, ignore_index=True)\n",
    "        summary_file = os.path.join(base_results_dir, 'all_results_summary.csv')\n",
    "        combined_df.to_csv(summary_file, index=False)\n",
    "        print(f\"\\nCombined results saved to {summary_file}\")\n",
    "\n",
    "# In a Jupyter/IPython kernel the interactive namespace's __name__ is \"__main__\",\n",
    "# so running this cell starts the full experiment immediately.\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
