{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-24T00:02:07.169949Z",
     "iopub.status.busy": "2025-01-24T00:02:07.169289Z",
     "iopub.status.idle": "2025-01-24T00:02:10.358626Z",
     "shell.execute_reply": "2025-01-24T00:02:10.357898Z",
     "shell.execute_reply.started": "2025-01-24T00:02:07.169878Z"
    },
    "id": "RPFnfEqRAYA5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Defining configs successful!\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import random\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import keras\n",
    "import pandas as pd\n",
    "from tensorflow.keras.datasets import cifar10, fashion_mnist, mnist\n",
    "from tensorflow.keras.models import load_model\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping  #ReduceLROnPlateau\n",
    "from tensorflow.keras.regularizers import l2\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder, MinMaxScaler, StandardScaler \n",
    "from PIL import Image\n",
    "\n",
    "try:\n",
    "    import pyHSICLasso\n",
    "except:\n",
    "    os.system(\"python -m pip install pyHSICLasso\")\n",
    "try:\n",
    "    import lassonet\n",
    "except:\n",
    "    os.system(\"python -m pip install lassonet\")\n",
    "\n",
    "from layers       import HadamardDense, HadamardConv2D, StrHadamardDenseV2\n",
    "from initializers import TwiceTruncatedNormalInitializer, equivar_initializer, equivar_initializer_conv2d\n",
    "from callbacks    import HadamardCallback, PrintLRCallback, TerminateBadRuns\n",
    "from models       import LeNet300100, HadamardLeNet300100,\\\n",
    "                         InpHadamardLeNet300100, hsic_dnn, hsic_svm, lassoNet\n",
    "from utils        import get_optimizer, process_sparsity_callback, create_and_save_trajectory_plot, compute_input_sparsity\n",
    "\n",
    "################################################################################\n",
    "\n",
    "# Hadamard\n",
    "DEPTH              = 2\n",
    "LA                 = 5e-4\n",
    "INIT_TYPE          ='ones' #vanilla, equivar, root, ones\n",
    "#INIT_REST         = tf.keras.initializers.Ones() \n",
    "MINPROD            = 3e-3\n",
    "#INIT               = TwiceTruncatedNormalInitializer(minprod=MINPROD,depth=DEPTH) \n",
    "INIT               = tf.keras.initializers.HeNormal()\n",
    "USE_BIAS           = True # only for dense layers\n",
    "FACTORIZE_BIAS     = True\n",
    "\n",
    "# Training\n",
    "PRETRAIN_OPT       = 'sgd' # sgd, adam\n",
    "LR_SCHEDULE        = 'cosine' # piecewise, constant, cosine\n",
    "BATCH_SIZE         = 256 #256\n",
    "EPOCHS             = 100 #200\n",
    "INIT_LR            = 0.1 if PRETRAIN_OPT == 'sgd' else 2e-3 # depth 2=0.2, depth3 3 = 0.5, depth 4 = 0.7\n",
    "WARMUP             = False # linear lr warmup from 0.01 to INIT_LR over 5 epochs\n",
    "MOMENTUM           = 0.9    # only for SGD\n",
    "LR_DECAY_FACT      = 0.1   # only for piecewise\n",
    "LARGE_LRSTART      = False  # only for piecewise\n",
    "\n",
    "# Misc\n",
    "CLASS_NUM          = 2\n",
    "PAT                = 100\n",
    "RESTORE_WEIGHTS    = False\n",
    "GRACE              = 10\n",
    "FINE_GRACE         = 20\n",
    "MINACC             = (1 / CLASS_NUM) + 0.01\n",
    "SEED               = 123 #42\n",
    "SAVE_METRICS       = True\n",
    "VERBOSE            = 1\n",
    "\n",
    "# Directories and saving path\n",
    "LENET_FILE_PATH = './results/input_sparsity/madelon/'\n",
    "\n",
    "print('Defining configs successful!')\n",
    "\n",
    "# Lambda grid\n",
    "LAMBDA_LIST = [\n",
    "    0,\n",
    "    1e-6,\n",
    "    1e-5,\n",
    "    1e-4,\n",
    "    2e-4,\n",
    "    4e-4,\n",
    "    5e-4,\n",
    "    7e-4,\n",
    "    8e-4,\n",
    "    9e-4,\n",
    "    1e-3,\n",
    "    1.5e-3,\n",
    "    2e-3,\n",
    "    4e-3,\n",
    "    5e-3,\n",
    "    8e-3,\n",
    "    1e-2,\n",
    "    #1.5e-2,\n",
    "    2e-2,\n",
    "    2.5e-2,\n",
    "    3e-2,\n",
    "    3.5e-2,\n",
    "    4e-2,\n",
    "    4.5e-2,\n",
    "    5e-2,\n",
    "    5.5e-2,\n",
    "    6e-2,\n",
    "    6.5e-2,\n",
    "    7e-2,\n",
    "    8e-2,\n",
    "    9e-2,\n",
    "    9.5e-2,\n",
    "    1e-1,\n",
    "    #1.1e-1,\n",
    "    1.2e-1,\n",
    "    #1.3e-1,\n",
    "    1.4e-1,\n",
    "    1.5e-1,\n",
    "    #1.6e-1,\n",
    "    #1.7e-1,\n",
    "    1.8e-1,\n",
    "    1.9e-1,\n",
    "    2e-1,\n",
    "    5e-1,\n",
    "    7e-1,\n",
    "    1,\n",
    "    2\n",
    "    #5\n",
    "]\n",
    "\n",
    "#LAMBDA_LIST.reverse()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-24T00:02:14.476801Z",
     "iopub.status.busy": "2025-01-24T00:02:14.476005Z",
     "iopub.status.idle": "2025-01-24T00:02:14.542172Z",
     "shell.execute_reply": "2025-01-24T00:02:14.541454Z",
     "shell.execute_reply.started": "2025-01-24T00:02:14.476775Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Keys in .mat file: dict_keys(['__header__', '__version__', '__globals__', 'Y', 'X'])\n",
      "Unique labels before preprocessing: [-1  1]\n",
      "Unique train labels after processing: [0. 1.]\n",
      "Unique test labels after processing: [0. 1.]\n",
      "X_train shape: (2080, 500), Y_train shape: (2080, 2)\n",
      "X_test shape: (520, 500), Y_test shape: (520, 2)\n",
      "Normalized Training Set Mean and SD: [ 1.99446308e-08  1.98514982e-08 -1.87410762e-08 -4.58497268e-10\n",
      " -2.00592556e-10  9.05532094e-09  1.10863203e-08  9.05532094e-09\n",
      " -2.19218999e-09 -1.01915347e-07  7.62251684e-09 -2.49307881e-08\n",
      " -6.07508888e-09  2.05177528e-08 -3.72529030e-08 -1.40414791e-09\n",
      "  6.64821043e-09 -1.11185585e-08  4.39870806e-09 -1.34683571e-08\n",
      "  3.20948090e-09 -1.50730983e-08  2.82262373e-08  5.98912075e-09\n",
      "  1.05454374e-08  1.28379236e-08  2.04031281e-08 -2.08043129e-08\n",
      "  2.57904720e-09  4.64228478e-09 -6.97417306e-08 -2.00592565e-09\n",
      " -1.05454374e-08  1.19209291e-08 -8.74010375e-09 -1.56032343e-08\n",
      "  1.57178590e-08  4.01185130e-09 -6.73417877e-10  8.87263862e-09\n",
      " -8.76875994e-09 -1.30098599e-08  6.16105700e-09  3.00172420e-09\n",
      " -6.13240081e-09 -1.37405900e-08 -1.77094570e-08  1.58754681e-08\n",
      "  1.66205254e-08 -2.63635935e-09 -7.73714159e-09 -1.04881250e-08\n",
      " -2.65068723e-09 -1.62050124e-08  3.32410508e-08 -1.32032882e-08\n",
      " -8.08101408e-09 -1.37907383e-08 -2.22944294e-08 -2.18359322e-08\n",
      " -7.07805148e-09 -4.29841190e-10  1.34683571e-08  9.50951957e-08\n",
      " -1.43853516e-08 -1.98873185e-08  2.59624073e-08  1.43423673e-08\n",
      " -2.34406734e-08  5.03200752e-08 -3.20948090e-09 -7.67982922e-09\n",
      "  2.33547048e-08  7.24998817e-09  1.39340184e-09 -1.56892028e-08\n",
      "  3.10631911e-08 -1.35256695e-08 -9.39919431e-09 -6.73417855e-09\n",
      " -7.04939529e-09 -7.96459876e-09  6.47627374e-09  7.46490869e-09\n",
      " -1.22934578e-08 -2.67934341e-09  2.62203126e-09 -8.48219983e-09\n",
      "  4.94603931e-08 -9.54247437e-09  9.39847737e-08 -6.64821043e-09\n",
      "  1.36116380e-08 -1.20928654e-08 -2.80829582e-09  6.71985045e-09\n",
      " -1.16236221e-08 -4.75690909e-09  5.27271871e-09  8.65413607e-09\n",
      " -3.20948090e-09  1.51877213e-08 -1.26086752e-09 -1.89130134e-09\n",
      " -4.40408110e-09 -2.16639950e-08  2.86560792e-11 -1.26659874e-08\n",
      "  1.99732870e-08  7.90907784e-09 -8.23862312e-09  4.29841197e-11\n",
      "  4.92884578e-09  6.67686662e-09  1.83398907e-09  3.26679306e-09\n",
      " -2.19621974e-08  9.16815424e-09 -1.94861349e-09 -3.78260268e-09\n",
      " -4.49398954e-08  7.53654916e-09  1.12940768e-08  6.41896181e-09\n",
      " -2.06896900e-08  2.80829582e-09  1.80533299e-09  2.24663665e-08\n",
      " -1.63912777e-08 -2.88970821e-08 -1.56032343e-08 -3.55335383e-09\n",
      "  1.98586623e-08 -2.77784871e-08 -1.13191512e-09 -3.08231951e-09\n",
      "  6.53358612e-09  2.23517427e-09 -9.28456956e-09 -2.75384924e-08\n",
      " -5.12943821e-09 -4.12647561e-09  8.48219983e-09 -2.94297937e-08\n",
      " -3.32410521e-09 -4.69243311e-09  2.65928417e-08  5.38734302e-09\n",
      "  6.44188631e-08  1.17919763e-08  3.61066599e-09 -9.92933114e-09\n",
      " -2.60770316e-09  1.12904956e-08 -4.05235454e-07  7.24998817e-09\n",
      " -6.01150818e-09 -1.01872359e-08 -4.55631657e-08 -1.07460298e-08\n",
      " -2.36412649e-08  1.49011614e-09  1.04881250e-08 -3.61066590e-08\n",
      "  1.48438488e-08 -1.44426640e-08 -1.06027489e-08  2.38705145e-08\n",
      " -2.38217993e-07  4.98615771e-09  3.52613050e-08 -1.13764633e-08\n",
      "  8.59682370e-09 -4.42449846e-08 -1.61333720e-08 -2.16066844e-08\n",
      " -7.18980999e-08  2.34979858e-09 -2.94011375e-08 -8.26727842e-09\n",
      " -4.35572423e-09 -2.70799951e-08 -5.01481390e-09 -2.65570215e-08\n",
      " -2.20365255e-08  1.42420715e-08 -1.81679543e-08  7.96639021e-09\n",
      " -3.89722699e-09 -1.20498811e-08 -1.26803155e-08 -6.44761800e-09\n",
      "  1.91995730e-09 -2.30394885e-08  3.03181302e-08 -2.13487787e-08\n",
      " -4.62795668e-08 -5.50196733e-09 -1.14781926e-07 -1.44713201e-08\n",
      " -8.59682395e-11 -2.02025365e-08 -7.79445397e-09 -6.67686662e-09\n",
      "  4.49900428e-09 -4.96000894e-08 -1.40844634e-08 -6.42755822e-08\n",
      "  3.89722699e-09 -2.04747685e-08 -9.05532094e-09  1.84545144e-08\n",
      " -1.03735003e-08 -4.75690909e-09 -4.06916323e-09 -3.10058788e-08\n",
      "  3.66224704e-08  1.08893099e-08 -5.04347000e-08  3.71382782e-08\n",
      "  8.02370259e-09  4.55918219e-08 -1.54742830e-09  2.41857308e-08\n",
      "  2.00592565e-09  6.07508888e-09  2.18359322e-08 -4.12934114e-08\n",
      " -8.01223976e-08  3.38141746e-08 -3.23813687e-09  1.69643997e-08\n",
      "  2.45009470e-08  4.59930050e-09  2.69653704e-08 -7.80878118e-09\n",
      "  1.64772462e-08 -8.62548006e-08  4.21244373e-09  4.54485409e-08\n",
      "  1.02588764e-08 -3.55335383e-09 -2.02455208e-08 -1.80533295e-08\n",
      " -1.69070873e-08  1.36402933e-08  1.27734472e-08 -4.01758236e-08\n",
      " -2.04031281e-08  1.33394051e-08 -1.60617333e-08 -1.17418288e-08\n",
      "  1.24653941e-08  5.04347009e-09  3.86857080e-09 -3.66439612e-09\n",
      "  1.71936476e-09  2.63635935e-09 -2.96876976e-08  1.48553113e-07\n",
      " -2.05177528e-08  1.15627280e-07 -6.64821043e-09 -4.09781933e-08\n",
      " -5.27271871e-09  1.36976057e-08  2.06323780e-09 -2.02025365e-08\n",
      "  1.91422611e-08 -3.81125842e-09  3.32123946e-08 -7.67982922e-09\n",
      "  1.23937545e-08 -7.43983453e-09 -1.10899023e-08  9.59978674e-09\n",
      " -1.01356555e-07  2.87707032e-08 -9.39919431e-09 -1.79960171e-08\n",
      " -1.40414791e-08 -2.04031281e-08 -1.39841667e-08  2.84268307e-08\n",
      "  1.53668225e-08 -1.17507835e-08 -5.87449644e-09  1.62193405e-08\n",
      " -2.27529267e-08 -1.72509598e-08 -3.03181302e-08 -1.40987906e-08\n",
      "  4.02331359e-08 -2.38705145e-08  1.76808008e-08 -9.86306414e-09\n",
      "  1.60474045e-09 -4.41303616e-09 -6.56224231e-09 -5.47331114e-09\n",
      " -1.49011614e-09 -9.74306680e-09  1.62676983e-08  7.22133198e-09\n",
      " -6.18971319e-09 -1.27519550e-08 -1.63339653e-08 -1.14624317e-10\n",
      "  1.96294145e-08  2.57331596e-08 -5.86375037e-09 -3.99752320e-09\n",
      " -3.61559116e-11  2.57904720e-09 -2.08616253e-08  7.75146969e-09\n",
      "  8.94069707e-09 -3.88289889e-09 -1.85404829e-08  7.62251684e-09\n",
      "  8.96935237e-09  2.12628102e-08  3.43156548e-09 -9.74306680e-09\n",
      " -4.06916323e-09  3.42726700e-08  2.20508536e-08  1.27806112e-08\n",
      "  2.18359322e-08 -6.70552236e-09  4.04050704e-09  3.88289889e-09\n",
      " -1.70560980e-07  4.52766047e-09  7.82310927e-09  2.31899318e-08\n",
      "  4.47034854e-09 -1.37549183e-09  1.26086750e-08  1.67351502e-08\n",
      "  7.42192441e-09  4.04337293e-08  1.95721022e-08  2.00592565e-09\n",
      " -1.07854321e-08  8.36757508e-09 -2.57904725e-10  5.73121606e-10\n",
      " -2.98023228e-09 -1.33537332e-08 -9.91500304e-09  1.81106419e-08\n",
      "  4.90018959e-09 -1.18492887e-08  4.19811563e-09  3.98319511e-09\n",
      " -1.37764102e-08 -3.09485659e-09 -4.87153373e-10  1.78527380e-08\n",
      " -9.74306746e-10 -4.24109992e-09  1.76951287e-09  1.13764633e-08\n",
      "  2.43576679e-08 -8.46070769e-09  2.08687894e-08  3.72529030e-09\n",
      "  2.12054978e-08 -6.76283474e-09  8.16698265e-09 -1.16486962e-08\n",
      " -1.94861349e-09 -9.45650669e-10 -5.44465528e-10  1.06188685e-08\n",
      " -5.33003064e-09 -2.10908748e-08  1.50730983e-08  1.30098599e-08\n",
      " -5.11779685e-09  1.07961782e-08  4.69959716e-09 -3.56481635e-08\n",
      " -1.63339653e-09 -9.16994536e-10  1.31531408e-08 -4.41303616e-09\n",
      " -5.81718407e-09  1.49011612e-08  3.79907981e-08 -4.12647561e-09\n",
      " -1.75375199e-08 -1.48151926e-08  4.29841185e-09 -7.59386065e-09\n",
      " -2.35015669e-08 -1.13764633e-08 -2.49307881e-08 -7.56520535e-09\n",
      "  8.02370259e-09 -1.17400374e-08 -6.76283474e-09  3.61066599e-09\n",
      " -6.69119444e-08  2.06323780e-09  9.29889765e-09  9.91500304e-09\n",
      "  1.68784311e-08  1.10325908e-08  7.76579778e-09  2.04425312e-08\n",
      " -1.46719126e-08  8.55383941e-09 -2.92292013e-09 -3.23383844e-08\n",
      " -1.30671722e-08  1.85977953e-08 -8.68279226e-09 -1.96294137e-09\n",
      "  2.14347473e-08  4.69959716e-09  5.08358831e-08 -1.44412311e-07\n",
      "  3.23813687e-09 -1.10325908e-08  1.46719126e-08 -1.26265853e-09\n",
      " -1.32677647e-08  5.84870570e-08  4.01829858e-08 -6.36164943e-09\n",
      "  7.74287230e-08  1.80533295e-08 -1.71936474e-08  1.03448450e-08\n",
      "  7.45058060e-09 -3.63753117e-09 -1.43280392e-08  5.15809451e-10\n",
      "  1.47005688e-08  6.66683704e-08  1.60474045e-09 -6.05001471e-09\n",
      "  1.98013499e-08 -1.38695420e-08  8.94069707e-09 -2.72949152e-09\n",
      " -6.40463371e-09 -1.34683571e-08  1.07173737e-08 -2.60770316e-09\n",
      " -1.57894995e-08  4.09781942e-09 -2.82262391e-09 -6.13240081e-09\n",
      " -2.86560797e-09 -1.38122305e-08 -6.88462309e-09 -4.06916323e-09\n",
      " -1.12331833e-08  5.96046457e-09  3.00888842e-10 -1.00869402e-08\n",
      " -1.48438488e-08  1.22791297e-08  9.28456956e-09  8.25295121e-09\n",
      "  2.86560797e-09 -6.24702512e-09  8.08101408e-09 -1.83972020e-08\n",
      " -8.02370226e-10 -3.19085451e-08  2.55612225e-08  1.06600613e-08\n",
      "  2.72232747e-09 -8.19563883e-09 -4.51046702e-08 -1.23650983e-08\n",
      "  1.52736899e-08 -2.34406734e-08 -4.12647561e-09  4.16086259e-08\n",
      "  5.73121595e-09  1.10039347e-08 -3.02321652e-08  1.68497749e-08\n",
      " -3.29544919e-10  2.76674452e-08 -2.03458157e-08 -3.28971801e-08\n",
      "  1.57357700e-08 -1.57608433e-08  4.66377692e-09 -5.78852788e-09\n",
      " -1.10039347e-08 -1.88454017e-08 -3.00888825e-09 -5.33003064e-09] [0.99999905 1.0000008  0.99999905 0.99999833 1.0000001  1.0000004\n",
      " 1.0000007  0.99999905 1.0000013  1.0000005  1.0000006  0.99999946\n",
      " 0.9999999  0.9999994  1.0000018  0.99999887 1.0000012  0.9999997\n",
      " 1.0000004  1.0000014  0.9999998  0.9999989  0.99999946 1.0000012\n",
      " 0.9999996  0.9999995  0.9999984  1.         1.0000004  0.9999996\n",
      " 0.99999934 0.9999994  0.99999994 0.99999964 0.9999997  0.9999992\n",
      " 0.9999992  0.999999   1.0000001  1.0000013  1.0000013  1.0000001\n",
      " 1.0000015  0.99999946 1.0000002  1.0000004  0.9999999  1.0000001\n",
      " 0.99999946 0.9999994  0.99999934 1.0000005  1.0000021  0.9999968\n",
      " 1.0000004  1.0000006  1.0000002  1.000001   1.0000002  1.0000004\n",
      " 1.0000001  0.9999996  1.0000012  0.9999997  1.         1.0000001\n",
      " 0.9999999  1.0000004  1.0000004  1.0000007  1.0000008  1.000001\n",
      " 1.0000018  0.9999996  0.99999803 0.99999934 0.99999905 0.99999934\n",
      " 0.9999997  1.0000005  1.0000001  0.9999993  0.9999996  1.0000008\n",
      " 0.99999875 1.0000004  0.9999989  1.0000002  1.0000033  0.9999999\n",
      " 0.99999005 0.99999964 0.9999982  0.9999996  1.0000017  1.0000006\n",
      " 0.9999974  0.99999905 1.         0.9999999  1.0000011  1.0000006\n",
      " 0.99999696 1.         0.99999845 1.0000005  1.0000006  1.0000002\n",
      " 0.99999946 1.000002   0.9999983  1.0000004  0.9999978  1.0000007\n",
      " 0.99999934 0.9999997  0.9999993  0.9999998  0.9999998  0.99999934\n",
      " 0.9999994  0.99999857 1.         1.0000002  1.         1.0000014\n",
      " 0.99999994 1.0000006  1.0000012  1.0000007  1.0000006  1.0000013\n",
      " 0.99999875 1.0000008  1.0000013  0.9999985  1.0000005  1.0000007\n",
      " 0.9999979  1.0000024  1.0000007  1.0000002  0.9999984  0.9999999\n",
      " 1.0000013  1.0000005  1.0000021  1.0000001  1.0000005  0.9999996\n",
      " 0.99999964 1.0000014  0.9999982  1.0000006  0.99998987 1.000001\n",
      " 0.9999999  0.9999991  0.9999983  1.000001   0.99999976 1.0000027\n",
      " 0.99999815 0.99999833 0.9999991  1.0000012  1.0000039  0.99999833\n",
      " 1.0000062  1.0000001  1.0000005  0.9999986  0.99999994 0.9999987\n",
      " 0.9999997  0.9999984  1.0000012  0.99999934 0.99999917 0.9999997\n",
      " 0.99999887 0.9999989  1.0000002  1.0000006  1.0000014  0.99999905\n",
      " 0.99999976 0.99999946 1.         0.9999961  0.9999995  0.9999991\n",
      " 1.0000007  1.0000007  0.99999994 1.0000005  1.0000005  1.0000006\n",
      " 0.99999976 0.9999999  0.99999976 0.999998   1.0000006  0.9999998\n",
      " 0.9999991  1.0000006  0.99999917 0.9999985  1.0000013  1.0000002\n",
      " 1.0000004  1.0000008  1.0000015  0.99999964 0.9999992  1.0000013\n",
      " 1.         1.0000002  1.0000004  1.         1.0000004  0.9999996\n",
      " 0.99999976 1.         1.0000002  1.0000012  0.99999946 1.0000018\n",
      " 0.9999964  1.0000014  0.99999887 1.0000007  0.99999917 0.9999993\n",
      " 0.9999999  1.0000002  0.99999857 1.0000058  1.         1.0000007\n",
      " 0.9999991  1.0000001  0.9999995  1.0000002  1.0000007  1.0000006\n",
      " 0.9999998  1.0000015  0.9999966  1.0000005  0.99999976 0.9999997\n",
      " 0.99999994 1.         1.0000001  0.9999994  1.0000006  1.0000002\n",
      " 0.9999998  0.9999996  0.9999986  1.         0.9999987  0.9999994\n",
      " 0.9999982  0.99999976 1.0000004  0.99999964 1.0000005  0.999999\n",
      " 0.99999887 1.0000004  0.99999994 0.99999905 0.99999976 0.99999994\n",
      " 1.0000073  0.9999997  0.9999991  1.0000002  1.000001   0.99999976\n",
      " 1.0000011  0.9999917  0.99999946 0.9999991  0.9999989  1.0000005\n",
      " 0.9999988  0.9999977  1.0000002  0.99999994 0.99999857 1.0000019\n",
      " 1.0000005  1.0000008  1.0000001  1.0000001  0.9999997  1.0000005\n",
      " 0.99999994 1.0000002  1.0000012  1.0000012  1.0000002  0.9999999\n",
      " 0.99999845 1.000001   1.         1.0000008  0.9999987  1.0000007\n",
      " 0.9999992  0.99999946 0.99999774 0.99999905 0.99999714 1.0000007\n",
      " 1.0000002  0.9999991  1.0000006  1.         1.0000004  1.0000019\n",
      " 0.9999994  1.0000013  1.0000002  0.999998   1.0000008  1.0000013\n",
      " 0.9999989  1.0000011  1.0000074  1.0000007  1.0000006  0.99999976\n",
      " 1.0000001  1.         1.         1.0000001  0.9999989  1.0000001\n",
      " 0.9999985  1.0000006  1.0000004  1.0000025  1.0000015  1.0000004\n",
      " 0.99999946 1.0000008  1.0000014  0.99999887 0.99999976 0.99999917\n",
      " 1.0000011  1.0000012  1.0000005  0.9999962  1.         0.99999964\n",
      " 0.9999992  0.99999994 1.0000002  0.9999997  1.0000013  1.0000006\n",
      " 0.9999996  0.99999756 0.9999989  1.0000004  1.0000005  1.0000004\n",
      " 1.0000012  0.9999986  0.9999994  0.9999983  1.0000008  0.9999995\n",
      " 1.         1.0000007  0.999999   1.0000004  1.0000002  1.0000002\n",
      " 1.0000011  1.0000008  0.9999993  0.99999976 0.99999917 0.99999905\n",
      " 1.0000013  1.0000012  1.000001   0.9999994  1.000004   1.0000004\n",
      " 1.         0.99999964 1.0000004  1.0000008  1.0000031  1.000001\n",
      " 0.999998   1.0000004  0.9999942  1.0000006  1.0000005  0.99999684\n",
      " 1.0000007  1.0000004  1.0000006  1.0000001  0.9999993  1.0000001\n",
      " 1.0000012  1.         0.99999785 0.99999934 1.0000006  0.9999998\n",
      " 1.0000006  0.99999934 1.0000015  1.0000032  0.9999996  1.0000005\n",
      " 0.99999994 1.0000002  1.0000005  1.0000018  1.0000005  0.9999995\n",
      " 1.0000004  0.9999987  1.0000002  1.         1.         0.9999992\n",
      " 0.99999994 1.0000002  1.0000011  1.0000006  1.0000001  1.0000007\n",
      " 1.0000006  1.0000024  0.9999994  1.000002   1.         1.0000002\n",
      " 1.0000012  0.9999976  0.99999917 0.9999998  0.99999887 1.0000004\n",
      " 1.0000006  1.0000001  0.9999991  0.9999997  1.0000005  0.99999905\n",
      " 1.0000011  0.9999981  1.000001   1.0000006  1.0000012  0.99999964\n",
      " 1.0000002  0.99999875 0.9999998  1.0000027  1.0000005  0.99999857\n",
      " 0.99999744 1.0000001  1.0000001  0.9999993  1.0000004  1.0000019\n",
      " 0.9999994  0.99999875 0.9999979  1.0000004  0.9999995  1.0000005\n",
      " 0.99999887 0.9999981  1.000001   0.9999975  0.9999999  1.0000046\n",
      " 0.9999996  1.0000001  0.9999983  0.9999993  0.9999997  0.9999989\n",
      " 1.0000005  0.99999833]\n",
      "Train data shape:  (2080, 500)\n",
      "Train labels shape:  (2080, 2)\n",
      "Test data shape:  (520, 500)\n",
      "Test labels shape:  (520, 2)\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import scipy.io\n",
    "import random\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import tensorflow as tf\n",
    "\n",
    "# Set random seeds for reproducibility\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "# Define one-hot encoding function\n",
    "def one_hot_encode(y, num_classes):\n",
    "    return np.eye(num_classes)[y]\n",
    "\n",
    "def load_madelon(one_hot=True):\n",
    "    madelon_mat_path = \"./data/madelon/madelon.mat\"\n",
    "\n",
    "    # Load the .mat file\n",
    "    mat_data = scipy.io.loadmat(madelon_mat_path)\n",
    "\n",
    "    # Print available keys\n",
    "    print(\"Keys in .mat file:\", mat_data.keys())\n",
    "\n",
    "    # Extract features and labels (adjust keys if necessary)\n",
    "    X = np.array(mat_data[\"X\"], dtype=np.float32)  # Feature matrix\n",
    "    y = np.array(mat_data[\"Y\"], dtype=np.int64).flatten()  # Labels\n",
    "\n",
    "    print(\"Unique labels before preprocessing:\", np.unique(y))\n",
    "\n",
    "    # Convert labels from {-1, +1} to {0, 1}\n",
    "    y = (y + 1) // 2  # Map -1 → 0, +1 → 1\n",
    "\n",
    "    # Ensure correct label range\n",
    "    assert np.array_equal(np.unique(y), [0, 1]), f\"Unexpected label values after transformation: {np.unique(y)}\"\n",
    "\n",
    "    num_classes = np.unique(y).shape[0]\n",
    "\n",
    "    if one_hot:\n",
    "        Y = one_hot_encode(y, num_classes)\n",
    "    else:\n",
    "        Y = y  # Return as integers\n",
    "\n",
    "    # Shuffle dataset\n",
    "    indices = np.arange(X.shape[0])\n",
    "    np.random.shuffle(indices)\n",
    "    X = X[indices]\n",
    "    Y = Y[indices]\n",
    "\n",
    "    # Split into train/test (80/20 split)\n",
    "    split_index = int(X.shape[0] * 0.8)\n",
    "    X_train, Y_train = X[:split_index], Y[:split_index]\n",
    "    X_test, Y_test = X[split_index:], Y[split_index:]\n",
    "\n",
    "    # Standardize using training set statistics\n",
    "    scaler = StandardScaler().fit(X_train)\n",
    "    X_train = scaler.transform(X_train)\n",
    "    X_test = scaler.transform(X_test)\n",
    "\n",
    "    X_train = X_train.astype(np.float32)\n",
    "    X_test = X_test.astype(np.float32)\n",
    "\n",
    "    if one_hot:\n",
    "        Y_train = Y_train.astype(np.float32)\n",
    "        Y_test = Y_test.astype(np.float32)\n",
    "    else:\n",
    "        Y_train = Y_train.astype(np.int64)\n",
    "        Y_test = Y_test.astype(np.int64)\n",
    "\n",
    "    print(\"Unique train labels after processing:\", np.unique(Y_train))\n",
    "    print(\"Unique test labels after processing:\", np.unique(Y_test))\n",
    "\n",
    "    print(\"X_train shape: {}, Y_train shape: {}\".format(X_train.shape, Y_train.shape))\n",
    "    print(\"X_test shape: {}, Y_test shape: {}\".format(X_test.shape, Y_test.shape))\n",
    "\n",
    "    return (X_train, Y_train), (X_test, Y_test)\n",
    "\n",
    "(X_train, Y_train), (X_test, Y_test) = load_madelon(one_hot=True)\n",
    "\n",
    "############\n",
    "\n",
    "def calculate_mean_std(dataset):\n",
    "    means = np.mean(dataset, axis=0)  \n",
    "    stds = np.std(dataset, axis=0)  \n",
    "    return means, stds\n",
    "\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "print(\"Normalized Training Set Mean and SD:\", train_mean, train_std)\n",
    "\n",
    "# Check sample sizes of split\n",
    "print('Train data shape: ', X_train.shape)\n",
    "print('Train labels shape: ', Y_train.shape)\n",
    "print('Test data shape: ', X_test.shape)\n",
    "print('Test labels shape: ', Y_test.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2025-01-23T23:28:07.963044Z",
     "iopub.status.busy": "2025-01-23T23:28:07.962380Z",
     "iopub.status.idle": "2025-01-23T23:28:12.852352Z",
     "shell.execute_reply": "2025-01-23T23:28:12.851695Z",
     "shell.execute_reply.started": "2025-01-23T23:28:07.963020Z"
    },
    "id": "6yy9KHxyWOBY",
    "outputId": "7cfef7ad-a3ae-49a2-92d7-820a174079d9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"VanillaLeNet300100\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " input_1 (InputLayer)        [(None, 500)]             0         \n",
      "                                                                 \n",
      " dense (Dense)               (None, 300)               150300    \n",
      "                                                                 \n",
      " dense_1 (Dense)             (None, 100)               30100     \n",
      "                                                                 \n",
      " dense_2 (Dense)             (None, 2)                 202       \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 180,602\n",
      "Trainable params: 180,602\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n",
      "\n",
      "Epoch 1: Current learning rate = 1.000e-01\n",
      "Epoch 1/100\n",
      "9/9 [==============================] - 1s 2ms/step - loss: 1.3760 - accuracy: 0.5067\n",
      "\n",
      "Epoch 2: Current learning rate = 9.997e-02\n",
      "Epoch 2/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.7484 - accuracy: 0.6173\n",
      "\n",
      "Epoch 3: Current learning rate = 9.988e-02\n",
      "Epoch 3/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.6841 - accuracy: 0.6846\n",
      "\n",
      "Epoch 4: Current learning rate = 9.972e-02\n",
      "Epoch 4/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.5390 - accuracy: 0.7904\n",
      "\n",
      "Epoch 5: Current learning rate = 9.950e-02\n",
      "Epoch 5/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.3436 - accuracy: 0.9125\n",
      "\n",
      "Epoch 6: Current learning rate = 9.922e-02\n",
      "Epoch 6/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.1932 - accuracy: 0.9784\n",
      "\n",
      "Epoch 7: Current learning rate = 9.888e-02\n",
      "Epoch 7/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.1356 - accuracy: 0.9894\n",
      "\n",
      "Epoch 8: Current learning rate = 9.848e-02\n",
      "Epoch 8/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.1048 - accuracy: 0.9986\n",
      "\n",
      "Epoch 9: Current learning rate = 9.801e-02\n",
      "Epoch 9/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0947 - accuracy: 1.0000\n",
      "\n",
      "Epoch 10: Current learning rate = 9.749e-02\n",
      "Epoch 10/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0911 - accuracy: 1.0000\n",
      "\n",
      "Epoch 11: Current learning rate = 9.691e-02\n",
      "Epoch 11/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0891 - accuracy: 1.0000\n",
      "\n",
      "Epoch 12: Current learning rate = 9.627e-02\n",
      "Epoch 12/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0882 - accuracy: 1.0000\n",
      "\n",
      "Epoch 13: Current learning rate = 9.557e-02\n",
      "Epoch 13/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0875 - accuracy: 1.0000\n",
      "\n",
      "Epoch 14: Current learning rate = 9.481e-02\n",
      "Epoch 14/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0870 - accuracy: 1.0000\n",
      "\n",
      "Epoch 15: Current learning rate = 9.400e-02\n",
      "Epoch 15/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0865 - accuracy: 1.0000\n",
      "\n",
      "Epoch 16: Current learning rate = 9.314e-02\n",
      "Epoch 16/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0861 - accuracy: 1.0000\n",
      "\n",
      "Epoch 17: Current learning rate = 9.222e-02\n",
      "Epoch 17/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0857 - accuracy: 1.0000\n",
      "\n",
      "Epoch 18: Current learning rate = 9.124e-02\n",
      "Epoch 18/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0854 - accuracy: 1.0000\n",
      "\n",
      "Epoch 19: Current learning rate = 9.022e-02\n",
      "Epoch 19/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0851 - accuracy: 1.0000\n",
      "\n",
      "Epoch 20: Current learning rate = 8.914e-02\n",
      "Epoch 20/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0848 - accuracy: 1.0000\n",
      "\n",
      "Epoch 21: Current learning rate = 8.802e-02\n",
      "Epoch 21/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0845 - accuracy: 1.0000\n",
      "\n",
      "Epoch 22: Current learning rate = 8.685e-02\n",
      "Epoch 22/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0842 - accuracy: 1.0000\n",
      "\n",
      "Epoch 23: Current learning rate = 8.563e-02\n",
      "Epoch 23/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0839 - accuracy: 1.0000\n",
      "\n",
      "Epoch 24: Current learning rate = 8.437e-02\n",
      "Epoch 24/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0836 - accuracy: 1.0000\n",
      "\n",
      "Epoch 25: Current learning rate = 8.307e-02\n",
      "Epoch 25/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0834 - accuracy: 1.0000\n",
      "\n",
      "Epoch 26: Current learning rate = 8.172e-02\n",
      "Epoch 26/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0831 - accuracy: 1.0000\n",
      "\n",
      "Epoch 27: Current learning rate = 8.033e-02\n",
      "Epoch 27/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0829 - accuracy: 1.0000\n",
      "\n",
      "Epoch 28: Current learning rate = 7.891e-02\n",
      "Epoch 28/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0827 - accuracy: 1.0000\n",
      "\n",
      "Epoch 29: Current learning rate = 7.745e-02\n",
      "Epoch 29/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0824 - accuracy: 1.0000\n",
      "\n",
      "Epoch 30: Current learning rate = 7.596e-02\n",
      "Epoch 30/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0822 - accuracy: 1.0000\n",
      "\n",
      "Epoch 31: Current learning rate = 7.443e-02\n",
      "Epoch 31/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0820 - accuracy: 1.0000\n",
      "\n",
      "Epoch 32: Current learning rate = 7.287e-02\n",
      "Epoch 32/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0818 - accuracy: 1.0000\n",
      "\n",
      "Epoch 33: Current learning rate = 7.129e-02\n",
      "Epoch 33/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0816 - accuracy: 1.0000\n",
      "\n",
      "Epoch 34: Current learning rate = 6.968e-02\n",
      "Epoch 34/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0814 - accuracy: 1.0000\n",
      "\n",
      "Epoch 35: Current learning rate = 6.804e-02\n",
      "Epoch 35/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0812 - accuracy: 1.0000\n",
      "\n",
      "Epoch 36: Current learning rate = 6.638e-02\n",
      "Epoch 36/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0810 - accuracy: 1.0000\n",
      "\n",
      "Epoch 37: Current learning rate = 6.470e-02\n",
      "Epoch 37/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0808 - accuracy: 1.0000\n",
      "\n",
      "Epoch 38: Current learning rate = 6.300e-02\n",
      "Epoch 38/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0806 - accuracy: 1.0000\n",
      "\n",
      "Epoch 39: Current learning rate = 6.129e-02\n",
      "Epoch 39/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0804 - accuracy: 1.0000\n",
      "\n",
      "Epoch 40: Current learning rate = 5.956e-02\n",
      "Epoch 40/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0802 - accuracy: 1.0000\n",
      "\n",
      "Epoch 41: Current learning rate = 5.782e-02\n",
      "Epoch 41/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0801 - accuracy: 1.0000\n",
      "\n",
      "Epoch 42: Current learning rate = 5.607e-02\n",
      "Epoch 42/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0799 - accuracy: 1.0000\n",
      "\n",
      "Epoch 43: Current learning rate = 5.431e-02\n",
      "Epoch 43/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0798 - accuracy: 1.0000\n",
      "\n",
      "Epoch 44: Current learning rate = 5.255e-02\n",
      "Epoch 44/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0796 - accuracy: 1.0000\n",
      "\n",
      "Epoch 45: Current learning rate = 5.079e-02\n",
      "Epoch 45/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0795 - accuracy: 1.0000\n",
      "\n",
      "Epoch 46: Current learning rate = 4.902e-02\n",
      "Epoch 46/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0793 - accuracy: 1.0000\n",
      "\n",
      "Epoch 47: Current learning rate = 4.725e-02\n",
      "Epoch 47/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0792 - accuracy: 1.0000\n",
      "\n",
      "Epoch 48: Current learning rate = 4.549e-02\n",
      "Epoch 48/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0791 - accuracy: 1.0000\n",
      "\n",
      "Epoch 49: Current learning rate = 4.373e-02\n",
      "Epoch 49/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0790 - accuracy: 1.0000\n",
      "\n",
      "Epoch 50: Current learning rate = 4.198e-02\n",
      "Epoch 50/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0788 - accuracy: 1.0000\n",
      "\n",
      "Epoch 51: Current learning rate = 4.025e-02\n",
      "Epoch 51/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0787 - accuracy: 1.0000\n",
      "\n",
      "Epoch 52: Current learning rate = 3.852e-02\n",
      "Epoch 52/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0786 - accuracy: 1.0000\n",
      "\n",
      "Epoch 53: Current learning rate = 3.681e-02\n",
      "Epoch 53/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0785 - accuracy: 1.0000\n",
      "\n",
      "Epoch 54: Current learning rate = 3.511e-02\n",
      "Epoch 54/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0784 - accuracy: 1.0000\n",
      "\n",
      "Epoch 55: Current learning rate = 3.343e-02\n",
      "Epoch 55/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0783 - accuracy: 1.0000\n",
      "\n",
      "Epoch 56: Current learning rate = 3.178e-02\n",
      "Epoch 56/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0782 - accuracy: 1.0000\n",
      "\n",
      "Epoch 57: Current learning rate = 3.014e-02\n",
      "Epoch 57/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0781 - accuracy: 1.0000\n",
      "\n",
      "Epoch 58: Current learning rate = 2.853e-02\n",
      "Epoch 58/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0781 - accuracy: 1.0000\n",
      "\n",
      "Epoch 59: Current learning rate = 2.695e-02\n",
      "Epoch 59/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0780 - accuracy: 1.0000\n",
      "\n",
      "Epoch 60: Current learning rate = 2.540e-02\n",
      "Epoch 60/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0779 - accuracy: 1.0000\n",
      "\n",
      "Epoch 61: Current learning rate = 2.388e-02\n",
      "Epoch 61/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0778 - accuracy: 1.0000\n",
      "\n",
      "Epoch 62: Current learning rate = 2.238e-02\n",
      "Epoch 62/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0778 - accuracy: 1.0000\n",
      "\n",
      "Epoch 63: Current learning rate = 2.093e-02\n",
      "Epoch 63/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0777 - accuracy: 1.0000\n",
      "\n",
      "Epoch 64: Current learning rate = 1.951e-02\n",
      "Epoch 64/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0777 - accuracy: 1.0000\n",
      "\n",
      "Epoch 65: Current learning rate = 1.813e-02\n",
      "Epoch 65/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0776 - accuracy: 1.0000\n",
      "\n",
      "Epoch 66: Current learning rate = 1.679e-02\n",
      "Epoch 66/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0776 - accuracy: 1.0000\n",
      "\n",
      "Epoch 67: Current learning rate = 1.549e-02\n",
      "Epoch 67/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0775 - accuracy: 1.0000\n",
      "\n",
      "Epoch 68: Current learning rate = 1.423e-02\n",
      "Epoch 68/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0775 - accuracy: 1.0000\n",
      "\n",
      "Epoch 69: Current learning rate = 1.302e-02\n",
      "Epoch 69/100\n",
      "9/9 [==============================] - 0s 3ms/step - loss: 0.0774 - accuracy: 1.0000\n",
      "\n",
      "Epoch 70: Current learning rate = 1.185e-02\n",
      "Epoch 70/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0774 - accuracy: 1.0000\n",
      "\n",
      "Epoch 71: Current learning rate = 1.073e-02\n",
      "Epoch 71/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0774 - accuracy: 1.0000\n",
      "\n",
      "Epoch 72: Current learning rate = 9.665e-03\n",
      "Epoch 72/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0773 - accuracy: 1.0000\n",
      "\n",
      "Epoch 73: Current learning rate = 8.646e-03\n",
      "Epoch 73/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0773 - accuracy: 1.0000\n",
      "\n",
      "Epoch 74: Current learning rate = 7.679e-03\n",
      "Epoch 74/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0773 - accuracy: 1.0000\n",
      "\n",
      "Epoch 75: Current learning rate = 6.764e-03\n",
      "Epoch 75/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0773 - accuracy: 1.0000\n",
      "\n",
      "Epoch 76: Current learning rate = 5.904e-03\n",
      "Epoch 76/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0773 - accuracy: 1.0000\n",
      "\n",
      "Epoch 77: Current learning rate = 5.099e-03\n",
      "Epoch 77/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 78: Current learning rate = 4.349e-03\n",
      "Epoch 78/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 79: Current learning rate = 3.657e-03\n",
      "Epoch 79/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 80: Current learning rate = 3.023e-03\n",
      "Epoch 80/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 81: Current learning rate = 2.447e-03\n",
      "Epoch 81/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 82: Current learning rate = 1.931e-03\n",
      "Epoch 82/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 83: Current learning rate = 1.475e-03\n",
      "Epoch 83/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 84: Current learning rate = 1.079e-03\n",
      "Epoch 84/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 85: Current learning rate = 7.445e-04\n",
      "Epoch 85/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 86: Current learning rate = 4.715e-04\n",
      "Epoch 86/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 87: Current learning rate = 2.604e-04\n",
      "Epoch 87/100\n",
      "9/9 [==============================] - 0s 3ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 88: Current learning rate = 1.114e-04\n",
      "Epoch 88/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 89: Current learning rate = 2.467e-05\n",
      "Epoch 89/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 90: Current learning rate = 0.000e+00\n",
      "Epoch 90/100\n",
      "9/9 [==============================] - 0s 3ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 91: Current learning rate = 0.000e+00\n",
      "Epoch 91/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 92: Current learning rate = 0.000e+00\n",
      "Epoch 92/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 93: Current learning rate = 0.000e+00\n",
      "Epoch 93/100\n",
      "9/9 [==============================] - 0s 6ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 94: Current learning rate = 0.000e+00\n",
      "Epoch 94/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 95: Current learning rate = 0.000e+00\n",
      "Epoch 95/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 96: Current learning rate = 0.000e+00\n",
      "Epoch 96/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 97: Current learning rate = 0.000e+00\n",
      "Epoch 97/100\n",
      "9/9 [==============================] - 0s 3ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 98: Current learning rate = 0.000e+00\n",
      "Epoch 98/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 99: Current learning rate = 0.000e+00\n",
      "Epoch 99/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "\n",
      "Epoch 100: Current learning rate = 0.000e+00\n",
      "Epoch 100/100\n",
      "9/9 [==============================] - 0s 2ms/step - loss: 0.0772 - accuracy: 1.0000\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.2196 - accuracy: 0.5269\n",
      "\n",
      "Test loss 2.219609022140503\n",
      "Test accuracy 0.5269230604171753\n"
     ]
    }
   ],
   "source": [
    "# Vanilla LeNet-300-100 on madelon\n",
    "#\n",
    "# Trains the LeNet300100 baseline on the training split and reports loss and\n",
    "# accuracy on the test split. Relies on names defined in other cells: the\n",
    "# hyperparameter constants (INIT_LR, EPOCHS, BATCH_SIZE, ...), get_optimizer\n",
    "# and the X_train / Y_train / X_test / Y_test arrays.\n",
    "\n",
    "# Model definition\n",
    "################################################################################\n",
    "MODEL = 'lenet300100_madelon'\n",
    "LA = 1e-4  # passed to LeNet300100 as `la`\n",
    "# INIT is not passed to LeNet300100 in this cell; the assignment is kept so the\n",
    "# name stays defined for any cell that reads it afterwards.\n",
    "INIT = tf.keras.initializers.HeNormal\n",
    "################################################################################\n",
    "\n",
    "# Directories and saving paths\n",
    "fmt_la = f\"{LA:.1e}\"\n",
    "RUN_NAME = f\"{MODEL}_dep1-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{INIT_LR:.1e}-vanilla-bs{BATCH_SIZE}\"\n",
    "RUN_PATH = os.path.join(LENET_FILE_PATH, RUN_NAME)\n",
    "\n",
    "# Create the run directory; exist_ok makes re-running the cell safe\n",
    "os.makedirs(RUN_PATH, exist_ok=True)\n",
    "\n",
    "################################################################################\n",
    "# Set seed\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "# Callbacks\n",
    "# NOTE(review): early_stopping monitors 'val_accuracy', but fit() below gets no\n",
    "# validation data, so it is deliberately not passed to fit().\n",
    "early_stopping = EarlyStopping(monitor='val_accuracy', patience=PAT, restore_best_weights=RESTORE_WEIGHTS)\n",
    "print_lr_cb = PrintLRCallback()\n",
    "terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "\n",
    "# Define model\n",
    "vanilla_lenet300100 = LeNet300100(input_shape=(X_train.shape[1],), n_classes=CLASS_NUM, la=LA, units1=300, units2=100)\n",
    "\n",
    "# Pretrain optimizer\n",
    "optimizer = get_optimizer(lr_schedule=LR_SCHEDULE, init_lr=INIT_LR, lr_decay_fact=LR_DECAY_FACT, epochs=EPOCHS,\n",
    "                          dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT, momentum=MOMENTUM, alpha=0,\n",
    "                          large_lr_start=LARGE_LRSTART, warmup=WARMUP)\n",
    "\n",
    "# Compile model\n",
    "vanilla_lenet300100.compile(optimizer=optimizer,\n",
    "                            loss='categorical_crossentropy',\n",
    "                            metrics=['accuracy'])\n",
    "\n",
    "# summary() prints the table itself and returns None, so it is not wrapped in print()\n",
    "vanilla_lenet300100.summary()\n",
    "\n",
    "################################################################################\n",
    "\n",
    "# Training; TerminateOnNaN stops the run as soon as the loss becomes NaN\n",
    "pre_hist = vanilla_lenet300100.fit(x=X_train, y=Y_train, epochs=EPOCHS, batch_size=BATCH_SIZE,\n",
    "                                   callbacks=[print_lr_cb, terminate_nan_cb])\n",
    "\n",
    "# Evaluate after training\n",
    "pretrain_loss, pretrain_acc = vanilla_lenet300100.evaluate(X_test, Y_test)\n",
    "print('\\nTest loss', pretrain_loss)\n",
    "print('Test accuracy', pretrain_acc)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-23T23:30:53.556050Z",
     "iopub.status.busy": "2025-01-23T23:30:53.555400Z",
     "iopub.status.idle": "2025-01-24T00:00:10.176446Z",
     "shell.execute_reply": "2025-01-24T00:00:10.175737Z",
     "shell.execute_reply.started": "2025-01-23T23:30:53.556032Z"
    }
   },
   "outputs": [],
   "source": [
    "# Sweep for the input-Hadamard LeNet-300-100 on madelon.\n",
    "# For every depth / repetition / lambda combination: build and train a fresh\n",
    "# model, evaluate it on the test split, compute the input sparsity and write a\n",
    "# one-row CSV into the run directory. Relies on names defined in other cells\n",
    "# (hyperparameter constants, LAMBDA_LIST, get_optimizer, compute_input_sparsity,\n",
    "# the data arrays).\n",
    "DEPTH_LIST = [2, 3, 4]\n",
    "REPS = 5\n",
    "BASE_SEED = SEED  # Store original seed value\n",
    "\n",
    "# Settings shared by every run of the sweep\n",
    "MODEL = 'lenet300100_madelon'\n",
    "INIT_TYPE = 'ones'  # within this cell only used as part of RUN_NAME\n",
    "\n",
    "for depth in DEPTH_LIST:\n",
    "    for rep in range(REPS):\n",
    "        # One seed per repetition; every lambda of that repetition reuses it\n",
    "        current_seed = BASE_SEED + rep\n",
    "\n",
    "        for LA_ITER in LAMBDA_LIST:\n",
    "            # Model definition\n",
    "            ################################################################################\n",
    "            DEPTH = depth  # Use the loop variable\n",
    "            LA = LA_ITER\n",
    "            INIT = tf.keras.initializers.HeNormal()\n",
    "\n",
    "            # Directories and saving paths\n",
    "            fmt_la = f\"{LA:.1e}\"\n",
    "            RUN_NAME = f\"{MODEL}_dep{DEPTH}-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{INIT_LR:.1e}-{INIT_TYPE}-bs{BATCH_SIZE}_rep{rep+1}\"\n",
    "            RUN_PATH = os.path.join(LENET_FILE_PATH, f\"depth_{DEPTH}\", f\"rep_{rep+1}\", RUN_NAME)\n",
    "            pretrain_csv_file_path = os.path.join(RUN_PATH, f'pretraining_{MODEL}_depth{DEPTH}_rep{rep+1}.csv')\n",
    "\n",
    "            # Skip-if-results-exist check (currently disabled, every run is recomputed):\n",
    "            #if os.path.exists(pretrain_csv_file_path):\n",
    "            #    print(f'Results already exist for depth={DEPTH}, lambda={LA:.2e}, repetition={rep+1}/{REPS}. Skipping...')\n",
    "            #    continue\n",
    "\n",
    "            print(f'Starting run with depth={DEPTH}, lambda={LA:.2e}, repetition={rep+1}/{REPS}')\n",
    "\n",
    "            # Create the run directory; exist_ok makes re-running the cell safe\n",
    "            os.makedirs(RUN_PATH, exist_ok=True)\n",
    "\n",
    "            ################################################################################\n",
    "            # Many models are created in this loop: reset the Keras global state left by\n",
    "            # the previous run so memory use does not grow over the sweep.\n",
    "            tf.keras.backend.clear_session()\n",
    "\n",
    "            # Set seed for this repetition\n",
    "            np.random.seed(current_seed)\n",
    "            random.seed(current_seed)\n",
    "            tf.random.set_seed(current_seed)\n",
    "\n",
    "            # Callbacks\n",
    "            terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "\n",
    "            # Define model\n",
    "            hadamard_lenet300100 = InpHadamardLeNet300100(input_shape=(X_train.shape[1],), n_classes=CLASS_NUM, depth=DEPTH, la=LA,\n",
    "                                                          init=INIT, use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS)\n",
    "\n",
    "            # Pretrain optimizer\n",
    "            optimizer = get_optimizer(lr_schedule=LR_SCHEDULE, init_lr=INIT_LR, lr_decay_fact=LR_DECAY_FACT, epochs=EPOCHS,\n",
    "                                      dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT, momentum=MOMENTUM, alpha=0,\n",
    "                                      large_lr_start=LARGE_LRSTART, warmup=WARMUP)\n",
    "\n",
    "            # Compile model\n",
    "            hadamard_lenet300100.compile(optimizer=optimizer,\n",
    "                                         loss='categorical_crossentropy',\n",
    "                                         metrics=['accuracy'])\n",
    "\n",
    "            # summary() prints the table itself and returns None, so it is not wrapped in print()\n",
    "            hadamard_lenet300100.summary()\n",
    "\n",
    "            ################################################################################\n",
    "\n",
    "            # Training (silent); the run stops early if the loss becomes NaN\n",
    "            pre_hist = hadamard_lenet300100.fit(x=X_train, y=Y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, verbose=0,\n",
    "                                                callbacks=[terminate_nan_cb])\n",
    "\n",
    "            ################################################################################\n",
    "\n",
    "            # Evaluate once after pretraining\n",
    "            pretrain_loss, pretrain_acc = hadamard_lenet300100.evaluate(X_test, Y_test)\n",
    "            df, pretrain_sparsity = compute_input_sparsity(hadamard_lenet300100, DEPTH)\n",
    "            # A sparsity of 1 would divide by zero; report an infinite compression rate instead\n",
    "            if pretrain_sparsity >= 1:\n",
    "                pretrain_compression_rate = float('inf')\n",
    "            else:\n",
    "                pretrain_compression_rate = 1 / (1 - pretrain_sparsity)\n",
    "            print('\\nTest loss', pretrain_loss)\n",
    "            print('Test accuracy', pretrain_acc)\n",
    "            print('Sparsity (pretrain)', pretrain_sparsity)\n",
    "            print('Compression rate (pretrain)', pretrain_compression_rate)\n",
    "\n",
    "            # Formatted results of this run; the key order defines the CSV column order\n",
    "            pretrain_res_dict = {\n",
    "                'Run': int(rep + 1),\n",
    "                'Pre Opt': PRETRAIN_OPT,\n",
    "                'Depth': int(DEPTH),\n",
    "                'Lambda': f'{LA:.2e}',\n",
    "                'Init LR': f'{INIT_LR:.2e}',\n",
    "                'LR Schedule': LR_SCHEDULE,\n",
    "                'Batch size': int(BATCH_SIZE),\n",
    "                'Pre Epochs': int(EPOCHS),\n",
    "                'Pre Loss': f'{pretrain_loss:.3f}',\n",
    "                'Pre Acc': f'{pretrain_acc * 100:.4f}%',\n",
    "                'Pre Sparsity': f'{pretrain_sparsity * 100:.4f}%',\n",
    "                'Pre CR': f'{pretrain_compression_rate:.2f}'\n",
    "            }\n",
    "\n",
    "            # Build the one-row results frame directly from the record\n",
    "            pretrain_res_df = pd.DataFrame([pretrain_res_dict])\n",
    "\n",
    "            # Save df to CSV\n",
    "            pretrain_res_df.to_csv(pretrain_csv_file_path, index=False)\n",
    "            print(f'Pretrain results saved to {pretrain_csv_file_path}')\n",
    "            print(\"\\nPretraining Results:\")\n",
    "            print(pretrain_res_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-24T00:02:37.020802Z",
     "iopub.status.busy": "2025-01-24T00:02:37.019931Z",
     "iopub.status.idle": "2025-01-24T00:02:37.075793Z",
     "shell.execute_reply": "2025-01-24T00:02:37.074935Z",
     "shell.execute_reply.started": "2025-01-24T00:02:37.020772Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Keys in .mat file: dict_keys(['__header__', '__version__', '__globals__', 'Y', 'X'])\n",
      "Unique labels before preprocessing: [-1  1]\n",
      "Unique train labels after processing: [0 1]\n",
      "Unique test labels after processing: [0 1]\n",
      "X_train shape: (2080, 500), Y_train shape: (2080,)\n",
      "X_test shape: (520, 500), Y_test shape: (520,)\n",
      "Normalized Training Set Mean and SD: [ 4.04050704e-09 -2.23517418e-08  1.06027489e-08  1.49871298e-08\n",
      " -3.63932195e-09 -4.84001177e-08 -1.45572878e-08  1.28952360e-09\n",
      " -1.24940502e-08  2.34979858e-09 -1.89774880e-08  9.05532094e-09\n",
      "  2.21798047e-08  7.30730010e-09  2.87133908e-08  5.78852788e-09\n",
      " -4.34165344e-08  5.70972380e-09 -5.90601807e-08  5.15809451e-10\n",
      " -2.17786211e-09 -2.39564830e-08 -1.69070873e-08 -4.58497285e-09\n",
      "  2.01380601e-08  1.91422611e-08  5.58793545e-09  2.36412667e-10\n",
      "  1.83398907e-09  1.52450337e-08 -9.72157466e-09  1.78384096e-09\n",
      "  8.10967027e-09 -4.68526906e-09  1.31817961e-08  1.22074901e-08\n",
      "  4.92884578e-09  6.53358612e-09  4.63655354e-08 -4.69959716e-08\n",
      "  1.73082721e-08 -1.71936476e-09 -3.12351247e-08 -1.16916805e-08\n",
      " -4.01185130e-09  9.90067495e-09  1.45143044e-08  2.93438251e-08\n",
      "  1.73082721e-08 -1.71936476e-09 -2.84268307e-08  2.72232747e-09\n",
      " -5.11511011e-09  2.85244397e-08 -2.46442280e-08  7.90907784e-09\n",
      " -1.33537332e-08  1.26659874e-08  1.04881250e-08  2.12341540e-08\n",
      "  1.22074901e-08 -2.14060911e-08  3.06799142e-09  7.73714148e-10\n",
      "  8.02370259e-09 -9.88634730e-10 -2.60770316e-09 -1.04917071e-08\n",
      " -2.01738803e-08  1.74802075e-08  2.17786211e-09 -2.92292013e-09\n",
      "  5.75987213e-09  1.30385160e-08 -3.50177274e-08  6.34732134e-09\n",
      " -2.34979858e-09 -3.94737487e-09 -5.15809440e-09  2.45296032e-08\n",
      " -3.55335383e-09 -2.27099428e-09 -2.75098366e-09 -3.38141737e-09\n",
      " -3.95453892e-09  8.59682395e-11  1.66205254e-08 -2.17786211e-09\n",
      "  9.49411749e-09 -5.53742900e-08  7.89797383e-08  8.19563883e-09\n",
      " -5.64524782e-09  2.08616253e-08 -3.43872952e-09 -1.14624321e-09\n",
      "  6.27066683e-08 -1.96007583e-08 -5.52323520e-09 -5.61659164e-09\n",
      "  4.12647561e-09 -1.92282297e-08 -4.87583200e-08  8.32817282e-09\n",
      "  1.19782415e-08 -9.22725718e-09  1.12045271e-08 -3.89722699e-09\n",
      "  1.29006086e-08  2.12628102e-08 -9.62844293e-09 -2.43863241e-08\n",
      "  6.52642207e-09  2.15189253e-08  5.27271871e-09 -1.37262619e-08\n",
      "  9.85769155e-09 -3.12781090e-08  3.56624916e-08  1.09466223e-08\n",
      " -4.69529873e-08 -1.18922729e-08  2.60197197e-08 -1.54742832e-08\n",
      " -7.62251702e-08  5.20466026e-09 -1.62802349e-09  1.33895526e-08\n",
      " -2.29248638e-08  3.89722699e-09  2.34979858e-09  1.73655845e-08\n",
      "  1.28952360e-08 -3.65365005e-09 -1.29525475e-08 -5.37301492e-09\n",
      " -3.10022963e-09 -2.46442280e-08  1.02588764e-08  5.64524782e-09\n",
      " -1.65381397e-08 -1.01442517e-08 -3.43872952e-09 -8.16698265e-09\n",
      "  1.15197443e-08  4.58497285e-09 -2.96017291e-08 -1.18886909e-08\n",
      "  5.98912075e-09  1.16630243e-08  1.80533299e-09 -2.69940266e-08\n",
      "  1.03018607e-08  2.55039101e-09 -8.64885266e-08 -1.93715088e-08\n",
      "  2.77104277e-08 -1.62479967e-08 -2.27529267e-08 -1.43853516e-08\n",
      " -7.20700388e-09 -4.59213689e-09  3.15216880e-10 -6.73417855e-09\n",
      " -2.00019432e-08 -1.11185585e-08 -6.05538801e-08  1.01478337e-08\n",
      " -5.08358831e-08 -1.29525475e-08  2.95157609e-09 -8.39623127e-09\n",
      "  4.59356961e-08 -8.96075605e-08 -1.44426640e-08  1.03161888e-08\n",
      "  8.36757508e-09 -1.78240818e-08  2.62024025e-09 -1.15770558e-08\n",
      " -2.72232761e-08 -6.76283474e-09 -2.48734775e-08  1.17776482e-08\n",
      "  4.58497268e-10  7.80878118e-09  1.62766529e-08  3.86857080e-09\n",
      " -4.41303616e-09 -1.43853516e-08  2.80829582e-09  1.85977953e-07\n",
      " -6.84880286e-09 -1.02588764e-08  1.79960171e-08 -2.19505569e-08\n",
      " -7.39326822e-09 -2.12054978e-08 -4.19954844e-08  1.30277700e-08\n",
      "  3.05473797e-08 -1.53596584e-08 -1.97153831e-08  1.25728548e-08\n",
      "  5.28704680e-09 -1.38695420e-08 -1.24653941e-08  4.23364924e-07\n",
      "  1.04308127e-08 -4.58497268e-10  2.60770316e-09 -2.55755506e-09\n",
      "  1.71076788e-08  1.29382194e-08 -1.58468119e-08 -1.46002721e-08\n",
      " -2.63492641e-08 -1.36402933e-08  4.36217178e-08  6.97417306e-08\n",
      "  3.49604168e-09 -6.44045395e-09  7.32341920e-09 -7.26789828e-09\n",
      " -9.56038448e-09  8.61401759e-08 -2.85127992e-08 -1.28321929e-07\n",
      " -4.24396518e-08 -1.20928654e-08 -2.63815036e-09  1.04451408e-08\n",
      " -1.27662831e-08  5.15809440e-09  3.63932195e-09  2.87747337e-09\n",
      " -5.73121606e-10 -1.64485897e-07 -4.29841185e-09 -1.33107489e-08\n",
      "  1.42563996e-08  4.90018959e-09 -2.45009470e-08 -3.30691137e-08\n",
      " -1.27232989e-08 -4.40587211e-09  1.14337757e-08 -2.10335624e-08\n",
      " -4.22390620e-08  1.06027498e-09  5.33003064e-09 -1.50730983e-08\n",
      " -9.88634774e-09  1.57894995e-08 -3.20948090e-09 -9.59978674e-09\n",
      " -9.16994536e-10 -1.77008602e-07  2.11209638e-07 -9.97231542e-09\n",
      "  2.51600376e-08 -1.36402933e-08  1.28952360e-08 -3.09485659e-09\n",
      " -9.15561760e-09 -2.00592565e-09 -1.00869402e-08 -1.85977953e-08\n",
      " -3.67370951e-08 -1.66778378e-08  2.29248634e-10  9.45650580e-09\n",
      " -4.47034854e-09 -8.88338458e-10 -1.66061973e-08  2.86560784e-08\n",
      "  2.98166505e-08 -2.01165680e-08  2.73020788e-08  8.98726249e-09\n",
      "  1.80533299e-09 -1.22361454e-08  1.40414791e-08  7.86609391e-08\n",
      " -3.07649883e-09  1.82825790e-08 -3.09485659e-09  2.33260486e-08\n",
      " -3.31264260e-08 -6.47627374e-09  2.95157609e-09 -1.35973099e-08\n",
      "  1.77882615e-08  1.18636168e-08  1.16916805e-08  3.06620040e-09\n",
      " -1.23507702e-08  2.37487257e-08  6.67686662e-09 -6.59089805e-09\n",
      " -7.96639021e-09 -1.81249700e-08  5.44465495e-09  3.63932195e-09\n",
      "  8.48219983e-09  1.11758709e-08  1.24474848e-08  1.31244846e-08\n",
      " -1.28235955e-09  1.00163746e-07  1.55258633e-07 -1.42707277e-08\n",
      "  1.06815534e-08  1.63626215e-08  1.31137385e-08 -3.15216875e-09\n",
      " -9.62844293e-09  5.50196733e-09 -4.01185113e-10 -3.14070618e-08\n",
      " -2.15063878e-08 -1.59327804e-08  2.08043129e-08  3.19228732e-08\n",
      "  8.83681839e-09 -5.81718407e-09  1.54742832e-08  1.28665798e-08\n",
      " -1.44082762e-07 -1.67351502e-08  6.41896181e-09  2.30394885e-08\n",
      " -2.37272332e-07  3.23813687e-09 -1.18206323e-09 -4.30987441e-08\n",
      " -1.59327804e-08  6.74850664e-09  7.96639021e-09  1.09752785e-08\n",
      " -2.89426394e-09 -3.58774095e-08  8.59682381e-10 -2.91718880e-08\n",
      "  1.42563996e-08  2.28675514e-08  1.28379236e-08 -3.49604168e-09\n",
      " -1.06314051e-08  1.03735003e-08  1.71452896e-08 -5.35868683e-09\n",
      "  3.09485659e-09 -1.06185098e-07  1.03735003e-08  6.47627374e-09\n",
      " -6.23448804e-09  1.24940499e-07  6.82014667e-09 -1.38122305e-08\n",
      "  1.73655845e-08 -3.81698975e-08 -3.09485659e-09 -1.88413729e-08\n",
      "  1.34339700e-07 -8.25295121e-09  1.96580707e-08 -8.31026303e-10\n",
      " -7.16401981e-12  1.30671722e-08  3.75394649e-09  1.34110447e-08\n",
      "  1.51304107e-08  1.53596584e-08 -2.17356355e-08  1.79387047e-08\n",
      "  1.04129025e-08 -4.87153373e-10 -2.40137936e-08  1.58969602e-08\n",
      "  7.79445397e-09  2.16353402e-09 -8.02370226e-10 -3.03754444e-09\n",
      "  2.57904720e-08  2.24950227e-08 -1.38122305e-08 -2.83551902e-08\n",
      " -6.01777683e-10 -1.62694889e-08  1.80246733e-08  1.14624321e-09\n",
      " -1.03161890e-09 -5.18531742e-08 -7.57666712e-08 -4.25542757e-09\n",
      " -4.93171122e-08 -3.49102685e-08 -5.91748028e-09  1.70503667e-09\n",
      " -1.17825738e-09 -1.36617855e-08 -6.00380687e-08 -1.62193405e-08\n",
      " -1.46031383e-07  6.30433750e-09 -1.28952360e-08  2.91432318e-08\n",
      " -6.01777650e-09 -1.80390014e-08  5.47331114e-09 -2.92005442e-08\n",
      " -1.34110447e-08 -3.03754444e-09  2.16066844e-08 -7.73714159e-09\n",
      " -2.80829582e-09 -4.58497285e-09  9.06248498e-09  2.34406734e-08\n",
      " -3.55335388e-08  1.84545144e-08  1.23364421e-08  6.38779838e-08\n",
      " -7.45058071e-10  8.36757508e-09 -1.85082456e-08  2.47588527e-08\n",
      "  1.20642092e-08  1.49011612e-08 -1.66626144e-08 -4.20527968e-08\n",
      "  2.35552964e-08  2.17213074e-08 -2.13487787e-08 -8.95502450e-10\n",
      "  9.39919431e-09 -1.61047158e-08 -1.62193405e-08 -1.97297112e-08\n",
      "  1.54742830e-09  1.97153831e-08 -1.24080826e-08 -4.82138551e-09\n",
      "  2.97450100e-08  1.06027493e-07 -6.06362605e-08  1.97153831e-08\n",
      "  3.69448507e-08 -2.34979858e-09 -1.99804511e-08  2.59624073e-08\n",
      "  1.37549181e-08  6.87745916e-10 -3.12351256e-09 -1.37549183e-09\n",
      " -1.73655845e-08  2.57904720e-08 -7.02073955e-09 -1.07173737e-08\n",
      " -8.53951132e-09 -3.16363113e-08  3.78260268e-09  3.23813687e-08\n",
      "  3.26679306e-09 -6.70552236e-09 -8.20996693e-09 -2.59337507e-09\n",
      "  9.22725718e-09 -2.82262391e-09  5.47331114e-09  2.59194231e-08\n",
      "  5.15809440e-09 -2.18588568e-07 -2.62489692e-08 -5.55927926e-09\n",
      "  2.28102390e-08 -7.53654916e-09 -1.18277965e-08  4.06916323e-09\n",
      "  7.60818875e-09  2.15708642e-08 -2.80829582e-09  3.32410521e-09\n",
      "  1.43280399e-09  7.33595629e-09 -2.75098366e-09 -8.79741613e-09\n",
      " -4.32420251e-08 -4.58497268e-10  1.49011614e-09  1.53023461e-08\n",
      "  1.01442517e-08  1.33537332e-08 -3.36708939e-10 -2.86560803e-10\n",
      "  2.00664196e-08 -1.17489929e-09 -1.20928654e-08  1.06027498e-09] [0.9999985  1.0000005  0.9999992  0.9999989  1.0000007  0.9999992\n",
      " 0.9999997  1.0000013  1.0000004  0.9999999  0.9999982  1.0000004\n",
      " 1.0000011  1.0000013  0.9999998  1.0000011  1.0000006  0.99999946\n",
      " 1.000001   0.999998   0.99999994 0.99999905 0.99999833 0.99999845\n",
      " 0.99999934 1.         0.9999984  1.0000035  1.         1.0000007\n",
      " 1.0000002  1.0000002  1.         1.0000006  0.9999991  1.0000007\n",
      " 0.9999999  1.000001   0.9999988  0.99999946 0.9999997  0.9999996\n",
      " 0.9999994  0.999999   1.000001   0.999999   0.99999976 1.0000013\n",
      " 1.0000002  0.9999997  1.         1.000001   0.9999986  1.0000031\n",
      " 0.99999905 0.9999985  0.9999994  1.0000004  1.0000011  0.9999986\n",
      " 1.0000008  1.0000015  0.99999934 1.0000005  1.         1.0000006\n",
      " 1.0000001  1.0000012  0.9999998  1.0000008  1.0000001  0.99999934\n",
      " 0.99999946 1.0000008  0.9999981  1.0000006  0.99999857 1.0000004\n",
      " 1.0000002  1.0000001  0.9999992  1.0000008  1.0000002  1.0000006\n",
      " 0.99999964 0.99999905 0.99999607 0.999998   0.9999985  1.0000004\n",
      " 0.9999985  1.0000011  1.0000006  1.000001   1.0000007  1.0000012\n",
      " 1.0000032  1.0000004  0.9999996  1.0000005  1.0000006  1.0000005\n",
      " 0.9999969  1.0000029  1.0000004  0.9999993  0.99999917 1.\n",
      " 0.99999946 0.9999979  0.9999982  0.9999994  0.9999983  1.0000004\n",
      " 1.0000014  1.         0.99999964 0.999998   0.99999934 1.\n",
      " 1.0000027  1.0000001  0.99999833 1.0000001  1.0000001  1.0000006\n",
      " 1.0000019  0.99999857 1.0000008  1.0000008  1.         1.\n",
      " 0.9999993  0.9999981  0.9999991  0.9999994  0.9999991  1.0000013\n",
      " 0.9999985  1.0000002  0.9999993  1.0000002  1.0000005  1.0000014\n",
      " 1.0000001  1.0000006  0.9999988  0.99999833 0.99999887 0.9999995\n",
      " 0.99999976 1.0000002  0.9999995  1.0000002  1.000003   0.999999\n",
      " 1.0000004  0.9999995  0.9999996  1.0000006  1.0000006  1.0000007\n",
      " 0.99999917 0.99999994 1.0000013  0.99999815 1.000003   0.9999987\n",
      " 0.999997   1.0000005  0.99999946 1.0000018  0.99999976 1.0000057\n",
      " 0.99999774 0.99999905 0.9999998  1.0000011  1.0000018  0.99999964\n",
      " 0.99999976 1.0000011  1.0000005  0.9999991  0.99999624 0.99999964\n",
      " 0.99999976 0.99999976 1.         1.0000013  1.0000002  0.99999994\n",
      " 0.99999964 1.0000005  0.9999989  1.0000011  0.99999857 1.\n",
      " 1.0000002  0.9999998  0.99999994 1.0000013  1.0000011  1.0000013\n",
      " 1.0000005  0.999999   0.9999994  0.9999999  0.9999987  1.0000013\n",
      " 0.99999654 1.         0.9999989  1.0000001  1.0000014  1.\n",
      " 1.0000001  1.0000001  0.99999994 0.99999946 1.0000002  0.9999996\n",
      " 1.0000011  0.99999994 0.99999905 1.0000001  0.9999985  0.999999\n",
      " 1.0000019  1.         1.0000015  0.99999934 1.0000013  0.99999994\n",
      " 1.000001   1.0000011  0.99999857 0.99999654 0.9999998  0.99999815\n",
      " 1.0000008  0.9999995  0.99999905 0.99999934 1.0000013  1.0000002\n",
      " 0.9999996  0.99999917 1.0000002  0.99999875 1.0000005  1.0000019\n",
      " 1.0000029  0.99999774 1.0000021  1.         0.9999994  0.99999905\n",
      " 0.99999994 1.0000002  0.9999998  0.99999803 1.0000001  0.999999\n",
      " 1.         0.99999917 0.99999964 1.0000005  1.0000005  1.0000011\n",
      " 1.0000006  0.9999999  0.9999998  0.99999964 0.9999997  0.9999999\n",
      " 1.0000043  0.9999999  0.9999997  0.99999994 1.0000044  1.0000004\n",
      " 1.000001   0.99999285 0.99999917 1.0000002  1.0000005  0.99999964\n",
      " 0.99999976 0.99999976 0.99999964 1.0000004  0.9999992  0.9999997\n",
      " 1.0000012  0.9999996  1.0000013  0.9999991  1.0000002  1.0000004\n",
      " 0.99999964 0.99999964 0.9999994  1.         1.0000001  0.9999996\n",
      " 0.9999999  1.0000018  0.9999993  1.0000002  1.0000002  0.9999997\n",
      " 1.0000004  1.0000002  1.0000004  0.9999996  0.9999986  0.9999983\n",
      " 1.0000002  1.0000001  0.999999   1.0000002  0.9999998  0.99999803\n",
      " 0.9999991  0.999998   1.         1.0000005  0.9999987  1.0000002\n",
      " 1.0000002  1.0000012  0.9999959  1.0000002  0.99999964 0.99999833\n",
      " 1.         0.99999964 1.         0.9999999  1.0000001  0.99999964\n",
      " 0.99999934 0.99999946 0.9999989  1.0000013  0.99999905 0.9999986\n",
      " 1.0000002  0.99999905 1.0000015  0.99999994 0.99999976 1.0000006\n",
      " 0.9999988  0.9999998  1.0000008  1.000002   0.9999994  0.9999985\n",
      " 0.99999905 0.9999994  0.9999997  1.0000005  1.0000002  1.0000008\n",
      " 1.0000029  1.0000025  0.99999833 0.999998   1.0000007  0.9999992\n",
      " 0.9999998  1.0000006  1.0000007  1.0000007  0.9999995  1.0000013\n",
      " 0.9999996  0.9999998  1.0000014  0.99999976 1.0000017  1.0000013\n",
      " 1.0000012  1.0000013  1.         0.9999987  0.9999996  0.99999917\n",
      " 1.0000008  1.0000011  0.9999997  0.9999999  1.0000018  1.0000013\n",
      " 0.99999905 1.0000011  1.0000005  1.0000015  1.0000015  1.0000017\n",
      " 0.9999997  0.9999986  0.99999875 1.0000002  0.9999993  0.99999523\n",
      " 1.0000006  0.99999976 1.000001   1.0000002  0.9999994  1.0000011\n",
      " 0.99999994 0.99999994 1.000002   0.9999992  0.99999964 1.0000002\n",
      " 0.9999996  1.0000007  1.0000007  0.99999267 0.9999999  0.99999887\n",
      " 0.99999917 0.9999996  0.99999887 0.99999815 0.9999994  0.99999994\n",
      " 0.99999887 0.99999994 0.9999996  0.99999887 0.99999934 1.0000013\n",
      " 0.9999994  0.99999774 1.0000001  0.99999917 0.9999998  1.\n",
      " 1.0000015  0.9999976  1.0000002  0.9999975  1.0000004  1.0000006\n",
      " 0.9999985  1.0000005  0.99999994 0.9999996  0.9999993  1.0000002\n",
      " 0.9999974  0.9999983  1.0000007  0.9999993  1.0000011  0.99999905\n",
      " 0.9999999  1.0000001  1.000001   1.0000011  1.0000002  0.9999991\n",
      " 1.0000001  1.000001   1.0000019  1.0000007  1.0000005  0.9999997\n",
      " 1.0000017  1.0000004  0.9999994  1.0000013  0.9999991  1.0000001\n",
      " 0.9999999  0.9999985  1.0000001  0.99999976 0.9999993  0.9999999\n",
      " 1.000001   0.99999934 1.         0.99999917 0.9999994  1.0000035\n",
      " 1.         0.99999964 1.0000007  0.9999995  1.0000004  1.0000021\n",
      " 1.0000006  0.99999946]\n",
      "Train data shape:  (2080, 500)\n",
      "Train labels shape:  (2080,)\n",
      "Test data shape:  (520, 500)\n",
      "Test labels shape:  (520,)\n"
     ]
    }
   ],
   "source": [
    "# HSIC lasso + SVM (following Ziyin and Liu, 2023)\n",
    "\n",
    "(X_train, Y_train), (X_test, Y_test) = load_madelon(one_hot=False)\n",
    "############\n",
    "\n",
    "def calculate_mean_std(dataset):\n",
    "    means = np.mean(dataset, axis=(0))  \n",
    "    stds = np.std(dataset, axis=(0))  \n",
    "    return means, stds\n",
    "\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "print(\"Normalized Training Set Mean and SD:\", train_mean, train_std)\n",
    "\n",
    "\n",
    "# Check sample sizes of split\n",
    "print('Train data shape: ', X_train.shape)\n",
    "print('Train labels shape: ', Y_train.shape)\n",
    "#print('Validation data shape: ', X_val.shape)\n",
    "#print('Validation labels shape: ', Y_val.shape)\n",
    "print('Test data shape: ', X_test.shape)\n",
    "print('Test labels shape: ', Y_test.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-24T00:04:13.657432Z",
     "iopub.status.busy": "2025-01-24T00:04:13.657162Z",
     "iopub.status.idle": "2025-01-24T00:50:05.967014Z",
     "shell.execute_reply": "2025-01-24T00:50:05.966473Z",
     "shell.execute_reply.started": "2025-01-24T00:04:13.657413Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Starting repetition 1/5\n",
      "Loading dataset: madelon for repetition 1/5\n",
      "Loading dataset: madelon with one_hot = False for repetition 1\n",
      "Keys in .mat file: dict_keys(['__header__', '__version__', '__globals__', 'Y', 'X'])\n",
      "Unique labels before preprocessing: [-1  1]\n",
      "Unique train labels after processing: [0 1]\n",
      "Unique test labels after processing: [0 1]\n",
      "X_train shape: (2080, 500), Y_train shape: (2080,)\n",
      "X_test shape: (520, 500), Y_test shape: (520,)\n",
      "Running LassoNet on madelon (repetition 1)\n",
      "Repetition 1: sparsity = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.006000000000000005, 0.010000000000000009, 0.0, 0.008000000000000007, 0.008000000000000007, 0.016000000000000014, 0.016000000000000014, 0.02400000000000002, 0.02400000000000002, 0.026000000000000023, 0.03200000000000003, 0.038000000000000034, 0.04600000000000004, 0.052000000000000046, 0.05800000000000005, 0.06599999999999995, 0.07399999999999995, 0.07799999999999996, 0.09799999999999998, 0.12, 0.126, 0.132, 0.14, 0.14800000000000002, 0.15400000000000003, 0.16400000000000003, 0.16200000000000003, 0.17600000000000005, 0.18000000000000005, 0.18600000000000005, 0.18999999999999995, 0.19799999999999995, 0.20199999999999996, 0.20799999999999996, 0.21199999999999997, 0.21999999999999997, 0.22799999999999998, 0.22999999999999998, 0.23399999999999999, 0.242, 0.256, 0.27, 0.278, 0.28400000000000003, 0.28600000000000003, 0.28800000000000003, 0.30600000000000005, 0.31999999999999995, 0.32599999999999996, 0.33199999999999996, 0.348, 0.354, 0.362, 0.368, 0.372, 0.374, 0.384, 0.394, 0.406, 0.41000000000000003, 0.41000000000000003, 0.43400000000000005, 0.43600000000000005, 0.43999999999999995, 0.44999999999999996, 0.46599999999999997, 0.474, 0.488, 0.498, 0.504, 0.516, 0.532, 0.54, 0.55, 0.562, 0.5740000000000001, 0.5800000000000001, 0.5900000000000001, 0.604, 0.612, 0.616, 0.624, 0.638, 0.648, 0.652, 0.656, 0.6619999999999999, 0.6739999999999999, 0.6819999999999999, 0.6839999999999999, 0.694, 0.702, 0.706, 0.718, 0.724, 0.734, 0.742, 0.75, 0.76, 0.77, 0.776, 0.782, 0.792, 0.796, 0.808, 0.8160000000000001, 0.834, 0.842, 0.84, 0.844, 0.85, 0.858, 0.866, 0.872, 0.872, 0.874, 0.876, 0.882, 0.884, 0.888, 0.898, 
0.902, 0.912, 0.918, 0.924, 0.926, 0.9359999999999999, 0.946, 0.95, 0.948, 0.952, 0.954, 0.958, 0.958, 0.958, 0.96, 0.96, 0.964, 0.964, 0.964, 0.964, 0.968, 0.97, 0.97, 0.972, 0.972, 0.976, 0.972, 0.974, 0.974, 0.974, 0.98, 0.98, 0.982, 0.982, 0.984, 0.984, 0.984, 0.984, 0.984, 0.984, 0.984, 0.986, 0.986, 0.986, 0.986, 0.994, 0.992, 0.99, 0.992, 0.992, 0.994, 0.994, 0.994, 0.994, 0.994, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.998, 0.996, 0.998, 0.996, 0.998, 0.998, 0.998, 0.996, 0.998, 1.0] and test accuracy = [0.5403846153846154, 0.5326923076923077, 0.5326923076923077, 0.5326923076923077, 0.5326923076923077, 0.5326923076923077, 0.5326923076923077, 0.5326923076923077, 0.5326923076923077, 0.5326923076923077, 0.5326923076923077, 0.5326923076923077, 0.5326923076923077, 0.5326923076923077, 0.5326923076923077, 0.5326923076923077, 0.5365384615384615, 0.5365384615384615, 0.5346153846153846, 0.5326923076923077, 0.5326923076923077, 0.5346153846153846, 0.5365384615384615, 0.5384615384615384, 0.5403846153846154, 0.5403846153846154, 0.5403846153846154, 0.5403846153846154, 0.5403846153846154, 0.5423076923076923, 0.5442307692307692, 0.5442307692307692, 0.5442307692307692, 0.5423076923076923, 0.5423076923076923, 0.5442307692307692, 0.5423076923076923, 0.5423076923076923, 0.5423076923076923, 0.5403846153846154, 0.5403846153846154, 0.5423076923076923, 0.5423076923076923, 0.5423076923076923, 0.5403846153846154, 0.5384615384615384, 0.5423076923076923, 0.5423076923076923, 0.5423076923076923, 0.5403846153846154, 0.5384615384615384, 0.5365384615384615, 0.5365384615384615, 0.5365384615384615, 0.5365384615384615, 0.5326923076923077, 0.5288461538461539, 0.5384615384615384, 
0.5403846153846154, 0.5461538461538461, 0.5403846153846154, 0.5403846153846154, 0.5346153846153846, 0.5307692307692308, 0.5346153846153846, 0.5423076923076923, 0.5480769230769231, 0.5423076923076923, 0.5423076923076923, 0.5442307692307692, 0.55, 0.5615384615384615, 0.5769230769230769, 0.573076923076923, 0.5807692307692308, 0.575, 0.5884615384615385, 0.5846153846153846, 0.5884615384615385, 0.5903846153846154, 0.5923076923076923, 0.5942307692307692, 0.6, 0.6, 0.6019230769230769, 0.6038461538461538, 0.6076923076923076, 0.6115384615384616, 0.6115384615384616, 0.6096153846153847, 0.6076923076923076, 0.6096153846153847, 0.6134615384615385, 0.6115384615384616, 0.6134615384615385, 0.6211538461538462, 0.6173076923076923, 0.6153846153846154, 0.6173076923076923, 0.625, 0.6307692307692307, 0.6288461538461538, 0.6307692307692307, 0.6326923076923077, 0.6442307692307693, 0.6442307692307693, 0.6480769230769231, 0.6519230769230769, 0.6576923076923077, 0.6596153846153846, 0.6615384615384615, 0.6653846153846154, 0.6653846153846154, 0.6673076923076923, 0.6653846153846154, 0.6692307692307692, 0.6711538461538461, 0.676923076923077, 0.6788461538461539, 0.6865384615384615, 0.6865384615384615, 0.6865384615384615, 0.6846153846153846, 0.6846153846153846, 0.6884615384615385, 0.6884615384615385, 0.6903846153846154, 0.6903846153846154, 0.6923076923076923, 0.6923076923076923, 0.6961538461538461, 0.7, 0.7038461538461539, 0.7096153846153846, 0.7134615384615385, 0.7115384615384616, 0.7115384615384616, 0.7096153846153846, 0.7038461538461539, 0.7019230769230769, 0.7038461538461539, 0.7057692307692308, 0.7115384615384616, 0.7096153846153846, 0.7134615384615385, 0.7153846153846154, 0.7192307692307692, 0.7153846153846154, 0.7115384615384616, 0.7134615384615385, 0.7096153846153846, 0.7096153846153846, 0.7134615384615385, 0.7115384615384616, 0.7096153846153846, 0.7173076923076923, 0.7134615384615385, 0.7134615384615385, 0.7115384615384616, 0.7096153846153846, 0.7115384615384616, 0.7096153846153846, 
0.7096153846153846, 0.7096153846153846, 0.7076923076923077, 0.7076923076923077, 0.7057692307692308, 0.7076923076923077, 0.7115384615384616, 0.7134615384615385, 0.7134615384615385, 0.7153846153846154, 0.7096153846153846, 0.7115384615384616, 0.7134615384615385, 0.7134615384615385, 0.7115384615384616, 0.7134615384615385, 0.7096153846153846, 0.7076923076923077, 0.7057692307692308, 0.698076923076923, 0.6961538461538461, 0.6884615384615385, 0.6903846153846154, 0.6923076923076923, 0.6961538461538461, 0.6961538461538461, 0.6923076923076923, 0.6826923076923077, 0.676923076923077, 0.675, 0.6788461538461539, 0.6730769230769231, 0.6673076923076923, 0.6634615384615384, 0.6596153846153846, 0.6576923076923077, 0.6557692307692308, 0.6576923076923077, 0.6653846153846154, 0.6634615384615384, 0.6615384615384615, 0.6596153846153846, 0.6596153846153846, 0.6557692307692308, 0.6538461538461539, 0.6557692307692308, 0.6519230769230769, 0.6538461538461539, 0.65, 0.65, 0.6461538461538462, 0.6442307692307693, 0.6346153846153846, 0.6211538461538462, 0.625, 0.6192307692307693, 0.6153846153846154, 0.6153846153846154, 0.6115384615384616, 0.6134615384615385, 0.6153846153846154, 0.6134615384615385, 0.6153846153846154, 0.6134615384615385, 0.6173076923076923, 0.6134615384615385, 0.6192307692307693, 0.6076923076923076, 0.6134615384615385, 0.6134615384615385, 0.6115384615384616, 0.6134615384615385, 0.6115384615384616, 0.6076923076923076, 0.6192307692307693, 0.6134615384615385, 0.6153846153846154, 0.6173076923076923, 0.6153846153846154, 0.6134615384615385, 0.6115384615384616, 0.6153846153846154, 0.6096153846153847, 0.6057692307692307, 0.6076923076923076, 0.6076923076923076, 0.6076923076923076, 0.6076923076923076, 0.6115384615384616, 0.6115384615384616, 0.6115384615384616, 0.6115384615384616, 0.6115384615384616, 0.6115384615384616, 0.6096153846153847, 0.6096153846153847, 0.6076923076923076, 0.6076923076923076, 0.6076923076923076, 0.6076923076923076, 0.6057692307692307, 0.6057692307692307, 
0.6057692307692307, 0.6038461538461538, 0.6038461538461538, 0.6057692307692307, 0.6057692307692307, 0.6038461538461538, 0.6019230769230769, 0.6019230769230769, 0.6019230769230769, 0.6019230769230769, 0.6019230769230769, 0.6019230769230769, 0.5961538461538461, 0.5961538461538461, 0.5961538461538461, 0.5961538461538461, 0.5942307692307692, 0.5884615384615385, 0.5865384615384616, 0.5865384615384616, 0.5865384615384616, 0.5846153846153846, 0.5865384615384616, 0.5826923076923077, 0.5769230769230769, 0.575, 0.575, 0.575, 0.5788461538461539, 0.573076923076923, 0.5384615384615384, 0.55, 0.5173076923076924, 0.5288461538461539, 0.5057692307692307, 0.4826923076923077, 0.47692307692307695, 0.46153846153846156, 0.4519230769230769, 0.4519230769230769]\n",
      "Results successfully saved to results/input_sparsity/madelon/LassoNet/rep_1/madelon_LassoNet_rep1_res.csv\n",
      "Loading dataset: madelon with one_hot = False for repetition 1\n",
      "Keys in .mat file: dict_keys(['__header__', '__version__', '__globals__', 'Y', 'X'])\n",
      "Unique labels before preprocessing: [-1  1]\n",
      "Unique train labels after processing: [0 1]\n",
      "Unique test labels after processing: [0 1]\n",
      "X_train shape: (2080, 500), Y_train shape: (2080,)\n",
      "X_test shape: (520, 500), Y_test shape: (520,)\n",
      "Running HSIC_dnn on madelon (repetition 1)\n",
      "Sequence of features is [  1   5  10  20  73 127 180 233 287 340 393 447 500]\n",
      "Downstream model for HSIC is nn\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 1\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.6578 - accuracy: 0.6365\n",
      "test acc for vanilla model is [0.6578324437141418, 0.6365384459495544]\n",
      "0.6578324437141418 0.6365384459495544\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 5\n",
      "17/17 [==============================] - 0s 1ms/step - loss: 0.6445 - accuracy: 0.6519\n",
      "test acc for vanilla model is [0.6444705724716187, 0.6519230604171753]\n",
      "0.6444705724716187 0.6519230604171753\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 10\n",
      "17/17 [==============================] - 0s 1ms/step - loss: 1.4007 - accuracy: 0.5846\n",
      "test acc for vanilla model is [1.4006986618041992, 0.5846154093742371]\n",
      "1.4006986618041992 0.5846154093742371\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 20\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.1487 - accuracy: 0.5635\n",
      "test acc for vanilla model is [2.1486563682556152, 0.5634615421295166]\n",
      "2.1486563682556152 0.5634615421295166\n",
      "Current number of features: 73\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 73\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.4287 - accuracy: 0.5250\n",
      "test acc for vanilla model is [2.4286975860595703, 0.5249999761581421]\n",
      "2.4286975860595703 0.5249999761581421\n",
      "Current number of features: 127\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 127\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.0957 - accuracy: 0.5769\n",
      "test acc for vanilla model is [2.095668315887451, 0.5769230723381042]\n",
      "2.095668315887451 0.5769230723381042\n",
      "Current number of features: 180\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 180\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.2710 - accuracy: 0.5212\n",
      "test acc for vanilla model is [2.2710049152374268, 0.5211538672447205]\n",
      "2.2710049152374268 0.5211538672447205\n",
      "Current number of features: 233\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 194\n",
      "17/17 [==============================] - 0s 3ms/step - loss: 2.3342 - accuracy: 0.5269\n",
      "test acc for vanilla model is [2.3341991901397705, 0.5269230604171753]\n",
      "2.3341991901397705 0.5269230604171753\n",
      "Current number of features: 287\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 194\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.2460 - accuracy: 0.5462\n",
      "test acc for vanilla model is [2.2460267543792725, 0.5461538434028625]\n",
      "2.2460267543792725 0.5461538434028625\n",
      "Current number of features: 340\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 194\n",
      "17/17 [==============================] - 0s 1ms/step - loss: 2.2505 - accuracy: 0.5365\n",
      "test acc for vanilla model is [2.250527858734131, 0.5365384817123413]\n",
      "2.250527858734131 0.5365384817123413\n",
      "Current number of features: 393\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 194\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.3751 - accuracy: 0.5712\n",
      "test acc for vanilla model is [2.3751165866851807, 0.5711538195610046]\n",
      "2.3751165866851807 0.5711538195610046\n",
      "Current number of features: 447\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 194\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.4847 - accuracy: 0.5519\n",
      "test acc for vanilla model is [2.484682559967041, 0.5519230961799622]\n",
      "2.484682559967041 0.5519230961799622\n",
      "Current number of features: 500\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 194\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.2983 - accuracy: 0.5192\n",
      "test acc for vanilla model is [2.2983462810516357, 0.5192307829856873]\n",
      "2.2983462810516357 0.5192307829856873\n",
      "Repetition 1: sparsity = [0.998, 0.99, 0.98, 0.96, 0.854, 0.746, 0.64, 0.534, 0.42600000000000005, 0.31999999999999995, 0.21399999999999997, 0.10599999999999998, 0.0] and test accuracy = [0.6365384459495544, 0.6519230604171753, 0.5846154093742371, 0.5634615421295166, 0.5249999761581421, 0.5769230723381042, 0.5211538672447205, 0.5269230604171753, 0.5461538434028625, 0.5365384817123413, 0.5711538195610046, 0.5519230961799622, 0.5192307829856873]\n",
      "Results successfully saved to results/input_sparsity/madelon/HSIC_dnn/rep_1/madelon_HSIC_dnn_rep1_res.csv\n",
      "Loading dataset: madelon with one_hot = False for repetition 1\n",
      "Keys in .mat file: dict_keys(['__header__', '__version__', '__globals__', 'Y', 'X'])\n",
      "Unique labels before preprocessing: [-1  1]\n",
      "Unique train labels after processing: [0 1]\n",
      "Unique test labels after processing: [0 1]\n",
      "X_train shape: (2080, 500), Y_train shape: (2080,)\n",
      "X_test shape: (520, 500), Y_test shape: (520,)\n",
      "Running HSIC_svm on madelon (repetition 1)\n",
      "Sequence of features is [  1   5  10  20  73 127 180 233 287 340 393 447 500]\n",
      "Downstream model for HSIC is svm\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 73\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 127\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 180\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 233\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 287\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 340\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 393\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 447\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 500\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Repetition 1: sparsity = [0.998, 0.99, 0.98, 0.96, 0.854, 0.746, 0.64, 0.534, 0.42600000000000005, 0.31999999999999995, 0.21399999999999997, 0.10599999999999998, 0.0] and test accuracy = [0.6057692307692307, 0.6423076923076924, 0.6134615384615385, 0.6307692307692307, 0.5884615384615385, 0.6, 0.5769230769230769, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461]\n",
      "Results successfully saved to results/input_sparsity/madelon/HSIC_svm/rep_1/madelon_HSIC_svm_rep1_res.csv\n",
      "\n",
      "Starting repetition 2/5\n",
      "Loading dataset: madelon for repetition 2/5\n",
      "Loading dataset: madelon with one_hot = False for repetition 2\n",
      "Keys in .mat file: dict_keys(['__header__', '__version__', '__globals__', 'Y', 'X'])\n",
      "Unique labels before preprocessing: [-1  1]\n",
      "Unique train labels after processing: [0 1]\n",
      "Unique test labels after processing: [0 1]\n",
      "X_train shape: (2080, 500), Y_train shape: (2080,)\n",
      "X_test shape: (520, 500), Y_test shape: (520,)\n",
      "Running LassoNet on madelon (repetition 2)\n",
      "Repetition 2: sparsity = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0020000000000000018, 0.006000000000000005, 0.008000000000000007, 0.008000000000000007, 0.006000000000000005, 0.010000000000000009, 0.016000000000000014, 0.020000000000000018, 0.03200000000000003, 0.03400000000000003, 0.04800000000000004, 0.04600000000000004, 0.05400000000000005, 0.05800000000000005, 0.06799999999999995, 0.07599999999999996, 0.08599999999999997, 0.08799999999999997, 0.09799999999999998, 0.10799999999999998, 0.11599999999999999, 0.138, 0.14600000000000002, 0.15600000000000003, 0.16800000000000004, 0.17800000000000005, 0.18799999999999994, 0.19199999999999995, 0.19799999999999995, 0.20399999999999996, 0.21199999999999997, 0.22199999999999998, 0.23199999999999998, 0.22799999999999998, 0.24, 0.25, 0.256, 0.262, 0.264, 0.27, 0.28400000000000003, 0.29600000000000004, 0.30200000000000005, 0.31200000000000006, 0.32199999999999995, 0.32799999999999996, 0.33999999999999997, 0.346, 0.352, 0.37, 0.376, 0.384, 0.386, 0.394, 0.404, 0.41000000000000003, 0.42000000000000004, 0.43000000000000005, 0.43999999999999995, 0.45199999999999996, 0.45799999999999996, 0.46599999999999997, 0.478, 0.486, 0.504, 0.51, 0.516, 0.532, 0.538, 0.544, 0.55, 0.556, 0.562, 0.5640000000000001, 0.5660000000000001, 0.5780000000000001, 0.598, 0.604, 0.61, 0.614, 0.62, 0.624, 0.628, 0.644, 0.652, 0.656, 0.6699999999999999, 0.6759999999999999, 0.688, 0.688, 0.7, 0.714, 0.72, 0.726, 0.738, 0.746, 0.75, 0.766, 0.772, 0.784, 0.8, 0.804, 0.812, 0.8180000000000001, 0.8260000000000001, 0.842, 0.848, 0.854, 0.86, 0.862, 0.864, 0.876, 0.874, 0.88, 0.884, 0.882, 0.89, 0.898, 0.902, 0.906, 0.912, 0.916, 0.922, 0.922, 0.926, 
0.9359999999999999, 0.938, 0.938, 0.94, 0.946, 0.948, 0.95, 0.952, 0.952, 0.954, 0.958, 0.962, 0.964, 0.968, 0.968, 0.968, 0.972, 0.974, 0.976, 0.976, 0.976, 0.98, 0.98, 0.98, 0.98, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.982, 0.992, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 1.0] and test accuracy = [0.5615384615384615, 0.5653846153846154, 0.5673076923076923, 0.5653846153846154, 0.5634615384615385, 0.5653846153846154, 0.5653846153846154, 0.5653846153846154, 0.5634615384615385, 0.5634615384615385, 0.5615384615384615, 0.5615384615384615, 0.5615384615384615, 0.5615384615384615, 0.5596153846153846, 0.5596153846153846, 0.5596153846153846, 0.5596153846153846, 0.5596153846153846, 0.5576923076923077, 0.5557692307692308, 0.5557692307692308, 0.5576923076923077, 0.5557692307692308, 0.5557692307692308, 0.5557692307692308, 0.5538461538461539, 0.5557692307692308, 0.5557692307692308, 0.5557692307692308, 0.5538461538461539, 0.5538461538461539, 0.551923076923077, 0.55, 0.55, 0.55, 0.5480769230769231, 0.55, 0.551923076923077, 0.551923076923077, 0.551923076923077, 0.551923076923077, 0.55, 0.5480769230769231, 0.5461538461538461, 0.5461538461538461, 0.5480769230769231, 0.5480769230769231, 0.5480769230769231, 0.5480769230769231, 0.5480769230769231, 0.5480769230769231, 0.55, 0.551923076923077, 0.55, 0.5576923076923077, 0.5596153846153846, 0.5615384615384615, 0.551923076923077, 0.55, 0.55, 0.5538461538461539, 0.5557692307692308, 0.5634615384615385, 0.5634615384615385, 0.5634615384615385, 0.5653846153846154, 
0.5634615384615385, 0.5634615384615385, 0.575, 0.5807692307692308, 0.5846153846153846, 0.6153846153846154, 0.6173076923076923, 0.6211538461538462, 0.6153846153846154, 0.6211538461538462, 0.6269230769230769, 0.6288461538461538, 0.6346153846153846, 0.6307692307692307, 0.6288461538461538, 0.6288461538461538, 0.6326923076923077, 0.6307692307692307, 0.6326923076923077, 0.6403846153846153, 0.6423076923076924, 0.6423076923076924, 0.6423076923076924, 0.6423076923076924, 0.65, 0.65, 0.6480769230769231, 0.6480769230769231, 0.65, 0.6519230769230769, 0.6576923076923077, 0.6596153846153846, 0.6576923076923077, 0.6596153846153846, 0.6653846153846154, 0.6711538461538461, 0.6711538461538461, 0.6730769230769231, 0.6730769230769231, 0.6730769230769231, 0.676923076923077, 0.6788461538461539, 0.6807692307692308, 0.6807692307692308, 0.6846153846153846, 0.6846153846153846, 0.6826923076923077, 0.6826923076923077, 0.6826923076923077, 0.6826923076923077, 0.6826923076923077, 0.6807692307692308, 0.6826923076923077, 0.6826923076923077, 0.6846153846153846, 0.6865384615384615, 0.6884615384615385, 0.6846153846153846, 0.6807692307692308, 0.6788461538461539, 0.6846153846153846, 0.6846153846153846, 0.6865384615384615, 0.6865384615384615, 0.6865384615384615, 0.6865384615384615, 0.6884615384615385, 0.6884615384615385, 0.6903846153846154, 0.6884615384615385, 0.6884615384615385, 0.6884615384615385, 0.6865384615384615, 0.6826923076923077, 0.6826923076923077, 0.6884615384615385, 0.6884615384615385, 0.6923076923076923, 0.6923076923076923, 0.6942307692307692, 0.6942307692307692, 0.6923076923076923, 0.6903846153846154, 0.6923076923076923, 0.6903846153846154, 0.6923076923076923, 0.6942307692307692, 0.6942307692307692, 0.698076923076923, 0.7, 0.6961538461538461, 0.6961538461538461, 0.6961538461538461, 0.7, 0.7057692307692308, 0.7019230769230769, 0.7019230769230769, 0.7038461538461539, 0.7076923076923077, 0.7057692307692308, 0.7057692307692308, 0.7019230769230769, 0.7057692307692308, 0.7076923076923077, 
0.7096153846153846, 0.7096153846153846, 0.7076923076923077, 0.7076923076923077, 0.7019230769230769, 0.7038461538461539, 0.7, 0.7019230769230769, 0.698076923076923, 0.6942307692307692, 0.7038461538461539, 0.7038461538461539, 0.7, 0.7057692307692308, 0.7057692307692308, 0.7057692307692308, 0.7019230769230769, 0.7057692307692308, 0.7076923076923077, 0.7057692307692308, 0.7057692307692308, 0.7076923076923077, 0.7096153846153846, 0.7096153846153846, 0.7076923076923077, 0.7076923076923077, 0.7096153846153846, 0.7076923076923077, 0.7096153846153846, 0.7096153846153846, 0.7134615384615385, 0.7134615384615385, 0.7115384615384616, 0.7115384615384616, 0.7096153846153846, 0.7096153846153846, 0.7096153846153846, 0.7019230769230769, 0.698076923076923, 0.6961538461538461, 0.7019230769230769, 0.7, 0.6961538461538461, 0.6961538461538461, 0.698076923076923, 0.6903846153846154, 0.6942307692307692, 0.6961538461538461, 0.6923076923076923, 0.6865384615384615, 0.6884615384615385, 0.6884615384615385, 0.6903846153846154, 0.6884615384615385, 0.6826923076923077, 0.6884615384615385, 0.6865384615384615, 0.6846153846153846, 0.6846153846153846, 0.6846153846153846, 0.6826923076923077, 0.6807692307692308, 0.6846153846153846, 0.6826923076923077, 0.6730769230769231, 0.6730769230769231, 0.6615384615384615, 0.6596153846153846, 0.6538461538461539, 0.6519230769230769, 0.6480769230769231, 0.6480769230769231, 0.6326923076923077, 0.6269230769230769, 0.6346153846153846, 0.6307692307692307, 0.6307692307692307, 0.6365384615384615, 0.6384615384615384, 0.6365384615384615, 0.6346153846153846, 0.6365384615384615, 0.6365384615384615, 0.6384615384615384, 0.6346153846153846, 0.6403846153846153, 0.6403846153846153, 0.6423076923076924, 0.6461538461538462, 0.6442307692307693, 0.6403846153846153, 0.6403846153846153, 0.6365384615384615, 0.6384615384615384, 0.6365384615384615, 0.6365384615384615, 0.6403846153846153, 0.6461538461538462, 0.6442307692307693, 0.6480769230769231, 0.6442307692307693, 0.6442307692307693, 
0.6403846153846153, 0.6403846153846153, 0.6442307692307693, 0.6442307692307693, 0.6423076923076924, 0.6403846153846153, 0.6403846153846153, 0.6403846153846153, 0.6403846153846153, 0.6403846153846153, 0.6384615384615384, 0.6365384615384615, 0.6384615384615384, 0.6384615384615384, 0.6384615384615384, 0.6384615384615384, 0.6365384615384615, 0.6307692307692307, 0.625, 0.625, 0.6307692307692307, 0.6192307692307693, 0.625, 0.6019230769230769, 0.6076923076923076, 0.6173076923076923, 0.625, 0.6038461538461538, 0.5057692307692307]\n",
      "Results successfully saved to results/input_sparsity/madelon/LassoNet/rep_2/madelon_LassoNet_rep2_res.csv\n",
      "Loading dataset: madelon with one_hot = False for repetition 2\n",
      "Keys in .mat file: dict_keys(['__header__', '__version__', '__globals__', 'Y', 'X'])\n",
      "Unique labels before preprocessing: [-1  1]\n",
      "Unique train labels after processing: [0 1]\n",
      "Unique test labels after processing: [0 1]\n",
      "X_train shape: (2080, 500), Y_train shape: (2080,)\n",
      "X_test shape: (520, 500), Y_test shape: (520,)\n",
      "Running HSIC_dnn on madelon (repetition 2)\n",
      "Sequence of features is [  1   5  10  20  73 127 180 233 287 340 393 447 500]\n",
      "Downstream model for HSIC is nn\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 1\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.6847 - accuracy: 0.5808\n",
      "test acc for vanilla model is [0.6846663355827332, 0.5807692408561707]\n",
      "0.6846663355827332 0.5807692408561707\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 5\n",
      "17/17 [==============================] - 0s 1ms/step - loss: 0.6304 - accuracy: 0.6788\n",
      "test acc for vanilla model is [0.6303988099098206, 0.6788461804389954]\n",
      "0.6303988099098206 0.6788461804389954\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 10\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 1.4433 - accuracy: 0.5673\n",
      "test acc for vanilla model is [1.4432895183563232, 0.567307710647583]\n",
      "1.4432895183563232 0.567307710647583\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 20\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.1188 - accuracy: 0.5788\n",
      "test acc for vanilla model is [2.118845224380493, 0.5788461565971375]\n",
      "2.118845224380493 0.5788461565971375\n",
      "Current number of features: 73\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 73\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.3384 - accuracy: 0.5442\n",
      "test acc for vanilla model is [2.3383865356445312, 0.5442307591438293]\n",
      "2.3383865356445312 0.5442307591438293\n",
      "Current number of features: 127\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 127\n",
      "17/17 [==============================] - 0s 1ms/step - loss: 2.4545 - accuracy: 0.5327\n",
      "test acc for vanilla model is [2.454524040222168, 0.5326923131942749]\n",
      "2.454524040222168 0.5326923131942749\n",
      "Current number of features: 180\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 180\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.4007 - accuracy: 0.5442\n",
      "test acc for vanilla model is [2.400696039199829, 0.5442307591438293]\n",
      "2.400696039199829 0.5442307591438293\n",
      "Current number of features: 233\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 181\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.4800 - accuracy: 0.5538\n",
      "test acc for vanilla model is [2.4800004959106445, 0.5538461804389954]\n",
      "2.4800004959106445 0.5538461804389954\n",
      "Current number of features: 287\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 181\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.2560 - accuracy: 0.5538\n",
      "test acc for vanilla model is [2.2560064792633057, 0.5538461804389954]\n",
      "2.2560064792633057 0.5538461804389954\n",
      "Current number of features: 340\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 181\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.4232 - accuracy: 0.5231\n",
      "test acc for vanilla model is [2.423229932785034, 0.5230769515037537]\n",
      "2.423229932785034 0.5230769515037537\n",
      "Current number of features: 393\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 181\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.3496 - accuracy: 0.5712\n",
      "test acc for vanilla model is [2.34961199760437, 0.5711538195610046]\n",
      "2.34961199760437 0.5711538195610046\n",
      "Current number of features: 447\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 181\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.4751 - accuracy: 0.5269\n",
      "test acc for vanilla model is [2.4750728607177734, 0.5269230604171753]\n",
      "2.4750728607177734 0.5269230604171753\n",
      "Current number of features: 500\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 181\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.0867 - accuracy: 0.5654\n",
      "test acc for vanilla model is [2.0866873264312744, 0.5653846263885498]\n",
      "2.0866873264312744 0.5653846263885498\n",
      "Repetition 2: sparsity = [0.998, 0.99, 0.98, 0.96, 0.854, 0.746, 0.64, 0.534, 0.42600000000000005, 0.31999999999999995, 0.21399999999999997, 0.10599999999999998, 0.0] and test accuracy = [0.5807692408561707, 0.6788461804389954, 0.567307710647583, 0.5788461565971375, 0.5442307591438293, 0.5326923131942749, 0.5442307591438293, 0.5538461804389954, 0.5538461804389954, 0.5230769515037537, 0.5711538195610046, 0.5269230604171753, 0.5653846263885498]\n",
      "Results successfully saved to results/input_sparsity/madelon/HSIC_dnn/rep_2/madelon_HSIC_dnn_rep2_res.csv\n",
      "Loading dataset: madelon with one_hot = False for repetition 2\n",
      "Keys in .mat file: dict_keys(['__header__', '__version__', '__globals__', 'Y', 'X'])\n",
      "Unique labels before preprocessing: [-1  1]\n",
      "Unique train labels after processing: [0 1]\n",
      "Unique test labels after processing: [0 1]\n",
      "X_train shape: (2080, 500), Y_train shape: (2080,)\n",
      "X_test shape: (520, 500), Y_test shape: (520,)\n",
      "Running HSIC_svm on madelon (repetition 2)\n",
      "Sequence of features is [  1   5  10  20  73 127 180 233 287 340 393 447 500]\n",
      "Downstream model for HSIC is svm\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 73\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 127\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 180\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 233\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 287\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 340\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 393\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 447\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 500\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Repetition 2: sparsity = [0.998, 0.99, 0.98, 0.96, 0.854, 0.746, 0.64, 0.534, 0.42600000000000005, 0.31999999999999995, 0.21399999999999997, 0.10599999999999998, 0.0] and test accuracy = [0.6057692307692307, 0.6423076923076924, 0.6134615384615385, 0.6307692307692307, 0.5884615384615385, 0.6, 0.5769230769230769, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461]\n",
      "Results successfully saved to results/input_sparsity/madelon/HSIC_svm/rep_2/madelon_HSIC_svm_rep2_res.csv\n",
      "\n",
      "Starting repetition 3/5\n",
      "Loading dataset: madelon for repetition 3/5\n",
      "Loading dataset: madelon with one_hot = False for repetition 3\n",
      "Keys in .mat file: dict_keys(['__header__', '__version__', '__globals__', 'Y', 'X'])\n",
      "Unique labels before preprocessing: [-1  1]\n",
      "Unique train labels after processing: [0 1]\n",
      "Unique test labels after processing: [0 1]\n",
      "X_train shape: (2080, 500), Y_train shape: (2080,)\n",
      "X_test shape: (520, 500), Y_test shape: (520,)\n",
      "Running LassoNet on madelon (repetition 3)\n",
      "Repetition 3: sparsity = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0020000000000000018, 0.006000000000000005, 0.006000000000000005, 0.006000000000000005, 0.010000000000000009, 0.008000000000000007, 0.010000000000000009, 0.016000000000000014, 0.018000000000000016, 0.026000000000000023, 0.030000000000000027, 0.03600000000000003, 0.040000000000000036, 0.05400000000000005, 0.06399999999999995, 0.06799999999999995, 0.07599999999999996, 0.08999999999999997, 0.09999999999999998, 0.11199999999999999, 0.12, 0.126, 0.128, 0.13, 0.13, 0.14, 0.14, 0.14600000000000002, 0.15600000000000003, 0.16400000000000003, 0.17000000000000004, 0.17600000000000005, 0.18799999999999994, 0.19199999999999995, 0.20399999999999996, 0.20999999999999996, 0.21799999999999997, 0.22399999999999998, 0.23199999999999998, 0.22999999999999998, 0.24, 0.25, 0.264, 0.266, 0.27, 0.268, 0.278, 0.28400000000000003, 0.28800000000000003, 0.29600000000000004, 0.30600000000000005, 0.31999999999999995, 0.32599999999999996, 0.33599999999999997, 0.33799999999999997, 0.33999999999999997, 0.344, 0.36, 0.372, 0.384, 0.392, 0.40800000000000003, 0.41600000000000004, 0.41800000000000004, 0.42400000000000004, 0.44399999999999995, 0.45599999999999996, 0.46599999999999997, 0.474, 0.494, 0.504, 0.51, 0.52, 0.53, 0.538, 0.548, 0.554, 0.554, 0.5660000000000001, 0.5800000000000001, 0.5900000000000001, 0.602, 0.618, 0.624, 0.632, 0.64, 0.6599999999999999, 0.6679999999999999, 0.6699999999999999, 0.6799999999999999, 0.6859999999999999, 0.694, 0.7, 0.712, 0.724, 0.732, 0.734, 0.74, 0.748, 0.762, 0.764, 0.772, 0.78, 0.784, 0.79, 0.8, 0.81, 0.8180000000000001, 0.836, 0.848, 0.854, 0.858, 0.866, 0.866, 0.866, 0.872, 0.88, 0.88, 
