{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-19T04:02:31.238448Z",
     "iopub.status.busy": "2025-01-19T04:02:31.237804Z",
     "iopub.status.idle": "2025-01-19T04:02:31.246772Z",
     "shell.execute_reply": "2025-01-19T04:02:31.246222Z",
     "shell.execute_reply.started": "2025-01-19T04:02:31.238424Z"
    },
    "id": "RPFnfEqRAYA5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Defining configs successful!\n"
     ]
    }
   ],
   "source": [
    "import importlib\n",
    "import os\n",
    "import random\n",
    "import subprocess\n",
    "import sys\n",
    "import time\n",
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import keras\n",
    "import pandas as pd\n",
    "from tensorflow.keras.datasets import cifar10, fashion_mnist, mnist\n",
    "from tensorflow.keras.models import load_model\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping  #ReduceLROnPlateau\n",
    "from tensorflow.keras.regularizers import l2\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder, MinMaxScaler, StandardScaler \n",
    "from PIL import Image\n",
    "\n",
    "# pyHSICLasso / lassonet: install on demand if missing.\n",
    "# - Only ImportError is caught, so real errors raised inside the packages are not hidden.\n",
    "# - sys.executable (not a bare `python`) makes pip target the interpreter running this kernel.\n",
    "# - check_call raises if the install fails instead of silently continuing.\n",
    "try:\n",
    "    import pyHSICLasso\n",
    "except ImportError:\n",
    "    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"pyHSICLasso\"])\n",
    "    importlib.invalidate_caches()  # let the running interpreter see the freshly installed package\n",
    "    import pyHSICLasso\n",
    "try:\n",
    "    import lassonet\n",
    "except ImportError:\n",
    "    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"lassonet\"])\n",
    "    importlib.invalidate_caches()\n",
    "    import lassonet\n",
    "\n",
    "from layers       import HadamardDense, HadamardConv2D, StrHadamardDenseV2\n",
    "from initializers import TwiceTruncatedNormalInitializer, equivar_initializer, equivar_initializer_conv2d\n",
    "from callbacks    import HadamardCallback, PrintLRCallback, TerminateBadRuns\n",
    "from models       import LeNet300100, HadamardLeNet300100,\\\n",
    "                         InpHadamardLeNet300100, hsic_dnn, hsic_svm, lassoNet\n",
    "from utils        import get_optimizer, process_sparsity_callback, create_and_save_trajectory_plot, compute_input_sparsity\n",
    "\n",
    "################################################################################\n",
    "# Experiment configuration\n",
    "################################################################################\n",
    "\n",
    "# Hadamard\n",
    "DEPTH              = 2\n",
    "LA                 = 5e-4\n",
    "INIT_TYPE          = 'ones'  # one of: vanilla, equivar, root, ones\n",
    "MINPROD            = 3e-3    # in this cell only referenced by the alternative initializer below\n",
    "# Alternative: INIT = TwiceTruncatedNormalInitializer(minprod=MINPROD, depth=DEPTH)\n",
    "INIT               = tf.keras.initializers.HeNormal()\n",
    "USE_BIAS           = True    # only for dense layers\n",
    "FACTORIZE_BIAS     = True\n",
    "\n",
    "# Training\n",
    "PRETRAIN_OPT       = 'sgd'     # sgd, adam\n",
    "LR_SCHEDULE        = 'cosine'  # piecewise, constant, cosine\n",
    "BATCH_SIZE         = 256\n",
    "EPOCHS             = 100       # alternative: 200\n",
    "# INIT_LR by DEPTH (SGD): 2 -> 0.2, 3 -> 0.5, 4 -> 0.7\n",
    "INIT_LR            = 0.2 if PRETRAIN_OPT == 'sgd' else 2e-3\n",
    "WARMUP             = False  # linear lr warmup from 0.01 to INIT_LR over 5 epochs\n",
    "MOMENTUM           = 0.9    # only for SGD\n",
    "LR_DECAY_FACT      = 0.1    # only for piecewise\n",
    "LARGE_LRSTART      = False  # only for piecewise\n",
    "\n",
    "# Misc\n",
    "CLASS_NUM          = 20\n",
    "PAT                = 100\n",
    "RESTORE_WEIGHTS    = False\n",
    "GRACE              = 10\n",
    "FINE_GRACE         = 20\n",
    "MINACC             = (1 / CLASS_NUM) + 0.01  # uniform-guess accuracy (1/CLASS_NUM) plus a 1% margin\n",
    "SEED               = 123\n",
    "SAVE_METRICS       = True\n",
    "VERBOSE            = 1\n",
    "\n",
    "# Directories and saving path\n",
    "LENET_FILE_PATH = './results/input_sparsity/coil20/'\n",
    "\n",
    "# Lambda grid (commented-out entries are excluded from the list)\n",
    "LAMBDA_LIST = [\n",
    "    0,\n",
    "    #1e-6,\n",
    "    1e-5,\n",
    "    1e-4,\n",
    "    2e-4,\n",
    "    #4e-4,\n",
    "    5e-4,\n",
    "    7e-4,\n",
    "    8e-4,\n",
    "    9e-4,\n",
    "    1e-3,\n",
    "    1.5e-3,\n",
    "    2e-3,\n",
    "    4e-3,\n",
    "    5e-3,\n",
    "    8e-3,\n",
    "    1e-2,\n",
    "    1.5e-2,\n",
    "    2e-2,\n",
    "    2.5e-2,\n",
    "    3e-2,\n",
    "    3.5e-2,\n",
    "    4e-2,\n",
    "    4.5e-2,\n",
    "    5e-2,\n",
    "    5.5e-2,\n",
    "    6e-2,\n",
    "    6.5e-2,\n",
    "    7e-2,\n",
    "    8e-2,\n",
    "    9e-2,\n",
    "    9.5e-2,\n",
    "    1e-1,\n",
    "    1.25e-1,\n",
    "    1.5e-1,\n",
    "    1.75e-1,\n",
    "    2e-1,\n",
    "    5e-1,\n",
    "    7e-1,\n",
    "    1,\n",
    "    2\n",
    "    #5\n",
    "]\n",
    "\n",
    "# Uncomment to iterate from the largest to the smallest lambda\n",
    "#LAMBDA_LIST.reverse()\n",
    "\n",
    "# Printed last so the message only appears once every config value (incl. LAMBDA_LIST) is defined\n",
    "print('Defining configs successful!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-19T04:02:39.096265Z",
     "iopub.status.busy": "2025-01-19T04:02:39.095503Z",
     "iopub.status.idle": "2025-01-19T04:02:39.472999Z",
     "shell.execute_reply": "2025-01-19T04:02:39.472371Z",
     "shell.execute_reply.started": "2025-01-19T04:02:39.096231Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X_train shape: (1152, 400), Y_train shape: (1152, 20)\n",
      "X_test shape: (288, 400), Y_test shape: (288, 20)\n",
      "Normalized Training Set Mean and SD: [ 3.61063417e-16  6.39438348e-17 -1.02300498e-16 -1.30875249e-16\n",
      "  1.50487262e-16  9.29040795e-17 -9.55061647e-17  3.56582048e-17\n",
      "  2.13852855e-16  1.80796736e-16 -8.88563914e-17 -2.37078875e-17\n",
      "  1.83109700e-18  1.76749047e-16 -8.31703533e-17 -1.51113689e-16\n",
      " -9.14102898e-17 -6.17995238e-16 -2.22646939e-16 -2.82205779e-16\n",
      " -1.82266432e-16 -8.02068674e-17  3.11623798e-16  1.74917950e-16\n",
      " -1.46295013e-16 -8.59651856e-17  4.64520397e-17 -5.58966453e-18\n",
      " -1.00228468e-17 -8.98201266e-17  3.26706255e-16 -2.67918404e-17\n",
      " -5.85951041e-17 -1.41958204e-16 -1.05047144e-16  2.99721667e-17\n",
      "  3.18610878e-16  1.94722710e-16 -3.84482184e-16 -1.36127606e-17\n",
      " -8.77480958e-17  2.05034678e-16 -1.37332275e-16  2.08070444e-16\n",
      "  3.07431549e-17 -4.31753398e-17  7.55568447e-17  5.78241159e-19\n",
      " -7.45931095e-17  5.74386218e-17  3.46944695e-18  9.61807794e-17\n",
      "  1.27213055e-17  2.06817588e-16 -1.13817135e-16  7.74843153e-17\n",
      "  6.23536716e-17  3.47908430e-16  1.66822574e-16 -1.78628331e-16\n",
      " -3.37307343e-19 -6.64977332e-18 -1.83109700e-18 -6.10044422e-17\n",
      "  1.45909519e-16  2.87193109e-17 -5.14634631e-17  5.97515864e-17\n",
      "  1.24803717e-17  7.20392110e-17  1.63353127e-17  1.67593562e-16\n",
      "  8.09537622e-17 -8.33149136e-17  4.45245692e-17 -1.74532456e-16\n",
      " -1.39741613e-17 -4.26934722e-17  1.22201632e-16  1.36850408e-17\n",
      " -1.98047597e-17 -1.03119673e-17 -1.29333272e-16  4.95359926e-17\n",
      " -1.34922937e-18  2.67918404e-17 -1.36850408e-17 -1.27213055e-17\n",
      "  8.07610152e-17 -1.46005893e-17  4.18140638e-17  7.46894830e-17\n",
      "  2.82856300e-17 -1.71906278e-17  2.05757479e-17 -4.79940162e-17\n",
      " -1.42536446e-16  3.31524931e-17 -9.65662735e-17 -6.74614685e-18\n",
      "  7.60387124e-17  3.55618313e-17  1.28369537e-16  4.17297369e-17\n",
      " -6.39920216e-17  2.87193109e-17 -5.47687739e-17 -2.02023005e-17\n",
      "  6.29800995e-17 -1.33621894e-16  2.02890367e-16  3.85494106e-18\n",
      " -5.54147777e-17  9.86864911e-17  1.46728694e-17 -7.88546263e-17\n",
      " -8.91455120e-17  1.11986038e-16  1.24514596e-16  2.08166817e-17\n",
      " -1.34344696e-16  7.42076154e-17  4.97287396e-17 -8.98201266e-17\n",
      "  1.19503173e-17  4.41872619e-17 -1.69231912e-16  7.77734358e-17\n",
      " -1.15840979e-16 -1.19093585e-16 -3.62364459e-17  4.51991839e-17\n",
      " -3.74170216e-17  5.10779690e-17 -8.88563914e-17  1.13913508e-16\n",
      " -7.28824794e-17  7.94117858e-17  1.69713780e-16  7.12561761e-17\n",
      " -1.83495194e-16  8.99165002e-17 -1.06396373e-16  5.78241159e-19\n",
      " -7.13585730e-17  6.36125508e-17  1.99300453e-16 -1.94192656e-16\n",
      "  1.36272166e-16  1.19960947e-16 -2.65027198e-17 -1.34055575e-16\n",
      " -2.12214505e-16  3.22851314e-17 -9.71445147e-17  3.08395285e-17\n",
      " -1.18202130e-16 -1.22105258e-16  8.14597232e-17 -3.30657569e-16\n",
      "  2.35633272e-16  7.93154123e-17  1.33959202e-16 -1.48415231e-17\n",
      " -7.41594286e-17  4.04768811e-18  1.55354125e-16 -3.04540344e-17\n",
      " -2.27056028e-16  8.13392563e-17 -5.70531277e-17 -1.48029737e-16\n",
      " -2.21610924e-16 -1.45524025e-16  5.84023570e-17  7.33402536e-17\n",
      " -2.46234360e-17 -1.26731187e-17  5.54147777e-18 -2.05564732e-16\n",
      "  1.44560290e-17 -1.07167361e-16 -9.30727332e-17 -9.22294648e-17\n",
      "  8.86154576e-17 -2.75146418e-17  2.17804170e-17  1.71484644e-16\n",
      " -1.58052583e-16 -1.89855847e-17 -1.69617407e-16 -6.66904803e-17\n",
      "  6.56544649e-17 -2.45752492e-16  1.11504170e-16 -9.31209199e-18\n",
      "  4.13683362e-17 -3.35861740e-17 -6.32210333e-17 -1.64124116e-16\n",
      " -4.98829373e-16  1.47969503e-16 -1.02155938e-16 -6.70759744e-17\n",
      "  4.68375339e-17  2.55582592e-16  1.50342701e-16  2.10865276e-16\n",
      "  1.32995466e-16  1.64027742e-16  5.12707161e-17  2.47198095e-16\n",
      " -1.22683499e-16 -1.77760970e-16  1.45343325e-17  5.73422482e-18\n",
      "  1.22731686e-16  1.82145965e-17  1.20852402e-16  1.87832003e-16\n",
      "  1.43235154e-17  4.69146327e-16 -1.48222484e-16 -1.34922937e-18\n",
      " -9.44460559e-18  1.46969628e-16 -6.26427922e-17 -1.59690933e-16\n",
      "  1.19503173e-16  1.06974614e-16 -4.49100633e-17 -1.26827561e-16\n",
      "  7.92431321e-17  4.88372845e-17 -5.79204894e-17  6.88106979e-17\n",
      " -3.73929283e-17  3.78747959e-17  9.84937440e-17  1.53041160e-16\n",
      " -8.01827740e-17 -3.15334179e-16  1.96023753e-16 -6.55339980e-18\n",
      "  2.69845874e-17  5.14634631e-17 -1.07938350e-16  1.00421215e-16\n",
      " -1.57137035e-16  1.20177787e-16 -1.47451495e-17  3.25742519e-17\n",
      " -2.19008839e-16 -9.24222119e-17  6.99671802e-17 -1.22683499e-16\n",
      "  5.37764278e-17 -7.13164096e-17 -2.82374432e-17  3.17647143e-16\n",
      "  8.76999091e-17 -1.43596554e-17 -8.04718946e-18 -1.98047597e-17\n",
      " -6.46666362e-17  2.26477787e-17  1.37669583e-16  1.45861332e-16\n",
      "  1.31357117e-16 -1.90241341e-16 -1.40705349e-17 -2.12118132e-16\n",
      "  6.29800995e-17  6.68832274e-17  6.47630098e-17 -1.25671078e-16\n",
      "  1.25285584e-17 -6.34137804e-17  1.44560290e-16  1.01866817e-16\n",
      "  2.59437533e-16  3.60436989e-17  1.81579771e-16  2.55389845e-17\n",
      "  1.74508363e-16 -6.78228692e-17 -4.44763825e-17  1.24514596e-16\n",
      "  5.41619219e-17 -1.18732185e-16  1.84651677e-16 -1.65961236e-16\n",
      "  5.12707161e-17  3.46944695e-17 -8.94346325e-17  7.09309155e-17\n",
      "  5.46437895e-17  2.29368993e-16  2.37271622e-16  1.36802221e-16\n",
      "  3.48486672e-16 -1.28755031e-16 -1.41669084e-16  1.80459428e-17\n",
      "  2.86108907e-17 -5.11382025e-17 -3.91276517e-17 -1.54968631e-16\n",
      " -8.79890296e-17  1.67689936e-17 -7.36293742e-17  3.66219400e-18\n",
      "  3.83566635e-17 -6.86179508e-17 -3.48872166e-17 -2.14912964e-17\n",
      " -6.77505891e-17 -1.53041160e-16 -1.09914007e-16  2.29128059e-16\n",
      "  3.05263145e-16  5.62821394e-17 -2.00697869e-16  1.48415231e-17\n",
      "  1.81182230e-17  1.69617407e-17  5.35836807e-17 -5.49329101e-17\n",
      " -1.05336264e-16  7.38221213e-17  1.18153943e-16  8.35558474e-17\n",
      " -1.29140525e-17 -4.99214867e-17 -9.75300088e-17 -6.65941068e-17\n",
      "  1.16611967e-17  7.52677242e-17  4.41872619e-17  1.89133046e-17\n",
      " -1.31284836e-16 -1.17142021e-16  6.72687215e-17  6.36065275e-18\n",
      "  1.25285584e-17 -1.06010879e-17  6.07153217e-17  8.01827740e-17\n",
      " -1.69810154e-16 -1.24129102e-16 -3.37307343e-17 -1.43018313e-16\n",
      "  1.23358114e-16 -1.64605983e-16 -1.18539438e-17  9.60844059e-17\n",
      "  8.78926561e-17 -1.10299501e-16 -2.28887125e-18 -1.94216749e-16\n",
      " -2.06552561e-16 -1.62654419e-16  8.38208746e-17  1.01192203e-17\n",
      "  9.42533089e-17 -1.21430643e-17  1.94289029e-16 -1.21816137e-16\n",
      "  6.74614685e-18 -1.58438077e-16 -8.59651856e-17 -1.90048594e-16\n",
      "  2.80254215e-16  1.29911514e-16  4.37535810e-17  3.65255665e-17\n",
      "  1.53715775e-17 -4.53678376e-17 -6.13176562e-17 -2.23393834e-16\n",
      "  1.85615412e-16 -1.35886672e-17  7.82553035e-17  5.85951041e-17\n",
      " -1.92650679e-16 -1.33670081e-16 -1.94770897e-16 -2.86518494e-16\n",
      "  2.31585584e-16 -7.17982772e-17 -2.95866726e-17 -1.82145965e-16\n",
      "  1.90434088e-16 -4.45245692e-17 -2.67436536e-17  5.69567541e-17\n",
      "  3.33886082e-16  1.03505167e-16  1.41910018e-16  3.88192565e-16] [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
      " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
      "Train data shape:  (1152, 400)\n",
      "Train labels shape:  (1152, 20)\n",
      "Test data shape:  (288, 400)\n",
      "Test labels shape:  (288, 20)\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import random\n",
    "from sklearn.model_selection import train_test_split\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n",
    "from PIL import Image\n",
    "import tensorflow as tf\n",
    "\n",
    "# NOTE: this cell reads SEED and CLASS_NUM from the configuration cell above.\n",
    "\n",
    "def one_hot_encode(y, num_classes):\n",
    "    \"\"\"Return a (len(y), num_classes) one-hot float matrix for integer labels y in [0, num_classes).\"\"\"\n",
    "    return np.eye(num_classes)[y]\n",
    "\n",
    "def load_coil(one_hot=True):\n",
    "    \"\"\"Load processed COIL-20, resize every image to 20x20 and return a standardized 80/20 split.\n",
    "\n",
    "    Reads data/coil-20/coil-20-proc/obj{i}__{k}.png for i = 1..20 and k = 0..71. Each image\n",
    "    becomes one row of pixel values with class id i - 1. Rows are shuffled after seeding with\n",
    "    SEED, the first 80% form the training set, and features are standardized with a\n",
    "    StandardScaler fit on the training rows only.\n",
    "\n",
    "    Args:\n",
    "        one_hot: if True, return labels one-hot encoded; otherwise as int64 class ids.\n",
    "\n",
    "    Returns:\n",
    "        (X_train, Y_train), (X_test, Y_test)\n",
    "    \"\"\"\n",
    "    coil20_base_path = os.path.join('data', 'coil-20', 'coil-20-proc')\n",
    "\n",
    "    samples = []\n",
    "    for i in range(1, 21):\n",
    "        for image_index in range(72):\n",
    "            image_path = os.path.join(coil20_base_path, 'obj%d__%d.png' % (i, image_index))\n",
    "            # The context manager releases the file handle; resize() returns an independent in-memory image.\n",
    "            with Image.open(image_path) as obj_img:\n",
    "                rescaled = obj_img.resize((20, 20))\n",
    "            pixels_values = [float(x) for x in list(rescaled.getdata())]\n",
    "            # Append the label as the last column so it is shuffled together with its pixels.\n",
    "            sample = np.array(pixels_values + [i])\n",
    "            samples.append(sample)\n",
    "\n",
    "    # Seed all RNGs right before the shuffle so the train/test split is reproducible.\n",
    "    np.random.seed(SEED)\n",
    "    random.seed(SEED)\n",
    "    tf.random.set_seed(SEED)\n",
    "\n",
    "    samples = np.array(samples)\n",
    "    np.random.shuffle(samples)\n",
    "    data = samples[:, :-1]\n",
    "    # Labels were stored as floats: round back to int and shift 1..20 -> 0..19.\n",
    "    targets = (samples[:, -1] + 0.5).astype(np.int64) - 1\n",
    "\n",
    "    num_classes = np.unique(targets).shape[0]\n",
    "\n",
    "    split_index = data.shape[0] * 4 // 5\n",
    "    X_train, Y_train = data[:split_index], targets[:split_index]\n",
    "    X_test, Y_test = data[split_index:], targets[split_index:]\n",
    "\n",
    "    # Standardize with training-split statistics only, so no test information leaks in.\n",
    "    scaler = StandardScaler().fit(X_train)\n",
    "    X_train = scaler.transform(X_train)\n",
    "    X_test = scaler.transform(X_test)\n",
    "\n",
    "    if one_hot:\n",
    "        Y_train = one_hot_encode(Y_train, num_classes)\n",
    "        Y_test = one_hot_encode(Y_test, num_classes)\n",
    "\n",
    "    print(\"X_train shape: {}, Y_train shape: {}\".format(X_train.shape, Y_train.shape))\n",
    "    print(\"X_test shape: {}, Y_test shape: {}\".format(X_test.shape, Y_test.shape))\n",
    "\n",
    "    return (X_train, Y_train), (X_test, Y_test)\n",
    "\n",
    "(X_train, Y_train), (X_test, Y_test) = load_coil()\n",
    "\n",
    "# The one-hot label width must match the configured number of output classes.\n",
    "assert Y_train.shape[1] == CLASS_NUM and Y_test.shape[1] == CLASS_NUM\n",
    "\n",
    "############\n",
    "\n",
    "def calculate_mean_std(dataset):\n",
    "    \"\"\"Return the per-column mean and standard deviation of a 2-D array.\"\"\"\n",
    "    means = np.mean(dataset, axis=0)\n",
    "    stds = np.std(dataset, axis=0)\n",
    "    return means, stds\n",
    "\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "# Summarize instead of dumping every per-feature value into the notebook output.\n",
    "print(\"Normalized training set: max |mean| = {:.2e}, SD range = [{:.3f}, {:.3f}]\".format(\n",
    "    np.abs(train_mean).max(), train_std.min(), train_std.max()))\n",
    "print(\"Normalized test set: max |mean| = {:.2e}, SD range = [{:.3f}, {:.3f}]\".format(\n",
    "    np.abs(test_mean).max(), test_std.min(), test_std.max()))\n",
    "# StandardScaler centres every training column, so these means must be ~0.\n",
    "assert np.allclose(train_mean, 0.0, atol=1e-8)\n",
    "\n",
    "# Check sample sizes of split\n",
    "print('Train data shape: ', X_train.shape)\n",
    "print('Train labels shape: ', Y_train.shape)\n",
    "print('Test data shape: ', X_test.shape)\n",
    "print('Test labels shape: ', Y_test.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2025-01-19T04:02:53.164262Z",
     "iopub.status.busy": "2025-01-19T04:02:53.163479Z",
     "iopub.status.idle": "2025-01-19T04:02:58.955686Z",
     "shell.execute_reply": "2025-01-19T04:02:58.954698Z",
     "shell.execute_reply.started": "2025-01-19T04:02:53.164237Z"
    },
    "id": "6yy9KHxyWOBY",
    "outputId": "7cfef7ad-a3ae-49a2-92d7-820a174079d9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"VanillaLeNet300100\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " input_1 (InputLayer)        [(None, 400)]             0         \n",
      "                                                                 \n",
      " dense (Dense)               (None, 300)               120300    \n",
      "                                                                 \n",
      " dense_1 (Dense)             (None, 100)               30100     \n",
      "                                                                 \n",
      " dense_2 (Dense)             (None, 20)                2020      \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 152,420\n",
      "Trainable params: 152,420\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n",
      "\n",
      "Epoch 1: Current learning rate = 2.000e-01\n",
      "Epoch 1/100\n",
      "5/5 [==============================] - 2s 3ms/step - loss: 2.1694 - accuracy: 0.4575\n",
      "\n",
      "Epoch 2: Current learning rate = 1.999e-01\n",
      "Epoch 2/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 0.4074 - accuracy: 0.8715\n",
      "\n",
      "Epoch 3: Current learning rate = 1.997e-01\n",
      "Epoch 3/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 0.0751 - accuracy: 0.9766\n",
      "\n",
      "Epoch 4: Current learning rate = 1.993e-01\n",
      "Epoch 4/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 0.0061 - accuracy: 1.0000\n",
      "\n",
      "Epoch 5: Current learning rate = 1.988e-01\n",
      "Epoch 5/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 0.0016 - accuracy: 1.0000\n",
      "\n",
      "Epoch 6: Current learning rate = 1.981e-01\n",
      "Epoch 6/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 5.6204e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 7: Current learning rate = 1.972e-01\n",
      "Epoch 7/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 2.8853e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 8: Current learning rate = 1.962e-01\n",
      "Epoch 8/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 1.7422e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 9: Current learning rate = 1.951e-01\n",
      "Epoch 9/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 1.2635e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 10: Current learning rate = 1.938e-01\n",
      "Epoch 10/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 9.8463e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 11: Current learning rate = 1.924e-01\n",
      "Epoch 11/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 8.5010e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 12: Current learning rate = 1.908e-01\n",
      "Epoch 12/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 7.4978e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 13: Current learning rate = 1.891e-01\n",
      "Epoch 13/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 6.8689e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 14: Current learning rate = 1.872e-01\n",
      "Epoch 14/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 6.4035e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 15: Current learning rate = 1.853e-01\n",
      "Epoch 15/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 6.0667e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 16: Current learning rate = 1.831e-01\n",
      "Epoch 16/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 5.7917e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 17: Current learning rate = 1.809e-01\n",
      "Epoch 17/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 5.5752e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 18: Current learning rate = 1.785e-01\n",
      "Epoch 18/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 5.3744e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 19: Current learning rate = 1.760e-01\n",
      "Epoch 19/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 5.2115e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 20: Current learning rate = 1.734e-01\n",
      "Epoch 20/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 5.0720e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 21: Current learning rate = 1.707e-01\n",
      "Epoch 21/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 4.9361e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 22: Current learning rate = 1.679e-01\n",
      "Epoch 22/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 4.8247e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 23: Current learning rate = 1.649e-01\n",
      "Epoch 23/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 4.7149e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 24: Current learning rate = 1.619e-01\n",
      "Epoch 24/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 4.6162e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 25: Current learning rate = 1.588e-01\n",
      "Epoch 25/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 4.5170e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 26: Current learning rate = 1.556e-01\n",
      "Epoch 26/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 4.4328e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 27: Current learning rate = 1.522e-01\n",
      "Epoch 27/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 4.3555e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 28: Current learning rate = 1.489e-01\n",
      "Epoch 28/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 4.2784e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 29: Current learning rate = 1.454e-01\n",
      "Epoch 29/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 4.2067e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 30: Current learning rate = 1.419e-01\n",
      "Epoch 30/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 4.1459e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 31: Current learning rate = 1.383e-01\n",
      "Epoch 31/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 4.0821e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 32: Current learning rate = 1.346e-01\n",
      "Epoch 32/100\n",
      "5/5 [==============================] - 0s 4ms/step - loss: 4.0218e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 33: Current learning rate = 1.309e-01\n",
      "Epoch 33/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.9641e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 34: Current learning rate = 1.271e-01\n",
      "Epoch 34/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.9102e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 35: Current learning rate = 1.233e-01\n",
      "Epoch 35/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.8631e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 36: Current learning rate = 1.195e-01\n",
      "Epoch 36/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.8162e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 37: Current learning rate = 1.156e-01\n",
      "Epoch 37/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.7693e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 38: Current learning rate = 1.118e-01\n",
      "Epoch 38/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.7281e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 39: Current learning rate = 1.078e-01\n",
      "Epoch 39/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.6883e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 40: Current learning rate = 1.039e-01\n",
      "Epoch 40/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.6511e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 41: Current learning rate = 1.000e-01\n",
      "Epoch 41/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.6174e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 42: Current learning rate = 9.607e-02\n",
      "Epoch 42/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.5858e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 43: Current learning rate = 9.215e-02\n",
      "Epoch 43/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.5517e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 44: Current learning rate = 8.825e-02\n",
      "Epoch 44/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.5226e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 45: Current learning rate = 8.436e-02\n",
      "Epoch 45/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.4962e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 46: Current learning rate = 8.049e-02\n",
      "Epoch 46/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.4683e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 47: Current learning rate = 7.666e-02\n",
      "Epoch 47/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.4449e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 48: Current learning rate = 7.286e-02\n",
      "Epoch 48/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.4195e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 49: Current learning rate = 6.910e-02\n",
      "Epoch 49/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.3997e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 50: Current learning rate = 6.539e-02\n",
      "Epoch 50/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.3780e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 51: Current learning rate = 6.173e-02\n",
      "Epoch 51/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.3597e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 52: Current learning rate = 5.813e-02\n",
      "Epoch 52/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.3423e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 53: Current learning rate = 5.460e-02\n",
      "Epoch 53/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.3260e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 54: Current learning rate = 5.114e-02\n",
      "Epoch 54/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.3120e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 55: Current learning rate = 4.775e-02\n",
      "Epoch 55/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.2992e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 56: Current learning rate = 4.444e-02\n",
      "Epoch 56/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.2848e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 57: Current learning rate = 4.122e-02\n",
      "Epoch 57/100\n",
      "5/5 [==============================] - 0s 4ms/step - loss: 3.2734e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 58: Current learning rate = 3.809e-02\n",
      "Epoch 58/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.2632e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 59: Current learning rate = 3.506e-02\n",
      "Epoch 59/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.2527e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 60: Current learning rate = 3.212e-02\n",
      "Epoch 60/100\n",
      "5/5 [==============================] - 0s 5ms/step - loss: 3.2435e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 61: Current learning rate = 2.929e-02\n",
      "Epoch 61/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.2353e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 62: Current learning rate = 2.657e-02\n",
      "Epoch 62/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.2275e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 63: Current learning rate = 2.396e-02\n",
      "Epoch 63/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.2206e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 64: Current learning rate = 2.147e-02\n",
      "Epoch 64/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.2141e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 65: Current learning rate = 1.910e-02\n",
      "Epoch 65/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.2079e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 66: Current learning rate = 1.685e-02\n",
      "Epoch 66/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.2031e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 67: Current learning rate = 1.474e-02\n",
      "Epoch 67/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.1984e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 68: Current learning rate = 1.275e-02\n",
      "Epoch 68/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.1943e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 69: Current learning rate = 1.090e-02\n",
      "Epoch 69/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.1906e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 70: Current learning rate = 9.186e-03\n",
      "Epoch 70/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.1878e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 71: Current learning rate = 7.612e-03\n",
      "Epoch 71/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.1851e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 72: Current learning rate = 6.181e-03\n",
      "Epoch 72/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.1827e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 73: Current learning rate = 4.894e-03\n",
      "Epoch 73/100\n",
      "5/5 [==============================] - 0s 4ms/step - loss: 3.1811e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 74: Current learning rate = 3.754e-03\n",
      "Epoch 74/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.1792e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 75: Current learning rate = 2.763e-03\n",
      "Epoch 75/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.1783e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 76: Current learning rate = 1.921e-03\n",
      "Epoch 76/100\n",
      "5/5 [==============================] - 0s 4ms/step - loss: 3.1773e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 77: Current learning rate = 1.231e-03\n",
      "Epoch 77/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.1765e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 78: Current learning rate = 6.932e-04\n",
      "Epoch 78/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.1759e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 79: Current learning rate = 3.083e-04\n",
      "Epoch 79/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.1754e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 80: Current learning rate = 7.710e-05\n",
      "Epoch 80/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.1752e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 81: Current learning rate = 0.000e+00\n",
      "Epoch 81/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.1751e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 82: Current learning rate = 0.000e+00\n",
      "Epoch 82/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.1751e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 83: Current learning rate = 0.000e+00\n",
      "Epoch 83/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.1751e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 84: Current learning rate = 0.000e+00\n",
      "Epoch 84/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.1751e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 85: Current learning rate = 0.000e+00\n",
      "Epoch 85/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.1750e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 86: Current learning rate = 0.000e+00\n",
      "Epoch 86/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.1750e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 87: Current learning rate = 0.000e+00\n",
      "Epoch 87/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.1750e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 88: Current learning rate = 0.000e+00\n",
      "Epoch 88/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.1750e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 89: Current learning rate = 0.000e+00\n",
      "Epoch 89/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.1750e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 90: Current learning rate = 0.000e+00\n",
      "Epoch 90/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.1750e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 91: Current learning rate = 0.000e+00\n",
      "Epoch 91/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.1750e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 92: Current learning rate = 0.000e+00\n",
      "Epoch 92/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.1750e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 93: Current learning rate = 0.000e+00\n",
      "Epoch 93/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.1750e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 94: Current learning rate = 0.000e+00\n",
      "Epoch 94/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.1750e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 95: Current learning rate = 0.000e+00\n",
      "Epoch 95/100\n",
      "5/5 [==============================] - 0s 2ms/step - loss: 3.1750e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 96: Current learning rate = 0.000e+00\n",
      "Epoch 96/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.1750e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 97: Current learning rate = 0.000e+00\n",
      "Epoch 97/100\n",
      "5/5 [==============================] - 0s 3ms/step - loss: 3.1750e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 98: Current learning rate = 0.000e+00\n",
      "Epoch 98/100\n",
      "5/5 [==============================] - 0s 7ms/step - loss: 3.1750e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 99: Current learning rate = 0.000e+00\n",
      "Epoch 99/100\n",
      "5/5 [==============================] - 0s 8ms/step - loss: 3.1750e-05 - accuracy: 1.0000\n",
      "\n",
      "Epoch 100: Current learning rate = 0.000e+00\n",
      "Epoch 100/100\n",
      "5/5 [==============================] - 0s 6ms/step - loss: 3.1750e-05 - accuracy: 1.0000\n",
      "9/9 [==============================] - 0s 4ms/step - loss: 0.0435 - accuracy: 0.9896\n",
      "\n",
      "Test loss 0.043488599359989166\n",
      "Test accuracy 0.9895833134651184\n"
     ]
    }
   ],
   "source": [
    "# Vanilla LeNet-300-100 on COIL-20\n",
    "\n",
    "# Model definition\n",
    "################################################################################\n",
    "MODEL = 'lenet300100_coil20'\n",
    "LA = 0  # the vanilla baseline is built with la=0\n",
    "################################################################################\n",
    "\n",
    "# Directories and saving paths\n",
    "fmt_la = f\"{LA:.1e}\"\n",
    "RUN_NAME = f\"{MODEL}_dep1-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{INIT_LR:.1e}-vanilla-bs{BATCH_SIZE}\"\n",
    "RUN_PATH = os.path.join(LENET_FILE_PATH, RUN_NAME)\n",
    "os.makedirs(RUN_PATH, exist_ok=True)\n",
    "\n",
    "################################################################################\n",
    "# Set seed\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "# Callbacks: only the LR printer is passed to fit() below. An EarlyStopping on\n",
    "# 'val_accuracy' would be inert here because fit() receives no validation data.\n",
    "print_lr_cb = PrintLRCallback()\n",
    "\n",
    "# Define model\n",
    "vanilla_lenet300100 = LeNet300100(input_shape=(X_train.shape[1],), n_classes=CLASS_NUM, la=LA, units1=300, units2=100)\n",
    "\n",
    "# Pretrain optimizer (same get_optimizer settings as the Hadamard sweep)\n",
    "optimizer = get_optimizer(lr_schedule=LR_SCHEDULE, init_lr=INIT_LR, lr_decay_fact=LR_DECAY_FACT, epochs=EPOCHS,\n",
    "                          dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT, momentum=MOMENTUM, alpha=0,\n",
    "                          large_lr_start=LARGE_LRSTART, warmup=WARMUP)\n",
    "\n",
    "# Compile model\n",
    "vanilla_lenet300100.compile(optimizer=optimizer,\n",
    "                            loss='categorical_crossentropy',\n",
    "                            metrics=['accuracy'])\n",
    "\n",
    "# summary() prints the table itself and returns None, so it is not wrapped in print()\n",
    "vanilla_lenet300100.summary()\n",
    "\n",
    "################################################################################\n",
    "\n",
    "# Training\n",
    "pre_hist = vanilla_lenet300100.fit(x=X_train, y=Y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, callbacks=[print_lr_cb])\n",
    "\n",
    "# Evaluate after training\n",
    "pretrain_loss, pretrain_acc = vanilla_lenet300100.evaluate(X_test, Y_test)\n",
    "print('\\nTest loss', pretrain_loss)\n",
    "print('Test accuracy', pretrain_acc)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-19T04:03:35.249287Z",
     "iopub.status.busy": "2025-01-19T04:03:35.248902Z",
     "iopub.status.idle": "2025-01-19T04:22:06.461854Z",
     "shell.execute_reply": "2025-01-19T04:22:06.461266Z",
     "shell.execute_reply.started": "2025-01-19T04:03:35.249255Z"
    }
   },
   "outputs": [],
   "source": [
    "DEPTH_LIST = [2, 3, 4]\n",
    "REPS = 5\n",
    "BASE_SEED = SEED  # Store original seed value\n",
    "\n",
    "# Loop-invariant settings\n",
    "MODEL = 'lenet300100_coil20'\n",
    "INIT_TYPE = 'ones'  # only used to label RUN_NAME; the model itself receives INIT\n",
    "\n",
    "for depth in DEPTH_LIST:\n",
    "    for rep in range(REPS):\n",
    "        current_seed = BASE_SEED + rep\n",
    "        \n",
    "        for LA_ITER in LAMBDA_LIST:\n",
    "            # Model definition\n",
    "            ################################################################################\n",
    "            DEPTH = depth  # Use the loop variable\n",
    "            LA = LA_ITER\n",
    "            INIT = tf.keras.initializers.HeNormal()\n",
    "            \n",
    "            # Directories and saving paths\n",
    "            fmt_la = f\"{LA:.1e}\"\n",
    "            RUN_NAME = f\"{MODEL}_dep{DEPTH}-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{INIT_LR:.1e}-{INIT_TYPE}-bs{BATCH_SIZE}_rep{rep+1}\"\n",
    "            RUN_PATH = os.path.join(LENET_FILE_PATH, f\"depth_{DEPTH}\", f\"rep_{rep+1}\", RUN_NAME)\n",
    "            \n",
    "            # Skip runs whose results CSV already exists (lets an interrupted sweep resume)\n",
    "            pretrain_csv_file_path = os.path.join(RUN_PATH, f'pretraining_{MODEL}_depth{DEPTH}_rep{rep+1}.csv')\n",
    "            if os.path.exists(pretrain_csv_file_path):\n",
    "                print(f'Results already exist for depth={DEPTH}, lambda={LA:.2e}, repetition={rep+1}/{REPS}. Skipping...')\n",
    "                continue\n",
    "                \n",
    "            print(f'Starting run with depth={DEPTH}, lambda={LA:.2e}, repetition={rep+1}/{REPS}')\n",
    "            os.makedirs(RUN_PATH, exist_ok=True)\n",
    "\n",
    "            ################################################################################\n",
    "            # Many models are built in this loop: drop Keras' accumulated global state first,\n",
    "            # then seed all RNGs for this repetition.\n",
    "            tf.keras.backend.clear_session()\n",
    "            np.random.seed(current_seed)\n",
    "            random.seed(current_seed)\n",
    "            tf.random.set_seed(current_seed)\n",
    "\n",
    "            # Callbacks\n",
    "            terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "\n",
    "            # Define model\n",
    "            hadamard_lenet300100 = InpHadamardLeNet300100(input_shape=(X_train.shape[1],), n_classes=CLASS_NUM, depth=DEPTH, la=LA,\n",
    "                                                          init=INIT, use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS)\n",
    "\n",
    "            # Pretrain optimizer\n",
    "            optimizer = get_optimizer(lr_schedule=LR_SCHEDULE, init_lr=INIT_LR, lr_decay_fact=LR_DECAY_FACT, epochs=EPOCHS,\n",
    "                                      dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT, momentum=MOMENTUM, alpha=0,\n",
    "                                      large_lr_start=LARGE_LRSTART, warmup=WARMUP)\n",
    "\n",
    "            # Compile model\n",
    "            hadamard_lenet300100.compile(optimizer=optimizer,\n",
    "                                         loss='categorical_crossentropy',\n",
    "                                         metrics=['accuracy'])\n",
    "\n",
    "            hadamard_lenet300100.summary()  # prints the table itself; returns None\n",
    "\n",
    "            ################################################################################\n",
    "\n",
    "            # Training (silent; TerminateOnNaN stops training if the loss becomes NaN)\n",
    "            pre_hist = hadamard_lenet300100.fit(x=X_train, y=Y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, verbose=0,\n",
    "                                                callbacks=[terminate_nan_cb])\n",
    "\n",
    "            ################################################################################\n",
    "\n",
    "            # Evaluate once after pretraining\n",
    "            pretrain_loss, pretrain_acc = hadamard_lenet300100.evaluate(X_test, Y_test)\n",
    "            df, pretrain_sparsity = compute_input_sparsity(hadamard_lenet300100, DEPTH)\n",
    "            # A fully sparse input (sparsity == 1) would divide by zero and abort the whole sweep;\n",
    "            # record an infinite compression rate instead.\n",
    "            pretrain_compression_rate = float('inf') if pretrain_sparsity >= 1 else 1 / (1 - pretrain_sparsity)\n",
    "            print('\\nTest loss', pretrain_loss)\n",
    "            print('Test accuracy', pretrain_acc)\n",
    "            print('Sparsity (pretrain)', pretrain_sparsity)\n",
    "            print('Compression rate (pretrain)', pretrain_compression_rate)\n",
    "\n",
    "            # Store formatted results in dict (key order = CSV column order)\n",
    "            pretrain_res_dict = {\n",
    "                'Run': int(rep + 1),\n",
    "                'Pre Opt': PRETRAIN_OPT,\n",
    "                'Depth': int(DEPTH),\n",
    "                'Lambda': f'{LA:.2e}',\n",
    "                'Init LR': f'{INIT_LR:.2e}',\n",
    "                'LR Schedule': LR_SCHEDULE,\n",
    "                'Batch size': int(BATCH_SIZE),\n",
    "                'Pre Epochs': int(EPOCHS),\n",
    "                'Pre Loss': f'{pretrain_loss:.3f}',\n",
    "                'Pre Acc': f'{pretrain_acc * 100:.4f}%',\n",
    "                'Pre Sparsity': f'{pretrain_sparsity * 100:.4f}%',\n",
    "                'Pre CR': f'{pretrain_compression_rate:.2f}'\n",
    "            }\n",
    "\n",
    "            # One-row results table, built directly (concatenating onto an empty frame is deprecated in pandas)\n",
    "            pretrain_res_df = pd.DataFrame([pretrain_res_dict])\n",
    "\n",
    "            # Save df to CSV\n",
    "            pretrain_res_df.to_csv(pretrain_csv_file_path, index=False)\n",
    "            print(f'Pretrain results saved to {pretrain_csv_file_path}')\n",
    "            print(\"\\nPretraining Results:\")\n",
    "            print(pretrain_res_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-19T04:25:14.604843Z",
     "iopub.status.busy": "2025-01-19T04:25:14.604200Z",
     "iopub.status.idle": "2025-01-19T04:25:14.985677Z",
     "shell.execute_reply": "2025-01-19T04:25:14.985100Z",
     "shell.execute_reply.started": "2025-01-19T04:25:14.604780Z"
    }
   },
   "outputs": [],
   "source": [
    "# Data for the HSIC lasso + SVM comparison (following Ziyin and Liu, 2023)\n",
    "\n",
    "# NOTE: this reloads COIL-20 with one_hot=False and overwrites X_train/Y_train/X_test/Y_test.\n",
    "# The LeNet cells above train with categorical_crossentropy (presumably on one-hot labels),\n",
    "# so reload the data before re-running them after this cell.\n",
    "(X_train, Y_train), (X_test, Y_test) = load_coil(one_hot=False)\n",
    "\n",
    "def calculate_mean_std(dataset):\n",
    "    \"\"\"Return the column-wise (per-feature) mean and standard deviation of `dataset`.\"\"\"\n",
    "    means = np.mean(dataset, axis=0)\n",
    "    stds = np.std(dataset, axis=0)\n",
    "    return means, stds\n",
    "\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "\n",
    "# Summarize the per-feature statistics instead of dumping one value per input column\n",
    "print(f\"Normalized training set: max |mean| = {np.abs(train_mean).max():.3e}, \"\n",
    "      f\"SD range = [{train_std.min():.4f}, {train_std.max():.4f}]\")\n",
    "print(f\"Test set: max |mean| = {np.abs(test_mean).max():.3e}, \"\n",
    "      f\"SD range = [{test_std.min():.4f}, {test_std.max():.4f}]\")\n",
    "\n",
    "# Check sample sizes of split\n",
    "print('Train data shape: ', X_train.shape)\n",
    "print('Train labels shape: ', Y_train.shape)\n",
    "print('Test data shape: ', X_test.shape)\n",
    "print('Test labels shape: ', Y_test.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-19T04:25:44.791084Z",
     "iopub.status.busy": "2025-01-19T04:25:44.790346Z",
     "iopub.status.idle": "2025-01-19T05:21:37.674024Z",
     "shell.execute_reply": "2025-01-19T05:21:37.673288Z",
     "shell.execute_reply.started": "2025-01-19T04:25:44.791058Z"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "from models import hadamard_nn_depth_2, hadamard_nn_depth_3, hadamard_nn_depth_4, hsic_svm, hsic_dnn, lassoNet\n",
    "import tensorflow as tf\n",
    "import config_inputsparse_compar\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "# Constants\n",
    "REPS = 5\n",
    "BASE_SEED = config_inputsparse_compar.SEED\n",
    "LENET_FILE_PATH = './results/input_sparsity/coil20/'\n",
    "\n",
    "def run_methods_on_dataset(dataset_name, load_func, results_dir, rep):\n",
    "    # Set seeds for this repetition\n",
    "    current_seed = BASE_SEED + rep\n",
    "    np.random.seed(current_seed)\n",
    "    random.seed(current_seed)\n",
    "    tf.random.set_seed(current_seed)\n",
    "    \n",
    "    print(f\"Loading dataset: {dataset_name} for repetition {rep + 1}/{REPS}\")\n",
    "    \n",
    "    for method_name, method_func, one_hot in [\n",
    "        (\"HSIC_dnn\", hsic_dnn, False),\n",
    "        (\"HSIC_svm\", hsic_svm, False),\n",
    "        (\"LassoNet\", lassoNet, False)\n",
    "    ]:\n",
    "        # Create method-specific directory\n",
    "        method_dir = os.path.join(results_dir, method_name, f'rep_{rep + 1}')\n",
    "        os.makedirs(method_dir, exist_ok=True)\n",
    "        \n",
    "        # Generate a unique filename for the method-dataset-repetition combination\n",
    "        result_filename = os.path.join(method_dir, f'{dataset_name}_{method_name}_rep{rep + 1}_res.csv')\n",
    "        \n",
    "        # Check if results already exist\n",
    "        if os.path.exists(result_filename):\n",
    "            print(f\"Results already exist for {method_name} on {dataset_name} (repetition {rep + 1}). Skipping...\")\n",
    "            continue\n",
    "            \n",
    "        results = []\n",
    "        print(f\"Loading dataset: {dataset_name} with one_hot = {one_hot} for repetition {rep + 1}\")\n",
    "        (train_X, train_y), (test_X, test_y) = load_func(one_hot=one_hot)\n",
    "        \n",
    "        print(f\"Running {method_name} on {dataset_name} (repetition {rep + 1})\")\n",
    "        sparsity, accuracy, value_seq = method_func(train_X, train_y, test_X, test_y)\n",
    "        print(f'Repetition {rep + 1}: sparsity = {sparsity} and test accuracy = {accuracy}')\n",
    "        \n",
    "        for s, a, v in zip(sparsity, accuracy, value_seq):\n",
    "            result = {\n",
    "                \"method\": method_name,\n",
    "                \"dataset\": dataset_name,\n",
    "                \"repetition\": rep + 1,\n",
    "                \"sparsity\": s,\n",
    "                \"accuracy\": a,\n",
    "                \"value\": v,\n",
    "                \"seed\": current_seed\n",
    "            }\n",
    "            results.append(result)\n",
    "            \n",
    "        # Save the results\n",
    "        save_results_to_csv(results, result_filename)\n",
    "        print(f'Results successfully saved to {result_filename}')\n",
    "\n",
    "def save_results_to_csv(results, result_filename):\n",
    "    df = pd.DataFrame(results)\n",
    "    df.to_csv(result_filename, index=False)\n",
    "\n",
    "def main():\n",
    "    \"\"\"Run every configured dataset for ``REPS`` repetitions, then merge the CSVs.\n",
    "\n",
    "    ``results/input_sparsity/coil20`` is created if needed. It is passed to each\n",
    "    ``run_methods_on_dataset`` call and finally to ``combine_all_results``.\n",
    "    \"\"\"\n",
    "    # Dataset name -> loader function; only COIL-20 is evaluated here.\n",
    "    loaders = {\"Coil-20\": load_coil}\n",
    "\n",
    "    # Fixed (not timestamped) output root shared by all repetitions.\n",
    "    base_results_dir = os.path.join('results', 'input_sparsity', 'coil20')\n",
    "    os.makedirs(base_results_dir, exist_ok=True)\n",
    "\n",
    "    for rep in range(REPS):\n",
    "        print(f\"\\nStarting repetition {rep + 1}/{REPS}\")\n",
    "        for name, loader in loaders.items():\n",
    "            run_methods_on_dataset(name, loader, base_results_dir, rep)\n",
    "\n",
    "    # Merge the per-repetition CSVs into a single summary file.\n",
    "    combine_all_results(base_results_dir)\n",
    "\n",
    "def combine_all_results(base_results_dir):\n",
    "    \"\"\"Combines all individual result files into a single summary file\"\"\"\n",
    "    all_results = []\n",
    "    \n",
    "    for method in ['HSIC_dnn', 'HSIC_svm', 'LassoNet']:\n",
    "        method_dir = os.path.join(base_results_dir, method)\n",
    "        if not os.path.exists(method_dir):\n",
    "            continue\n",
    "            \n",
    "        for rep_dir in os.listdir(method_dir):\n",
    "            if not rep_dir.startswith('rep_'):\n",
    "                continue\n",
    "                \n",
    "            rep_path = os.path.join(method_dir, rep_dir)\n",
    "            for result_file in os.listdir(rep_path):\n",
    "                if result_file.endswith('_res.csv'):\n",
    "                    df = pd.read_csv(os.path.join(rep_path, result_file))\n",
    "                    all_results.append(df)\n",
    "    \n",
    "    if all_results:\n",
    "        combined_df = pd.concat(all_results, ignore_index=True)\n",
    "        summary_file = os.path.join(base_results_dir, 'all_results_summary.csv')\n",
    "        combined_df.to_csv(summary_file, index=False)\n",
    "        print(f\"\\nCombined results saved to {summary_file}\")\n",
    "\n",
    "# Inside a notebook kernel, cells run in the __main__ namespace, so this guard\n",
    "# is true and executing the cell launches the full experiment.\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
