{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-20T10:38:08.033273Z",
     "iopub.status.busy": "2025-01-20T10:38:08.032623Z",
     "iopub.status.idle": "2025-01-20T10:38:08.042416Z",
     "shell.execute_reply": "2025-01-20T10:38:08.041781Z",
     "shell.execute_reply.started": "2025-01-20T10:38:08.033244Z"
    },
    "id": "RPFnfEqRAYA5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Defining configs successful!\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import random\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import keras\n",
    "import pandas as pd\n",
    "from tensorflow.keras.datasets import cifar10, fashion_mnist, mnist\n",
    "from tensorflow.keras.models import load_model\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping  #ReduceLROnPlateau\n",
    "from tensorflow.keras.regularizers import l2\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder, MinMaxScaler, StandardScaler \n",
    "from PIL import Image\n",
    "\n",
    "try:\n",
    "    import pyHSICLasso\n",
    "except:\n",
    "    os.system(\"python -m pip install pyHSICLasso\")\n",
    "try:\n",
    "    import lassonet\n",
    "except:\n",
    "    os.system(\"python -m pip install lassonet\")\n",
    "\n",
    "from layers       import HadamardDense, HadamardConv2D, StrHadamardDenseV2\n",
    "from initializers import TwiceTruncatedNormalInitializer, equivar_initializer, equivar_initializer_conv2d\n",
    "from callbacks    import HadamardCallback, PrintLRCallback, TerminateBadRuns\n",
    "from models       import LeNet300100, HadamardLeNet300100,\\\n",
    "                         InpHadamardLeNet300100, hsic_dnn,hsic_svm, lassoNet\n",
    "from utils        import get_optimizer, process_sparsity_callback, create_and_save_trajectory_plot, compute_input_sparsity\n",
    "\n",
    "################################################################################\n",
    "\n",
    "# Hadamard\n",
    "DEPTH              = 2\n",
    "LA                 = 5e-4\n",
    "INIT_TYPE          ='ones' #vanilla, equivar, root, ones\n",
    "#INIT_REST         = tf.keras.initializers.Ones() \n",
    "MINPROD            = 3e-3\n",
    "#INIT               = TwiceTruncatedNormalInitializer(minprod=MINPROD,depth=DEPTH) \n",
    "INIT               = tf.keras.initializers.HeNormal()\n",
    "USE_BIAS           = True # only for dense layers\n",
    "FACTORIZE_BIAS     = True\n",
    "\n",
    "# Training\n",
    "PRETRAIN_OPT       = 'sgd' # sgd, adam\n",
    "LR_SCHEDULE        = 'cosine' # piecewise, constant, cosine\n",
    "BATCH_SIZE         = 256 #256\n",
    "EPOCHS             = 100 #200\n",
    "INIT_LR            = 0.1 if PRETRAIN_OPT == 'sgd' else 2e-3 # depth 2=0.2, depth3 3 = 0.5, depth 4 = 0.7\n",
    "WARMUP             = False # linear lr warmup from 0.01 to INIT_LR over 5 epochs\n",
    "MOMENTUM           = 0.9    # only for SGD\n",
    "LR_DECAY_FACT      = 0.1   # only for piecewise\n",
    "LARGE_LRSTART      = False  # only for piecewise\n",
    "\n",
    "# Misc\n",
    "CLASS_NUM          = 10\n",
    "PAT                = 100\n",
    "RESTORE_WEIGHTS    = False\n",
    "GRACE              = 10\n",
    "FINE_GRACE         = 20\n",
    "MINACC             = (1 / CLASS_NUM) + 0.01\n",
    "SEED               = 123\n",
    "SAVE_METRICS       = True\n",
    "VERBOSE            = 1\n",
    "\n",
    "# Directories and saving path\n",
    "LENET_FILE_PATH = './results/input_sparsity/MNIST'\n",
    "\n",
    "print('Defining configs successful!')\n",
    "\n",
    "# Lambda grid\n",
    "LAMBDA_LIST = [\n",
    "    0,\n",
    "    #1e-6,\n",
    "    1e-5,\n",
    "    1e-4,\n",
    "    2e-4,\n",
    "    #4e-4,\n",
    "    5e-4,\n",
    "    7e-4,\n",
    "    8e-4,\n",
    "    9e-4,\n",
    "    1e-3,\n",
    "    1.5e-3,\n",
    "    2e-3,\n",
    "    4e-3,\n",
    "    5e-3,\n",
    "    8e-3,\n",
    "    1e-2,\n",
    "    1.5e-2,\n",
    "    2e-2,\n",
    "    2.5e-2,\n",
    "    3e-2,\n",
    "    3.5e-2,\n",
    "    4e-2,\n",
    "    4.5e-2,\n",
    "    5e-2,\n",
    "    5.5e-2,\n",
    "    6e-2,\n",
    "    6.5e-2,\n",
    "    7e-2,\n",
    "    8e-2,\n",
    "    9e-2,\n",
    "    9.5e-2,\n",
    "    1e-1,\n",
    "    1.25e-1,\n",
    "    1.5e-1,\n",
    "    1.75e-1,\n",
    "    2e-1,\n",
    "    5e-1,\n",
    "    7e-1,\n",
    "    1,\n",
    "    2\n",
    "    #5\n",
    "]\n",
    "\n",
    "#LAMBDA_LIST.reverse()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-20T10:38:12.341074Z",
     "iopub.status.busy": "2025-01-20T10:38:12.340168Z",
     "iopub.status.idle": "2025-01-20T10:38:13.378672Z",
     "shell.execute_reply": "2025-01-20T10:38:13.377994Z",
     "shell.execute_reply.started": "2025-01-20T10:38:12.341032Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train shape: (60000, 784), y_train shape: (60000, 10)\n",
      "x_test shape: (10000, 784), y_test shape: (10000, 10)\n",
      "Normalized Training Set Mean and SD: [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      " -4.41811746e-03 -5.75460540e-03 -4.08261409e-03 -4.08261409e-03\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      " -4.08261409e-03 -4.70944401e-03 -8.79918039e-03 -1.15902880e-02\n",
      " -4.10827546e-04  7.34113948e-03  1.62988584e-02  4.16030996e-02\n",
      "  2.12192275e-02 -2.18368624e-03  1.40861133e-02  2.34294031e-02\n",
      " -1.62362196e-02 -6.72437530e-03  1.82735343e-02  1.46022588e-02\n",
      " -1.81332615e-03 -1.09966453e-02 -8.32507201e-03 -4.38052649e-03\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00 -4.08261409e-03 -5.39535051e-03\n",
      " -8.52213986e-03  1.62882105e-01  2.82596853e-02 -6.80418860e-04\n",
      " -2.75809946e-03  2.19389424e-02  5.55645563e-02  9.38820988e-02\n",
      "  7.39584491e-02  3.26009393e-02  1.77913085e-02  1.09734777e-02\n",
      " -5.40107302e-03 -6.18119864e-03 -1.37357162e-02 -2.40913182e-02\n",
      " -1.69764850e-02 -4.40417323e-03  1.58787761e-02  1.13025401e-02\n",
      " -1.35878371e-02 -7.83111621e-03  0.00000000e+00  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00 -5.36850980e-03 -8.87078419e-03\n",
      "  4.95642005e-03  1.49929868e-02  3.77630778e-02  1.74969789e-02\n",
      "  3.17634605e-02  4.31552008e-02  5.79657406e-02  6.51872158e-02\n",
      "  6.95984289e-02  5.02408594e-02  1.19145364e-02 -4.23975894e-03\n",
      " -1.69138461e-02 -3.21586989e-02 -3.33180428e-02 -2.81243585e-02\n",
      " -2.04585288e-02 -2.81097014e-02 -5.48643339e-03 -1.78614762e-04\n",
      " -2.01879907e-03  6.57787779e-03 -5.28352521e-03  0.00000000e+00\n",
      "  0.00000000e+00 -4.08261409e-03 -7.76327914e-03  1.49458200e-02\n",
      " -4.18766495e-03 -2.23613940e-02  1.07237101e-02  1.90405026e-02\n",
      "  1.96592584e-02  3.35484110e-02  4.25123088e-02  5.26890568e-02\n",
      "  4.26664837e-02  3.57930884e-02  2.34533530e-02  1.10070463e-02\n",
      " -4.27128980e-04 -1.77778229e-02 -2.88259368e-02 -1.89742204e-02\n",
      " -2.28149947e-02 -2.75410265e-02 -1.10654766e-02 -1.60864610e-02\n",
      " -1.03070363e-02  1.73480459e-03 -4.03776811e-03 -5.77035174e-03\n",
      "  0.00000000e+00  0.00000000e+00 -1.16134398e-02 -1.68821972e-03\n",
      "  1.11869664e-03 -1.14550609e-02 -4.49247978e-04  9.97761637e-03\n",
      "  1.84000488e-02  2.26127096e-02  4.90894876e-02  5.98350950e-02\n",
      "  4.87283096e-02  3.02107446e-02  1.80491935e-02  3.70267918e-03\n",
      " -5.76877641e-03 -1.11815091e-02 -2.53235511e-02 -1.95845552e-02\n",
      " -7.30251940e-03 -9.80170164e-03 -7.79611943e-03 -2.34153904e-02\n",
      " -1.20436801e-02 -6.06956286e-03  1.67097524e-02 -8.20913538e-03\n",
      "  0.00000000e+00 -5.57001401e-03 -1.38114160e-02  6.78550918e-03\n",
      "  1.81221552e-02  1.24981825e-03 -2.06548866e-04  5.56156458e-03\n",
      "  1.55270230e-02  2.74176374e-02  3.63383070e-02  5.49860373e-02\n",
      "  5.38944043e-02  3.41811590e-02  3.26609612e-02  1.87016446e-02\n",
      " -2.26639491e-03 -1.42939240e-02 -3.15517560e-02 -3.14056985e-02\n",
      " -2.26257741e-02 -1.79869793e-02 -1.90046374e-02 -2.77752709e-02\n",
      " -8.34583584e-03 -1.02854380e-03 -4.40107239e-03 -5.12118824e-03\n",
      " -4.08261409e-03 -1.32304691e-02 -6.43922621e-03  9.91706457e-03\n",
      "  1.55324936e-02  1.02405651e-02 -3.56938876e-03 -2.94284383e-03\n",
      "  2.36373395e-02  3.48850563e-02  4.46509011e-02  6.21613935e-02\n",
      "  5.10751903e-02  3.94567326e-02  3.63223776e-02  2.79514063e-02\n",
      "  1.21207046e-03 -9.31385811e-03 -2.11941134e-02 -2.53404211e-02\n",
      " -3.30689959e-02 -2.90159546e-02 -3.26997004e-02 -2.45186239e-02\n",
      " -1.41697209e-02 -1.26495296e-02 -1.36567568e-02 -1.46672800e-02\n",
      " -5.02568530e-03  1.00326436e-02  2.42148079e-02 -5.89473685e-03\n",
      " -9.81412828e-04 -3.77764739e-03 -8.52157455e-03 -9.43157263e-03\n",
      "  3.33726108e-02  5.82444258e-02  8.52125958e-02  8.43273252e-02\n",
      "  4.01754603e-02  2.14373358e-02  2.05844082e-02  1.21960789e-02\n",
      "  2.08280236e-03  1.01914816e-03 -1.21372687e-02 -1.45974560e-02\n",
      " -2.44250670e-02 -4.13784347e-02 -3.92613634e-02 -2.44028643e-02\n",
      " -7.86530226e-03 -2.80681998e-03  7.91423954e-03  1.00944173e-02\n",
      "  4.23314497e-02  5.08148149e-02  2.46029440e-02  3.32195847e-03\n",
      " -5.67673612e-03 -1.86863318e-02 -2.41163746e-02 -1.00210020e-02\n",
      "  3.97832133e-02  8.28146338e-02  9.68241468e-02  6.20155148e-02\n",
      "  1.35290902e-02 -1.13764154e-02  4.77549474e-04  6.96658157e-03\n",
      " -1.79711671e-03  5.28488867e-03  2.96694483e-03 -1.09362323e-02\n",
      " -2.19472349e-02 -4.35140692e-02 -4.06785421e-02 -2.23062448e-02\n",
      " -8.44532717e-03 -1.22033171e-02  7.93837418e-04  2.64569279e-03\n",
      "  5.54948375e-02  1.75499450e-02 -7.65063334e-03 -9.57124960e-03\n",
      " -2.32865755e-02 -4.13067453e-02 -3.52517255e-02 -9.08283144e-03\n",
      "  5.00219166e-02  8.82577300e-02  7.83220828e-02  2.71792561e-02\n",
      " -3.13579589e-02 -2.79204641e-02 -3.52647598e-03 -7.16581801e-03\n",
      " -1.06959688e-02  2.03768481e-02  1.82100888e-02  2.20774044e-03\n",
      " -2.87568700e-02 -4.71206009e-02 -3.30201946e-02 -1.18533233e-02\n",
      " -5.77503350e-03 -9.28007439e-03  2.86790095e-02  2.99599227e-02\n",
      "  4.24134592e-03 -1.60132423e-02 -2.00060457e-02 -1.68357547e-02\n",
      " -4.34514955e-02 -4.75979373e-02 -4.17396389e-02  2.22611730e-03\n",
      "  5.65313660e-02  7.84565955e-02  5.14779203e-02 -1.10145723e-02\n",
      " -4.89527397e-02 -1.95959955e-02 -6.03359099e-03  8.12703744e-03\n",
      "  2.26473212e-02  3.49758156e-02  2.34207418e-02  5.19333920e-03\n",
      " -2.11802181e-02 -3.34797017e-02 -2.04111207e-02 -2.00267369e-03\n",
      "  1.11470614e-02  8.99341982e-03  1.19247911e-02 -9.12484061e-03\n",
      " -5.30433701e-03 -1.92702457e-03 -8.15872382e-03 -1.98461860e-02\n",
      " -3.42833698e-02 -3.56188677e-02 -1.28956847e-02  2.85033397e-02\n",
      "  5.72260730e-02  5.27206287e-02  2.25755665e-02 -2.11761035e-02\n",
      " -2.67201364e-02 -7.39592686e-03  1.71533320e-02  3.55378836e-02\n",
      "  4.77375649e-02  3.55996676e-02  2.30031516e-02 -4.26578277e-04\n",
      " -1.03206690e-02 -1.78005826e-02  2.36024684e-03  9.68122855e-03\n",
      "  9.91401076e-03  1.23962658e-02 -7.42188562e-03 -9.92202573e-03\n",
      " -4.08261409e-03  1.19835259e-02 -2.57651857e-03 -2.05178726e-02\n",
      " -2.91534010e-02 -2.35685986e-02 -8.04510026e-04  3.17953601e-02\n",
      "  4.23842520e-02  2.30016951e-02 -3.44195403e-04 -1.59274656e-02\n",
      " -2.66919974e-02  2.50606332e-03  3.53288651e-02  5.48702404e-02\n",
      "  4.80154566e-02  3.17603536e-02  1.46136926e-02  7.22691091e-03\n",
      "  1.46815320e-03  9.95976734e-04  2.21733991e-02  1.00731356e-02\n",
      " -4.49132611e-04  5.13953215e-04 -5.69798611e-03 -7.40432693e-03\n",
      " -4.08261409e-03 -7.98808783e-03 -1.02737295e-02 -2.47155316e-02\n",
      " -2.11008321e-02 -1.12750102e-02 -5.36655961e-03  1.00128651e-02\n",
      "  2.03726329e-02  1.57157816e-02 -2.14449456e-03 -2.62317597e-03\n",
      " -1.22394068e-02  1.02235246e-02  2.15032194e-02  4.45979759e-02\n",
      "  5.13432585e-02  3.08734756e-02  2.17046216e-02  8.45012721e-03\n",
      "  5.36837790e-04  1.40687767e-02  1.85617134e-02  1.49831427e-02\n",
      " -6.24519773e-03 -1.11615434e-02 -1.89122744e-02 -6.98474329e-03\n",
      " -4.77027940e-03 -4.18943726e-03 -8.82868096e-03 -2.19771266e-02\n",
      " -1.61028132e-02 -1.02803065e-02 -1.88928694e-02 -9.32159182e-03\n",
      "  7.68913375e-03  7.92805478e-03  3.35091818e-03  4.89460072e-03\n",
      "  1.41604273e-02  3.22548747e-02  2.37208586e-02  4.35541496e-02\n",
      "  5.52673154e-02  4.82611395e-02  3.02265361e-02  8.08343291e-03\n",
      "  4.17829351e-03  1.73312202e-02  1.30356066e-02  1.29084894e-02\n",
      " -9.83185321e-03 -1.40072182e-02 -2.05306690e-02 -9.10989009e-03\n",
      " -4.08261409e-03 -6.70758402e-03  7.74516317e-04 -2.92797517e-02\n",
      " -1.05634918e-02 -1.29393917e-02 -2.20253803e-02 -1.02981823e-02\n",
      "  1.54633205e-02  1.07212923e-02  3.82813648e-03 -5.31125255e-03\n",
      "  1.65223517e-02  2.07003020e-02  1.05960611e-02  2.88257264e-02\n",
      "  5.85474111e-02  7.48338476e-02  4.65679504e-02  3.13898921e-02\n",
      "  2.67218444e-02  1.58677213e-02  5.66643663e-03 -2.34023505e-03\n",
      " -1.88134275e-02 -5.75260399e-03 -1.68579053e-02  1.03213442e-02\n",
      "  0.00000000e+00 -8.93191434e-03 -3.28481954e-04 -3.06835510e-02\n",
      " -8.66520219e-03  8.77050741e-04 -8.49256199e-03 -1.22461831e-02\n",
      "  1.70802735e-02  2.78926343e-02  2.18792725e-02  2.88548507e-03\n",
      "  2.27689510e-03 -8.12013913e-03 -1.34364022e-02  2.22838409e-02\n",
      "  6.13453612e-02  9.44503173e-02  8.24194998e-02  6.45607412e-02\n",
      "  5.15957437e-02  1.90572832e-02  9.74230189e-03 -2.09910795e-03\n",
      " -1.72413383e-02 -2.04935037e-02 -2.48995908e-02  2.51586698e-02\n",
      " -5.90576651e-03 -6.35195151e-03 -9.69331420e-04 -1.97413545e-02\n",
      " -2.72226613e-03  8.89686926e-04  4.31215204e-03 -3.30640818e-03\n",
      "  2.09884215e-02  2.02501602e-02  1.17556509e-02 -2.22103714e-04\n",
      " -8.37222207e-03 -2.11433414e-02 -1.23987151e-02  3.04329880e-02\n",
      "  7.35105127e-02  9.69756842e-02  8.75395536e-02  8.14764425e-02\n",
      "  4.83968481e-02  1.06824087e-02 -7.33584771e-03 -1.82500910e-02\n",
      " -2.49640122e-02 -1.63379200e-02  1.85905793e-03  1.04673030e-02\n",
      " -4.08261409e-03 -1.12813441e-02 -1.58734061e-02 -1.51830455e-02\n",
      " -1.26617122e-02 -2.98885349e-03  2.13782117e-03  2.01105652e-03\n",
      "  7.97699764e-03  6.29997300e-03  1.13672409e-02  7.83106964e-03\n",
      " -8.15426093e-03 -1.08835557e-02  1.70712508e-02  5.21537550e-02\n",
      "  8.94790068e-02  8.42244551e-02  8.56550932e-02  7.07507357e-02\n",
      "  2.11949311e-02 -1.72263607e-02 -2.65874956e-02 -3.37393843e-02\n",
      " -3.11973970e-02 -1.12663805e-02  6.06848858e-04 -2.15330860e-03\n",
      "  0.00000000e+00 -1.03417365e-02 -1.22197848e-02 -4.27834317e-03\n",
      " -2.16761585e-02 -1.19962459e-02 -1.47948368e-02 -1.12066008e-02\n",
      " -2.72350921e-03 -2.51717190e-03  1.58068519e-02  2.42214389e-02\n",
      "  8.03474430e-03  1.75393503e-02  4.30478565e-02  6.64937571e-02\n",
      "  7.63814449e-02  7.55630657e-02  6.52094632e-02  3.21744233e-02\n",
      " -1.18789412e-02 -3.32951173e-02 -3.46551649e-02 -2.50138585e-02\n",
      " -1.75183099e-02 -1.43489260e-02  4.79382416e-03 -5.77350287e-03\n",
      " -4.08261409e-03 -1.00096930e-02 -1.31540536e-03 -1.46820064e-04\n",
      " -2.28588618e-02 -1.04693556e-02 -1.81319062e-02 -2.35702917e-02\n",
      " -1.85565315e-02 -1.43159144e-02 -6.61694724e-03  4.31766361e-03\n",
      "  9.35346540e-03  3.09117306e-02  5.17731197e-02  5.91272153e-02\n",
      "  6.89778849e-02  6.08396009e-02  3.11856512e-02 -8.34356248e-03\n",
      " -3.37222554e-02 -4.85104211e-02 -3.99011672e-02 -1.69564448e-02\n",
      " -1.30429976e-02 -2.70048063e-02  9.24569438e-04 -4.08261409e-03\n",
      " -4.08261409e-03 -6.00481918e-03 -2.26263311e-02 -2.95542250e-03\n",
      " -2.29217056e-02 -1.13520082e-02 -2.56241206e-02 -2.97484025e-02\n",
      " -3.51864472e-02 -2.72086877e-02 -2.58972049e-02 -9.62943584e-03\n",
      "  4.09017457e-03  2.02904511e-02  4.60916944e-02  4.98231128e-02\n",
      "  5.59843779e-02  2.98551247e-02  1.01009458e-02 -1.98486410e-02\n",
      " -4.24833633e-02 -5.44052608e-02 -3.72422449e-02 -1.98121760e-02\n",
      "  1.24626688e-03  2.08764267e-03 -8.95156339e-03 -4.08261409e-03\n",
      "  0.00000000e+00  0.00000000e+00 -1.91064626e-02 -1.35435145e-02\n",
      " -2.16376688e-02 -1.03856167e-02 -2.45866496e-02 -2.94696223e-02\n",
      " -4.15393040e-02 -6.16575517e-02 -4.36944328e-02 -2.20641773e-02\n",
      "  7.69260200e-03  2.48496085e-02  4.90829125e-02  5.72306849e-02\n",
      "  6.01896644e-02  3.96406464e-02  9.39116348e-03 -2.73344070e-02\n",
      " -4.55973335e-02 -3.86242680e-02 -2.85568126e-02 -1.53164165e-02\n",
      "  4.31055715e-03  2.92128008e-02 -1.04026571e-02  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00 -1.03530018e-02 -1.68102235e-02\n",
      " -2.24907678e-02 -2.10993234e-02 -3.08284760e-02 -4.58658747e-02\n",
      " -4.31182496e-02 -6.45727068e-02 -5.30257560e-02 -3.24722826e-02\n",
      "  2.10235245e-03  2.13605110e-02  4.91039939e-02  7.43381009e-02\n",
      "  8.82792696e-02  4.60067727e-02 -4.17719921e-03 -2.29650829e-02\n",
      " -3.25202122e-02 -4.37310264e-02 -3.97582501e-02 -3.26538906e-02\n",
      " -1.98085010e-02 -1.08662900e-02 -4.75306250e-03  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00 -6.66427705e-03 -1.09637678e-02\n",
      " -8.71103373e-04 -2.75413934e-02 -4.61031832e-02 -4.74594682e-02\n",
      " -3.83740030e-02 -3.54633592e-02 -2.97762491e-02  5.76993497e-03\n",
      "  3.57955880e-02  3.72918956e-02  5.51128723e-02  5.37002236e-02\n",
      "  5.57434000e-02  2.00574957e-02 -3.23551223e-02 -4.29426692e-02\n",
      " -2.72130780e-02 -2.38635503e-02 -3.41538191e-02 -2.24017650e-02\n",
      " -2.59548135e-04  2.13750787e-02 -4.08261409e-03  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00 -4.08261409e-03\n",
      " -1.01094292e-02 -1.03547452e-02 -1.63998064e-02 -2.63651311e-02\n",
      " -2.02881917e-02 -2.40834001e-02 -2.56966818e-02  2.03555543e-03\n",
      "  2.51829512e-02  2.63418723e-02  2.14687567e-02  2.21937317e-02\n",
      "  3.37087959e-02  6.75980421e-03 -1.50866285e-02 -1.35353142e-02\n",
      " -1.64796859e-02  1.23606250e-03 -1.14327017e-02  1.73845831e-02\n",
      " -4.49435040e-03 -4.08261409e-03  0.00000000e+00  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      " -5.79020847e-03 -9.05387662e-03 -1.45641081e-02 -3.70762567e-03\n",
      " -3.07631126e-04 -1.08020976e-02 -3.09229475e-02 -2.82016695e-02\n",
      " -2.04058159e-02 -8.88458174e-03 -1.42098330e-02 -2.73996079e-03\n",
      " -5.47817070e-03 -1.14821615e-02  1.76217553e-04  1.61317724e-03\n",
      " -4.62485012e-03 -1.14325648e-02 -9.00690258e-03 -5.77026419e-03\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00] [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      " 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      " 0.0000000e+00 0.0000000e+00 3.9580062e-08 2.1467318e-07 9.7322207e-08\n",
      " 9.6857548e-08 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      " 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      " 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      " 0.0000000e+00 0.0000000e+00 9.7322207e-08 2.4400114e-07 1.6577050e-07\n",
      " 2.7101578e-07 9.0735304e-01 1.0841730e+00 1.2711843e+00 1.5736740e+00\n",
      " 1.2879838e+00 9.6679288e-01 1.1832184e+00 1.4103526e+00 5.5485547e-01\n",
      " 8.7634039e-01 1.4318393e+00 1.4727974e+00 7.9626387e-01 2.8497971e-07\n",
      " 2.1141406e-07 1.6717675e-07 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      " 0.0000000e+00 0.0000000e+00 0.0000000e+00 9.7322207e-08 4.6566129e-09\n",
      " 2.7194076e-07 7.9728589e+00 1.9015563e+00 1.0075092e+00 8.3693093e-01\n",
      " 1.1488374e+00 1.3370754e+00 1.5042129e+00 1.3405477e+00 1.0911613e+00\n",
      " 1.0712247e+00 1.0505410e+00 9.6034199e-01 9.8247695e-01 9.4710863e-01\n",
      " 8.7330800e-01 8.5679340e-01 9.9351680e-01 1.2876799e+00 1.2477993e+00\n",
      " 3.1622518e-02 2.7567148e-07 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      " 0.0000000e+00 1.3224560e-07 1.7881393e-07 1.2033213e+00 1.3229245e+00\n",
      " 1.5235522e+00 1.1076982e+00 1.1632768e+00 1.1813263e+00 1.2103313e+00\n",
      " 1.1834729e+00 1.1790432e+00 1.1173617e+00 1.0291021e+00 9.9540633e-01\n",
      " 9.7862196e-01 9.3266672e-01 9.2032367e-01 9.3253803e-01 9.3327069e-01\n",
      " 8.5878128e-01 9.7875196e-01 9.7926056e-01 1.0608642e+00 9.9633288e-01\n",
      " 1.9418022e-07 0.0000000e+00 0.0000000e+00 9.7322207e-08 1.4342014e-07\n",
      " 1.8705932e+00 9.9057370e-01 6.7717135e-01 1.0727829e+00 1.1030509e+00\n",
      " 1.0745533e+00 1.0588750e+00 1.0800353e+00 1.0832589e+00 1.0565811e+00\n",
      " 1.0338585e+00 1.0175095e+00 1.0101744e+00 9.9652171e-01 9.7620130e-01\n",
      " 9.6532804e-01 9.8159212e-01 9.5991987e-01 9.2027038e-01 9.7023094e-01\n",
      " 8.8855731e-01 9.6019071e-01 1.2105538e+00 4.7314534e-01 1.4621519e-07\n",
      " 0.0000000e+00 0.0000000e+00 1.8238679e-02 6.4063019e-01 9.5403302e-01\n",
      " 8.8345098e-01 9.6042144e-01 1.0148319e+00 1.0400944e+00 1.0327995e+00\n",
      " 1.0425024e+00 1.0377536e+00 1.0239737e+00 1.0123330e+00 1.0112470e+00\n",
      " 1.0080161e+00 1.0037187e+00 9.9538267e-01 9.8741001e-01 9.8437506e-01\n",
      " 9.9683672e-01 9.9902904e-01 1.0026282e+00 9.2901921e-01 9.9042606e-01\n",
      " 9.7649688e-01 1.2939514e+00 2.1699631e-07 0.0000000e+00 1.3923587e-07\n",
      " 1.1705050e-01 1.1144403e+00 1.1219565e+00 1.0156411e+00 1.0055025e+00\n",
      " 1.0032154e+00 1.0150743e+00 1.0142095e+00 1.0177913e+00 1.0253644e+00\n",
      " 1.0126159e+00 1.0002276e+00 9.9769658e-01 9.9821514e-01 9.9752837e-01\n",
      " 1.0002893e+00 9.8797697e-01 9.8150152e-01 9.8384297e-01 9.8327041e-01\n",
      " 9.7447395e-01 9.3285036e-01 1.0062920e+00 1.0228804e+00 9.3114692e-01\n",
      " 6.0625881e-01 9.7322207e-08 5.4108796e-07 6.1278325e-01 1.0809740e+00\n",
      " 1.0716383e+00 1.0269432e+00 9.9736565e-01 9.9421358e-01 1.0101815e+00\n",
      " 1.0143552e+00 1.0102618e+00 1.0035728e+00 1.0012336e+00 1.0068064e+00\n",
      " 9.9998099e-01 1.0015785e+00 9.9871236e-01 1.0003262e+00 9.9401909e-01\n",
      " 9.9334884e-01 9.8619485e-01 9.8499912e-01 9.6626389e-01 9.6145034e-01\n",
      " 9.7330666e-01 9.2773533e-01 8.1458670e-01 3.2373477e-02 1.9231571e-07\n",
      " 9.2407572e-01 1.2349374e+00 9.8670894e-01 9.7850680e-01 9.8879433e-01\n",
      " 9.9419016e-01 9.8188657e-01 1.0216916e+00 1.0213515e+00 1.0170696e+00\n",
      " 1.0121727e+00 1.0103637e+00 1.0066109e+00 1.0036218e+00 1.0004394e+00\n",
      " 1.0047470e+00 1.0006140e+00 9.9976277e-01 9.9997038e-01 9.9302065e-01\n",
      " 9.6662885e-01 9.6240062e-01 9.6749943e-01 9.9116701e-01 1.0045290e+00\n",
      " 1.1164415e+00 1.4924070e+00 2.7103467e+00 1.9174715e+00 1.3451148e+00\n",
      " 1.0265794e+00 9.7885364e-01 9.6735573e-01 9.6668452e-01 9.8186725e-01\n",
      " 1.0174313e+00 1.0245342e+00 1.0095873e+00 1.0098666e+00 9.9849361e-01\n",
      " 9.9785185e-01 1.0026022e+00 1.0009006e+00 1.0035757e+00 9.9672711e-01\n",
      " 1.0041075e+00 1.0015215e+00 9.8773885e-01 9.6467364e-01 9.6012163e-01\n",
      " 9.8062044e-01 9.6755785e-01 9.5196462e-01 1.0006055e+00 7.2126853e-01\n",
      " 3.1905856e+00 1.2599845e+00 9.0233284e-01 9.2299712e-01 9.0331191e-01\n",
      " 9.0232867e-01 9.4106817e-01 9.8069572e-01 1.0152668e+00 1.0258174e+00\n",
      " 1.0167036e+00 9.9983895e-01 9.9173474e-01 9.9820709e-01 1.0018262e+00\n",
      " 9.9564558e-01 9.9454623e-01 1.0030926e+00 1.0056641e+00 1.0005929e+00\n",
      " 9.7976387e-01 9.6089113e-01 9.6415073e-01 9.9407095e-01 9.5212066e-01\n",
      " 9.2987031e-01 1.3590082e+00 1.9902158e+00 7.1434253e-01 2.3389368e-01\n",
      " 7.7720976e-01 9.3217534e-01 8.0347645e-01 8.8489568e-01 9.2521417e-01\n",
      " 9.9453872e-01 1.0160288e+00 1.0285687e+00 1.0081717e+00 9.9491817e-01\n",
      " 9.8478585e-01 9.9377292e-01 9.9617654e-01 1.0060290e+00 1.0034151e+00\n",
      " 1.0052661e+00 1.0052780e+00 1.0010229e+00 9.9105483e-01 9.6934813e-01\n",
      " 9.8147309e-01 1.0200268e+00 1.0651495e+00 1.0581427e+00 1.1763512e+00\n",
      " 2.5984110e-07 1.4435500e-08 9.0504318e-01 9.5050532e-01 8.7664127e-01\n",
      " 8.0706799e-01 9.0298909e-01 9.7410768e-01 1.0173290e+00 1.0283030e+00\n",
      " 1.0157714e+00 1.0118625e+00 9.9707508e-01 9.8638266e-01 1.0008118e+00\n",
      " 1.0048327e+00 1.0092380e+00 1.0017540e+00 1.0080857e+00 9.9933672e-01\n",
      " 9.9881637e-01 9.9094051e-01 9.8523450e-01 1.0106570e+00 1.0249540e+00\n",
      " 1.0355624e+00 1.0197817e+00 5.1864988e-01 4.4703484e-08 9.7322207e-08\n",
      " 1.4493158e+00 9.1972619e-01 8.2681626e-01 8.4247661e-01 9.3865812e-01\n",
      " 9.8919445e-01 1.0190479e+00 1.0275589e+00 1.0058289e+00 1.0027103e+00\n",
      " 1.0038918e+00 1.0005037e+00 1.0056378e+00 1.0066073e+00 1.0076938e+00\n",
      " 9.9912965e-01 9.9762428e-01 1.0057149e+00 1.0076731e+00 1.0044971e+00\n",
      " 9.9908680e-01 1.0338198e+00 1.0175071e+00 9.7237480e-01 1.0830640e+00\n",
      " 1.1174977e+00 3.2596290e-09 9.7322207e-08 2.5798232e-07 5.0732064e-01\n",
      " 7.3046529e-01 9.1091830e-01 9.7230238e-01 9.9563217e-01 1.0061709e+00\n",
      " 1.0034721e+00 9.9966353e-01 1.0041084e+00 1.0060837e+00 9.9787831e-01\n",
      " 1.0058340e+00 1.0006223e+00 1.0029024e+00 9.9462098e-01 1.0013952e+00\n",
      " 1.0123702e+00 9.9947691e-01 1.0098333e+00 1.0171574e+00 1.0313308e+00\n",
      " 1.0366822e+00 9.4923753e-01 8.4834319e-01 3.6245281e-01 1.5879372e-07\n",
      " 9.3132257e-10 9.3132257e-09 6.2101704e-01 7.1983945e-01 9.4460630e-01\n",
      " 9.8393750e-01 9.8402941e-01 9.9349350e-01 1.0099164e+00 9.9792266e-01\n",
      " 1.0025674e+00 1.0015945e+00 1.0048823e+00 1.0071789e+00 1.0032666e+00\n",
      " 9.9865419e-01 9.9743408e-01 1.0057753e+00 1.0058385e+00 9.9566138e-01\n",
      " 1.0043901e+00 1.0170617e+00 1.0251448e+00 1.0297610e+00 9.6444690e-01\n",
      " 8.5974181e-01 4.9973881e-01 3.5017729e-07 9.7322207e-08 3.0641237e-07\n",
      " 1.0631645e+00 6.3315368e-01 9.4704551e-01 9.7759259e-01 9.7950596e-01\n",
      " 9.9370676e-01 1.0113342e+00 1.0054590e+00 1.0058295e+00 1.0008030e+00\n",
      " 1.0060775e+00 1.0019033e+00 9.9966353e-01 9.9341452e-01 1.0066601e+00\n",
      " 1.0111128e+00 1.0030842e+00 1.0105126e+00 1.0217290e+00 1.0127797e+00\n",
      " 9.9830294e-01 9.9251193e-01 9.3274713e-01 9.8436588e-01 6.3796514e-01\n",
      " 1.2180102e+00 0.0000000e+00 1.7415594e-07 9.9864435e-01 7.4397069e-01\n",
      " 9.6015650e-01 9.9942750e-01 9.9643946e-01 9.9913925e-01 1.0040557e+00\n",
      " 1.0170377e+00 1.0148900e+00 1.0014157e+00 1.0040817e+00 1.0083574e+00\n",
      " 1.0054440e+00 9.9641848e-01 1.0074476e+00 1.0063462e+00 1.0123031e+00\n",
      " 1.0293672e+00 1.0331516e+00 1.0169036e+00 1.0180503e+00 1.0015250e+00\n",
      " 9.5148140e-01 8.4654909e-01 4.2003885e-01 2.0356822e+00 2.0023239e-07\n",
      " 1.6297916e-07 1.0391560e+00 8.9351517e-01 9.8739928e-01 1.0013471e+00\n",
      " 1.0049818e+00 9.9958962e-01 1.0174489e+00 1.0169990e+00 1.0082986e+00\n",
      " 1.0018625e+00 1.0038110e+00 9.9881631e-01 9.9855900e-01 1.0022398e+00\n",
      " 1.0077553e+00 1.0045288e+00 1.0252571e+00 1.0371392e+00 1.0264668e+00\n",
      " 1.0054306e+00 9.8378843e-01 9.5740718e-01 8.8965821e-01 9.2027599e-01\n",
      " 1.1332550e+00 1.2792183e+00 9.6857548e-08 3.1850766e-07 7.4214190e-01\n",
      " 9.0674192e-01 9.8763019e-01 9.9015820e-01 1.0010471e+00 1.0111983e+00\n",
      " 1.0098671e+00 1.0045424e+00 1.0049363e+00 9.9989396e-01 9.9957389e-01\n",
      " 1.0034101e+00 1.0088469e+00 1.0028118e+00 1.0016403e+00 1.0100046e+00\n",
      " 1.0240328e+00 1.0334529e+00 1.0103495e+00 9.7453707e-01 9.4891161e-01\n",
      " 8.9989531e-01 8.6691415e-01 9.4533086e-01 1.0132674e+00 3.7138328e-01\n",
      " 0.0000000e+00 3.6879297e-07 7.4791992e-01 9.7683603e-01 9.4203979e-01\n",
      " 9.7781181e-01 9.8652631e-01 9.8968637e-01 1.0024040e+00 1.0043759e+00\n",
      " 1.0089241e+00 1.0060288e+00 1.0058736e+00 9.9735051e-01 1.0018818e+00\n",
      " 9.9741471e-01 1.0021681e+00 1.0169861e+00 1.0182829e+00 1.0089793e+00\n",
      " 9.8126775e-01 9.4765019e-01 9.2369431e-01 9.4647449e-01 9.1423219e-01\n",
      " 8.3075094e-01 9.2752743e-01 1.8719376e-07 9.7322207e-08 1.8998981e-07\n",
      " 9.7220099e-01 1.0356858e+00 9.3653607e-01 9.7871840e-01 9.8335409e-01\n",
      " 9.8107022e-01 9.9268007e-01 1.0018334e+00 1.0037558e+00 1.0010437e+00\n",
      " 1.0026059e+00 1.0023873e+00 9.9690431e-01 9.9614894e-01 1.0052675e+00\n",
      " 1.0159725e+00 1.0114850e+00 9.9292845e-01 9.5680875e-01 9.1042578e-01\n",
      " 8.9261472e-01 9.4262779e-01 9.1356033e-01 5.7761925e-01 1.0734981e+00\n",
      " 9.7322207e-08 9.6857548e-08 8.6610420e-08 4.2267248e-01 1.0085765e+00\n",
      " 9.1185874e-01 9.7253478e-01 9.6825254e-01 9.7799748e-01 9.8459119e-01\n",
      " 9.9474281e-01 1.0042838e+00 1.0026510e+00 1.0041611e+00 1.0042924e+00\n",
      " 1.0015846e+00 1.0012909e+00 1.0152729e+00 1.0144143e+00 1.0032771e+00\n",
      " 9.7161514e-01 9.1166180e-01 8.6349279e-01 8.6388534e-01 9.1234052e-01\n",
      " 1.0282832e+00 1.0408180e+00 2.0552683e-01 9.6857548e-08 0.0000000e+00\n",
      " 0.0000000e+00 2.2969939e-01 8.7106019e-01 8.4800917e-01 9.5905763e-01\n",
      " 9.5409286e-01 9.6251649e-01 9.6252888e-01 9.5984608e-01 9.7493786e-01\n",
      " 9.9090856e-01 9.9993831e-01 1.0012022e+00 1.0072546e+00 1.0190736e+00\n",
      " 1.0417241e+00 1.0406396e+00 1.0053297e+00 9.4491112e-01 8.8543069e-01\n",
      " 8.6751318e-01 8.5332257e-01 8.6996138e-01 9.2686403e-01 1.4954894e+00\n",
      " 1.7508864e-07 0.0000000e+00 0.0000000e+00 0.0000000e+00 3.0081540e-07\n",
      " 5.2966851e-01 6.6614431e-01 8.6142522e-01 8.5901976e-01 8.8141215e-01\n",
      " 9.1684020e-01 9.1267139e-01 9.4329536e-01 9.7206110e-01 1.0004144e+00\n",
      " 1.0099839e+00 1.0359678e+00 1.0626701e+00 1.0976437e+00 1.0580245e+00\n",
      " 9.8248208e-01 9.3355733e-01 8.5958636e-01 7.3327971e-01 6.4231306e-01\n",
      " 5.7512474e-01 3.4111878e-01 3.2441661e-02 1.8766239e-07 0.0000000e+00\n",
      " 0.0000000e+00 0.0000000e+00 2.3748726e-08 4.4516858e-07 1.0235959e+00\n",
      " 7.4200058e-01 6.6580683e-01 7.9511470e-01 8.6733103e-01 9.1189891e-01\n",
      " 9.3892008e-01 1.0127660e+00 1.0543573e+00 1.0565462e+00 1.0885599e+00\n",
      " 1.0790602e+00 1.1105967e+00 1.0336881e+00 8.7358195e-01 7.9610276e-01\n",
      " 8.3847237e-01 8.3425897e-01 3.5514408e-01 3.9943677e-01 9.7599399e-01\n",
      " 1.7106270e+00 9.7322207e-08 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      " 0.0000000e+00 9.7322207e-08 2.7383116e-01 8.0129248e-01 8.1639725e-01\n",
      " 8.0921692e-01 9.0992832e-01 9.1669482e-01 9.1026628e-01 9.9220777e-01\n",
      " 1.0592594e+00 1.0637164e+00 1.0289983e+00 1.0634072e+00 1.1291505e+00\n",
      " 9.7389418e-01 9.0384561e-01 8.7397712e-01 7.5759560e-01 1.1151954e+00\n",
      " 3.5504842e-01 1.9224802e+00 1.3503920e-07 9.7322207e-08 0.0000000e+00\n",
      " 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      " 4.4703484e-08 3.0733645e-08 9.1169477e-02 9.1856825e-01 1.1865302e+00\n",
      " 8.1469429e-01 5.5920196e-01 6.2081337e-01 8.0629909e-01 9.5259869e-01\n",
      " 8.4624058e-01 1.0237356e+00 9.6282297e-01 8.8217759e-01 9.8680794e-01\n",
      " 1.1329248e+00 7.2375351e-01 3.6134253e-07 3.9208430e-07 1.2992055e-07\n",
      " 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]\n",
      "Train data shape:  (4000, 784)\n",
      "Train labels shape:  (4000, 10)\n",
      "Test data shape:  (6000, 784)\n",
      "Test labels shape:  (6000, 10)\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from sklearn.model_selection import train_test_split\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n",
    "from PIL import Image\n",
    "import tensorflow as tf\n",
    "\n",
    "def one_hot_encode(y, num_classes):\n",
    "    return np.eye(num_classes)[y]\n",
    "\n",
    "def preprocess_line(line):\n",
    "    # Split the line by commas to handle comma-separated values\n",
    "    preprocessed = line.strip().split(',')\n",
    "    return preprocessed\n",
    "\n",
    "def load_data(fashion=False, digit=None, normalize=False, one_hot=True):\n",
    "    \"\"\"Load MNIST or Fashion-MNIST as flattened float32 feature matrices.\n",
    "\n",
    "    Args:\n",
    "        fashion: Load Fashion-MNIST when True, MNIST otherwise.\n",
    "        digit: Optional class label in 0..9. When given, only samples of that\n",
    "            class are kept; images and labels are filtered together. None or\n",
    "            a value outside 0..9 leaves the data unfiltered.\n",
    "        normalize: When True, standardize the features with a StandardScaler\n",
    "            fitted on the training split only.\n",
    "        one_hot: When True, return the labels one-hot encoded.\n",
    "\n",
    "    Returns:\n",
    "        ``((x_train, y_train), (x_test, y_test))`` with each ``x`` reshaped to\n",
    "        ``(n_samples, height * width)``.\n",
    "    \"\"\"\n",
    "    if fashion:\n",
    "        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()\n",
    "    else:\n",
    "        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "\n",
    "    # Count the classes on the full label set, before any per-class filtering,\n",
    "    # so the one-hot width stays the same when `digit` is set.\n",
    "    num_classes = np.unique(y_train).shape[0]\n",
    "\n",
    "    if digit is not None and 0 <= digit <= 9:\n",
    "        # Separate masks per split keep train and test samples apart, and\n",
    "        # applying each mask to images and labels keeps them aligned.\n",
    "        train_mask = y_train == digit\n",
    "        test_mask = y_test == digit\n",
    "        x_train, y_train = x_train[train_mask], y_train[train_mask]\n",
    "        x_test, y_test = x_test[test_mask], y_test[test_mask]\n",
    "\n",
    "    # Flatten every image into a single feature row.\n",
    "    x_train = x_train.reshape((-1, x_train.shape[1] * x_train.shape[2])).astype(np.float32)\n",
    "    x_test = x_test.reshape((-1, x_test.shape[1] * x_test.shape[2])).astype(np.float32)\n",
    "\n",
    "    if normalize:\n",
    "        # Fit the scaler on the training split only and reuse it for the test split.\n",
    "        scaler = StandardScaler().fit(x_train)\n",
    "        x_train = scaler.transform(x_train)\n",
    "        x_test = scaler.transform(x_test)\n",
    "\n",
    "    if one_hot:\n",
    "        y_train = one_hot_encode(y_train, num_classes)\n",
    "        y_test = one_hot_encode(y_test, num_classes)\n",
    "\n",
    "    print(\"x_train shape: {}, y_train shape: {}\".format(x_train.shape, y_train.shape))\n",
    "    print(\"x_test shape: {}, y_test shape: {}\".format(x_test.shape, y_test.shape))\n",
    "\n",
    "    return (x_train, y_train), (x_test, y_test)\n",
    "\n",
    "\n",
    "\n",
    "def load_mnist(one_hot=True):\n",
    "    \"\"\"Return a train/test split drawn from the standardized MNIST test set.\n",
    "\n",
    "    The training part returned by ``load_data`` is discarded. NumPy, ``random``\n",
    "    and TensorFlow are seeded with ``SEED`` before the test set is split with\n",
    "    ``test_size=0.6``.\n",
    "    \"\"\"\n",
    "    _, (features, labels) = load_data(fashion=False, normalize=True, one_hot=one_hot)\n",
    "\n",
    "    # Seed every RNG source before the shuffle-split\n",
    "    np.random.seed(SEED)\n",
    "    random.seed(SEED)\n",
    "    tf.random.set_seed(SEED)\n",
    "\n",
    "    x_tr, x_te, y_tr, y_te = train_test_split(features, labels, test_size=0.6)\n",
    "    return (x_tr, y_tr), (x_te, y_te)\n",
    "\n",
    "\n",
    "def load_fashion(one_hot=True):\n",
    "    \"\"\"Return a train/test split drawn from the standardized Fashion-MNIST test set.\n",
    "\n",
    "    The training part returned by ``load_data`` is discarded. NumPy, ``random``\n",
    "    and TensorFlow are seeded with ``SEED`` before the test set is split with\n",
    "    ``test_size=0.6``.\n",
    "    \"\"\"\n",
    "    _, (features, labels) = load_data(fashion=True, normalize=True, one_hot=one_hot)\n",
    "\n",
    "    # Seed every RNG source before the shuffle-split\n",
    "    np.random.seed(SEED)\n",
    "    random.seed(SEED)\n",
    "    tf.random.set_seed(SEED)\n",
    "\n",
    "    x_tr, x_te, y_tr, y_te = train_test_split(features, labels, test_size=0.6)\n",
    "    return (x_tr, y_tr), (x_te, y_te)\n",
    "\n",
    "# Build the working split: load_mnist() draws both parts from the MNIST test set (test_size=0.6).\n",
    "(X_train, Y_train), (X_test, Y_test) = load_mnist()\n",
    "############\n",
    "\n",
    "def calculate_mean_std(dataset):\n",
    "    means = np.mean(dataset, axis=(0))  \n",
    "    stds = np.std(dataset, axis=(0))  \n",
    "    return means, stds\n",
    "\n",
    "# Axis-0 mean/std of both splits; only the training statistics are printed below.\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "print(\"Normalized Training Set Mean and SD:\", train_mean, train_std)\n",
    "\n",
    "\n",
    "# Check sample sizes of split\n",
    "print('Train data shape: ', X_train.shape)\n",
    "print('Train labels shape: ', Y_train.shape)\n",
    "# NOTE(review): no X_val / Y_val is created in this cell, so the validation prints stay disabled.\n",
    "#print('Validation data shape: ', X_val.shape)\n",
    "#print('Validation labels shape: ', Y_val.shape)\n",
    "print('Test data shape: ', X_test.shape)\n",
    "print('Test labels shape: ', Y_test.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2025-01-19T17:44:11.152499Z",
     "iopub.status.busy": "2025-01-19T17:44:11.151852Z",
     "iopub.status.idle": "2025-01-19T17:44:19.060037Z",
     "shell.execute_reply": "2025-01-19T17:44:19.059438Z",
     "shell.execute_reply.started": "2025-01-19T17:44:11.152474Z"
    },
    "id": "6yy9KHxyWOBY",
    "outputId": "7cfef7ad-a3ae-49a2-92d7-820a174079d9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"VanillaLeNet300100\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " input_1 (InputLayer)        [(None, 784)]             0         \n",
      "                                                                 \n",
      " dense (Dense)               (None, 300)               235500    \n",
      "                                                                 \n",
      " dense_1 (Dense)             (None, 100)               30100     \n",
      "                                                                 \n",
      " dense_2 (Dense)             (None, 10)                1010      \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 266,610\n",
      "Trainable params: 266,610\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n",
      "\n",
      "Epoch 1: Current learning rate = 1.000e-01\n",
      "Epoch 1/100\n",
      "16/16 [==============================] - 1s 3ms/step - loss: 0.8305 - accuracy: 0.7577\n",
      "\n",
      "Epoch 2: Current learning rate = 9.997e-02\n",
      "Epoch 2/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 0.2112 - accuracy: 0.9358\n",
      "\n",
      "Epoch 3: Current learning rate = 9.989e-02\n",
      "Epoch 3/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 0.0798 - accuracy: 0.9750\n",
      "\n",
      "Epoch 4: Current learning rate = 9.975e-02\n",
      "Epoch 4/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.0300 - accuracy: 0.9935\n",
      "\n",
      "Epoch 5: Current learning rate = 9.955e-02\n",
      "Epoch 5/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 0.0117 - accuracy: 0.9995\n",
      "\n",
      "Epoch 6: Current learning rate = 9.930e-02\n",
      "Epoch 6/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 0.0064 - accuracy: 1.0000\n",
      "\n",
      "Epoch 7: Current learning rate = 9.899e-02\n",
      "Epoch 7/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 0.0043 - accuracy: 1.0000\n",
      "\n",
      "Epoch 8: Current learning rate = 9.863e-02\n",
      "Epoch 8/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 0.0032 - accuracy: 1.0000\n",
      "\n",
      "Epoch 9: Current learning rate = 9.821e-02\n",
      "Epoch 9/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 0.0026 - accuracy: 1.0000\n",
      "\n",
      "Epoch 10: Current learning rate = 9.774e-02\n",
      "Epoch 10/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 0.0022 - accuracy: 1.0000\n",
      "\n",
      "Epoch 11: Current learning rate = 9.722e-02\n",
      "Epoch 11/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 0.0019 - accuracy: 1.0000\n",
      "\n",
      "Epoch 12: Current learning rate = 9.664e-02\n",
      "Epoch 12/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 0.0017 - accuracy: 1.0000\n",
      "\n",
      "Epoch 13: Current learning rate = 9.601e-02\n",
      "Epoch 13/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.0015 - accuracy: 1.0000\n",
      "\n",
      "Epoch 14: Current learning rate = 9.533e-02\n",
      "Epoch 14/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 0.0014 - accuracy: 1.0000\n",
      "\n",
      "Epoch 15: Current learning rate = 9.460e-02\n",
      "Epoch 15/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 0.0013 - accuracy: 1.0000\n",
      "\n",
      "Epoch 16: Current learning rate = 9.382e-02\n",
      "Epoch 16/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 0.0012 - accuracy: 1.0000\n",
      "\n",
      "Epoch 17: Current learning rate = 9.298e-02\n",
      "Epoch 17/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 0.0011 - accuracy: 1.0000\n",
      "\n",
      "Epoch 18: Current learning rate = 9.210e-02\n",
      "Epoch 18/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 0.0010 - accuracy: 1.0000\n",
      "\n",
      "Epoch 19: Current learning rate = 9.118e-02\n",
      "Epoch 19/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 9.6123e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 20: Current learning rate = 9.020e-02\n",
      "Epoch 20/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 9.0334e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 21: Current learning rate = 8.918e-02\n",
      "Epoch 21/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 8.5254e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 22: Current learning rate = 8.812e-02\n",
      "Epoch 22/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 8.1006e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 23: Current learning rate = 8.702e-02\n",
      "Epoch 23/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 7.7232e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 24: Current learning rate = 8.587e-02\n",
      "Epoch 24/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 7.3212e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 25: Current learning rate = 8.468e-02\n",
      "Epoch 25/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 6.9883e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 26: Current learning rate = 8.346e-02\n",
      "Epoch 26/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 6.6858e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 27: Current learning rate = 8.219e-02\n",
      "Epoch 27/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 6.4137e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 28: Current learning rate = 8.089e-02\n",
      "Epoch 28/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 6.1751e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 29: Current learning rate = 7.956e-02\n",
      "Epoch 29/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 5.9433e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 30: Current learning rate = 7.819e-02\n",
      "Epoch 30/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 5.7523e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 31: Current learning rate = 7.679e-02\n",
      "Epoch 31/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 5.5662e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 32: Current learning rate = 7.536e-02\n",
      "Epoch 32/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 5.4021e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 33: Current learning rate = 7.390e-02\n",
      "Epoch 33/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 5.2144e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 34: Current learning rate = 7.242e-02\n",
      "Epoch 34/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 5.0579e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 35: Current learning rate = 7.091e-02\n",
      "Epoch 35/100\n",
      "16/16 [==============================] - 0s 4ms/step - loss: 4.9237e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 36: Current learning rate = 6.938e-02\n",
      "Epoch 36/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 4.7918e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 37: Current learning rate = 6.782e-02\n",
      "Epoch 37/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 4.6744e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 38: Current learning rate = 6.625e-02\n",
      "Epoch 38/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 4.5601e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 39: Current learning rate = 6.465e-02\n",
      "Epoch 39/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 4.4528e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 40: Current learning rate = 6.304e-02\n",
      "Epoch 40/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 4.3587e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 41: Current learning rate = 6.142e-02\n",
      "Epoch 41/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 4.2627e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 42: Current learning rate = 5.978e-02\n",
      "Epoch 42/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 4.1785e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 43: Current learning rate = 5.813e-02\n",
      "Epoch 43/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 4.0952e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 44: Current learning rate = 5.647e-02\n",
      "Epoch 44/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 4.0228e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 45: Current learning rate = 5.481e-02\n",
      "Epoch 45/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 3.9463e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 46: Current learning rate = 5.314e-02\n",
      "Epoch 46/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.8833e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 47: Current learning rate = 5.147e-02\n",
      "Epoch 47/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.8232e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 48: Current learning rate = 4.979e-02\n",
      "Epoch 48/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.7635e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 49: Current learning rate = 4.812e-02\n",
      "Epoch 49/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.7125e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 50: Current learning rate = 4.644e-02\n",
      "Epoch 50/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.6551e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 51: Current learning rate = 4.477e-02\n",
      "Epoch 51/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.6069e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 52: Current learning rate = 4.311e-02\n",
      "Epoch 52/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.5604e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 53: Current learning rate = 4.146e-02\n",
      "Epoch 53/100\n",
      "16/16 [==============================] - 0s 4ms/step - loss: 3.5190e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 54: Current learning rate = 3.981e-02\n",
      "Epoch 54/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.4792e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 55: Current learning rate = 3.818e-02\n",
      "Epoch 55/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.4389e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 56: Current learning rate = 3.655e-02\n",
      "Epoch 56/100\n",
      "16/16 [==============================] - 0s 4ms/step - loss: 3.4030e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 57: Current learning rate = 3.495e-02\n",
      "Epoch 57/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.3707e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 58: Current learning rate = 3.336e-02\n",
      "Epoch 58/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.3391e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 59: Current learning rate = 3.179e-02\n",
      "Epoch 59/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.3107e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 60: Current learning rate = 3.024e-02\n",
      "Epoch 60/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.2811e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 61: Current learning rate = 2.871e-02\n",
      "Epoch 61/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.2554e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 62: Current learning rate = 2.721e-02\n",
      "Epoch 62/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.2306e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 63: Current learning rate = 2.573e-02\n",
      "Epoch 63/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.2079e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 64: Current learning rate = 2.428e-02\n",
      "Epoch 64/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 3.1863e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 65: Current learning rate = 2.286e-02\n",
      "Epoch 65/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 3.1667e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 66: Current learning rate = 2.146e-02\n",
      "Epoch 66/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.1503e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 67: Current learning rate = 2.010e-02\n",
      "Epoch 67/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.1316e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 68: Current learning rate = 1.878e-02\n",
      "Epoch 68/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.1169e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 69: Current learning rate = 1.749e-02\n",
      "Epoch 69/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.1013e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 70: Current learning rate = 1.623e-02\n",
      "Epoch 70/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.0881e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 71: Current learning rate = 1.502e-02\n",
      "Epoch 71/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.0747e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 72: Current learning rate = 1.384e-02\n",
      "Epoch 72/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.0638e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 73: Current learning rate = 1.270e-02\n",
      "Epoch 73/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.0534e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 74: Current learning rate = 1.161e-02\n",
      "Epoch 74/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.0437e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 75: Current learning rate = 1.056e-02\n",
      "Epoch 75/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.0351e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 76: Current learning rate = 9.549e-03\n",
      "Epoch 76/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.0270e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 77: Current learning rate = 8.587e-03\n",
      "Epoch 77/100\n",
      "16/16 [==============================] - 0s 4ms/step - loss: 3.0199e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 78: Current learning rate = 7.672e-03\n",
      "Epoch 78/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.0139e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 79: Current learning rate = 6.804e-03\n",
      "Epoch 79/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 3.0080e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 80: Current learning rate = 5.984e-03\n",
      "Epoch 80/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 3.0032e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 81: Current learning rate = 5.214e-03\n",
      "Epoch 81/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 2.9988e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 82: Current learning rate = 4.495e-03\n",
      "Epoch 82/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 2.9950e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 83: Current learning rate = 3.826e-03\n",
      "Epoch 83/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 2.9920e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 84: Current learning rate = 3.209e-03\n",
      "Epoch 84/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.9892e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 85: Current learning rate = 2.645e-03\n",
      "Epoch 85/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 2.9867e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 86: Current learning rate = 2.134e-03\n",
      "Epoch 86/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 2.9847e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 87: Current learning rate = 1.677e-03\n",
      "Epoch 87/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 2.9833e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 88: Current learning rate = 1.274e-03\n",
      "Epoch 88/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 2.9821e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 89: Current learning rate = 9.253e-04\n",
      "Epoch 89/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.9811e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 90: Current learning rate = 6.321e-04\n",
      "Epoch 90/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.9804e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 91: Current learning rate = 3.943e-04\n",
      "Epoch 91/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 2.9800e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 92: Current learning rate = 2.122e-04\n",
      "Epoch 92/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 2.9797e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 93: Current learning rate = 8.595e-05\n",
      "Epoch 93/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.9795e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 94: Current learning rate = 1.579e-05\n",
      "Epoch 94/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.9794e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 95: Current learning rate = 0.000e+00\n",
      "Epoch 95/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 2.9794e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 96: Current learning rate = 0.000e+00\n",
      "Epoch 96/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.9794e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 97: Current learning rate = 0.000e+00\n",
      "Epoch 97/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 2.9794e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 98: Current learning rate = 0.000e+00\n",
      "Epoch 98/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 2.9794e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 99: Current learning rate = 0.000e+00\n",
      "Epoch 99/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.9794e-04 - accuracy: 1.0000\n",
      "\n",
      "Epoch 100: Current learning rate = 0.000e+00\n",
      "Epoch 100/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 2.9794e-04 - accuracy: 1.0000\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 0.4268 - accuracy: 0.9338\n",
      "\n",
      "Test loss 0.4268372654914856\n",
      "Test accuracy 0.9338333606719971\n"
     ]
    }
   ],
   "source": [
    "# Vanilla LeNet-300-100 on mnist\n",
    "\n",
    "# Model definition\n",
    "################################################################################\n",
    "# Used as the prefix of RUN_NAME below\n",
    "MODEL = 'lenet300100_mnist'\n",
    "#DEPTH = DEPTH\n",
    "# Passed to LeNet300100 as `la` further down in this cell\n",
    "LA = 0 #lambdas[0] #LA\n",
    "#print(f'Starting run with lambda={LA:.2e}')\n",
    "#INIT_TYPE = 'equivar'\n",
    "# Initial learning rate handed to get_optimizer as `init_lr`\n",
    "INIT_LR = 0.1 #INIT_LR\n",
    "#INIT = TwiceTruncatedNormalInitializer(minprod=MINPROD,depth=DEPTH)\n",
    "# NOTE(review): this binds the HeNormal class itself (it is not called) and INIT is not\n",
    "# passed to the model built in this cell.\n",
    "INIT = tf.keras.initializers.HeNormal\n",
    "#INIT = tf.keras.initializers.HeUniform\n",
    "# Self-assignment: keeps whatever EPOCHS an earlier cell defined (fails on a fresh kernel otherwise)\n",
    "EPOCHS = EPOCHS\n",
    "################################################################################\n",
    "\n",
    "################################################################################\n",
    "\n",
    "# Directories and saving paths\n",
    "fmt_la = f\"{LA:.1e}\"\n",
    "RUN_NAME = f\"{MODEL}_dep1-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{INIT_LR:.1e}-vanilla-bs{BATCH_SIZE}\"\n",
    "RUN_PATH = os.path.join(LENET_FILE_PATH, RUN_NAME)\n",
    "\n",
    "# Create the run directory; exist_ok avoids the check-then-create race and\n",
    "# makes re-running the cell a no-op\n",
    "os.makedirs(RUN_PATH, exist_ok=True)\n",
    "    \n",
    "################################################################################\n",
    "# Set seed (NumPy, Python's `random`, TensorFlow) so the run below starts from a fixed RNG state\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "# Callbacks\n",
    "# NOTE(review): the fit() call later in this cell passes only print_lr_cb and no\n",
    "# validation data, so early_stopping (which monitors 'val_accuracy'),\n",
    "# terminate_nan_cb and early_abort_cb are constructed but not attached here.\n",
    "early_stopping = EarlyStopping(monitor='val_accuracy', patience=PAT, restore_best_weights=RESTORE_WEIGHTS)\n",
    "#custom_sparsity_callback = HadamardCallback(save_metrics=SAVE_METRICS, verbose=VERBOSE)\n",
    "print_lr_cb = PrintLRCallback()\n",
    "terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "early_abort_cb = TerminateBadRuns(grace=GRACE, minacc = MINACC)\n",
    "\n",
    "# Define model: input width is taken from the flattened training features\n",
    "vanilla_lenet300100 = LeNet300100(input_shape=(X_train.shape[1],), n_classes=CLASS_NUM, la=LA, units1=300, units2=100)\n",
    "\n",
    "#hadamard_net  = hadamard_resnet18(use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS, depth=DEPTH,\\\n",
    "#                                 init_type=INIT_TYPE, init=INIT, la=LA,\\\n",
    "#                                 input_shape=(IMG_ROWS,IMG_COLS,IMG_CHANNELS), n_classes=CLASS_NUM)\n",
    "\n",
    "# Pretrain optimizer (no line-continuation backslashes needed inside the parentheses)\n",
    "optimizer = get_optimizer(lr_schedule=LR_SCHEDULE, init_lr=INIT_LR, lr_decay_fact=LR_DECAY_FACT, epochs=EPOCHS,\n",
    "                          dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT, momentum=MOMENTUM, alpha=0,\n",
    "                          large_lr_start=LARGE_LRSTART, warmup=WARMUP)\n",
    "\n",
    "# Compile model\n",
    "vanilla_lenet300100.compile(optimizer=optimizer,\n",
    "                            loss='categorical_crossentropy',\n",
    "                            metrics=['accuracy'])\n",
    "\n",
    "# summary() prints the table itself and returns None, so wrapping it in\n",
    "# print() only adds a stray \"None\" line to the output\n",
    "vanilla_lenet300100.summary()\n",
    "\n",
    "################################################################################\n",
    "\n",
    "# Fit on the training split; only the learning-rate printer callback is attached\n",
    "pre_hist = vanilla_lenet300100.fit(\n",
    "    x=X_train,\n",
    "    y=Y_train,\n",
    "    epochs=EPOCHS,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    callbacks=[print_lr_cb],\n",
    ")\n",
    "\n",
    "# Score the fitted model on the held-out split and report loss / accuracy\n",
    "pretrain_loss, pretrain_acc = vanilla_lenet300100.evaluate(X_test, Y_test)\n",
    "print('\\nTest loss', pretrain_loss)\n",
    "print('Test accuracy', pretrain_acc)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-19T17:44:50.230780Z",
     "iopub.status.busy": "2025-01-19T17:44:50.230189Z",
     "iopub.status.idle": "2025-01-19T18:34:24.528135Z",
     "shell.execute_reply": "2025-01-19T18:34:24.527469Z",
     "shell.execute_reply.started": "2025-01-19T17:44:50.230755Z"
    }
   },
   "outputs": [],
   "source": [
    "DEPTH_LIST = [2, 3, 4]\n",
    "REPS = 5\n",
    "BASE_SEED = SEED\n",
    "\n",
    "for depth in DEPTH_LIST:\n",
    "    for rep in range(REPS):\n",
    "        current_seed = BASE_SEED + rep\n",
    "        \n",
    "        for LA_ITER in LAMBDA_LIST:\n",
    "            # Model definition\n",
    "            ################################################################################\n",
    "            MODEL = 'lenet300100_mnist'  # Keep original dataset\n",
    "            DEPTH = depth  # Use the loop variable\n",
    "            LA = LA_ITER\n",
    "            print(f'Starting run with depth={DEPTH}, lambda={LA:.2e}, repetition={rep+1}/{REPS}')\n",
    "            INIT_TYPE = 'ones'\n",
    "            INIT_LR = INIT_LR\n",
    "            INIT = tf.keras.initializers.HeNormal()\n",
    "            EPOCHS = EPOCHS\n",
    "            ################################################################################\n",
    "\n",
    "            # Directories and saving paths\n",
    "            fmt_la = f\"{LA:.1e}\"\n",
    "            RUN_NAME = f\"{MODEL}_dep{DEPTH}-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{INIT_LR:.1e}-{INIT_TYPE}-bs{BATCH_SIZE}_rep{rep+1}\"\n",
    "            RUN_PATH = os.path.join(LENET_FILE_PATH, f\"depth_{DEPTH}\", f\"rep_{rep+1}\", RUN_NAME)\n",
    "\n",
    "            # Create dir\n",
    "            if not os.path.exists(RUN_PATH):\n",
    "                os.makedirs(RUN_PATH)\n",
    "\n",
    "            ################################################################################\n",
    "            # Set seed for this repetition\n",
    "            np.random.seed(current_seed)\n",
    "            random.seed(current_seed)\n",
    "            tf.random.set_seed(current_seed)\n",
    "\n",
    "            # Callbacks\n",
    "            terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "\n",
    "            # Define model\n",
    "            hadamard_lenet300100 = InpHadamardLeNet300100(input_shape=(X_train.shape[1],), n_classes=CLASS_NUM, depth=DEPTH, la=LA,\n",
    "                                   init=INIT, use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS)\n",
    "\n",
    "            # Pretrain optimizer\n",
    "            optimizer = get_optimizer(lr_schedule=LR_SCHEDULE, init_lr=INIT_LR, lr_decay_fact=LR_DECAY_FACT, epochs=EPOCHS,\\\n",
    "                              dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT, momentum=MOMENTUM, alpha=0,\\\n",
    "                              large_lr_start=LARGE_LRSTART, warmup=WARMUP)\n",
    "\n",
    "            # Compile model\n",
    "            hadamard_lenet300100.compile(optimizer=optimizer,\n",
    "                   loss='categorical_crossentropy',\n",
    "                   metrics=['accuracy'])\n",
    "\n",
    "            print(hadamard_lenet300100.summary())\n",
    "\n",
    "            ################################################################################\n",
    "\n",
    "            # Training\n",
    "            pre_hist = hadamard_lenet300100.fit(x=X_train, y=Y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, verbose=0,\n",
    "                                   callbacks=[terminate_nan_cb])\n",
    "\n",
    "            # Evaluate after training\n",
    "            pretrain_loss, pretrain_acc = hadamard_lenet300100.evaluate(X_test, Y_test)\n",
    "            print('\\nTest loss', pretrain_loss)\n",
    "            print('Test accuracy', pretrain_acc)\n",
    "\n",
    "            ################################################################################\n",
    "\n",
    "            # Evaluate after pretraining\n",
    "            pretrain_loss, pretrain_acc = hadamard_lenet300100.evaluate(X_test, Y_test)\n",
    "            df, pretrain_sparsity = compute_input_sparsity(hadamard_lenet300100, DEPTH)\n",
    "            pretrain_compression_rate = 1 / (1 - pretrain_sparsity)\n",
    "            print('\\nTest loss', pretrain_loss)\n",
    "            print('Test accuracy', pretrain_acc)\n",
    "            print('Sparsity (pretrain)', pretrain_sparsity)\n",
    "            print('Compression rate (pretrain)', pretrain_compression_rate)\n",
    "\n",
    "            # Initialize df to store results with added run number column\n",
    "            pretrain_res_df = pd.DataFrame(columns=['Run', 'Pre Opt', 'Depth', 'Lambda', 'Init LR', 'LR Schedule', 'Batch size',\\\n",
    "                                            'Pre Epochs', 'Pre Loss', 'Pre Acc', 'Pre Sparsity', 'Pre CR'])\n",
    "\n",
    "            # Store formatted results in dict\n",
    "            pretrain_res_dict = {\n",
    "                'Run': int(rep + 1),\n",
    "                'Pre Opt': PRETRAIN_OPT,\n",
    "                'Depth': int(DEPTH),\n",
    "                'Lambda': f'{LA:.2e}',\n",
    "                'Init LR': f'{INIT_LR:.2e}',\n",
    "                'LR Schedule': LR_SCHEDULE,\n",
    "                'Batch size': int(BATCH_SIZE),\n",
    "                'Pre Epochs': int(EPOCHS),\n",
    "                'Pre Loss': f'{pretrain_loss:.3f}',\n",
    "                'Pre Acc': f'{pretrain_acc * 100:.4f}%',\n",
    "                'Pre Sparsity': f'{pretrain_sparsity * 100:.4f}%',\n",
    "                'Pre CR': f'{pretrain_compression_rate:.2f}'\n",
    "            }\n",
    "\n",
    "            # Append results to df\n",
    "            pretrain_res_df = pd.concat([pretrain_res_df, pd.DataFrame([pretrain_res_dict])], ignore_index=True)\n",
    "\n",
    "            # Save df to CSV\n",
    "            pretrain_csv_file_path = os.path.join(RUN_PATH, f'pretraining_{MODEL}_depth{DEPTH}_rep{rep+1}.csv')\n",
    "            pretrain_res_df.to_csv(pretrain_csv_file_path, index=False)\n",
    "            print(f'Pretrain results saved to {pretrain_csv_file_path}')\n",
    "            print(\"\\nPretraining Results:\")\n",
    "            print(pretrain_res_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-20T10:37:36.427069Z",
     "iopub.status.busy": "2025-01-20T10:37:36.426809Z",
     "iopub.status.idle": "2025-01-20T10:37:37.382132Z",
     "shell.execute_reply": "2025-01-20T10:37:37.381249Z",
     "shell.execute_reply.started": "2025-01-20T10:37:36.427052Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Normalized Training Set Mean and SD: [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      " -4.41811746e-03 -5.75460540e-03 -4.08261409e-03 -4.08261409e-03\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      " -4.08261409e-03 -4.70944401e-03 -8.79918039e-03 -1.15902880e-02\n",
      " -4.10827546e-04  7.34113948e-03  1.62988584e-02  4.16030996e-02\n",
      "  2.12192275e-02 -2.18368624e-03  1.40861133e-02  2.34294031e-02\n",
      " -1.62362196e-02 -6.72437530e-03  1.82735343e-02  1.46022588e-02\n",
      " -1.81332615e-03 -1.09966453e-02 -8.32507201e-03 -4.38052649e-03\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00 -4.08261409e-03 -5.39535051e-03\n",
      " -8.52213986e-03  1.62882105e-01  2.82596853e-02 -6.80418860e-04\n",
      " -2.75809946e-03  2.19389424e-02  5.55645563e-02  9.38820988e-02\n",
      "  7.39584491e-02  3.26009393e-02  1.77913085e-02  1.09734777e-02\n",
      " -5.40107302e-03 -6.18119864e-03 -1.37357162e-02 -2.40913182e-02\n",
      " -1.69764850e-02 -4.40417323e-03  1.58787761e-02  1.13025401e-02\n",
      " -1.35878371e-02 -7.83111621e-03  0.00000000e+00  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00 -5.36850980e-03 -8.87078419e-03\n",
      "  4.95642005e-03  1.49929868e-02  3.77630778e-02  1.74969789e-02\n",
      "  3.17634605e-02  4.31552008e-02  5.79657406e-02  6.51872158e-02\n",
      "  6.95984289e-02  5.02408594e-02  1.19145364e-02 -4.23975894e-03\n",
      " -1.69138461e-02 -3.21586989e-02 -3.33180428e-02 -2.81243585e-02\n",
      " -2.04585288e-02 -2.81097014e-02 -5.48643339e-03 -1.78614762e-04\n",
      " -2.01879907e-03  6.57787779e-03 -5.28352521e-03  0.00000000e+00\n",
      "  0.00000000e+00 -4.08261409e-03 -7.76327914e-03  1.49458200e-02\n",
      " -4.18766495e-03 -2.23613940e-02  1.07237101e-02  1.90405026e-02\n",
      "  1.96592584e-02  3.35484110e-02  4.25123088e-02  5.26890568e-02\n",
      "  4.26664837e-02  3.57930884e-02  2.34533530e-02  1.10070463e-02\n",
      " -4.27128980e-04 -1.77778229e-02 -2.88259368e-02 -1.89742204e-02\n",
      " -2.28149947e-02 -2.75410265e-02 -1.10654766e-02 -1.60864610e-02\n",
      " -1.03070363e-02  1.73480459e-03 -4.03776811e-03 -5.77035174e-03\n",
      "  0.00000000e+00  0.00000000e+00 -1.16134398e-02 -1.68821972e-03\n",
      "  1.11869664e-03 -1.14550609e-02 -4.49247978e-04  9.97761637e-03\n",
      "  1.84000488e-02  2.26127096e-02  4.90894876e-02  5.98350950e-02\n",
      "  4.87283096e-02  3.02107446e-02  1.80491935e-02  3.70267918e-03\n",
      " -5.76877641e-03 -1.11815091e-02 -2.53235511e-02 -1.95845552e-02\n",
      " -7.30251940e-03 -9.80170164e-03 -7.79611943e-03 -2.34153904e-02\n",
      " -1.20436801e-02 -6.06956286e-03  1.67097524e-02 -8.20913538e-03\n",
      "  0.00000000e+00 -5.57001401e-03 -1.38114160e-02  6.78550918e-03\n",
      "  1.81221552e-02  1.24981825e-03 -2.06548866e-04  5.56156458e-03\n",
      "  1.55270230e-02  2.74176374e-02  3.63383070e-02  5.49860373e-02\n",
      "  5.38944043e-02  3.41811590e-02  3.26609612e-02  1.87016446e-02\n",
      " -2.26639491e-03 -1.42939240e-02 -3.15517560e-02 -3.14056985e-02\n",
      " -2.26257741e-02 -1.79869793e-02 -1.90046374e-02 -2.77752709e-02\n",
      " -8.34583584e-03 -1.02854380e-03 -4.40107239e-03 -5.12118824e-03\n",
      " -4.08261409e-03 -1.32304691e-02 -6.43922621e-03  9.91706457e-03\n",
      "  1.55324936e-02  1.02405651e-02 -3.56938876e-03 -2.94284383e-03\n",
      "  2.36373395e-02  3.48850563e-02  4.46509011e-02  6.21613935e-02\n",
      "  5.10751903e-02  3.94567326e-02  3.63223776e-02  2.79514063e-02\n",
      "  1.21207046e-03 -9.31385811e-03 -2.11941134e-02 -2.53404211e-02\n",
      " -3.30689959e-02 -2.90159546e-02 -3.26997004e-02 -2.45186239e-02\n",
      " -1.41697209e-02 -1.26495296e-02 -1.36567568e-02 -1.46672800e-02\n",
      " -5.02568530e-03  1.00326436e-02  2.42148079e-02 -5.89473685e-03\n",
      " -9.81412828e-04 -3.77764739e-03 -8.52157455e-03 -9.43157263e-03\n",
      "  3.33726108e-02  5.82444258e-02  8.52125958e-02  8.43273252e-02\n",
      "  4.01754603e-02  2.14373358e-02  2.05844082e-02  1.21960789e-02\n",
      "  2.08280236e-03  1.01914816e-03 -1.21372687e-02 -1.45974560e-02\n",
      " -2.44250670e-02 -4.13784347e-02 -3.92613634e-02 -2.44028643e-02\n",
      " -7.86530226e-03 -2.80681998e-03  7.91423954e-03  1.00944173e-02\n",
      "  4.23314497e-02  5.08148149e-02  2.46029440e-02  3.32195847e-03\n",
      " -5.67673612e-03 -1.86863318e-02 -2.41163746e-02 -1.00210020e-02\n",
      "  3.97832133e-02  8.28146338e-02  9.68241468e-02  6.20155148e-02\n",
      "  1.35290902e-02 -1.13764154e-02  4.77549474e-04  6.96658157e-03\n",
      " -1.79711671e-03  5.28488867e-03  2.96694483e-03 -1.09362323e-02\n",
      " -2.19472349e-02 -4.35140692e-02 -4.06785421e-02 -2.23062448e-02\n",
      " -8.44532717e-03 -1.22033171e-02  7.93837418e-04  2.64569279e-03\n",
      "  5.54948375e-02  1.75499450e-02 -7.65063334e-03 -9.57124960e-03\n",
      " -2.32865755e-02 -4.13067453e-02 -3.52517255e-02 -9.08283144e-03\n",
      "  5.00219166e-02  8.82577300e-02  7.83220828e-02  2.71792561e-02\n",
      " -3.13579589e-02 -2.79204641e-02 -3.52647598e-03 -7.16581801e-03\n",
      " -1.06959688e-02  2.03768481e-02  1.82100888e-02  2.20774044e-03\n",
      " -2.87568700e-02 -4.71206009e-02 -3.30201946e-02 -1.18533233e-02\n",
      " -5.77503350e-03 -9.28007439e-03  2.86790095e-02  2.99599227e-02\n",
      "  4.24134592e-03 -1.60132423e-02 -2.00060457e-02 -1.68357547e-02\n",
      " -4.34514955e-02 -4.75979373e-02 -4.17396389e-02  2.22611730e-03\n",
      "  5.65313660e-02  7.84565955e-02  5.14779203e-02 -1.10145723e-02\n",
      " -4.89527397e-02 -1.95959955e-02 -6.03359099e-03  8.12703744e-03\n",
      "  2.26473212e-02  3.49758156e-02  2.34207418e-02  5.19333920e-03\n",
      " -2.11802181e-02 -3.34797017e-02 -2.04111207e-02 -2.00267369e-03\n",
      "  1.11470614e-02  8.99341982e-03  1.19247911e-02 -9.12484061e-03\n",
      " -5.30433701e-03 -1.92702457e-03 -8.15872382e-03 -1.98461860e-02\n",
      " -3.42833698e-02 -3.56188677e-02 -1.28956847e-02  2.85033397e-02\n",
      "  5.72260730e-02  5.27206287e-02  2.25755665e-02 -2.11761035e-02\n",
      " -2.67201364e-02 -7.39592686e-03  1.71533320e-02  3.55378836e-02\n",
      "  4.77375649e-02  3.55996676e-02  2.30031516e-02 -4.26578277e-04\n",
      " -1.03206690e-02 -1.78005826e-02  2.36024684e-03  9.68122855e-03\n",
      "  9.91401076e-03  1.23962658e-02 -7.42188562e-03 -9.92202573e-03\n",
      " -4.08261409e-03  1.19835259e-02 -2.57651857e-03 -2.05178726e-02\n",
      " -2.91534010e-02 -2.35685986e-02 -8.04510026e-04  3.17953601e-02\n",
      "  4.23842520e-02  2.30016951e-02 -3.44195403e-04 -1.59274656e-02\n",
      " -2.66919974e-02  2.50606332e-03  3.53288651e-02  5.48702404e-02\n",
      "  4.80154566e-02  3.17603536e-02  1.46136926e-02  7.22691091e-03\n",
      "  1.46815320e-03  9.95976734e-04  2.21733991e-02  1.00731356e-02\n",
      " -4.49132611e-04  5.13953215e-04 -5.69798611e-03 -7.40432693e-03\n",
      " -4.08261409e-03 -7.98808783e-03 -1.02737295e-02 -2.47155316e-02\n",
      " -2.11008321e-02 -1.12750102e-02 -5.36655961e-03  1.00128651e-02\n",
      "  2.03726329e-02  1.57157816e-02 -2.14449456e-03 -2.62317597e-03\n",
      " -1.22394068e-02  1.02235246e-02  2.15032194e-02  4.45979759e-02\n",
      "  5.13432585e-02  3.08734756e-02  2.17046216e-02  8.45012721e-03\n",
      "  5.36837790e-04  1.40687767e-02  1.85617134e-02  1.49831427e-02\n",
      " -6.24519773e-03 -1.11615434e-02 -1.89122744e-02 -6.98474329e-03\n",
      " -4.77027940e-03 -4.18943726e-03 -8.82868096e-03 -2.19771266e-02\n",
      " -1.61028132e-02 -1.02803065e-02 -1.88928694e-02 -9.32159182e-03\n",
      "  7.68913375e-03  7.92805478e-03  3.35091818e-03  4.89460072e-03\n",
      "  1.41604273e-02  3.22548747e-02  2.37208586e-02  4.35541496e-02\n",
      "  5.52673154e-02  4.82611395e-02  3.02265361e-02  8.08343291e-03\n",
      "  4.17829351e-03  1.73312202e-02  1.30356066e-02  1.29084894e-02\n",
      " -9.83185321e-03 -1.40072182e-02 -2.05306690e-02 -9.10989009e-03\n",
      " -4.08261409e-03 -6.70758402e-03  7.74516317e-04 -2.92797517e-02\n",
      " -1.05634918e-02 -1.29393917e-02 -2.20253803e-02 -1.02981823e-02\n",
      "  1.54633205e-02  1.07212923e-02  3.82813648e-03 -5.31125255e-03\n",
      "  1.65223517e-02  2.07003020e-02  1.05960611e-02  2.88257264e-02\n",
      "  5.85474111e-02  7.48338476e-02  4.65679504e-02  3.13898921e-02\n",
      "  2.67218444e-02  1.58677213e-02  5.66643663e-03 -2.34023505e-03\n",
      " -1.88134275e-02 -5.75260399e-03 -1.68579053e-02  1.03213442e-02\n",
      "  0.00000000e+00 -8.93191434e-03 -3.28481954e-04 -3.06835510e-02\n",
      " -8.66520219e-03  8.77050741e-04 -8.49256199e-03 -1.22461831e-02\n",
      "  1.70802735e-02  2.78926343e-02  2.18792725e-02  2.88548507e-03\n",
      "  2.27689510e-03 -8.12013913e-03 -1.34364022e-02  2.22838409e-02\n",
      "  6.13453612e-02  9.44503173e-02  8.24194998e-02  6.45607412e-02\n",
      "  5.15957437e-02  1.90572832e-02  9.74230189e-03 -2.09910795e-03\n",
      " -1.72413383e-02 -2.04935037e-02 -2.48995908e-02  2.51586698e-02\n",
      " -5.90576651e-03 -6.35195151e-03 -9.69331420e-04 -1.97413545e-02\n",
      " -2.72226613e-03  8.89686926e-04  4.31215204e-03 -3.30640818e-03\n",
      "  2.09884215e-02  2.02501602e-02  1.17556509e-02 -2.22103714e-04\n",
      " -8.37222207e-03 -2.11433414e-02 -1.23987151e-02  3.04329880e-02\n",
      "  7.35105127e-02  9.69756842e-02  8.75395536e-02  8.14764425e-02\n",
      "  4.83968481e-02  1.06824087e-02 -7.33584771e-03 -1.82500910e-02\n",
      " -2.49640122e-02 -1.63379200e-02  1.85905793e-03  1.04673030e-02\n",
      " -4.08261409e-03 -1.12813441e-02 -1.58734061e-02 -1.51830455e-02\n",
      " -1.26617122e-02 -2.98885349e-03  2.13782117e-03  2.01105652e-03\n",
      "  7.97699764e-03  6.29997300e-03  1.13672409e-02  7.83106964e-03\n",
      " -8.15426093e-03 -1.08835557e-02  1.70712508e-02  5.21537550e-02\n",
      "  8.94790068e-02  8.42244551e-02  8.56550932e-02  7.07507357e-02\n",
      "  2.11949311e-02 -1.72263607e-02 -2.65874956e-02 -3.37393843e-02\n",
      " -3.11973970e-02 -1.12663805e-02  6.06848858e-04 -2.15330860e-03\n",
      "  0.00000000e+00 -1.03417365e-02 -1.22197848e-02 -4.27834317e-03\n",
      " -2.16761585e-02 -1.19962459e-02 -1.47948368e-02 -1.12066008e-02\n",
      " -2.72350921e-03 -2.51717190e-03  1.58068519e-02  2.42214389e-02\n",
      "  8.03474430e-03  1.75393503e-02  4.30478565e-02  6.64937571e-02\n",
      "  7.63814449e-02  7.55630657e-02  6.52094632e-02  3.21744233e-02\n",
      " -1.18789412e-02 -3.32951173e-02 -3.46551649e-02 -2.50138585e-02\n",
      " -1.75183099e-02 -1.43489260e-02  4.79382416e-03 -5.77350287e-03\n",
      " -4.08261409e-03 -1.00096930e-02 -1.31540536e-03 -1.46820064e-04\n",
      " -2.28588618e-02 -1.04693556e-02 -1.81319062e-02 -2.35702917e-02\n",
      " -1.85565315e-02 -1.43159144e-02 -6.61694724e-03  4.31766361e-03\n",
      "  9.35346540e-03  3.09117306e-02  5.17731197e-02  5.91272153e-02\n",
      "  6.89778849e-02  6.08396009e-02  3.11856512e-02 -8.34356248e-03\n",
      " -3.37222554e-02 -4.85104211e-02 -3.99011672e-02 -1.69564448e-02\n",
      " -1.30429976e-02 -2.70048063e-02  9.24569438e-04 -4.08261409e-03\n",
      " -4.08261409e-03 -6.00481918e-03 -2.26263311e-02 -2.95542250e-03\n",
      " -2.29217056e-02 -1.13520082e-02 -2.56241206e-02 -2.97484025e-02\n",
      " -3.51864472e-02 -2.72086877e-02 -2.58972049e-02 -9.62943584e-03\n",
      "  4.09017457e-03  2.02904511e-02  4.60916944e-02  4.98231128e-02\n",
      "  5.59843779e-02  2.98551247e-02  1.01009458e-02 -1.98486410e-02\n",
      " -4.24833633e-02 -5.44052608e-02 -3.72422449e-02 -1.98121760e-02\n",
      "  1.24626688e-03  2.08764267e-03 -8.95156339e-03 -4.08261409e-03\n",
      "  0.00000000e+00  0.00000000e+00 -1.91064626e-02 -1.35435145e-02\n",
      " -2.16376688e-02 -1.03856167e-02 -2.45866496e-02 -2.94696223e-02\n",
      " -4.15393040e-02 -6.16575517e-02 -4.36944328e-02 -2.20641773e-02\n",
      "  7.69260200e-03  2.48496085e-02  4.90829125e-02  5.72306849e-02\n",
      "  6.01896644e-02  3.96406464e-02  9.39116348e-03 -2.73344070e-02\n",
      " -4.55973335e-02 -3.86242680e-02 -2.85568126e-02 -1.53164165e-02\n",
      "  4.31055715e-03  2.92128008e-02 -1.04026571e-02  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00 -1.03530018e-02 -1.68102235e-02\n",
      " -2.24907678e-02 -2.10993234e-02 -3.08284760e-02 -4.58658747e-02\n",
      " -4.31182496e-02 -6.45727068e-02 -5.30257560e-02 -3.24722826e-02\n",
      "  2.10235245e-03  2.13605110e-02  4.91039939e-02  7.43381009e-02\n",
      "  8.82792696e-02  4.60067727e-02 -4.17719921e-03 -2.29650829e-02\n",
      " -3.25202122e-02 -4.37310264e-02 -3.97582501e-02 -3.26538906e-02\n",
      " -1.98085010e-02 -1.08662900e-02 -4.75306250e-03  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00 -6.66427705e-03 -1.09637678e-02\n",
      " -8.71103373e-04 -2.75413934e-02 -4.61031832e-02 -4.74594682e-02\n",
      " -3.83740030e-02 -3.54633592e-02 -2.97762491e-02  5.76993497e-03\n",
      "  3.57955880e-02  3.72918956e-02  5.51128723e-02  5.37002236e-02\n",
      "  5.57434000e-02  2.00574957e-02 -3.23551223e-02 -4.29426692e-02\n",
      " -2.72130780e-02 -2.38635503e-02 -3.41538191e-02 -2.24017650e-02\n",
      " -2.59548135e-04  2.13750787e-02 -4.08261409e-03  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00 -4.08261409e-03\n",
      " -1.01094292e-02 -1.03547452e-02 -1.63998064e-02 -2.63651311e-02\n",
      " -2.02881917e-02 -2.40834001e-02 -2.56966818e-02  2.03555543e-03\n",
      "  2.51829512e-02  2.63418723e-02  2.14687567e-02  2.21937317e-02\n",
      "  3.37087959e-02  6.75980421e-03 -1.50866285e-02 -1.35353142e-02\n",
      " -1.64796859e-02  1.23606250e-03 -1.14327017e-02  1.73845831e-02\n",
      " -4.49435040e-03 -4.08261409e-03  0.00000000e+00  0.00000000e+00\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      " -5.79020847e-03 -9.05387662e-03 -1.45641081e-02 -3.70762567e-03\n",
      " -3.07631126e-04 -1.08020976e-02 -3.09229475e-02 -2.82016695e-02\n",
      " -2.04058159e-02 -8.88458174e-03 -1.42098330e-02 -2.73996079e-03\n",
      " -5.47817070e-03 -1.14821615e-02  1.76217553e-04  1.61317724e-03\n",
      " -4.62485012e-03 -1.14325648e-02 -9.00690258e-03 -5.77026419e-03\n",
      "  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00] [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      " 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      " 0.0000000e+00 0.0000000e+00 3.9580062e-08 2.1467318e-07 9.7322207e-08\n",
      " 9.6857548e-08 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      " 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      " 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      " 0.0000000e+00 0.0000000e+00 9.7322207e-08 2.4400114e-07 1.6577050e-07\n",
      " 2.7101578e-07 9.0735304e-01 1.0841730e+00 1.2711843e+00 1.5736740e+00\n",
      " 1.2879838e+00 9.6679288e-01 1.1832184e+00 1.4103526e+00 5.5485547e-01\n",
      " 8.7634039e-01 1.4318393e+00 1.4727974e+00 7.9626387e-01 2.8497971e-07\n",
      " 2.1141406e-07 1.6717675e-07 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      " 0.0000000e+00 0.0000000e+00 0.0000000e+00 9.7322207e-08 4.6566129e-09\n",
      " 2.7194076e-07 7.9728589e+00 1.9015563e+00 1.0075092e+00 8.3693093e-01\n",
      " 1.1488374e+00 1.3370754e+00 1.5042129e+00 1.3405477e+00 1.0911613e+00\n",
      " 1.0712247e+00 1.0505410e+00 9.6034199e-01 9.8247695e-01 9.4710863e-01\n",
      " 8.7330800e-01 8.5679340e-01 9.9351680e-01 1.2876799e+00 1.2477993e+00\n",
      " 3.1622518e-02 2.7567148e-07 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      " 0.0000000e+00 1.3224560e-07 1.7881393e-07 1.2033213e+00 1.3229245e+00\n",
      " 1.5235522e+00 1.1076982e+00 1.1632768e+00 1.1813263e+00 1.2103313e+00\n",
      " 1.1834729e+00 1.1790432e+00 1.1173617e+00 1.0291021e+00 9.9540633e-01\n",
      " 9.7862196e-01 9.3266672e-01 9.2032367e-01 9.3253803e-01 9.3327069e-01\n",
      " 8.5878128e-01 9.7875196e-01 9.7926056e-01 1.0608642e+00 9.9633288e-01\n",
      " 1.9418022e-07 0.0000000e+00 0.0000000e+00 9.7322207e-08 1.4342014e-07\n",
      " 1.8705932e+00 9.9057370e-01 6.7717135e-01 1.0727829e+00 1.1030509e+00\n",
      " 1.0745533e+00 1.0588750e+00 1.0800353e+00 1.0832589e+00 1.0565811e+00\n",
      " 1.0338585e+00 1.0175095e+00 1.0101744e+00 9.9652171e-01 9.7620130e-01\n",
      " 9.6532804e-01 9.8159212e-01 9.5991987e-01 9.2027038e-01 9.7023094e-01\n",
      " 8.8855731e-01 9.6019071e-01 1.2105538e+00 4.7314534e-01 1.4621519e-07\n",
      " 0.0000000e+00 0.0000000e+00 1.8238679e-02 6.4063019e-01 9.5403302e-01\n",
      " 8.8345098e-01 9.6042144e-01 1.0148319e+00 1.0400944e+00 1.0327995e+00\n",
      " 1.0425024e+00 1.0377536e+00 1.0239737e+00 1.0123330e+00 1.0112470e+00\n",
      " 1.0080161e+00 1.0037187e+00 9.9538267e-01 9.8741001e-01 9.8437506e-01\n",
      " 9.9683672e-01 9.9902904e-01 1.0026282e+00 9.2901921e-01 9.9042606e-01\n",
      " 9.7649688e-01 1.2939514e+00 2.1699631e-07 0.0000000e+00 1.3923587e-07\n",
      " 1.1705050e-01 1.1144403e+00 1.1219565e+00 1.0156411e+00 1.0055025e+00\n",
      " 1.0032154e+00 1.0150743e+00 1.0142095e+00 1.0177913e+00 1.0253644e+00\n",
      " 1.0126159e+00 1.0002276e+00 9.9769658e-01 9.9821514e-01 9.9752837e-01\n",
      " 1.0002893e+00 9.8797697e-01 9.8150152e-01 9.8384297e-01 9.8327041e-01\n",
      " 9.7447395e-01 9.3285036e-01 1.0062920e+00 1.0228804e+00 9.3114692e-01\n",
      " 6.0625881e-01 9.7322207e-08 5.4108796e-07 6.1278325e-01 1.0809740e+00\n",
      " 1.0716383e+00 1.0269432e+00 9.9736565e-01 9.9421358e-01 1.0101815e+00\n",
      " 1.0143552e+00 1.0102618e+00 1.0035728e+00 1.0012336e+00 1.0068064e+00\n",
      " 9.9998099e-01 1.0015785e+00 9.9871236e-01 1.0003262e+00 9.9401909e-01\n",
      " 9.9334884e-01 9.8619485e-01 9.8499912e-01 9.6626389e-01 9.6145034e-01\n",
      " 9.7330666e-01 9.2773533e-01 8.1458670e-01 3.2373477e-02 1.9231571e-07\n",
      " 9.2407572e-01 1.2349374e+00 9.8670894e-01 9.7850680e-01 9.8879433e-01\n",
      " 9.9419016e-01 9.8188657e-01 1.0216916e+00 1.0213515e+00 1.0170696e+00\n",
      " 1.0121727e+00 1.0103637e+00 1.0066109e+00 1.0036218e+00 1.0004394e+00\n",
      " 1.0047470e+00 1.0006140e+00 9.9976277e-01 9.9997038e-01 9.9302065e-01\n",
      " 9.6662885e-01 9.6240062e-01 9.6749943e-01 9.9116701e-01 1.0045290e+00\n",
      " 1.1164415e+00 1.4924070e+00 2.7103467e+00 1.9174715e+00 1.3451148e+00\n",
      " 1.0265794e+00 9.7885364e-01 9.6735573e-01 9.6668452e-01 9.8186725e-01\n",
      " 1.0174313e+00 1.0245342e+00 1.0095873e+00 1.0098666e+00 9.9849361e-01\n",
      " 9.9785185e-01 1.0026022e+00 1.0009006e+00 1.0035757e+00 9.9672711e-01\n",
      " 1.0041075e+00 1.0015215e+00 9.8773885e-01 9.6467364e-01 9.6012163e-01\n",
      " 9.8062044e-01 9.6755785e-01 9.5196462e-01 1.0006055e+00 7.2126853e-01\n",
      " 3.1905856e+00 1.2599845e+00 9.0233284e-01 9.2299712e-01 9.0331191e-01\n",
      " 9.0232867e-01 9.4106817e-01 9.8069572e-01 1.0152668e+00 1.0258174e+00\n",
      " 1.0167036e+00 9.9983895e-01 9.9173474e-01 9.9820709e-01 1.0018262e+00\n",
      " 9.9564558e-01 9.9454623e-01 1.0030926e+00 1.0056641e+00 1.0005929e+00\n",
      " 9.7976387e-01 9.6089113e-01 9.6415073e-01 9.9407095e-01 9.5212066e-01\n",
      " 9.2987031e-01 1.3590082e+00 1.9902158e+00 7.1434253e-01 2.3389368e-01\n",
      " 7.7720976e-01 9.3217534e-01 8.0347645e-01 8.8489568e-01 9.2521417e-01\n",
      " 9.9453872e-01 1.0160288e+00 1.0285687e+00 1.0081717e+00 9.9491817e-01\n",
      " 9.8478585e-01 9.9377292e-01 9.9617654e-01 1.0060290e+00 1.0034151e+00\n",
      " 1.0052661e+00 1.0052780e+00 1.0010229e+00 9.9105483e-01 9.6934813e-01\n",
      " 9.8147309e-01 1.0200268e+00 1.0651495e+00 1.0581427e+00 1.1763512e+00\n",
      " 2.5984110e-07 1.4435500e-08 9.0504318e-01 9.5050532e-01 8.7664127e-01\n",
      " 8.0706799e-01 9.0298909e-01 9.7410768e-01 1.0173290e+00 1.0283030e+00\n",
      " 1.0157714e+00 1.0118625e+00 9.9707508e-01 9.8638266e-01 1.0008118e+00\n",
      " 1.0048327e+00 1.0092380e+00 1.0017540e+00 1.0080857e+00 9.9933672e-01\n",
      " 9.9881637e-01 9.9094051e-01 9.8523450e-01 1.0106570e+00 1.0249540e+00\n",
      " 1.0355624e+00 1.0197817e+00 5.1864988e-01 4.4703484e-08 9.7322207e-08\n",
      " 1.4493158e+00 9.1972619e-01 8.2681626e-01 8.4247661e-01 9.3865812e-01\n",
      " 9.8919445e-01 1.0190479e+00 1.0275589e+00 1.0058289e+00 1.0027103e+00\n",
      " 1.0038918e+00 1.0005037e+00 1.0056378e+00 1.0066073e+00 1.0076938e+00\n",
      " 9.9912965e-01 9.9762428e-01 1.0057149e+00 1.0076731e+00 1.0044971e+00\n",
      " 9.9908680e-01 1.0338198e+00 1.0175071e+00 9.7237480e-01 1.0830640e+00\n",
      " 1.1174977e+00 3.2596290e-09 9.7322207e-08 2.5798232e-07 5.0732064e-01\n",
      " 7.3046529e-01 9.1091830e-01 9.7230238e-01 9.9563217e-01 1.0061709e+00\n",
      " 1.0034721e+00 9.9966353e-01 1.0041084e+00 1.0060837e+00 9.9787831e-01\n",
      " 1.0058340e+00 1.0006223e+00 1.0029024e+00 9.9462098e-01 1.0013952e+00\n",
      " 1.0123702e+00 9.9947691e-01 1.0098333e+00 1.0171574e+00 1.0313308e+00\n",
      " 1.0366822e+00 9.4923753e-01 8.4834319e-01 3.6245281e-01 1.5879372e-07\n",
      " 9.3132257e-10 9.3132257e-09 6.2101704e-01 7.1983945e-01 9.4460630e-01\n",
      " 9.8393750e-01 9.8402941e-01 9.9349350e-01 1.0099164e+00 9.9792266e-01\n",
      " 1.0025674e+00 1.0015945e+00 1.0048823e+00 1.0071789e+00 1.0032666e+00\n",
      " 9.9865419e-01 9.9743408e-01 1.0057753e+00 1.0058385e+00 9.9566138e-01\n",
      " 1.0043901e+00 1.0170617e+00 1.0251448e+00 1.0297610e+00 9.6444690e-01\n",
      " 8.5974181e-01 4.9973881e-01 3.5017729e-07 9.7322207e-08 3.0641237e-07\n",
      " 1.0631645e+00 6.3315368e-01 9.4704551e-01 9.7759259e-01 9.7950596e-01\n",
      " 9.9370676e-01 1.0113342e+00 1.0054590e+00 1.0058295e+00 1.0008030e+00\n",
      " 1.0060775e+00 1.0019033e+00 9.9966353e-01 9.9341452e-01 1.0066601e+00\n",
      " 1.0111128e+00 1.0030842e+00 1.0105126e+00 1.0217290e+00 1.0127797e+00\n",
      " 9.9830294e-01 9.9251193e-01 9.3274713e-01 9.8436588e-01 6.3796514e-01\n",
      " 1.2180102e+00 0.0000000e+00 1.7415594e-07 9.9864435e-01 7.4397069e-01\n",
      " 9.6015650e-01 9.9942750e-01 9.9643946e-01 9.9913925e-01 1.0040557e+00\n",
      " 1.0170377e+00 1.0148900e+00 1.0014157e+00 1.0040817e+00 1.0083574e+00\n",
      " 1.0054440e+00 9.9641848e-01 1.0074476e+00 1.0063462e+00 1.0123031e+00\n",
      " 1.0293672e+00 1.0331516e+00 1.0169036e+00 1.0180503e+00 1.0015250e+00\n",
      " 9.5148140e-01 8.4654909e-01 4.2003885e-01 2.0356822e+00 2.0023239e-07\n",
      " 1.6297916e-07 1.0391560e+00 8.9351517e-01 9.8739928e-01 1.0013471e+00\n",
      " 1.0049818e+00 9.9958962e-01 1.0174489e+00 1.0169990e+00 1.0082986e+00\n",
      " 1.0018625e+00 1.0038110e+00 9.9881631e-01 9.9855900e-01 1.0022398e+00\n",
      " 1.0077553e+00 1.0045288e+00 1.0252571e+00 1.0371392e+00 1.0264668e+00\n",
      " 1.0054306e+00 9.8378843e-01 9.5740718e-01 8.8965821e-01 9.2027599e-01\n",
      " 1.1332550e+00 1.2792183e+00 9.6857548e-08 3.1850766e-07 7.4214190e-01\n",
      " 9.0674192e-01 9.8763019e-01 9.9015820e-01 1.0010471e+00 1.0111983e+00\n",
      " 1.0098671e+00 1.0045424e+00 1.0049363e+00 9.9989396e-01 9.9957389e-01\n",
      " 1.0034101e+00 1.0088469e+00 1.0028118e+00 1.0016403e+00 1.0100046e+00\n",
      " 1.0240328e+00 1.0334529e+00 1.0103495e+00 9.7453707e-01 9.4891161e-01\n",
      " 8.9989531e-01 8.6691415e-01 9.4533086e-01 1.0132674e+00 3.7138328e-01\n",
      " 0.0000000e+00 3.6879297e-07 7.4791992e-01 9.7683603e-01 9.4203979e-01\n",
      " 9.7781181e-01 9.8652631e-01 9.8968637e-01 1.0024040e+00 1.0043759e+00\n",
      " 1.0089241e+00 1.0060288e+00 1.0058736e+00 9.9735051e-01 1.0018818e+00\n",
      " 9.9741471e-01 1.0021681e+00 1.0169861e+00 1.0182829e+00 1.0089793e+00\n",
      " 9.8126775e-01 9.4765019e-01 9.2369431e-01 9.4647449e-01 9.1423219e-01\n",
      " 8.3075094e-01 9.2752743e-01 1.8719376e-07 9.7322207e-08 1.8998981e-07\n",
      " 9.7220099e-01 1.0356858e+00 9.3653607e-01 9.7871840e-01 9.8335409e-01\n",
      " 9.8107022e-01 9.9268007e-01 1.0018334e+00 1.0037558e+00 1.0010437e+00\n",
      " 1.0026059e+00 1.0023873e+00 9.9690431e-01 9.9614894e-01 1.0052675e+00\n",
      " 1.0159725e+00 1.0114850e+00 9.9292845e-01 9.5680875e-01 9.1042578e-01\n",
      " 8.9261472e-01 9.4262779e-01 9.1356033e-01 5.7761925e-01 1.0734981e+00\n",
      " 9.7322207e-08 9.6857548e-08 8.6610420e-08 4.2267248e-01 1.0085765e+00\n",
      " 9.1185874e-01 9.7253478e-01 9.6825254e-01 9.7799748e-01 9.8459119e-01\n",
      " 9.9474281e-01 1.0042838e+00 1.0026510e+00 1.0041611e+00 1.0042924e+00\n",
      " 1.0015846e+00 1.0012909e+00 1.0152729e+00 1.0144143e+00 1.0032771e+00\n",
      " 9.7161514e-01 9.1166180e-01 8.6349279e-01 8.6388534e-01 9.1234052e-01\n",
      " 1.0282832e+00 1.0408180e+00 2.0552683e-01 9.6857548e-08 0.0000000e+00\n",
      " 0.0000000e+00 2.2969939e-01 8.7106019e-01 8.4800917e-01 9.5905763e-01\n",
      " 9.5409286e-01 9.6251649e-01 9.6252888e-01 9.5984608e-01 9.7493786e-01\n",
      " 9.9090856e-01 9.9993831e-01 1.0012022e+00 1.0072546e+00 1.0190736e+00\n",
      " 1.0417241e+00 1.0406396e+00 1.0053297e+00 9.4491112e-01 8.8543069e-01\n",
      " 8.6751318e-01 8.5332257e-01 8.6996138e-01 9.2686403e-01 1.4954894e+00\n",
      " 1.7508864e-07 0.0000000e+00 0.0000000e+00 0.0000000e+00 3.0081540e-07\n",
      " 5.2966851e-01 6.6614431e-01 8.6142522e-01 8.5901976e-01 8.8141215e-01\n",
      " 9.1684020e-01 9.1267139e-01 9.4329536e-01 9.7206110e-01 1.0004144e+00\n",
      " 1.0099839e+00 1.0359678e+00 1.0626701e+00 1.0976437e+00 1.0580245e+00\n",
      " 9.8248208e-01 9.3355733e-01 8.5958636e-01 7.3327971e-01 6.4231306e-01\n",
      " 5.7512474e-01 3.4111878e-01 3.2441661e-02 1.8766239e-07 0.0000000e+00\n",
      " 0.0000000e+00 0.0000000e+00 2.3748726e-08 4.4516858e-07 1.0235959e+00\n",
      " 7.4200058e-01 6.6580683e-01 7.9511470e-01 8.6733103e-01 9.1189891e-01\n",
      " 9.3892008e-01 1.0127660e+00 1.0543573e+00 1.0565462e+00 1.0885599e+00\n",
      " 1.0790602e+00 1.1105967e+00 1.0336881e+00 8.7358195e-01 7.9610276e-01\n",
      " 8.3847237e-01 8.3425897e-01 3.5514408e-01 3.9943677e-01 9.7599399e-01\n",
      " 1.7106270e+00 9.7322207e-08 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      " 0.0000000e+00 9.7322207e-08 2.7383116e-01 8.0129248e-01 8.1639725e-01\n",
      " 8.0921692e-01 9.0992832e-01 9.1669482e-01 9.1026628e-01 9.9220777e-01\n",
      " 1.0592594e+00 1.0637164e+00 1.0289983e+00 1.0634072e+00 1.1291505e+00\n",
      " 9.7389418e-01 9.0384561e-01 8.7397712e-01 7.5759560e-01 1.1151954e+00\n",
      " 3.5504842e-01 1.9224802e+00 1.3503920e-07 9.7322207e-08 0.0000000e+00\n",
      " 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n",
      " 4.4703484e-08 3.0733645e-08 9.1169477e-02 9.1856825e-01 1.1865302e+00\n",
      " 8.1469429e-01 5.5920196e-01 6.2081337e-01 8.0629909e-01 9.5259869e-01\n",
      " 8.4624058e-01 1.0237356e+00 9.6282297e-01 8.8217759e-01 9.8680794e-01\n",
      " 1.1329248e+00 7.2375351e-01 3.6134253e-07 3.9208430e-07 1.2992055e-07\n",
      " 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]\n",
      "Train data shape:  (4000, 784)\n",
      "Train labels shape:  (4000,)\n",
      "Test data shape:  (6000, 784)\n",
      "Test labels shape:  (6000,)\n"
     ]
    }
   ],
   "source": [
    "# HSIC lasso + SVM (following Ziyin and Liu, 2023)\n",
    "\n",
    "(X_train, Y_train), (X_test, Y_test) = load_mnist(one_hot=False)\n",
    "############\n",
    "\n",
    "def calculate_mean_std(dataset):\n",
    "    means = np.mean(dataset, axis=(0))  \n",
    "    stds = np.std(dataset, axis=(0))  \n",
    "    return means, stds\n",
    "\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "print(\"Normalized Training Set Mean and SD:\", train_mean, train_std)\n",
    "\n",
    "\n",
    "# Check sample sizes of split\n",
    "print('Train data shape: ', X_train.shape)\n",
    "print('Train labels shape: ', Y_train.shape)\n",
    "#print('Validation data shape: ', X_val.shape)\n",
    "#print('Validation labels shape: ', Y_val.shape)\n",
    "print('Test data shape: ', X_test.shape)\n",
    "print('Test labels shape: ', Y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-20T10:44:29.738814Z",
     "iopub.status.busy": "2025-01-20T10:44:29.737937Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Starting repetition 1/5\n",
      "Loading dataset: MNIST for repetition 1/5\n",
      "Loading dataset: MNIST with one_hot = False for repetition 1\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running LassoNet on MNIST (repetition 1)\n",
      "Repetition 1: sparsity = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.02423469387755106, 0.08290816326530615, 0.08545918367346939, 0.09183673469387754, 0.125, 0.22959183673469385, 0.3533163265306123, 0.4119897959183674, 0.45280612244897955, 0.4693877551020408, 0.4961734693877551, 0.5076530612244898, 0.5204081632653061, 0.5242346938775511, 0.5357142857142857, 0.5408163265306123, 0.5484693877551021, 0.5497448979591837, 0.5522959183673469, 0.5573979591836735, 0.5701530612244898, 0.5714285714285714, 0.5714285714285714, 0.5727040816326531, 0.5829081632653061, 0.5854591836734694, 0.5892857142857143, 0.5956632653061225, 0.5982142857142857, 0.6020408163265306, 0.6045918367346939, 0.6084183673469388, 0.6096938775510203, 0.6122448979591837, 0.6135204081632653, 0.6147959183673469, 0.6160714285714286, 0.6237244897959184, 0.6275510204081632, 0.6288265306122449, 0.6352040816326531, 0.6377551020408163, 0.6390306122448979, 0.6390306122448979, 0.6441326530612245, 0.6466836734693877, 0.6466836734693877, 0.6517857142857143, 0.6492346938775511, 0.6568877551020409, 0.6581632653061225, 0.6568877551020409, 0.659438775510204, 0.6645408163265306, 0.6709183673469388, 0.6734693877551021, 0.6772959183673469, 0.6772959183673469, 0.6836734693877551, 0.6887755102040816, 0.6926020408163265, 0.6951530612244898, 0.7002551020408163, 0.7053571428571428, 0.7066326530612245, 0.7079081632653061, 0.7130102040816326, 0.7181122448979591, 0.721938775510204, 0.7232142857142857, 0.7295918367346939, 0.7295918367346939, 0.7308673469387755, 0.7321428571428572, 0.7346938775510203, 0.7359693877551021, 0.7397959183673469, 0.7423469387755102, 0.7474489795918368, 0.7487244897959184, 0.7551020408163265, 0.7576530612244898, 0.7589285714285714, 
0.7602040816326531, 0.7627551020408163, 0.7678571428571428, 0.7653061224489796, 0.7691326530612245, 0.7691326530612245, 0.7767857142857143, 0.7755102040816326, 0.7806122448979592, 0.7767857142857143, 0.784438775510204, 0.7869897959183674, 0.7908163265306123, 0.7946428571428572, 0.8035714285714286, 0.7997448979591837, 0.8086734693877551, 0.8035714285714286, 0.8048469387755102, 0.8061224489795918, 0.8112244897959184, 0.8112244897959184, 0.8137755102040816, 0.8150510204081632, 0.8163265306122449, 0.8163265306122449, 0.8201530612244898, 0.8201530612244898, 0.8227040816326531, 0.8252551020408163, 0.8278061224489796, 0.8316326530612245, 0.8329081632653061, 0.8367346938775511, 0.8392857142857143, 0.840561224489796, 0.8418367346938775, 0.8431122448979592, 0.8443877551020408, 0.8443877551020408, 0.846938775510204, 0.846938775510204, 0.8494897959183674, 0.8482142857142857, 0.8494897959183674, 0.8507653061224489, 0.8520408163265306, 0.8533163265306123, 0.8533163265306123, 0.8558673469387755, 0.8571428571428572, 0.8584183673469388, 0.8596938775510204, 0.8622448979591837, 0.8647959183673469, 0.8647959183673469, 0.8647959183673469, 0.8673469387755102, 0.8686224489795918, 0.8698979591836735, 0.8724489795918368, 0.8737244897959184, 0.8724489795918368, 0.8737244897959184, 0.8737244897959184, 0.8737244897959184, 0.8762755102040817, 0.8813775510204082, 0.8826530612244898, 0.8839285714285714, 0.889030612244898, 0.8915816326530612, 0.8941326530612245, 0.8941326530612245, 0.8954081632653061, 0.9005102040816326, 0.9043367346938775, 0.9056122448979592, 0.909438775510204, 0.9068877551020408, 0.9107142857142857, 0.9158163265306123, 0.9145408163265306, 0.9170918367346939, 0.9183673469387755, 0.9196428571428571, 0.9247448979591837, 0.9247448979591837, 0.9260204081632653, 0.9285714285714286, 0.9285714285714286, 0.9285714285714286, 0.9298469387755102, 0.9311224489795918, 0.9323979591836735, 0.9336734693877551, 0.9349489795918368, 0.9375, 0.9387755102040817, 0.9413265306122449, 
0.9400510204081632, 0.9438775510204082, 0.9489795918367347, 0.9489795918367347, 0.951530612244898, 0.951530612244898, 0.9528061224489796, 0.9553571428571429, 0.9566326530612245, 0.9579081632653061, 0.9591836734693877, 0.9630102040816326, 0.9630102040816326, 0.9642857142857143, 0.9681122448979592, 0.9706632653061225, 0.9744897959183674, 0.9757653061224489, 0.9808673469387755, 0.9808673469387755, 0.9808673469387755, 0.9821428571428571, 0.9821428571428571, 0.9846938775510204, 0.9846938775510204, 0.9846938775510204, 0.9846938775510204, 0.9834183673469388, 0.9834183673469388, 0.9846938775510204, 0.985969387755102, 0.9885204081632653, 0.9885204081632653, 0.9910714285714286, 0.9936224489795918, 0.9948979591836735, 0.9974489795918368, 0.9974489795918368, 0.9974489795918368, 0.9987244897959183, 1.0] and test accuracy = [0.9291666666666667, 0.9296666666666666, 0.9296666666666666, 0.9293333333333333, 0.9293333333333333, 0.9296666666666666, 0.9296666666666666, 0.9296666666666666, 0.9296666666666666, 0.9296666666666666, 0.9296666666666666, 0.9295, 0.9293333333333333, 0.9293333333333333, 0.9293333333333333, 0.9293333333333333, 0.9291666666666667, 0.9296666666666666, 0.9296666666666666, 0.9298333333333333, 0.9298333333333333, 0.93, 0.93, 0.9301666666666667, 0.93, 0.9301666666666667, 0.9301666666666667, 0.9301666666666667, 0.93, 0.9296666666666666, 0.9298333333333333, 0.9298333333333333, 0.9298333333333333, 0.9303333333333333, 0.9303333333333333, 0.9303333333333333, 0.9303333333333333, 0.93, 0.93, 0.93, 0.93, 0.93, 0.93, 0.93, 0.93, 0.93, 0.93, 0.93, 0.9298333333333333, 0.93, 0.93, 0.9293333333333333, 0.93, 0.9301666666666667, 0.93, 0.9303333333333333, 0.93, 0.9291666666666667, 0.93, 0.9296666666666666, 0.9296666666666666, 0.929, 0.9288333333333333, 0.9265, 0.9243333333333333, 0.9235, 0.92, 0.9163333333333333, 0.9131666666666667, 0.9106666666666666, 0.9103333333333333, 0.9086666666666666, 0.908, 0.9075, 0.906, 0.9056666666666666, 0.9055, 0.9043333333333333, 0.9036666666666666, 
0.9025, 0.902, 0.9016666666666666, 0.9006666666666666, 0.8993333333333333, 0.8978333333333334, 0.8975, 0.8963333333333333, 0.896, 0.8955, 0.8945, 0.8945, 0.8931666666666667, 0.8928333333333334, 0.8925, 0.8915, 0.8915, 0.8908333333333334, 0.8901666666666667, 0.89, 0.8888333333333334, 0.8886666666666667, 0.888, 0.8866666666666667, 0.8866666666666667, 0.886, 0.8848333333333334, 0.8841666666666667, 0.8835, 0.8825, 0.8813333333333333, 0.8808333333333334, 0.8798333333333334, 0.8793333333333333, 0.879, 0.8781666666666667, 0.8776666666666667, 0.8768333333333334, 0.8766666666666667, 0.8758333333333334, 0.8755, 0.8748333333333334, 0.8743333333333333, 0.8728333333333333, 0.8726666666666667, 0.8723333333333333, 0.872, 0.871, 0.8708333333333333, 0.87, 0.8703333333333333, 0.869, 0.8685, 0.8685, 0.868, 0.867, 0.866, 0.865, 0.8646666666666667, 0.863, 0.8628333333333333, 0.863, 0.8623333333333333, 0.8605, 0.86, 0.8591666666666666, 0.8585, 0.857, 0.8568333333333333, 0.8553333333333333, 0.8553333333333333, 0.8543333333333333, 0.854, 0.8528333333333333, 0.8523333333333334, 0.8506666666666667, 0.8496666666666667, 0.8483333333333334, 0.8476666666666667, 0.846, 0.8453333333333334, 0.8443333333333334, 0.8436666666666667, 0.8426666666666667, 0.8413333333333334, 0.8398333333333333, 0.8378333333333333, 0.8365, 0.835, 0.8338333333333333, 0.833, 0.8318333333333333, 0.8311666666666667, 0.8296666666666667, 0.8285, 0.8271666666666667, 0.8256666666666667, 0.8245, 0.8223333333333334, 0.8211666666666667, 0.8201666666666667, 0.8188333333333333, 0.817, 0.816, 0.8143333333333334, 0.8115, 0.8111666666666667, 0.8101666666666667, 0.8093333333333333, 0.8085, 0.808, 0.8065, 0.8041666666666667, 0.8025, 0.8006666666666666, 0.7996666666666666, 0.7976666666666666, 0.7951666666666667, 0.7923333333333333, 0.7908333333333334, 0.7878333333333334, 0.7853333333333333, 0.7831666666666667, 0.7811666666666667, 0.7793333333333333, 0.7778333333333334, 0.7748333333333334, 0.7728333333333334, 0.7695, 0.767, 
0.7653333333333333, 0.7628333333333334, 0.7605, 0.7576666666666667, 0.7558333333333334, 0.753, 0.7493333333333333, 0.7465, 0.7423333333333333, 0.7398333333333333, 0.7361666666666666, 0.7318333333333333, 0.7278333333333333, 0.723, 0.7203333333333334, 0.7168333333333333, 0.712, 0.7083333333333334, 0.7038333333333333, 0.7008333333333333, 0.6955, 0.6921666666666667, 0.689, 0.6843333333333333, 0.6806666666666666, 0.676, 0.671, 0.6661666666666667, 0.6628333333333334, 0.6585, 0.6541666666666667, 0.6485, 0.6415, 0.6343333333333333, 0.6278333333333334, 0.6205, 0.6131666666666666, 0.6051666666666666, 0.5981666666666666, 0.5915, 0.5848333333333333, 0.5793333333333334, 0.5741666666666667, 0.569, 0.5616666666666666, 0.5585, 0.553, 0.5471666666666667, 0.5435, 0.5381666666666667, 0.5323333333333333, 0.5283333333333333, 0.5265, 0.522, 0.5163333333333333, 0.513, 0.5031666666666667, 0.49966666666666665, 0.495, 0.48733333333333334, 0.4765, 0.4671666666666667, 0.457, 0.41, 0.39866666666666667, 0.3685, 0.304, 0.27066666666666667, 0.25316666666666665, 0.18866666666666668, 0.184, 0.18183333333333335, 0.17866666666666667, 0.094, 0.094]\n",
      "Results successfully saved to results/input_sparsity/MNIST/LassoNet/rep_1/MNIST_LassoNet_rep1_res.csv\n",
      "Loading dataset: MNIST with one_hot = False for repetition 1\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running HSIC_dnn on MNIST (repetition 1)\n",
      "Sequence of features is [  1   5  10  20 105 190 275 360 444 529 614 699 784]\n",
      "Downstream model for HSIC is nn\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 1\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 2.0416 - accuracy: 0.2100\n",
      "test acc for vanilla model is [2.0415728092193604, 0.20999999344348907]\n",
      "2.0415728092193604 0.20999999344348907\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 5\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.3576 - accuracy: 0.5240\n",
      "test acc for vanilla model is [1.3576191663742065, 0.5239999890327454]\n",
      "1.3576191663742065 0.5239999890327454\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 10\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.4268 - accuracy: 0.6068\n",
      "test acc for vanilla model is [1.426792860031128, 0.6068333387374878]\n",
      "1.426792860031128 0.6068333387374878\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 20\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.3753 - accuracy: 0.7265\n",
      "test acc for vanilla model is [1.3752996921539307, 0.7264999747276306]\n",
      "1.3752996921539307 0.7264999747276306\n",
      "Current number of features: 105\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 105\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 0.3700 - accuracy: 0.9230\n",
      "test acc for vanilla model is [0.37002885341644287, 0.9229999780654907]\n",
      "0.37002885341644287 0.9229999780654907\n",
      "Current number of features: 190\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 190\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 0.3152 - accuracy: 0.9358\n",
      "test acc for vanilla model is [0.31516531109809875, 0.9358333349227905]\n",
      "0.31516531109809875 0.9358333349227905\n",
      "Current number of features: 275\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 262\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 0.2936 - accuracy: 0.9413\n",
      "test acc for vanilla model is [0.2935599088668823, 0.9413333535194397]\n",
      "0.2935599088668823 0.9413333535194397\n",
      "Current number of features: 360\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 262\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 0.3054 - accuracy: 0.9372\n",
      "test acc for vanilla model is [0.3054382801055908, 0.937166690826416]\n",
      "0.3054382801055908 0.937166690826416\n",
      "Current number of features: 444\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 262\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 0.3001 - accuracy: 0.9375\n",
      "test acc for vanilla model is [0.30005648732185364, 0.9375]\n",
      "0.30005648732185364 0.9375\n",
      "Current number of features: 529\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 262\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 0.3114 - accuracy: 0.9390\n",
      "test acc for vanilla model is [0.31135350465774536, 0.9390000104904175]\n",
      "0.31135350465774536 0.9390000104904175\n",
      "Current number of features: 614\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 262\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 0.2959 - accuracy: 0.9378\n",
      "test acc for vanilla model is [0.2959471344947815, 0.937833309173584]\n",
      "0.2959471344947815 0.937833309173584\n",
      "Current number of features: 699\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 262\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 0.3146 - accuracy: 0.9375\n",
      "test acc for vanilla model is [0.3146422207355499, 0.9375]\n",
      "0.3146422207355499 0.9375\n",
      "Current number of features: 784\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 262\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 0.3064 - accuracy: 0.9400\n",
      "test acc for vanilla model is [0.3064430058002472, 0.9399999976158142]\n",
      "0.3064430058002472 0.9399999976158142\n",
      "Repetition 1: sparsity = [0.9987244897959183, 0.9936224489795918, 0.9872448979591837, 0.9744897959183674, 0.8660714285714286, 0.7576530612244898, 0.6492346938775511, 0.5408163265306123, 0.4336734693877551, 0.3252551020408163, 0.21683673469387754, 0.10841836734693877, 0.0] and test accuracy = [0.20999999344348907, 0.5239999890327454, 0.6068333387374878, 0.7264999747276306, 0.9229999780654907, 0.9358333349227905, 0.9413333535194397, 0.937166690826416, 0.9375, 0.9390000104904175, 0.937833309173584, 0.9375, 0.9399999976158142]\n",
      "Results successfully saved to results/input_sparsity/MNIST/HSIC_dnn/rep_1/MNIST_HSIC_dnn_rep1_res.csv\n",
      "Loading dataset: MNIST with one_hot = False for repetition 1\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running HSIC_svm on MNIST (repetition 1)\n",
      "Sequence of features is [  1   5  10  20 105 190 275 360 444 529 614 699 784]\n",
      "Downstream model for HSIC is svm\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 105\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 190\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 275\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 360\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 444\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 529\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 614\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 699\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 784\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Repetition 1: sparsity = [0.9987244897959183, 0.9936224489795918, 0.9872448979591837, 0.9744897959183674, 0.8660714285714286, 0.7576530612244898, 0.6492346938775511, 0.5408163265306123, 0.4336734693877551, 0.3252551020408163, 0.21683673469387754, 0.10841836734693877, 0.0] and test accuracy = [0.2095, 0.5165, 0.6026666666666667, 0.732, 0.9368333333333333, 0.9518333333333333, 0.9528333333333333, 0.9528333333333333, 0.9528333333333333, 0.9528333333333333, 0.9528333333333333, 0.9528333333333333, 0.9528333333333333]\n",
      "Results successfully saved to results/input_sparsity/MNIST/HSIC_svm/rep_1/MNIST_HSIC_svm_rep1_res.csv\n",
      "\n",
      "Starting repetition 2/5\n",
      "Loading dataset: MNIST for repetition 2/5\n",
      "Loading dataset: MNIST with one_hot = False for repetition 2\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running LassoNet on MNIST (repetition 2)\n",
      "Repetition 2: sparsity = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.008928571428571397, 0.07015306122448983, 0.08545918367346939, 0.08673469387755106, 0.11096938775510201, 0.20408163265306123, 0.32908163265306123, 0.40306122448979587, 0.4387755102040817, 0.46683673469387754, 0.48341836734693877, 0.49744897959183676, 0.5051020408163265, 0.5127551020408163, 0.5204081632653061, 0.5242346938775511, 0.5255102040816326, 0.5306122448979591, 0.5357142857142857, 0.5420918367346939, 0.5497448979591837, 0.5561224489795918, 0.5625, 0.5663265306122449, 0.5727040816326531, 0.5765306122448979, 0.5829081632653061, 0.5892857142857143, 0.5931122448979591, 0.5943877551020409, 0.5982142857142857, 0.6045918367346939, 0.6071428571428572, 0.6071428571428572, 0.6084183673469388, 0.6122448979591837, 0.6160714285714286, 0.6224489795918368, 0.6224489795918368, 0.6262755102040816, 0.6301020408163265, 0.6339285714285714, 0.6326530612244898, 0.6403061224489797, 0.6428571428571428, 0.6415816326530612, 0.6492346938775511, 0.6505102040816326, 0.653061224489796, 0.659438775510204, 0.6683673469387755, 0.6734693877551021, 0.6760204081632653, 0.6772959183673469, 0.6823979591836735, 0.6862244897959184, 0.6875, 0.6951530612244898, 0.6977040816326531, 0.7002551020408163, 0.6964285714285714, 0.7040816326530612, 0.7053571428571428, 0.7104591836734694, 0.7130102040816326, 0.7206632653061225, 0.7206632653061225, 0.7232142857142857, 0.7193877551020409, 0.7244897959183674, 0.7232142857142857, 0.7270408163265306, 0.7232142857142857, 0.7295918367346939, 0.7321428571428572, 0.7397959183673469, 0.7397959183673469, 0.7423469387755102, 0.7410714285714286, 0.7461734693877551, 0.7474489795918368, 0.7538265306122449, 0.7563775510204082, 0.7576530612244898, 
0.7627551020408163, 0.7653061224489796, 0.764030612244898, 0.7678571428571428, 0.7678571428571428, 0.7716836734693877, 0.7653061224489796, 0.7729591836734694, 0.7704081632653061, 0.7793367346938775, 0.7755102040816326, 0.784438775510204, 0.7767857142857143, 0.7831632653061225, 0.7806122448979592, 0.7882653061224489, 0.7857142857142857, 0.7869897959183674, 0.7946428571428572, 0.7959183673469388, 0.8035714285714286, 0.8048469387755102, 0.8073979591836735, 0.8099489795918368, 0.8112244897959184, 0.8137755102040816, 0.8163265306122449, 0.8188775510204082, 0.8214285714285714, 0.8239795918367347, 0.8227040816326531, 0.8239795918367347, 0.826530612244898, 0.8303571428571428, 0.8316326530612245, 0.8329081632653061, 0.8354591836734694, 0.840561224489796, 0.8431122448979592, 0.8482142857142857, 0.8494897959183674, 0.8507653061224489, 0.8507653061224489, 0.8507653061224489, 0.8520408163265306, 0.8507653061224489, 0.8520408163265306, 0.8533163265306123, 0.8558673469387755, 0.8558673469387755, 0.8571428571428572, 0.8584183673469388, 0.8584183673469388, 0.860969387755102, 0.8647959183673469, 0.8647959183673469, 0.8660714285714286, 0.8660714285714286, 0.8686224489795918, 0.8724489795918368, 0.8711734693877551, 0.8724489795918368, 0.875, 0.875, 0.8762755102040817, 0.8775510204081632, 0.8801020408163265, 0.8852040816326531, 0.8877551020408163, 0.8877551020408163, 0.889030612244898, 0.8903061224489796, 0.8915816326530612, 0.8954081632653061, 0.8966836734693877, 0.8992346938775511, 0.8992346938775511, 0.9017857142857143, 0.9043367346938775, 0.9056122448979592, 0.9081632653061225, 0.909438775510204, 0.9081632653061225, 0.9107142857142857, 0.9158163265306123, 0.9145408163265306, 0.9170918367346939, 0.923469387755102, 0.9247448979591837, 0.9247448979591837, 0.9247448979591837, 0.9260204081632653, 0.9272959183673469, 0.9285714285714286, 0.9323979591836735, 0.9323979591836735, 0.9336734693877551, 0.9349489795918368, 0.9375, 0.9413265306122449, 0.9413265306122449, 0.9438775510204082, 
0.9451530612244898, 0.9477040816326531, 0.9489795918367347, 0.9489795918367347, 0.951530612244898, 0.9540816326530612, 0.9566326530612245, 0.9566326530612245, 0.9591836734693877, 0.9630102040816326, 0.9642857142857143, 0.9668367346938775, 0.9681122448979592, 0.9681122448979592, 0.9719387755102041, 0.9719387755102041, 0.9744897959183674, 0.9744897959183674, 0.9808673469387755, 0.9821428571428571, 0.9821428571428571, 0.9834183673469388, 0.9846938775510204, 0.9846938775510204, 0.9846938775510204, 0.985969387755102, 0.9872448979591837, 0.985969387755102, 0.9885204081632653, 0.9872448979591837, 0.9897959183673469, 0.9923469387755102, 0.9936224489795918, 0.9936224489795918, 0.9948979591836735, 0.9961734693877551, 0.9974489795918368, 0.9974489795918368, 0.9974489795918368, 1.0] and test accuracy = [0.9276666666666666, 0.9285, 0.9278333333333333, 0.9278333333333333, 0.928, 0.928, 0.9276666666666666, 0.9276666666666666, 0.9281666666666667, 0.928, 0.928, 0.9281666666666667, 0.928, 0.928, 0.928, 0.928, 0.9278333333333333, 0.9276666666666666, 0.9276666666666666, 0.9276666666666666, 0.9276666666666666, 0.9276666666666666, 0.9275, 0.9273333333333333, 0.9275, 0.9276666666666666, 0.9276666666666666, 0.9273333333333333, 0.9273333333333333, 0.9273333333333333, 0.9271666666666667, 0.9271666666666667, 0.9271666666666667, 0.9271666666666667, 0.9271666666666667, 0.927, 0.927, 0.9271666666666667, 0.9273333333333333, 0.9275, 0.9275, 0.9276666666666666, 0.9276666666666666, 0.9276666666666666, 0.9276666666666666, 0.9278333333333333, 0.9276666666666666, 0.9278333333333333, 0.9283333333333333, 0.9278333333333333, 0.9276666666666666, 0.9271666666666667, 0.9268333333333333, 0.9273333333333333, 0.9268333333333333, 0.9268333333333333, 0.9268333333333333, 0.9271666666666667, 0.9263333333333333, 0.9273333333333333, 0.9268333333333333, 0.9273333333333333, 0.9266666666666666, 0.9265, 0.9251666666666667, 0.9223333333333333, 0.9183333333333333, 0.9141666666666667, 0.912, 0.9101666666666667, 
0.9091666666666667, 0.9083333333333333, 0.9071666666666667, 0.907, 0.9063333333333333, 0.9048333333333334, 0.904, 0.904, 0.903, 0.9011666666666667, 0.9006666666666666, 0.8998333333333334, 0.8998333333333334, 0.8995, 0.8988333333333334, 0.898, 0.8976666666666666, 0.8966666666666666, 0.896, 0.8951666666666667, 0.8945, 0.8933333333333333, 0.8925, 0.8921666666666667, 0.892, 0.8911666666666667, 0.8903333333333333, 0.8896666666666667, 0.8893333333333333, 0.8886666666666667, 0.8875, 0.8868333333333334, 0.8866666666666667, 0.8865, 0.8861666666666667, 0.8865, 0.8856666666666667, 0.885, 0.8845, 0.8836666666666667, 0.8836666666666667, 0.883, 0.883, 0.8823333333333333, 0.8816666666666667, 0.8803333333333333, 0.8801666666666667, 0.8795, 0.8786666666666667, 0.8778333333333334, 0.8775, 0.8763333333333333, 0.8756666666666667, 0.8748333333333334, 0.8741666666666666, 0.8721666666666666, 0.8718333333333333, 0.871, 0.8705, 0.869, 0.8683333333333333, 0.8673333333333333, 0.8671666666666666, 0.8663333333333333, 0.866, 0.8646666666666667, 0.864, 0.8636666666666667, 0.863, 0.8611666666666666, 0.8608333333333333, 0.8588333333333333, 0.8578333333333333, 0.8565, 0.8556666666666667, 0.8551666666666666, 0.855, 0.854, 0.8526666666666667, 0.851, 0.8506666666666667, 0.8496666666666667, 0.8493333333333334, 0.848, 0.8475, 0.8468333333333333, 0.8465, 0.8463333333333334, 0.8456666666666667, 0.8443333333333334, 0.844, 0.8423333333333334, 0.8423333333333334, 0.841, 0.8408333333333333, 0.839, 0.8375, 0.8376666666666667, 0.8361666666666666, 0.8351666666666666, 0.8335, 0.8321666666666667, 0.831, 0.8293333333333334, 0.828, 0.8265, 0.8243333333333334, 0.8228333333333333, 0.8225, 0.8213333333333334, 0.8205, 0.8188333333333333, 0.817, 0.8165, 0.8143333333333334, 0.813, 0.8121666666666667, 0.81, 0.8083333333333333, 0.8076666666666666, 0.8061666666666667, 0.8041666666666667, 0.8033333333333333, 0.8013333333333333, 0.7998333333333333, 0.7975, 0.7956666666666666, 0.7928333333333333, 0.7906666666666666, 
0.7876666666666666, 0.7856666666666666, 0.7828333333333334, 0.7808333333333334, 0.7778333333333334, 0.7756666666666666, 0.7745, 0.7716666666666666, 0.7683333333333333, 0.765, 0.7623333333333333, 0.7596666666666667, 0.7578333333333334, 0.756, 0.7543333333333333, 0.751, 0.7488333333333334, 0.747, 0.743, 0.7401666666666666, 0.7375, 0.734, 0.731, 0.7268333333333333, 0.7223333333333334, 0.7191666666666666, 0.7161666666666666, 0.7128333333333333, 0.7093333333333334, 0.7048333333333333, 0.6986666666666667, 0.6948333333333333, 0.6901666666666667, 0.6851666666666667, 0.6821666666666667, 0.6773333333333333, 0.6726666666666666, 0.6676666666666666, 0.6626666666666666, 0.6576666666666666, 0.6521666666666667, 0.6458333333333334, 0.6401666666666667, 0.6338333333333334, 0.6275, 0.6196666666666667, 0.6131666666666666, 0.606, 0.5993333333333334, 0.593, 0.5886666666666667, 0.5836666666666667, 0.578, 0.5736666666666667, 0.5685, 0.5625, 0.5576666666666666, 0.5533333333333333, 0.5495, 0.5456666666666666, 0.5425, 0.5383333333333333, 0.5326666666666666, 0.5283333333333333, 0.5221666666666667, 0.517, 0.5096666666666667, 0.5033333333333333, 0.49883333333333335, 0.49316666666666664, 0.4846666666666667, 0.47583333333333333, 0.4686666666666667, 0.45916666666666667, 0.44533333333333336, 0.38383333333333336, 0.325, 0.30083333333333334, 0.272, 0.23683333333333334, 0.18866666666666668, 0.1845, 0.181, 0.17483333333333334, 0.094, 0.094]\n",
      "Results successfully saved to results/input_sparsity/MNIST/LassoNet/rep_2/MNIST_LassoNet_rep2_res.csv\n",
      "Loading dataset: MNIST with one_hot = False for repetition 2\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running HSIC_dnn on MNIST (repetition 2)\n",
      "Sequence of features is [  1   5  10  20 105 190 275 360 444 529 614 699 784]\n",
      "Downstream model for HSIC is nn\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 1\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 2.0415 - accuracy: 0.2088\n",
      "test acc for vanilla model is [2.041537284851074, 0.20883333683013916]\n",
      "2.041537284851074 0.20883333683013916\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 5\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.3518 - accuracy: 0.5238\n",
      "test acc for vanilla model is [1.3517793416976929, 0.5238333344459534]\n",
      "1.3517793416976929 0.5238333344459534\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 10\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.4091 - accuracy: 0.6037\n",
      "test acc for vanilla model is [1.40913724899292, 0.6036666631698608]\n",
      "1.40913724899292 0.6036666631698608\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 20\n",
      "188/188 [==============================] - 1s 2ms/step - loss: 1.3736 - accuracy: 0.7285\n",
      "test acc for vanilla model is [1.3735713958740234, 0.7285000085830688]\n",
      "1.3735713958740234 0.7285000085830688\n",
      "Current number of features: 105\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "from models import hadamard_nn_depth_2, hadamard_nn_depth_3, hadamard_nn_depth_4, hsic_dnn, hsic_svm, lassoNet\n",
    "import tensorflow as tf\n",
    "import config_inputsparse_compar\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "# Constants\n",
    "# Number of independent repetitions of the experiment; main() iterates over range(REPS).\n",
    "REPS = 5\n",
    "# Base random seed read from the config module; repetition `rep` is seeded with BASE_SEED + rep.\n",
    "BASE_SEED = config_inputsparse_compar.SEED\n",
    "# NOTE(review): not referenced anywhere in this cell -- presumably used by other cells; confirm or remove.\n",
    "LENET_FILE_PATH = './results/input_sparsity/MNIST/'\n",
    "\n",
    "def run_methods_on_dataset(dataset_name, load_func, results_dir, rep):\n",
    "    \"\"\"Run each baseline feature-selection method on one dataset for one repetition.\n",
    "\n",
    "    NumPy, Python's ``random`` and TensorFlow are seeded once with ``BASE_SEED + rep``;\n",
    "    LassoNet, HSIC_dnn and HSIC_svm then run in that order. For every method the data\n",
    "    is loaded through ``load_func`` and the rows are written to\n",
    "    ``<results_dir>/<method>/rep_<rep + 1>/<dataset>_<method>_rep<rep + 1>_res.csv``.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: Label used in log messages, file names and result rows.\n",
    "        load_func: Callable accepting ``one_hot=`` and returning\n",
    "            ``(train_X, train_y), (test_X, test_y)``.\n",
    "        results_dir: Directory under which the per-method folders are created.\n",
    "        rep: Zero-based repetition index (reported one-based in logs and files).\n",
    "    \"\"\"\n",
    "    # One seed per repetition, applied to all three random number sources.\n",
    "    current_seed = BASE_SEED + rep\n",
    "    for set_seed in (np.random.seed, random.seed, tf.random.set_seed):\n",
    "        set_seed(current_seed)\n",
    "\n",
    "    print(f\"Loading dataset: {dataset_name} for repetition {rep + 1}/{REPS}\")\n",
    "\n",
    "    # (label, callable, one_hot flag passed to load_func) -- executed in this order.\n",
    "    baselines = (\n",
    "        (\"LassoNet\", lassoNet, False),\n",
    "        (\"HSIC_dnn\", hsic_dnn, False),\n",
    "        (\"HSIC_svm\", hsic_svm, False),\n",
    "    )\n",
    "    for method_name, method_func, one_hot in baselines:\n",
    "        # Each method/repetition pair gets its own folder and its own CSV file.\n",
    "        out_dir = os.path.join(results_dir, method_name, f'rep_{rep + 1}')\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "        result_filename = os.path.join(out_dir, f'{dataset_name}_{method_name}_rep{rep + 1}_res.csv')\n",
    "\n",
    "        print(f\"Loading dataset: {dataset_name} with one_hot = {one_hot} for repetition {rep + 1}\")\n",
    "        (train_X, train_y), (test_X, test_y) = load_func(one_hot=one_hot)\n",
    "\n",
    "        print(f\"Running {method_name} on {dataset_name} (repetition {rep + 1})\")\n",
    "        sparsity, accuracy, value_seq = method_func(train_X, train_y, test_X, test_y)\n",
    "        print(f'Repetition {rep + 1}: sparsity = {sparsity} and test accuracy = {accuracy}')\n",
    "\n",
    "        # One CSV row per (sparsity, accuracy, value) triple returned by the method.\n",
    "        rows = [\n",
    "            {\n",
    "                \"method\": method_name,\n",
    "                \"dataset\": dataset_name,\n",
    "                \"repetition\": rep + 1,\n",
    "                \"sparsity\": s,\n",
    "                \"accuracy\": a,\n",
    "                \"value\": v,\n",
    "                \"seed\": current_seed,\n",
    "            }\n",
    "            for s, a, v in zip(sparsity, accuracy, value_seq)\n",
    "        ]\n",
    "\n",
    "        save_results_to_csv(rows, result_filename)\n",
    "        print(f'Results successfully saved to {result_filename}')\n",
    "\n",
    "def save_results_to_csv(results, result_filename):\n",
    "    \"\"\"Write result rows to a CSV file, creating the parent directory when missing.\n",
    "\n",
    "    Args:\n",
    "        results: Rows to store; handed unchanged to ``pd.DataFrame`` (the caller in\n",
    "            this cell supplies a list of dicts).\n",
    "        result_filename: Destination path of the CSV file. The index is not written.\n",
    "    \"\"\"\n",
    "    parent_dir = os.path.dirname(result_filename)\n",
    "    # A bare file name has no directory part, and os.makedirs('') would raise.\n",
    "    if parent_dir:\n",
    "        os.makedirs(parent_dir, exist_ok=True)\n",
    "    pd.DataFrame(results).to_csv(result_filename, index=False)\n",
    "\n",
    "def main():\n",
    "    \"\"\"Run all repetitions of the input-sparsity comparison, then merge the results.\n",
    "\n",
    "    For each of the ``REPS`` repetitions every registered dataset is passed to\n",
    "    ``run_methods_on_dataset``. Once all repetitions are done,\n",
    "    ``combine_all_results`` is called on the shared output directory\n",
    "    ``results/input_sparsity/MNIST``.\n",
    "    \"\"\"\n",
    "    # Dataset label -> loader callable.\n",
    "    dataset_loaders = {\"MNIST\": load_mnist}\n",
    "\n",
    "    # Fixed output root (no timestamp is added); created up front so runs can write into it.\n",
    "    output_root = os.path.join('results', 'input_sparsity', 'MNIST')\n",
    "    os.makedirs(output_root, exist_ok=True)\n",
    "\n",
    "    # Repetitions form the outer loop, so every dataset is processed within one repetition\n",
    "    # before the next repetition starts.\n",
    "    for rep in range(REPS):\n",
    "        print(f\"\\nStarting repetition {rep + 1}/{REPS}\")\n",
    "        for name, loader in dataset_loaders.items():\n",
    "            run_methods_on_dataset(name, loader, output_root, rep)\n",
    "\n",
    "    # Merge the per-method, per-repetition CSV files into one summary file.\n",
    "    combine_all_results(output_root)\n",
    "\n",
    "def combine_all_results(base_results_dir):\n",
    "    \"\"\"Combines all individual result files into a single summary file.\n",
    "\n",
    "    Scans ``<base_results_dir>/<method>/rep_*/`` for files ending in ``_res.csv`` for\n",
    "    the methods HSIC_dnn, HSIC_svm and LassoNet, concatenates them and writes\n",
    "    ``<base_results_dir>/all_results_summary.csv``. Nothing is written when no result\n",
    "    file is found.\n",
    "\n",
    "    Directory listings are sorted by name so the row order of the summary is\n",
    "    reproducible, and entries that are not directories are skipped instead of\n",
    "    raising.\n",
    "\n",
    "    Args:\n",
    "        base_results_dir: Directory that contains one sub-directory per method.\n",
    "    \"\"\"\n",
    "    all_results = []\n",
    "\n",
    "    for method in ['HSIC_dnn', 'HSIC_svm', 'LassoNet']:\n",
    "        method_dir = os.path.join(base_results_dir, method)\n",
    "        # isdir (not exists): a stray file with a method's name must not reach os.listdir.\n",
    "        if not os.path.isdir(method_dir):\n",
    "            continue\n",
    "\n",
    "        # os.listdir returns entries in arbitrary order; sort for a deterministic summary.\n",
    "        for rep_dir in sorted(os.listdir(method_dir)):\n",
    "            rep_path = os.path.join(method_dir, rep_dir)\n",
    "            # Skip anything that is not a repetition folder, e.g. a file named 'rep_notes.txt'.\n",
    "            if not rep_dir.startswith('rep_') or not os.path.isdir(rep_path):\n",
    "                continue\n",
    "\n",
    "            for result_file in sorted(os.listdir(rep_path)):\n",
    "                if result_file.endswith('_res.csv'):\n",
    "                    all_results.append(pd.read_csv(os.path.join(rep_path, result_file)))\n",
    "\n",
    "    if all_results:\n",
    "        combined_df = pd.concat(all_results, ignore_index=True)\n",
    "        summary_file = os.path.join(base_results_dir, 'all_results_summary.csv')\n",
    "        combined_df.to_csv(summary_file, index=False)\n",
    "        print(f\"\\nCombined results saved to {summary_file}\")\n",
    "\n",
    "# Run the full experiment only when this code executes as the main module (not on import).\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