0.878, 0.882, 0.892, 0.9, 0.908, 0.916, 0.926, 0.928, 0.9339999999999999, 0.9339999999999999, 0.938, 0.94, 0.942, 0.946, 0.954, 0.956, 0.96, 0.962, 0.962, 0.962, 0.964, 0.968, 0.97, 0.974, 0.974, 0.976, 0.976, 0.976, 0.98, 0.982, 0.984, 0.984, 0.986, 0.986, 0.986, 0.982, 0.982, 0.984, 0.984, 0.986, 0.988, 0.988, 0.986, 0.986, 0.986, 0.986, 0.986, 0.988, 0.982, 0.988, 0.988, 0.982, 0.986, 0.988, 0.992, 0.99, 0.99, 0.992, 0.992, 0.992, 0.992, 0.992, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 1.0] and test accuracy = [0.6, 0.5961538461538461, 0.5961538461538461, 0.5942307692307692, 0.5942307692307692, 0.5942307692307692, 0.5961538461538461, 0.5961538461538461, 0.5961538461538461, 0.5961538461538461, 0.5961538461538461, 0.5961538461538461, 0.5961538461538461, 0.5961538461538461, 0.5961538461538461, 0.5980769230769231, 0.6, 0.6, 0.6019230769230769, 0.6038461538461538, 0.6019230769230769, 0.6, 0.6, 0.6, 0.5980769230769231, 0.5980769230769231, 0.5980769230769231, 0.5980769230769231, 0.5980769230769231, 0.5980769230769231, 0.5980769230769231, 0.5980769230769231, 0.6, 0.6038461538461538, 0.6038461538461538, 0.6057692307692307, 0.6057692307692307, 0.6057692307692307, 0.6057692307692307, 0.6057692307692307, 0.6057692307692307, 0.6057692307692307, 0.6038461538461538, 0.6019230769230769, 0.6, 0.6, 0.6, 0.5980769230769231, 0.5980769230769231, 0.5961538461538461, 0.5980769230769231, 0.5980769230769231, 0.5961538461538461, 0.5942307692307692, 0.5942307692307692, 0.5923076923076923, 0.5923076923076923, 0.5903846153846154, 0.5846153846153846, 0.5846153846153846, 0.5826923076923077, 0.5846153846153846, 0.5846153846153846, 0.5865384615384616, 
0.5769230769230769, 0.573076923076923, 0.575, 0.5826923076923077, 0.5846153846153846, 0.5942307692307692, 0.6057692307692307, 0.5961538461538461, 0.6192307692307693, 0.6307692307692307, 0.6442307692307693, 0.6423076923076924, 0.65, 0.6480769230769231, 0.65, 0.65, 0.6519230769230769, 0.6538461538461539, 0.6557692307692308, 0.6576923076923077, 0.6557692307692308, 0.6615384615384615, 0.6653846153846154, 0.6673076923076923, 0.6711538461538461, 0.675, 0.676923076923077, 0.6826923076923077, 0.6865384615384615, 0.6807692307692308, 0.676923076923077, 0.6788461538461539, 0.6788461538461539, 0.6807692307692308, 0.6846153846153846, 0.6846153846153846, 0.6846153846153846, 0.6846153846153846, 0.6865384615384615, 0.6846153846153846, 0.6826923076923077, 0.6942307692307692, 0.7, 0.6961538461538461, 0.6961538461538461, 0.6961538461538461, 0.6942307692307692, 0.7, 0.7, 0.7, 0.698076923076923, 0.7, 0.7057692307692308, 0.7076923076923077, 0.7096153846153846, 0.7096153846153846, 0.7076923076923077, 0.7096153846153846, 0.7134615384615385, 0.7153846153846154, 0.7153846153846154, 0.7192307692307692, 0.7211538461538461, 0.7211538461538461, 0.7230769230769231, 0.7230769230769231, 0.7211538461538461, 0.7230769230769231, 0.7230769230769231, 0.725, 0.7192307692307692, 0.7192307692307692, 0.7192307692307692, 0.7211538461538461, 0.7269230769230769, 0.7269230769230769, 0.7269230769230769, 0.7269230769230769, 0.7288461538461538, 0.7307692307692307, 0.7288461538461538, 0.7288461538461538, 0.7269230769230769, 0.7269230769230769, 0.7269230769230769, 0.7269230769230769, 0.7288461538461538, 0.7307692307692307, 0.7326923076923076, 0.7326923076923076, 0.7307692307692307, 0.7288461538461538, 0.7307692307692307, 0.7307692307692307, 0.7269230769230769, 0.7326923076923076, 0.7384615384615385, 0.7384615384615385, 0.7423076923076923, 0.7403846153846154, 0.7365384615384616, 0.7384615384615385, 0.7365384615384616, 0.7326923076923076, 0.7423076923076923, 0.7423076923076923, 0.7346153846153847, 0.7403846153846154, 
0.7365384615384616, 0.7365384615384616, 0.7365384615384616, 0.7384615384615385, 0.7346153846153847, 0.7288461538461538, 0.7326923076923076, 0.7269230769230769, 0.725, 0.7211538461538461, 0.7269230769230769, 0.7211538461538461, 0.7211538461538461, 0.7230769230769231, 0.7211538461538461, 0.7115384615384616, 0.7115384615384616, 0.7153846153846154, 0.7173076923076923, 0.7134615384615385, 0.7134615384615385, 0.7115384615384616, 0.7076923076923077, 0.7076923076923077, 0.7057692307692308, 0.7057692307692308, 0.7115384615384616, 0.7076923076923077, 0.6961538461538461, 0.6923076923076923, 0.6942307692307692, 0.6923076923076923, 0.6942307692307692, 0.6865384615384615, 0.6884615384615385, 0.6826923076923077, 0.6846153846153846, 0.6826923076923077, 0.6865384615384615, 0.6788461538461539, 0.6788461538461539, 0.676923076923077, 0.676923076923077, 0.675, 0.6711538461538461, 0.6692307692307692, 0.6711538461538461, 0.6673076923076923, 0.6634615384615384, 0.6596153846153846, 0.6596153846153846, 0.65, 0.6423076923076924, 0.65, 0.6442307692307693, 0.6423076923076924, 0.6384615384615384, 0.6288461538461538, 0.6211538461538462, 0.6173076923076923, 0.6211538461538462, 0.6192307692307693, 0.6173076923076923, 0.6134615384615385, 0.6211538461538462, 0.6211538461538462, 0.6230769230769231, 0.6307692307692307, 0.625, 0.6403846153846153, 0.6346153846153846, 0.6346153846153846, 0.6346153846153846, 0.6423076923076924, 0.6423076923076924, 0.6423076923076924, 0.6423076923076924, 0.6403846153846153, 0.6384615384615384, 0.6384615384615384, 0.6384615384615384, 0.6365384615384615, 0.6346153846153846, 0.6346153846153846, 0.6326923076923077, 0.6307692307692307, 0.6307692307692307, 0.6307692307692307, 0.6307692307692307, 0.6307692307692307, 0.6307692307692307, 0.6307692307692307, 0.6288461538461538, 0.6307692307692307, 0.6307692307692307, 0.6307692307692307, 0.6307692307692307, 0.6307692307692307, 0.6307692307692307, 0.6288461538461538, 0.6307692307692307, 0.6307692307692307, 0.6307692307692307, 
0.6307692307692307, 0.6307692307692307, 0.6288461538461538, 0.6288461538461538, 0.6307692307692307, 0.6307692307692307, 0.6307692307692307, 0.6307692307692307, 0.6307692307692307, 0.6307692307692307, 0.6288461538461538, 0.6307692307692307, 0.6269230769230769, 0.6269230769230769, 0.6211538461538462, 0.6153846153846154, 0.6153846153846154, 0.6173076923076923, 0.6173076923076923, 0.6153846153846154, 0.6076923076923076, 0.5923076923076923, 0.5903846153846154, 0.5923076923076923, 0.6038461538461538, 0.6115384615384616, 0.6115384615384616, 0.5903846153846154, 0.47884615384615387]\n",
      "Results successfully saved to results/input_sparsity/madelon/LassoNet/rep_3/madelon_LassoNet_rep3_res.csv\n",
      "Loading dataset: madelon with one_hot = False for repetition 3\n",
      "Keys in .mat file: dict_keys(['__header__', '__version__', '__globals__', 'Y', 'X'])\n",
      "Unique labels before preprocessing: [-1  1]\n",
      "Unique train labels after processing: [0 1]\n",
      "Unique test labels after processing: [0 1]\n",
      "X_train shape: (2080, 500), Y_train shape: (2080,)\n",
      "X_test shape: (520, 500), Y_test shape: (520,)\n",
      "Running HSIC_dnn on madelon (repetition 3)\n",
      "Sequence of features is [  1   5  10  20  73 127 180 233 287 340 393 447 500]\n",
      "Downstream model for HSIC is nn\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 1\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.6807 - accuracy: 0.5885\n",
      "test acc for vanilla model is [0.6807010769844055, 0.5884615182876587]\n",
      "0.6807010769844055 0.5884615182876587\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 5\n",
      "17/17 [==============================] - 0s 1ms/step - loss: 0.7178 - accuracy: 0.6077\n",
      "test acc for vanilla model is [0.7178414463996887, 0.607692301273346]\n",
      "0.7178414463996887 0.607692301273346\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 10\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 1.7154 - accuracy: 0.5808\n",
      "test acc for vanilla model is [1.7153574228286743, 0.5807692408561707]\n",
      "1.7153574228286743 0.5807692408561707\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 20\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 1.7123 - accuracy: 0.6538\n",
      "test acc for vanilla model is [1.7123446464538574, 0.6538461446762085]\n",
      "1.7123446464538574 0.6538461446762085\n",
      "Current number of features: 73\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 73\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.3023 - accuracy: 0.5385\n",
      "test acc for vanilla model is [2.302320957183838, 0.5384615659713745]\n",
      "2.302320957183838 0.5384615659713745\n",
      "Current number of features: 127\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 127\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.3028 - accuracy: 0.5519\n",
      "test acc for vanilla model is [2.3027656078338623, 0.5519230961799622]\n",
      "2.3027656078338623 0.5519230961799622\n",
      "Current number of features: 180\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 180\n",
      "17/17 [==============================] - 0s 1ms/step - loss: 2.2441 - accuracy: 0.5308\n",
      "test acc for vanilla model is [2.244138240814209, 0.5307692289352417]\n",
      "2.244138240814209 0.5307692289352417\n",
      "Current number of features: 233\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 192\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.3989 - accuracy: 0.5423\n",
      "test acc for vanilla model is [2.398918867111206, 0.5423076748847961]\n",
      "2.398918867111206 0.5423076748847961\n",
      "Current number of features: 287\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 192\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.2815 - accuracy: 0.5423\n",
      "test acc for vanilla model is [2.28145170211792, 0.5423076748847961]\n",
      "2.28145170211792 0.5423076748847961\n",
      "Current number of features: 340\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 192\n",
      "17/17 [==============================] - 0s 1ms/step - loss: 2.2753 - accuracy: 0.5231\n",
      "test acc for vanilla model is [2.2753357887268066, 0.5230769515037537]\n",
      "2.2753357887268066 0.5230769515037537\n",
      "Current number of features: 393\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 192\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.1984 - accuracy: 0.5173\n",
      "test acc for vanilla model is [2.198427200317383, 0.517307698726654]\n",
      "2.198427200317383 0.517307698726654\n",
      "Current number of features: 447\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 192\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.1293 - accuracy: 0.5462\n",
      "test acc for vanilla model is [2.129255771636963, 0.5461538434028625]\n",
      "2.129255771636963 0.5461538434028625\n",
      "Current number of features: 500\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 192\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.4677 - accuracy: 0.5269\n",
      "test acc for vanilla model is [2.467702865600586, 0.5269230604171753]\n",
      "2.467702865600586 0.5269230604171753\n",
      "Repetition 3: sparsity = [0.998, 0.99, 0.98, 0.96, 0.854, 0.746, 0.64, 0.534, 0.42600000000000005, 0.31999999999999995, 0.21399999999999997, 0.10599999999999998, 0.0] and test accuracy = [0.5884615182876587, 0.607692301273346, 0.5807692408561707, 0.6538461446762085, 0.5384615659713745, 0.5519230961799622, 0.5307692289352417, 0.5423076748847961, 0.5423076748847961, 0.5230769515037537, 0.517307698726654, 0.5461538434028625, 0.5269230604171753]\n",
      "Results successfully saved to results/input_sparsity/madelon/HSIC_dnn/rep_3/madelon_HSIC_dnn_rep3_res.csv\n",
      "Loading dataset: madelon with one_hot = False for repetition 3\n",
      "Keys in .mat file: dict_keys(['__header__', '__version__', '__globals__', 'Y', 'X'])\n",
      "Unique labels before preprocessing: [-1  1]\n",
      "Unique train labels after processing: [0 1]\n",
      "Unique test labels after processing: [0 1]\n",
      "X_train shape: (2080, 500), Y_train shape: (2080,)\n",
      "X_test shape: (520, 500), Y_test shape: (520,)\n",
      "Running HSIC_svm on madelon (repetition 3)\n",
      "Sequence of features is [  1   5  10  20  73 127 180 233 287 340 393 447 500]\n",
      "Downstream model for HSIC is svm\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 73\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 127\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 180\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 233\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 287\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 340\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 393\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 447\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 500\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Repetition 3: sparsity = [0.998, 0.99, 0.98, 0.96, 0.854, 0.746, 0.64, 0.534, 0.42600000000000005, 0.31999999999999995, 0.21399999999999997, 0.10599999999999998, 0.0] and test accuracy = [0.6057692307692307, 0.6423076923076924, 0.6134615384615385, 0.6307692307692307, 0.5884615384615385, 0.6, 0.5769230769230769, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461]\n",
      "Results successfully saved to results/input_sparsity/madelon/HSIC_svm/rep_3/madelon_HSIC_svm_rep3_res.csv\n",
      "\n",
      "Starting repetition 4/5\n",
      "Loading dataset: madelon for repetition 4/5\n",
      "Loading dataset: madelon with one_hot = False for repetition 4\n",
      "Keys in .mat file: dict_keys(['__header__', '__version__', '__globals__', 'Y', 'X'])\n",
      "Unique labels before preprocessing: [-1  1]\n",
      "Unique train labels after processing: [0 1]\n",
      "Unique test labels after processing: [0 1]\n",
      "X_train shape: (2080, 500), Y_train shape: (2080,)\n",
      "X_test shape: (520, 500), Y_test shape: (520,)\n",
      "Running LassoNet on madelon (repetition 4)\n",
      "Repetition 4: sparsity = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0020000000000000018, 0.010000000000000009, 0.014000000000000012, 0.01200000000000001, 0.0040000000000000036, 0.010000000000000009, 0.010000000000000009, 0.014000000000000012, 0.02200000000000002, 0.02200000000000002, 0.028000000000000025, 0.03400000000000003, 0.04200000000000004, 0.04400000000000004, 0.052000000000000046, 0.05800000000000005, 0.06000000000000005, 0.07399999999999995, 0.08199999999999996, 0.08399999999999996, 0.08999999999999997, 0.09199999999999997, 0.09999999999999998, 0.10199999999999998, 0.10599999999999998, 0.10999999999999999, 0.11599999999999999, 0.124, 0.134, 0.14, 0.14800000000000002, 0.15800000000000003, 0.17000000000000004, 0.17200000000000004, 0.18000000000000005, 0.18999999999999995, 0.19599999999999995, 0.19999999999999996, 0.20199999999999996, 0.20999999999999996, 0.21599999999999997, 0.22199999999999998, 0.236, 0.236, 0.252, 0.256, 0.266, 0.276, 0.278, 0.28600000000000003, 0.29200000000000004, 0.30200000000000005, 0.31799999999999995, 0.32799999999999996, 0.32599999999999996, 0.346, 0.356, 0.364, 0.372, 0.384, 0.388, 0.402, 0.41400000000000003, 0.42200000000000004, 0.42800000000000005, 0.43799999999999994, 0.45399999999999996, 0.45999999999999996, 0.47, 0.48, 0.49, 0.498, 0.506, 0.51, 0.518, 0.532, 0.54, 0.548, 0.5700000000000001, 0.5740000000000001, 0.5920000000000001, 0.596, 0.604, 0.618, 0.622, 0.634, 0.636, 0.644, 0.65, 0.6579999999999999, 0.6639999999999999, 0.6759999999999999, 0.696, 0.708, 0.718, 0.726, 0.738, 0.746, 0.76, 0.77, 0.782, 0.786, 0.792, 0.8, 0.808, 0.812, 0.8220000000000001, 0.83, 0.836, 0.844, 0.844, 0.846, 0.858, 0.858, 0.856, 0.86, 0.862, 
0.864, 0.874, 0.884, 0.892, 0.9, 0.908, 0.918, 0.922, 0.9299999999999999, 0.9319999999999999, 0.938, 0.942, 0.94, 0.94, 0.944, 0.948, 0.956, 0.96, 0.962, 0.964, 0.966, 0.97, 0.97, 0.972, 0.974, 0.974, 0.974, 0.976, 0.98, 0.98, 0.98, 0.98, 0.98, 0.978, 0.98, 0.98, 0.98, 0.98, 0.98, 0.98, 0.98, 0.98, 0.98, 0.982, 0.98, 0.984, 0.982, 0.982, 0.984, 0.982, 0.984, 0.986, 0.988, 0.988, 0.988, 0.988, 0.994, 0.992, 0.994, 0.994, 0.994, 0.994, 0.994, 0.996, 0.994, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.998, 0.996, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 0.998, 1.0] and test accuracy = [0.5846153846153846, 0.5942307692307692, 0.5942307692307692, 0.5923076923076923, 0.5923076923076923, 0.5923076923076923, 0.5942307692307692, 0.5942307692307692, 0.5961538461538461, 0.5961538461538461, 0.5961538461538461, 0.5961538461538461, 0.5942307692307692, 0.5942307692307692, 0.5942307692307692, 0.5961538461538461, 0.5980769230769231, 0.5980769230769231, 0.5961538461538461, 0.5961538461538461, 0.5980769230769231, 0.5980769230769231, 0.5980769230769231, 0.5961538461538461, 0.5961538461538461, 0.5961538461538461, 0.5961538461538461, 0.5980769230769231, 0.6, 0.6, 0.5961538461538461, 0.5961538461538461, 0.5980769230769231, 0.5980769230769231, 0.5980769230769231, 0.6, 0.6, 0.6, 0.5980769230769231, 0.6, 0.6, 0.5980769230769231, 0.5980769230769231, 0.5980769230769231, 0.5961538461538461, 0.5961538461538461, 0.5980769230769231, 0.5980769230769231, 0.6, 0.6, 0.5980769230769231, 0.5980769230769231, 0.5980769230769231, 0.6, 0.6, 0.6019230769230769, 0.6019230769230769, 0.6038461538461538, 0.5980769230769231, 0.5923076923076923, 0.6, 0.5980769230769231, 0.6019230769230769, 0.5980769230769231, 0.5980769230769231, 
0.5942307692307692, 0.5961538461538461, 0.5980769230769231, 0.6019230769230769, 0.5980769230769231, 0.6019230769230769, 0.6134615384615385, 0.6269230769230769, 0.6365384615384615, 0.6461538461538462, 0.6557692307692308, 0.65, 0.6480769230769231, 0.6538461538461539, 0.6576923076923077, 0.6576923076923077, 0.6576923076923077, 0.6557692307692308, 0.6596153846153846, 0.6557692307692308, 0.6596153846153846, 0.6596153846153846, 0.6615384615384615, 0.6653846153846154, 0.6634615384615384, 0.6711538461538461, 0.6692307692307692, 0.6692307692307692, 0.6730769230769231, 0.676923076923077, 0.675, 0.6730769230769231, 0.6730769230769231, 0.676923076923077, 0.676923076923077, 0.6788461538461539, 0.6826923076923077, 0.6826923076923077, 0.6903846153846154, 0.6903846153846154, 0.6903846153846154, 0.6884615384615385, 0.6903846153846154, 0.7, 0.7057692307692308, 0.7057692307692308, 0.7115384615384616, 0.7153846153846154, 0.7134615384615385, 0.7173076923076923, 0.7134615384615385, 0.7153846153846154, 0.7153846153846154, 0.7173076923076923, 0.7173076923076923, 0.7173076923076923, 0.7211538461538461, 0.7211538461538461, 0.725, 0.7288461538461538, 0.725, 0.7269230769230769, 0.7384615384615385, 0.7384615384615385, 0.7384615384615385, 0.7384615384615385, 0.7384615384615385, 0.7384615384615385, 0.7403846153846154, 0.7442307692307693, 0.7403846153846154, 0.7423076923076923, 0.7442307692307693, 0.7423076923076923, 0.7442307692307693, 0.7442307692307693, 0.7384615384615385, 0.7365384615384616, 0.7365384615384616, 0.7346153846153847, 0.7307692307692307, 0.7307692307692307, 0.7307692307692307, 0.7346153846153847, 0.7365384615384616, 0.7403846153846154, 0.7403846153846154, 0.7423076923076923, 0.7403846153846154, 0.7442307692307693, 0.7461538461538462, 0.7461538461538462, 0.7403846153846154, 0.7384615384615385, 0.7403846153846154, 0.7384615384615385, 0.7384615384615385, 0.7365384615384616, 0.7365384615384616, 0.7365384615384616, 0.7423076923076923, 0.7384615384615385, 0.7365384615384616, 
0.7326923076923076, 0.7326923076923076, 0.7346153846153847, 0.7365384615384616, 0.7288461538461538, 0.7346153846153847, 0.7326923076923076, 0.7346153846153847, 0.7307692307692307, 0.7269230769230769, 0.7230769230769231, 0.7230769230769231, 0.7211538461538461, 0.7173076923076923, 0.7230769230769231, 0.7192307692307692, 0.7096153846153846, 0.7134615384615385, 0.7096153846153846, 0.7038461538461539, 0.7057692307692308, 0.7038461538461539, 0.698076923076923, 0.7, 0.7, 0.7019230769230769, 0.698076923076923, 0.6923076923076923, 0.6884615384615385, 0.6942307692307692, 0.6961538461538461, 0.6884615384615385, 0.6884615384615385, 0.6884615384615385, 0.6884615384615385, 0.6865384615384615, 0.6826923076923077, 0.6807692307692308, 0.6807692307692308, 0.676923076923077, 0.6788461538461539, 0.6788461538461539, 0.6788461538461539, 0.675, 0.6692307692307692, 0.6711538461538461, 0.6653846153846154, 0.6615384615384615, 0.6538461538461539, 0.6519230769230769, 0.6423076923076924, 0.6365384615384615, 0.6384615384615384, 0.6269230769230769, 0.6288461538461538, 0.6211538461538462, 0.6192307692307693, 0.6153846153846154, 0.6173076923076923, 0.6153846153846154, 0.6076923076923076, 0.6115384615384616, 0.6115384615384616, 0.6096153846153847, 0.6153846153846154, 0.6076923076923076, 0.6115384615384616, 0.6038461538461538, 0.6038461538461538, 0.6096153846153847, 0.6076923076923076, 0.6134615384615385, 0.6134615384615385, 0.6153846153846154, 0.6365384615384615, 0.6326923076923077, 0.6307692307692307, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6269230769230769, 0.6269230769230769, 0.6269230769230769, 0.6307692307692307, 0.6288461538461538, 0.6307692307692307, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 
0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6269230769230769, 0.6269230769230769, 0.6269230769230769, 0.6269230769230769, 0.6269230769230769, 0.6269230769230769, 0.6269230769230769, 0.6269230769230769, 0.6269230769230769, 0.6269230769230769, 0.6269230769230769, 0.6269230769230769, 0.6269230769230769, 0.6269230769230769, 0.6288461538461538, 0.6288461538461538, 0.6211538461538462, 0.625, 0.625, 0.6288461538461538, 0.6307692307692307, 0.625, 0.6269230769230769, 0.625, 0.6269230769230769, 0.6365384615384615, 0.6326923076923077, 0.6365384615384615, 0.6211538461538462, 0.5]\n",
      "Results successfully saved to results/input_sparsity/madelon/LassoNet/rep_4/madelon_LassoNet_rep4_res.csv\n",
      "Loading dataset: madelon with one_hot = False for repetition 4\n",
      "Keys in .mat file: dict_keys(['__header__', '__version__', '__globals__', 'Y', 'X'])\n",
      "Unique labels before preprocessing: [-1  1]\n",
      "Unique train labels after processing: [0 1]\n",
      "Unique test labels after processing: [0 1]\n",
      "X_train shape: (2080, 500), Y_train shape: (2080,)\n",
      "X_test shape: (520, 500), Y_test shape: (520,)\n",
      "Running HSIC_dnn on madelon (repetition 4)\n",
      "Sequence of features is [  1   5  10  20  73 127 180 233 287 340 393 447 500]\n",
      "Downstream model for HSIC is nn\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 1\n",
      "17/17 [==============================] - 0s 1ms/step - loss: 0.6931 - accuracy: 0.4981\n",
      "test acc for vanilla model is [0.6931481957435608, 0.4980769157409668]\n",
      "0.6931481957435608 0.4980769157409668\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 5\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.6791 - accuracy: 0.6212\n",
      "test acc for vanilla model is [0.6791387796401978, 0.6211538314819336]\n",
      "0.6791387796401978 0.6211538314819336\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 10\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 1.4248 - accuracy: 0.5769\n",
      "test acc for vanilla model is [1.4247627258300781, 0.5769230723381042]\n",
      "1.4247627258300781 0.5769230723381042\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 20\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.3048 - accuracy: 0.5500\n",
      "test acc for vanilla model is [2.3047945499420166, 0.550000011920929]\n",
      "2.3047945499420166 0.550000011920929\n",
      "Current number of features: 73\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 73\n",
      "17/17 [==============================] - 0s 1ms/step - loss: 2.2167 - accuracy: 0.5788\n",
      "test acc for vanilla model is [2.216674327850342, 0.5788461565971375]\n",
      "2.216674327850342 0.5788461565971375\n",
      "Current number of features: 127\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 127\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.0164 - accuracy: 0.5885\n",
      "test acc for vanilla model is [2.016407012939453, 0.5884615182876587]\n",
      "2.016407012939453 0.5884615182876587\n",
      "Current number of features: 180\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 180\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 1.9930 - accuracy: 0.5962\n",
      "test acc for vanilla model is [1.9929537773132324, 0.5961538553237915]\n",
      "1.9929537773132324 0.5961538553237915\n",
      "Current number of features: 233\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 184\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.1776 - accuracy: 0.5827\n",
      "test acc for vanilla model is [2.1775877475738525, 0.5826923251152039]\n",
      "2.1775877475738525 0.5826923251152039\n",
      "Current number of features: 287\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 184\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.2092 - accuracy: 0.5500\n",
      "test acc for vanilla model is [2.2092204093933105, 0.550000011920929]\n",
      "2.2092204093933105 0.550000011920929\n",
      "Current number of features: 340\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 184\n",
      "17/17 [==============================] - 0s 1ms/step - loss: 2.1791 - accuracy: 0.5788\n",
      "test acc for vanilla model is [2.1791155338287354, 0.5788461565971375]\n",
      "2.1791155338287354 0.5788461565971375\n",
      "Current number of features: 393\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 184\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.2108 - accuracy: 0.5462\n",
      "test acc for vanilla model is [2.2108330726623535, 0.5461538434028625]\n",
      "2.2108330726623535 0.5461538434028625\n",
      "Current number of features: 447\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 184\n",
      "17/17 [==============================] - 0s 1ms/step - loss: 2.1487 - accuracy: 0.5558\n",
      "test acc for vanilla model is [2.1487061977386475, 0.5557692050933838]\n",
      "2.1487061977386475 0.5557692050933838\n",
      "Current number of features: 500\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 184\n",
      "17/17 [==============================] - 0s 1ms/step - loss: 2.2062 - accuracy: 0.5654\n",
      "test acc for vanilla model is [2.2061617374420166, 0.5653846263885498]\n",
      "2.2061617374420166 0.5653846263885498\n",
      "Repetition 4: sparsity = [0.998, 0.99, 0.98, 0.96, 0.854, 0.746, 0.64, 0.534, 0.42600000000000005, 0.31999999999999995, 0.21399999999999997, 0.10599999999999998, 0.0] and test accuracy = [0.4980769157409668, 0.6211538314819336, 0.5769230723381042, 0.550000011920929, 0.5788461565971375, 0.5884615182876587, 0.5961538553237915, 0.5826923251152039, 0.550000011920929, 0.5788461565971375, 0.5461538434028625, 0.5557692050933838, 0.5653846263885498]\n",
      "Results successfully saved to results/input_sparsity/madelon/HSIC_dnn/rep_4/madelon_HSIC_dnn_rep4_res.csv\n",
      "Loading dataset: madelon with one_hot = False for repetition 4\n",
      "Keys in .mat file: dict_keys(['__header__', '__version__', '__globals__', 'Y', 'X'])\n",
      "Unique labels before preprocessing: [-1  1]\n",
      "Unique train labels after processing: [0 1]\n",
      "Unique test labels after processing: [0 1]\n",
      "X_train shape: (2080, 500), Y_train shape: (2080,)\n",
      "X_test shape: (520, 500), Y_test shape: (520,)\n",
      "Running HSIC_svm on madelon (repetition 4)\n",
      "Sequence of features is [  1   5  10  20  73 127 180 233 287 340 393 447 500]\n",
      "Downstream model for HSIC is svm\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 73\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 127\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 180\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 233\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 287\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 340\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 393\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 447\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 500\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Repetition 4: sparsity = [0.998, 0.99, 0.98, 0.96, 0.854, 0.746, 0.64, 0.534, 0.42600000000000005, 0.31999999999999995, 0.21399999999999997, 0.10599999999999998, 0.0] and test accuracy = [0.6057692307692307, 0.6423076923076924, 0.6134615384615385, 0.6307692307692307, 0.5884615384615385, 0.6, 0.5769230769230769, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461]\n",
      "Results successfully saved to results/input_sparsity/madelon/HSIC_svm/rep_4/madelon_HSIC_svm_rep4_res.csv\n",
      "\n",
      "Starting repetition 5/5\n",
      "Loading dataset: madelon for repetition 5/5\n",
      "Loading dataset: madelon with one_hot = False for repetition 5\n",
      "Keys in .mat file: dict_keys(['__header__', '__version__', '__globals__', 'Y', 'X'])\n",
      "Unique labels before preprocessing: [-1  1]\n",
      "Unique train labels after processing: [0 1]\n",
      "Unique test labels after processing: [0 1]\n",
      "X_train shape: (2080, 500), Y_train shape: (2080,)\n",
      "X_test shape: (520, 500), Y_test shape: (520,)\n",
      "Running LassoNet on madelon (repetition 5)\n",
      "Repetition 5: sparsity = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.006000000000000005, 0.006000000000000005, 0.01200000000000001, 0.006000000000000005, 0.0020000000000000018, 0.0020000000000000018, 0.0020000000000000018, 0.0040000000000000036, 0.0040000000000000036, 0.008000000000000007, 0.01200000000000001, 0.018000000000000016, 0.02200000000000002, 0.026000000000000023, 0.03200000000000003, 0.04400000000000004, 0.052000000000000046, 0.062000000000000055, 0.07599999999999996, 0.08599999999999997, 0.09399999999999997, 0.09399999999999997, 0.10799999999999998, 0.12, 0.124, 0.132, 0.14200000000000002, 0.14800000000000002, 0.15400000000000003, 0.16200000000000003, 0.17400000000000004, 0.18400000000000005, 0.19599999999999995, 0.20799999999999996, 0.21599999999999997, 0.21799999999999997, 0.22599999999999998, 0.22599999999999998, 0.22799999999999998, 0.23399999999999999, 0.24, 0.242, 0.256, 0.266, 0.28, 0.30000000000000004, 0.31399999999999995, 0.31799999999999995, 0.31999999999999995, 0.33799999999999997, 0.346, 0.354, 0.366, 0.386, 0.396, 0.406, 0.41000000000000003, 0.41400000000000003, 0.42000000000000004, 0.42400000000000004, 0.43799999999999994, 0.44999999999999996, 0.45799999999999996, 0.46799999999999997, 0.472, 0.484, 0.496, 0.502, 0.512, 0.52, 0.526, 0.532, 0.546, 0.5640000000000001, 0.5700000000000001, 0.5720000000000001, 0.5820000000000001, 0.5920000000000001, 0.608, 0.622, 0.628, 0.636, 0.642, 0.65, 0.6679999999999999, 0.6699999999999999, 0.6779999999999999, 0.6859999999999999, 0.696, 0.7, 0.704, 0.71, 0.72, 0.724, 0.738, 0.746, 0.762, 0.774, 0.782, 0.792, 0.8, 0.806, 0.812, 0.8160000000000001, 0.8200000000000001, 0.8240000000000001, 
0.8280000000000001, 0.842, 0.854, 0.85, 0.852, 0.856, 0.862, 0.87, 0.87, 0.874, 0.888, 0.898, 0.916, 0.926, 0.9339999999999999, 0.9359999999999999, 0.9359999999999999, 0.94, 0.942, 0.946, 0.948, 0.952, 0.954, 0.954, 0.954, 0.956, 0.958, 0.962, 0.964, 0.966, 0.97, 0.972, 0.974, 0.974, 0.974, 0.978, 0.978, 0.974, 0.976, 0.976, 0.976, 0.974, 0.976, 0.98, 0.98, 0.98, 0.98, 0.98, 0.98, 0.98, 0.98, 0.98, 0.98, 0.984, 0.986, 0.986, 0.982, 0.982, 0.982, 0.986, 0.982, 0.984, 0.99, 0.992, 0.992, 0.992, 0.992, 0.992, 0.992, 0.992, 0.992, 0.992, 0.992, 0.992, 0.992, 0.992, 0.992, 0.992, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.994, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.996, 0.998, 1.0] and test accuracy = [0.5653846153846154, 0.5788461538461539, 0.5788461538461539, 0.5788461538461539, 0.5788461538461539, 0.5788461538461539, 0.5807692307692308, 0.5788461538461539, 0.5788461538461539, 0.5807692307692308, 0.5807692307692308, 0.5807692307692308, 0.5788461538461539, 0.5788461538461539, 0.5788461538461539, 0.5788461538461539, 0.575, 0.575, 0.573076923076923, 0.573076923076923, 0.5711538461538461, 0.5711538461538461, 0.5692307692307692, 0.573076923076923, 0.573076923076923, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.573076923076923, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5692307692307692, 0.5673076923076923, 0.5673076923076923, 0.5673076923076923, 0.5673076923076923, 0.5653846153846154, 0.5653846153846154, 0.5653846153846154, 0.5653846153846154, 0.5673076923076923, 0.5653846153846154, 0.5653846153846154, 0.5653846153846154, 0.5673076923076923, 0.5692307692307692, 0.5711538461538461, 0.5634615384615385, 
0.5653846153846154, 0.5596153846153846, 0.5596153846153846, 0.5596153846153846, 0.5557692307692308, 0.5538461538461539, 0.5557692307692308, 0.551923076923077, 0.55, 0.5538461538461539, 0.5538461538461539, 0.5576923076923077, 0.5557692307692308, 0.5596153846153846, 0.5634615384615385, 0.5634615384615385, 0.5692307692307692, 0.5769230769230769, 0.573076923076923, 0.5846153846153846, 0.6, 0.6038461538461538, 0.6076923076923076, 0.6173076923076923, 0.6211538461538462, 0.625, 0.6288461538461538, 0.6365384615384615, 0.6346153846153846, 0.6365384615384615, 0.6384615384615384, 0.6365384615384615, 0.6423076923076924, 0.6423076923076924, 0.6442307692307693, 0.6442307692307693, 0.6442307692307693, 0.6423076923076924, 0.6423076923076924, 0.6442307692307693, 0.6403846153846153, 0.6384615384615384, 0.6423076923076924, 0.6423076923076924, 0.6403846153846153, 0.6423076923076924, 0.6461538461538462, 0.6461538461538462, 0.6480769230769231, 0.6538461538461539, 0.6519230769230769, 0.6557692307692308, 0.6711538461538461, 0.6692307692307692, 0.6692307692307692, 0.6653846153846154, 0.6692307692307692, 0.675, 0.6807692307692308, 0.6846153846153846, 0.6923076923076923, 0.6942307692307692, 0.6942307692307692, 0.6942307692307692, 0.6923076923076923, 0.698076923076923, 0.6961538461538461, 0.698076923076923, 0.698076923076923, 0.6961538461538461, 0.698076923076923, 0.7019230769230769, 0.7, 0.7, 0.7, 0.7076923076923077, 0.7057692307692308, 0.7057692307692308, 0.7038461538461539, 0.7057692307692308, 0.7076923076923077, 0.7038461538461539, 0.7096153846153846, 0.7115384615384616, 0.7134615384615385, 0.7134615384615385, 0.7057692307692308, 0.7019230769230769, 0.7057692307692308, 0.7019230769230769, 0.7057692307692308, 0.7096153846153846, 0.7096153846153846, 0.7038461538461539, 0.7038461538461539, 0.7038461538461539, 0.7019230769230769, 0.7057692307692308, 0.7057692307692308, 0.7038461538461539, 0.7038461538461539, 0.7038461538461539, 0.7, 0.6961538461538461, 0.7019230769230769, 0.7038461538461539, 
0.7, 0.7038461538461539, 0.7057692307692308, 0.7115384615384616, 0.7115384615384616, 0.7115384615384616, 0.7076923076923077, 0.7057692307692308, 0.7057692307692308, 0.7076923076923077, 0.7076923076923077, 0.7076923076923077, 0.7076923076923077, 0.7076923076923077, 0.7057692307692308, 0.7076923076923077, 0.7019230769230769, 0.6961538461538461, 0.7, 0.7038461538461539, 0.6961538461538461, 0.6942307692307692, 0.6923076923076923, 0.6903846153846154, 0.6865384615384615, 0.6846153846153846, 0.6826923076923077, 0.6730769230769231, 0.675, 0.6788461538461539, 0.675, 0.6788461538461539, 0.676923076923077, 0.6807692307692308, 0.675, 0.675, 0.675, 0.6807692307692308, 0.6788461538461539, 0.675, 0.676923076923077, 0.6788461538461539, 0.6788461538461539, 0.676923076923077, 0.676923076923077, 0.675, 0.676923076923077, 0.6788461538461539, 0.675, 0.675, 0.6730769230769231, 0.6730769230769231, 0.6788461538461539, 0.6711538461538461, 0.6730769230769231, 0.6634615384615384, 0.6615384615384615, 0.6596153846153846, 0.6538461538461539, 0.6596153846153846, 0.6596153846153846, 0.6576923076923077, 0.6596153846153846, 0.6557692307692308, 0.6538461538461539, 0.6461538461538462, 0.6423076923076924, 0.6423076923076924, 0.6403846153846153, 0.6384615384615384, 0.6365384615384615, 0.6326923076923077, 0.6346153846153846, 0.6288461538461538, 0.6211538461538462, 0.6230769230769231, 0.6173076923076923, 0.6230769230769231, 0.6134615384615385, 0.6384615384615384, 0.6346153846153846, 0.6384615384615384, 0.6403846153846153, 0.6403846153846153, 0.6403846153846153, 0.6384615384615384, 0.6384615384615384, 0.6384615384615384, 0.6384615384615384, 0.6384615384615384, 0.6384615384615384, 0.6384615384615384, 0.6384615384615384, 0.6384615384615384, 0.6384615384615384, 0.6384615384615384, 0.6384615384615384, 0.6384615384615384, 0.6384615384615384, 0.6384615384615384, 0.6346153846153846, 0.6346153846153846, 0.6346153846153846, 0.6346153846153846, 0.6365384615384615, 0.6326923076923077, 0.6326923076923077, 
0.6326923076923077, 0.6326923076923077, 0.6326923076923077, 0.6326923076923077, 0.6326923076923077, 0.6326923076923077, 0.6326923076923077, 0.6346153846153846, 0.6346153846153846, 0.6346153846153846, 0.6346153846153846, 0.6346153846153846, 0.6346153846153846, 0.6365384615384615, 0.6326923076923077, 0.6326923076923077, 0.6326923076923077, 0.6326923076923077, 0.6326923076923077, 0.6307692307692307, 0.6307692307692307, 0.6307692307692307, 0.6288461538461538, 0.6288461538461538, 0.6288461538461538, 0.6269230769230769, 0.625, 0.6269230769230769, 0.6153846153846154, 0.6173076923076923, 0.6019230769230769, 0.6057692307692307, 0.5846153846153846, 0.5769230769230769, 0.5538461538461539, 0.5423076923076923, 0.5, 0.5480769230769231, 0.5557692307692308, 0.5038461538461538, 0.49230769230769234, 0.49230769230769234]\n",
      "Results successfully saved to results/input_sparsity/madelon/LassoNet/rep_5/madelon_LassoNet_rep5_res.csv\n",
      "Loading dataset: madelon with one_hot = False for repetition 5\n",
      "Keys in .mat file: dict_keys(['__header__', '__version__', '__globals__', 'Y', 'X'])\n",
      "Unique labels before preprocessing: [-1  1]\n",
      "Unique train labels after processing: [0 1]\n",
      "Unique test labels after processing: [0 1]\n",
      "X_train shape: (2080, 500), Y_train shape: (2080,)\n",
      "X_test shape: (520, 500), Y_test shape: (520,)\n",
      "Running HSIC_dnn on madelon (repetition 5)\n",
      "Sequence of features is [  1   5  10  20  73 127 180 233 287 340 393 447 500]\n",
      "Downstream model for HSIC is nn\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 1\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.6722 - accuracy: 0.5942\n",
      "test acc for vanilla model is [0.6722131967544556, 0.5942307710647583]\n",
      "0.6722131967544556 0.5942307710647583\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 5\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 0.6382 - accuracy: 0.6731\n",
      "test acc for vanilla model is [0.6382294297218323, 0.6730769276618958]\n",
      "0.6382294297218323 0.6730769276618958\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 10\n",
      "17/17 [==============================] - 0s 1ms/step - loss: 1.3982 - accuracy: 0.6365\n",
      "test acc for vanilla model is [1.398185133934021, 0.6365384459495544]\n",
      "1.398185133934021 0.6365384459495544\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 20\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 1.9780 - accuracy: 0.5962\n",
      "test acc for vanilla model is [1.9779942035675049, 0.5961538553237915]\n",
      "1.9779942035675049 0.5961538553237915\n",
      "Current number of features: 73\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 73\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.2686 - accuracy: 0.5538\n",
      "test acc for vanilla model is [2.26859712600708, 0.5538461804389954]\n",
      "2.26859712600708 0.5538461804389954\n",
      "Current number of features: 127\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 127\n",
      "17/17 [==============================] - 0s 1ms/step - loss: 2.3318 - accuracy: 0.5577\n",
      "test acc for vanilla model is [2.3317534923553467, 0.557692289352417]\n",
      "2.3317534923553467 0.557692289352417\n",
      "Current number of features: 180\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 180\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.4419 - accuracy: 0.5250\n",
      "test acc for vanilla model is [2.441931962966919, 0.5249999761581421]\n",
      "2.441931962966919 0.5249999761581421\n",
      "Current number of features: 233\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 202\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.0451 - accuracy: 0.5846\n",
      "test acc for vanilla model is [2.0451107025146484, 0.5846154093742371]\n",
      "2.0451107025146484 0.5846154093742371\n",
      "Current number of features: 287\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 202\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.1357 - accuracy: 0.5538\n",
      "test acc for vanilla model is [2.135740041732788, 0.5538461804389954]\n",
      "2.135740041732788 0.5538461804389954\n",
      "Current number of features: 340\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 202\n",
      "17/17 [==============================] - 0s 2ms/step - loss: 2.2624 - accuracy: 0.5500\n",
      "test acc for vanilla model is [2.262394666671753, 0.550000011920929]\n",
      "2.262394666671753 0.550000011920929\n",
      "Current number of features: 393\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 202\n",
      "17/17 [==============================] - 0s 1ms/step - loss: 2.1389 - accuracy: 0.5712\n",
      "test acc for vanilla model is [2.138864040374756, 0.5711538195610046]\n",
      "2.138864040374756 0.5711538195610046\n",
      "Current number of features: 447\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 202\n",
      "17/17 [==============================] - 0s 1ms/step - loss: 2.3201 - accuracy: 0.5692\n",
      "test acc for vanilla model is [2.320051908493042, 0.5692307949066162]\n",
      "2.320051908493042 0.5692307949066162\n",
      "Current number of features: 500\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 202\n",
      "17/17 [==============================] - 0s 1ms/step - loss: 2.5908 - accuracy: 0.5519\n",
      "test acc for vanilla model is [2.5907669067382812, 0.5519230961799622]\n",
      "2.5907669067382812 0.5519230961799622\n",
      "Repetition 5: sparsity = [0.998, 0.99, 0.98, 0.96, 0.854, 0.746, 0.64, 0.534, 0.42600000000000005, 0.31999999999999995, 0.21399999999999997, 0.10599999999999998, 0.0] and test accuracy = [0.5942307710647583, 0.6730769276618958, 0.6365384459495544, 0.5961538553237915, 0.5538461804389954, 0.557692289352417, 0.5249999761581421, 0.5846154093742371, 0.5538461804389954, 0.550000011920929, 0.5711538195610046, 0.5692307949066162, 0.5519230961799622]\n",
      "Results successfully saved to results/input_sparsity/madelon/HSIC_dnn/rep_5/madelon_HSIC_dnn_rep5_res.csv\n",
      "Loading dataset: madelon with one_hot = False for repetition 5\n",
      "Keys in .mat file: dict_keys(['__header__', '__version__', '__globals__', 'Y', 'X'])\n",
      "Unique labels before preprocessing: [-1  1]\n",
      "Unique train labels after processing: [0 1]\n",
      "Unique test labels after processing: [0 1]\n",
      "X_train shape: (2080, 500), Y_train shape: (2080,)\n",
      "X_test shape: (520, 500), Y_test shape: (520,)\n",
      "Running HSIC_svm on madelon (repetition 5)\n",
      "Sequence of features is [  1   5  10  20  73 127 180 233 287 340 393 447 500]\n",
      "Downstream model for HSIC is svm\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 73\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 127\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 180\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 233\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 287\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 340\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 393\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 447\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 500\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Repetition 5: sparsity = [0.998, 0.99, 0.98, 0.96, 0.854, 0.746, 0.64, 0.534, 0.42600000000000005, 0.31999999999999995, 0.21399999999999997, 0.10599999999999998, 0.0] and test accuracy = [0.6057692307692307, 0.6423076923076924, 0.6134615384615385, 0.6307692307692307, 0.5884615384615385, 0.6, 0.5769230769230769, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461, 0.5711538461538461]\n",
      "Results successfully saved to results/input_sparsity/madelon/HSIC_svm/rep_5/madelon_HSIC_svm_rep5_res.csv\n",
      "\n",
      "Combined results saved to results/input_sparsity/madelon/all_results_summary.csv\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "from models import hadamard_nn_depth_2, hadamard_nn_depth_3, hadamard_nn_depth_4, hsic_dnn, hsic_svm, lassoNet\n",
    "import tensorflow as tf\n",
    "import config_inputsparse_compar\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "# Constants\n",
    "REPS = 5\n",
    "BASE_SEED = config_inputsparse_compar.SEED\n",
    "LENET_FILE_PATH = './results/input_sparsity/madelon/'\n",
    "\n",
    "def run_methods_on_dataset(dataset_name, load_func, results_dir, rep):\n",
    "    # Set seeds for this repetition\n",
    "    current_seed = BASE_SEED + rep\n",
    "    np.random.seed(current_seed)\n",
    "    random.seed(current_seed)\n",
    "    tf.random.set_seed(current_seed)\n",
    "    \n",
    "    print(f\"Loading dataset: {dataset_name} for repetition {rep + 1}/{REPS}\")\n",
    "    \n",
    "    for method_name, method_func, one_hot in [\n",
    "        (\"LassoNet\", lassoNet, False),\n",
    "        (\"HSIC_dnn\", hsic_dnn, False),\n",
    "        (\"HSIC_svm\", hsic_svm, False)\n",
    "    ]:\n",
    "        # Create method-specific directory\n",
    "        method_dir = os.path.join(results_dir, method_name, f'rep_{rep + 1}')\n",
    "        os.makedirs(method_dir, exist_ok=True)\n",
    "        \n",
    "        # Generate a unique filename for the method-dataset-repetition combination\n",
    "        result_filename = os.path.join(method_dir, f'{dataset_name}_{method_name}_rep{rep + 1}_res.csv')\n",
    "        \n",
    "        # Check if results already exist\n",
    "        if os.path.exists(result_filename):\n",
    "            print(f\"Results already exist for {method_name} on {dataset_name} (repetition {rep + 1}). Skipping...\")\n",
    "            continue\n",
    "            \n",
    "        results = []\n",
    "        print(f\"Loading dataset: {dataset_name} with one_hot = {one_hot} for repetition {rep + 1}\")\n",
    "        (train_X, train_y), (test_X, test_y) = load_func(one_hot=one_hot)\n",
    "        \n",
    "        print(f\"Running {method_name} on {dataset_name} (repetition {rep + 1})\")\n",
    "        sparsity, accuracy, value_seq = method_func(train_X, train_y, test_X, test_y)\n",
    "        print(f'Repetition {rep + 1}: sparsity = {sparsity} and test accuracy = {accuracy}')\n",
    "        \n",
    "        for s, a, v in zip(sparsity, accuracy, value_seq):\n",
    "            result = {\n",
    "                \"method\": method_name,\n",
    "                \"dataset\": dataset_name,\n",
    "                \"repetition\": rep + 1,\n",
    "                \"sparsity\": s,\n",
    "                \"accuracy\": a,\n",
    "                \"value\": v,\n",
    "                \"seed\": current_seed\n",
    "            }\n",
    "            results.append(result)\n",
    "            \n",
    "        # Save the results\n",
    "        save_results_to_csv(results, result_filename)\n",
    "        print(f'Results successfully saved to {result_filename}')\n",
    "\n",
    "def save_results_to_csv(results, result_filename):\n",
    "    df = pd.DataFrame(results)\n",
    "    df.to_csv(result_filename, index=False)\n",
    "\n",
    "def main():\n",
    "    datasets = {\n",
    "        \"madelon\": load_madelon\n",
    "    }\n",
    "    \n",
    "    # Create base results directory with timestamp\n",
    "    base_results_dir = os.path.join('results', 'input_sparsity', 'madelon')\n",
    "    os.makedirs(base_results_dir, exist_ok=True)\n",
    "    \n",
    "    # Run experiments for each repetition\n",
    "    for rep in range(REPS):\n",
    "        print(f\"\\nStarting repetition {rep + 1}/{REPS}\")\n",
    "        for dataset_name, load_func in datasets.items():\n",
    "            run_methods_on_dataset(dataset_name, load_func, base_results_dir, rep)\n",
    "            \n",
    "    # Optionally, combine all results into a single summary file\n",
    "    combine_all_results(base_results_dir)\n",
    "\n",
    "def combine_all_results(base_results_dir):\n",
    "    \"\"\"Combines all individual result files into a single summary file\"\"\"\n",
    "    all_results = []\n",
    "    \n",
    "    for method in ['HSIC_dnn', 'HSIC_svm', 'LassoNet']:\n",
    "        method_dir = os.path.join(base_results_dir, method)\n",
    "        if not os.path.exists(method_dir):\n",
    "            continue\n",
    "            \n",
    "        for rep_dir in os.listdir(method_dir):\n",
    "            if not rep_dir.startswith('rep_'):\n",
    "                continue\n",
    "                \n",
    "            rep_path = os.path.join(method_dir, rep_dir)\n",
    "            for result_file in os.listdir(rep_path):\n",
    "                if result_file.endswith('_res.csv'):\n",
    "                    df = pd.read_csv(os.path.join(rep_path, result_file))\n",
    "                    all_results.append(df)\n",
    "    \n",
    "    if all_results:\n",
    "        combined_df = pd.concat(all_results, ignore_index=True)\n",
    "        summary_file = os.path.join(base_results_dir, 'all_results_summary.csv')\n",
    "        combined_df.to_csv(summary_file, index=False)\n",
    "        print(f\"\\nCombined results saved to {summary_file}\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
