{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-20T13:36:42.057931Z",
     "iopub.status.busy": "2025-01-20T13:36:42.057192Z",
     "iopub.status.idle": "2025-01-20T13:36:58.488839Z",
     "shell.execute_reply": "2025-01-20T13:36:58.488124Z",
     "shell.execute_reply.started": "2025-01-20T13:36:42.057868Z"
    },
    "id": "RPFnfEqRAYA5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting pyHSICLasso\n",
      "  Downloading pyHSICLasso-1.4.2-py2.py3-none-any.whl (14 kB)\n",
      "Requirement already satisfied: numpy in /usr/local/lib/python3.9/dist-packages (from pyHSICLasso) (1.23.1)\n",
      "Requirement already satisfied: pandas in /usr/local/lib/python3.9/dist-packages (from pyHSICLasso) (1.4.3)\n",
      "Requirement already satisfied: future in /usr/lib/python3/dist-packages (from pyHSICLasso) (0.18.2)\n",
      "Requirement already satisfied: scipy in /usr/local/lib/python3.9/dist-packages (from pyHSICLasso) (1.8.1)\n",
      "Requirement already satisfied: seaborn in /usr/local/lib/python3.9/dist-packages (from pyHSICLasso) (0.11.2)\n",
      "Collecting pytest\n",
      "  Downloading pytest-8.3.4-py3-none-any.whl (343 kB)\n",
      "     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 343.1/343.1 kB 37.4 MB/s eta 0:00:00\n",
      "Requirement already satisfied: six in /usr/lib/python3/dist-packages (from pyHSICLasso) (1.14.0)\n",
      "Requirement already satisfied: joblib in /usr/local/lib/python3.9/dist-packages (from pyHSICLasso) (1.1.0)\n",
      "Requirement already satisfied: matplotlib in /usr/local/lib/python3.9/dist-packages (from pyHSICLasso) (3.5.2)\n",
      "Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (3.0.9)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (1.4.3)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (9.2.0)\n",
      "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (0.11.0)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (2.8.2)\n",
      "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (21.3)\n",
      "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->pyHSICLasso) (4.34.4)\n",
      "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas->pyHSICLasso) (2022.1)\n",
      "Collecting exceptiongroup>=1.0.0rc8\n",
      "  Downloading exceptiongroup-1.2.2-py3-none-any.whl (16 kB)\n",
      "Collecting pluggy<2,>=1.5\n",
      "  Downloading pluggy-1.5.0-py3-none-any.whl (20 kB)\n",
      "Collecting iniconfig\n",
      "  Downloading iniconfig-2.0.0-py3-none-any.whl (5.9 kB)\n",
      "Collecting tomli>=1\n",
      "  Downloading tomli-2.2.1-py3-none-any.whl (14 kB)\n",
      "Installing collected packages: tomli, pluggy, iniconfig, exceptiongroup, pytest, pyHSICLasso\n",
      "Successfully installed exceptiongroup-1.2.2 iniconfig-2.0.0 pluggy-1.5.0 pyHSICLasso-1.4.2 pytest-8.3.4 tomli-2.2.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting lassonet\n",
      "  Downloading lassonet-0.0.20-py3-none-any.whl (21 kB)\n",
      "Requirement already satisfied: tqdm in /usr/local/lib/python3.9/dist-packages (from lassonet) (4.64.0)\n",
      "Requirement already satisfied: matplotlib in /usr/local/lib/python3.9/dist-packages (from lassonet) (3.5.2)\n",
      "Collecting sortedcontainers\n",
      "  Downloading sortedcontainers-2.4.0-py2.py3-none-any.whl (29 kB)\n",
      "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.9/dist-packages (from lassonet) (1.1.1)\n",
      "Requirement already satisfied: torch>=1.11 in /usr/local/lib/python3.9/dist-packages (from lassonet) (1.12.0+cu116)\n",
      "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch>=1.11->lassonet) (4.3.0)\n",
      "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (0.11.0)\n",
      "Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (3.0.9)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (1.4.3)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (9.2.0)\n",
      "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (4.34.4)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (2.8.2)\n",
      "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (21.3)\n",
      "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.9/dist-packages (from matplotlib->lassonet) (1.23.1)\n",
      "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.9/dist-packages (from scikit-learn->lassonet) (3.1.0)\n",
      "Requirement already satisfied: joblib>=1.0.0 in /usr/local/lib/python3.9/dist-packages (from scikit-learn->lassonet) (1.1.0)\n",
      "Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.9/dist-packages (from scikit-learn->lassonet) (1.8.1)\n",
      "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.7->matplotlib->lassonet) (1.14.0)\n",
      "Installing collected packages: sortedcontainers, lassonet\n",
      "Successfully installed lassonet-0.0.20 sortedcontainers-2.4.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Defining configs successful!\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import random\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import keras\n",
    "import pandas as pd\n",
    "from tensorflow.keras.datasets import cifar10, fashion_mnist, mnist\n",
    "from tensorflow.keras.models import load_model\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping  #ReduceLROnPlateau\n",
    "from tensorflow.keras.regularizers import l2\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder, MinMaxScaler, StandardScaler \n",
    "from PIL import Image\n",
    "\n",
    "try:\n",
    "    import pyHSICLasso\n",
    "except:\n",
    "    os.system(\"python -m pip install pyHSICLasso\")\n",
    "try:\n",
    "    import lassonet\n",
    "except:\n",
    "    os.system(\"python -m pip install lassonet\")\n",
    "\n",
    "from layers       import HadamardDense, HadamardConv2D, StrHadamardDenseV2\n",
    "from initializers import TwiceTruncatedNormalInitializer, equivar_initializer, equivar_initializer_conv2d\n",
    "from callbacks    import HadamardCallback, PrintLRCallback, TerminateBadRuns\n",
    "from models       import LeNet300100, HadamardLeNet300100,\\\n",
    "                         InpHadamardLeNet300100, hsic_dnn, hsic_svm, lassoNet\n",
    "from utils        import get_optimizer, process_sparsity_callback, create_and_save_trajectory_plot, compute_input_sparsity\n",
    "\n",
    "################################################################################\n",
    "\n",
    "# Hadamard\n",
    "DEPTH              = 2\n",
    "LA                 = 5e-4\n",
    "INIT_TYPE          ='ones' #vanilla, equivar, root, ones\n",
    "#INIT_REST         = tf.keras.initializers.Ones() \n",
    "MINPROD            = 3e-3\n",
    "#INIT               = TwiceTruncatedNormalInitializer(minprod=MINPROD,depth=DEPTH) \n",
    "INIT               = tf.keras.initializers.HeNormal()\n",
    "USE_BIAS           = True # only for dense layers\n",
    "FACTORIZE_BIAS     = True\n",
    "\n",
    "# Training\n",
    "PRETRAIN_OPT       = 'sgd' # sgd, adam\n",
    "LR_SCHEDULE        = 'cosine' # piecewise, constant, cosine\n",
    "BATCH_SIZE         = 256 #256\n",
    "EPOCHS             = 100 #200\n",
    "INIT_LR            = 0.1 if PRETRAIN_OPT == 'sgd' else 2e-3 # depth 2=0.2, depth3 3 = 0.5, depth 4 = 0.7\n",
    "WARMUP             = False # linear lr warmup from 0.01 to INIT_LR over 5 epochs\n",
    "MOMENTUM           = 0.9    # only for SGD\n",
    "LR_DECAY_FACT      = 0.1   # only for piecewise\n",
    "LARGE_LRSTART      = False  # only for piecewise\n",
    "\n",
    "# Misc\n",
    "CLASS_NUM          = 10\n",
    "PAT                = 100\n",
    "RESTORE_WEIGHTS    = False\n",
    "GRACE              = 10\n",
    "FINE_GRACE         = 20\n",
    "MINACC             = (1 / CLASS_NUM) + 0.01\n",
    "SEED               = 123\n",
    "SAVE_METRICS       = True\n",
    "VERBOSE            = 1\n",
    "\n",
    "# Directories and saving path\n",
    "LENET_FILE_PATH = './results/input_sparsity/FMNIST'\n",
    "\n",
    "print('Defining configs successful!')\n",
    "\n",
    "# Lambda grid\n",
    "LAMBDA_LIST = [\n",
    "    0,\n",
    "    #1e-6,\n",
    "    1e-5,\n",
    "    1e-4,\n",
    "    2e-4,\n",
    "    #4e-4,\n",
    "    5e-4,\n",
    "    7e-4,\n",
    "    8e-4,\n",
    "    9e-4,\n",
    "    1e-3,\n",
    "    1.5e-3,\n",
    "    2e-3,\n",
    "    4e-3,\n",
    "    5e-3,\n",
    "    8e-3,\n",
    "    1e-2,\n",
    "    1.5e-2,\n",
    "    2e-2,\n",
    "    2.5e-2,\n",
    "    3e-2,\n",
    "    3.5e-2,\n",
    "    4e-2,\n",
    "    4.5e-2,\n",
    "    5e-2,\n",
    "    5.5e-2,\n",
    "    6e-2,\n",
    "    6.5e-2,\n",
    "    7e-2,\n",
    "    8e-2,\n",
    "    9e-2,\n",
    "    9.5e-2,\n",
    "    1e-1,\n",
    "    1.25e-1,\n",
    "    1.5e-1,\n",
    "    1.75e-1,\n",
    "    2e-1,\n",
    "    5e-1,\n",
    "    7e-1,\n",
    "    1,\n",
    "    2\n",
    "    #5\n",
    "]\n",
    "\n",
    "#LAMBDA_LIST.reverse()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-20T13:37:17.440090Z",
     "iopub.status.busy": "2025-01-20T13:37:17.439093Z",
     "iopub.status.idle": "2025-01-20T13:37:24.200328Z",
     "shell.execute_reply": "2025-01-20T13:37:24.199808Z",
     "shell.execute_reply.started": "2025-01-20T13:37:17.440061Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz\n",
      "29515/29515 [==============================] - 0s 0us/step\n",
      "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz\n",
      "26421880/26421880 [==============================] - 0s 0us/step\n",
      "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz\n",
      "5148/5148 [==============================] - 0s 0us/step\n",
      "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz\n",
      "4422102/4422102 [==============================] - 0s 0us/step\n",
      "x_train shape: (60000, 784), y_train shape: (60000, 10)\n",
      "x_test shape: (10000, 784), y_test shape: (10000, 10)\n",
      "Normalized Training Set Mean and SD: [-3.24146985e-03  3.50020714e-02  8.90986398e-02 -4.09977557e-03\n",
      " -4.60193958e-03 -6.55798614e-03 -2.84731085e-03 -1.16354357e-02\n",
      " -2.95641571e-02 -1.35834853e-03 -2.62191752e-03 -4.16113576e-03\n",
      "  2.39855493e-03 -1.66834856e-03  3.14865564e-03 -2.14544198e-04\n",
      " -6.56517595e-03 -7.31649483e-03 -4.51138802e-03  2.06359359e-03\n",
      " -2.91451986e-04 -9.64563899e-03 -1.72601989e-03  7.53892818e-04\n",
      " -1.55366980e-03  2.08137129e-02  1.08526191e-02  3.57899405e-02\n",
      "  3.76836769e-03  4.23307437e-03  1.31419199e-02 -1.25490846e-02\n",
      " -1.18861953e-02 -9.49829724e-03 -3.21947746e-02 -1.88831203e-02\n",
      " -7.12603796e-03  1.03516066e-02  1.64518468e-02  1.68899540e-02\n",
      "  1.84550378e-02  2.45181005e-02  2.35549156e-02  1.57501064e-02\n",
      "  4.95406287e-03  1.62830036e-02  2.16211341e-02  1.20076323e-02\n",
      " -4.41346411e-03  8.65958538e-03  1.57369561e-02  1.87500492e-02\n",
      "  1.80054680e-02  1.23995300e-02  9.10846982e-03  9.73001216e-03\n",
      "  7.49451807e-03  7.73095991e-03 -7.68886693e-03 -1.76386852e-02\n",
      " -1.36168897e-02 -3.20894197e-02 -9.53118689e-03  8.97258148e-03\n",
      "  1.13728810e-02  5.11359936e-03  8.43429100e-03  2.33553145e-02\n",
      "  1.24908397e-02  7.23513169e-03  5.25113055e-03  7.02910777e-03\n",
      "  1.67902522e-02  1.41728567e-02  2.00983752e-02  1.68732293e-02\n",
      "  7.41685508e-03  1.08532904e-05  1.16455620e-02  9.77592077e-04\n",
      "  7.38237100e-03  5.60628669e-03  1.18602896e-02  9.36703477e-03\n",
      " -4.71358048e-03  1.81554165e-03 -2.17479113e-02 -1.36672100e-02\n",
      " -2.12533567e-02 -1.13295233e-02  6.89093163e-03  1.59896556e-02\n",
      "  1.38822496e-02  8.67076125e-03  1.97429340e-02  1.00746788e-02\n",
      "  6.69226097e-03  1.54435281e-02  1.88798215e-02  1.07481321e-02\n",
      "  9.39126965e-03  1.79504193e-02  1.99522097e-02  1.92106236e-02\n",
      "  1.19977500e-02  6.41437341e-03  3.76900542e-04  6.62188884e-03\n",
      "  1.12041933e-02  1.54783539e-02  6.16374006e-03 -1.34121208e-03\n",
      "  1.04355272e-02  4.90792794e-03 -1.27791939e-02 -2.06226222e-02\n",
      " -2.18296852e-02  7.68736564e-03  2.90283523e-02  3.08515802e-02\n",
      "  1.01884315e-02  1.41542440e-03  1.95453819e-02  1.34634804e-02\n",
      "  2.05074474e-02  5.60368644e-03  5.59299998e-03  1.02447988e-02\n",
      "  1.50295598e-02  1.46457748e-02  1.88842174e-02  7.00653298e-03\n",
      "  1.17180850e-02  1.72173493e-02  4.94886748e-03  8.91567580e-03\n",
      "  8.86446331e-03  4.07411670e-03 -7.85611477e-03  6.26306981e-03\n",
      "  2.98483614e-02  2.20476221e-02  6.45708013e-03 -1.73305511e-03\n",
      " -8.36921763e-03  1.27537632e-02  3.50079015e-02  1.81897562e-02\n",
      "  5.50323399e-03  4.25501959e-03  1.21615492e-02  1.03994720e-02\n",
      "  1.37493787e-02  6.36998564e-03  1.12198165e-03  8.18203483e-03\n",
      "  9.96159203e-03  1.26122171e-02  2.76903287e-02  1.49028981e-02\n",
      "  8.35114066e-03  2.01347712e-02  1.60656963e-02  1.00508612e-02\n",
      "  3.92063661e-03 -3.08877463e-03 -4.28652717e-03  3.88022847e-02\n",
      "  1.47320926e-02  1.38088586e-02  2.04025418e-03 -6.06773049e-03\n",
      " -2.20272108e-03  8.42715055e-03  2.30567474e-02  1.59115642e-02\n",
      "  4.40632086e-03  6.28947862e-04  8.67252611e-03  2.50084000e-03\n",
      "  1.05789572e-03  5.27262175e-03  9.27112158e-03  1.47027392e-02\n",
      "  2.03473065e-02  1.04838228e-02  3.45936753e-02  1.15300072e-02\n",
      "  1.17647303e-02  1.72595959e-02  1.80222429e-02  1.54447062e-02\n",
      "  2.32500732e-02  8.19552317e-03  2.29843184e-02  4.27465849e-02\n",
      " -2.95497361e-03 -7.72496976e-04 -1.42707033e-02 -1.77589860e-02\n",
      " -2.76548485e-03  6.00193255e-03  1.82977803e-02  5.91139914e-03\n",
      " -5.05407155e-03 -4.24689194e-03  1.02214620e-03  7.35726906e-04\n",
      " -3.99989309e-03  6.44025113e-03  2.03928500e-02  1.07727181e-02\n",
      "  1.47489663e-02  2.01325230e-02  3.08272131e-02  1.87509488e-02\n",
      "  1.31416908e-02  2.16146037e-02  2.48543266e-02  1.95318274e-02\n",
      "  1.90912019e-02  1.15282759e-02  1.72114074e-02  1.75281093e-02\n",
      " -4.34914342e-04  6.49112801e-04  2.17957981e-03 -9.66576114e-03\n",
      "  7.31824338e-03  1.40350228e-02  2.39299815e-02  1.09695103e-02\n",
      " -9.12690442e-03 -7.38857826e-03  1.82140979e-03  5.43886842e-03\n",
      " -1.34857828e-02  9.14262794e-03  1.52161345e-02  3.74294212e-03\n",
      "  1.72984730e-02  2.32707839e-02  2.06036624e-02  2.52944101e-02\n",
      "  1.84361860e-02  2.30238661e-02  2.64812056e-02  2.16659699e-02\n",
      "  1.88271310e-02  2.49353610e-02  7.91690592e-03  1.28865913e-02\n",
      "  9.83919017e-03  1.89277809e-03 -1.29288423e-03 -1.91338325e-03\n",
      "  6.95666205e-03  1.22104548e-02  2.29894146e-02  1.90817639e-02\n",
      " -4.40158090e-03 -2.65659275e-03  7.84028042e-03  6.04661996e-04\n",
      " -1.19264470e-02  1.02456398e-02  1.24818599e-02  2.29945639e-03\n",
      "  3.58873489e-03  1.58448108e-02  8.64929892e-03  2.22379118e-02\n",
      "  2.18818486e-02  3.26141864e-02  3.26328799e-02  2.39232313e-02\n",
      "  7.94191100e-03  1.19983032e-02  3.13847908e-03  7.41759222e-03\n",
      "  9.85473767e-03 -3.15799913e-03 -5.38659748e-03 -8.41633137e-03\n",
      "  6.58888696e-03  1.09175975e-02  2.26549860e-02  2.52482891e-02\n",
      " -6.36749668e-03 -9.80483741e-03 -2.72569068e-05 -1.06364097e-02\n",
      " -1.72995217e-02 -8.86679045e-04 -1.41735151e-02 -1.48038333e-02\n",
      " -1.61951147e-02 -2.16940721e-03 -1.70281134e-03  1.95160024e-02\n",
      "  2.08090469e-02  3.37962434e-02  2.84436606e-02  2.00534537e-02\n",
      "  8.10593739e-03  3.58286384e-03 -3.47118825e-03  9.60062910e-03\n",
      " -5.99575287e-04 -1.42719625e-02 -6.29611174e-03 -3.69128888e-03\n",
      "  6.98779291e-03  1.21918237e-02  1.45923551e-02  1.79849006e-02\n",
      "  2.11910973e-03 -1.12883020e-02 -3.98944318e-03 -1.71009172e-02\n",
      " -1.38483718e-02 -1.03464490e-02 -1.10036451e-02 -4.15757531e-03\n",
      " -1.51342321e-02  6.59712451e-03  1.65977795e-02  2.09589824e-02\n",
      "  1.18348990e-02  1.68577153e-02  1.51924137e-02  1.15159182e-02\n",
      "  8.59583262e-03 -3.48898838e-03 -1.26731191e-02  1.11808861e-02\n",
      " -1.00415631e-03 -1.11097796e-02 -1.42033317e-03  1.10455193e-02\n",
      "  2.39075869e-02  2.55473144e-02  1.91194322e-02  1.42413601e-02\n",
      " -6.02416834e-03 -9.29465704e-03 -1.03793573e-02 -2.37691570e-02\n",
      " -2.41956767e-02  1.23394364e-02  4.04320331e-03  4.97306697e-03\n",
      " -1.69195943e-02 -3.30131454e-03  2.38798298e-02  1.03203543e-02\n",
      "  7.96440523e-03  3.03245131e-02  1.80220809e-02  2.07187794e-02\n",
      "  1.00715552e-02 -5.75410854e-03 -1.53732589e-02 -1.23584168e-02\n",
      " -1.07629569e-02 -5.15339756e-03 -6.04680413e-03  4.15101275e-03\n",
      "  2.00565215e-02  1.49135971e-02  1.50412302e-02  7.48063577e-03\n",
      " -1.58211552e-02 -9.66805778e-03 -2.35544313e-02 -2.59678252e-02\n",
      " -1.61386244e-02  1.07258204e-02  3.82859539e-03 -5.92057419e-04\n",
      " -5.57956146e-03  2.61169113e-03  1.65661722e-02  3.74642015e-03\n",
      "  8.58489145e-03  3.03768478e-02  1.53019046e-02  1.03019867e-02\n",
      "  3.12992197e-04 -1.12691149e-02 -2.20677238e-02 -3.09290122e-02\n",
      " -1.70237850e-02 -1.91039573e-02 -8.43686424e-03 -7.40929041e-03\n",
      "  3.56637058e-03  9.85611114e-04  1.20396083e-02  4.11019055e-03\n",
      " -2.41639949e-02 -2.07810216e-02 -1.24820014e-02 -1.01963999e-02\n",
      "  6.55351765e-03  2.60825604e-02  3.44907516e-03  4.96170996e-03\n",
      " -1.30345663e-02 -1.31515311e-02  1.66644901e-02  2.16905214e-02\n",
      "  1.26907360e-02  2.94462182e-02  1.22296717e-02  1.66509598e-02\n",
      "  1.22202560e-02 -3.82068451e-03 -1.59763694e-02 -2.55850852e-02\n",
      " -2.51963101e-02 -3.14252637e-02 -1.79041475e-02 -1.35027450e-02\n",
      " -5.96027635e-03 -1.08394457e-03 -9.06872656e-03 -9.99130495e-03\n",
      " -1.71328969e-02  2.71130982e-03 -8.54783878e-03 -1.99966617e-02\n",
      " -1.32298897e-04  1.13134012e-02 -1.42069999e-03  6.09499495e-03\n",
      " -6.20217947e-03 -5.87947387e-03  1.04708737e-02  2.73785535e-02\n",
      "  1.15972254e-02  2.83443872e-02  1.17514599e-02  1.56214004e-02\n",
      "  5.76471537e-03 -4.49284445e-03 -1.56441629e-02 -2.02929489e-02\n",
      " -3.56063209e-02 -3.05611212e-02 -1.24331368e-02 -1.29128229e-02\n",
      " -3.24089848e-03  1.45510770e-03 -1.85387419e-03 -1.12318108e-03\n",
      " -3.94390104e-03 -6.45412307e-04 -7.09560979e-03 -1.73076075e-02\n",
      " -1.01463608e-02  1.08109219e-02 -1.36665627e-03  1.05844373e-02\n",
      " -1.05727687e-02 -1.66294561e-03  2.43370086e-02  2.56550945e-02\n",
      "  1.20375091e-02  1.77099835e-02  4.27499227e-03  1.09973811e-02\n",
      "  4.62335674e-03 -2.79790000e-03 -7.55562168e-03 -2.47008540e-02\n",
      " -3.89256030e-02 -5.94134629e-03  2.00134721e-02  1.14927776e-02\n",
      "  1.59125105e-02  1.74571723e-02  1.99880432e-02  1.39567014e-02\n",
      "  2.39219167e-03  1.07850339e-02  3.30125098e-03  1.06469644e-02\n",
      "  9.64021217e-03  1.72360297e-02  3.27369384e-03  1.20961592e-02\n",
      "  1.32001226e-03  4.65136254e-03  2.73462608e-02  2.23633293e-02\n",
      "  1.72815826e-02  2.40312461e-02  1.90236736e-02  2.74667758e-02\n",
      "  1.79167539e-02  1.45224845e-02  2.52217194e-03 -1.40289366e-02\n",
      " -9.49648954e-03  1.67570896e-02  2.48924308e-02  1.48510197e-02\n",
      "  2.34984756e-02  1.84107590e-02  2.29204185e-02  1.09700449e-02\n",
      "  4.10654163e-03  8.00897181e-03 -1.14461277e-02  7.69499596e-03\n",
      "  1.85023565e-02  1.62013453e-02 -3.62084783e-03  1.24567933e-02\n",
      "  1.07568558e-02  7.87689630e-03  2.42281985e-02  1.92865636e-02\n",
      "  1.68159809e-02  2.73622889e-02  2.07526591e-02  2.78006960e-02\n",
      "  2.50637215e-02  1.67445894e-02  1.85114460e-03 -4.79112007e-03\n",
      "  2.43378561e-02  3.55292112e-02  3.01191863e-02  2.29799319e-02\n",
      "  2.80635692e-02  2.37510670e-02  2.91995443e-02  1.13418540e-02\n",
      "  1.04357377e-02  2.38615759e-02  6.22076448e-03  1.47007788e-02\n",
      "  1.60328764e-02  1.83605179e-02 -1.57727581e-03  1.94946900e-02\n",
      "  1.44812549e-02 -5.07370615e-03  1.08086942e-02  1.82577241e-02\n",
      "  1.46387778e-02  2.64895894e-02  1.64483953e-02  2.15900037e-02\n",
      "  1.58854071e-02  1.51738552e-02 -1.99169246e-03 -4.91505256e-03\n",
      "  2.80089509e-02  2.77092140e-02  3.24208699e-02  3.95269059e-02\n",
      "  4.25007381e-02  4.10222039e-02  4.46563102e-02  1.30418763e-02\n",
      "  1.28516173e-02  2.71496195e-02  1.52195431e-02  2.04199348e-02\n",
      "  1.69910826e-02  1.71835441e-02  9.97835863e-03  3.35642099e-02\n",
      "  2.69292276e-02  1.26692979e-02  1.97466370e-02  2.51621660e-02\n",
      "  1.88389085e-02  3.04785538e-02  2.36136280e-02  2.88658626e-02\n",
      "  2.84945965e-02  3.01410630e-02  1.16217211e-02  1.42262438e-02\n",
      "  5.12350956e-03  1.88326240e-02  2.31008548e-02  3.87077369e-02\n",
      "  3.76534350e-02  3.76285426e-02  3.99692208e-02  2.23133415e-02\n",
      "  1.92212407e-02  2.98664048e-02  2.12706029e-02  3.03554870e-02\n",
      "  3.55004743e-02  2.67215148e-02  1.98274255e-02  3.02497726e-02\n",
      "  1.96499508e-02  1.38431573e-02  2.84225866e-02  3.63588557e-02\n",
      "  3.40980887e-02  3.81107740e-02  3.63259465e-02  3.89332250e-02\n",
      "  3.83234695e-02  3.34627703e-02  1.87947657e-02  1.99926253e-02\n",
      "  1.72581887e-04  1.53934266e-02  1.20200645e-02  2.40869746e-02\n",
      "  3.02629508e-02  3.50647792e-02  4.34984416e-02  1.87423415e-02\n",
      "  2.35664584e-02  2.95071322e-02  1.58030707e-02  2.21294593e-02\n",
      "  2.37706807e-02  2.05446742e-02  1.35012148e-02  2.82926541e-02\n",
      "  1.78736579e-02  1.79263037e-02  2.62293741e-02  3.77412587e-02\n",
      "  3.41196395e-02  3.92291248e-02  3.65459211e-02  4.30990458e-02\n",
      "  4.37668152e-02  2.99241785e-02  1.74579471e-02  2.02124403e-03\n",
      "  6.42167777e-03  1.18750264e-03 -1.11572887e-03  2.40669940e-02\n",
      "  3.19256037e-02  3.48247364e-02  3.97264324e-02  8.61376338e-03\n",
      "  2.78050732e-02  4.28961366e-02  2.06672810e-02  2.12216526e-02\n",
      "  2.26726066e-02  2.28803121e-02  1.50749302e-02  2.55984105e-02\n",
      "  2.39844974e-02  1.78811979e-02  2.81970073e-02  4.05967981e-02\n",
      "  3.61265726e-02  3.42770703e-02  2.72472613e-02  3.08909118e-02\n",
      "  3.59589793e-02  1.85448043e-02  2.03878898e-03 -8.33288953e-03\n",
      " -6.80755312e-03  5.90216042e-03  4.84660501e-03  2.60295924e-02\n",
      "  2.58698370e-02  3.34796458e-02  2.79022995e-02  2.14220677e-03\n",
      "  1.99755952e-02  2.78575700e-02  1.09758209e-02  1.44785084e-02\n",
      "  1.99851617e-02  1.33757759e-02  4.93340567e-03  2.91063245e-02\n",
      "  2.40062922e-02  1.40671898e-02  2.79013030e-02  3.80676314e-02\n",
      "  3.95431705e-02  2.27731597e-02  2.08142437e-02  2.61080209e-02\n",
      "  3.65753546e-02  2.24250723e-02  1.71915395e-03 -2.54349574e-03\n",
      " -1.88358948e-02 -4.83348733e-03  4.89943707e-03  3.38398777e-02\n",
      "  3.06707025e-02  2.99117137e-02  2.78845094e-02  3.66679952e-03\n",
      "  1.32801682e-02  2.13778820e-02  5.27086528e-03  1.03532430e-02\n",
      "  9.75729153e-03  4.41166107e-03  1.35788193e-03  1.98175143e-02\n",
      "  1.44739710e-02  6.88119046e-03  1.69988424e-02  3.06651928e-02\n",
      "  3.05812787e-02  1.71711333e-02  2.72734575e-02  3.22281271e-02\n",
      "  3.65212373e-02  2.23197434e-02  1.15554994e-02  1.09978812e-02\n",
      "  1.00692492e-02  1.91419269e-03 -1.53961172e-03  3.02008409e-02\n",
      "  3.26466858e-02  3.44693437e-02  2.71769017e-02  3.07245390e-03\n",
      " -1.38861826e-03  9.84916370e-03 -5.11447666e-03 -3.41302203e-03\n",
      "  9.93191265e-04 -7.13662850e-03 -8.33349116e-03  8.69519357e-03\n",
      "  4.85636760e-03 -8.41387361e-03 -3.20522278e-03  9.90457926e-03\n",
      "  9.88449063e-03  1.16746910e-02  1.98120940e-02  2.66295653e-02\n",
      "  2.34886818e-02  4.95604426e-03 -2.31104949e-03 -8.09284858e-03\n",
      "  1.06914397e-02  1.58955418e-02  2.28102449e-02  3.38844284e-02\n",
      "  4.22481112e-02  4.30888534e-02  2.75083221e-02  1.28283957e-02\n",
      "  9.26356018e-03  5.10467682e-03 -5.06989844e-03 -3.65820969e-03\n",
      " -5.37981698e-03 -2.26781778e-02 -2.54808310e-02 -2.83276872e-03\n",
      " -6.95864065e-03 -1.92377493e-02 -7.20793800e-03  4.67667216e-03\n",
      "  1.32660307e-02  1.87259074e-02  1.55533925e-02  3.16411480e-02\n",
      "  1.84662919e-02  7.44097307e-03 -9.36161447e-03 -6.52152812e-05] [0.34163132 2.8825371  4.7776055  1.202886   1.3719418  1.1060272\n",
      " 1.0675011  0.9624039  0.92120767 0.99875206 0.9915094  0.9909271\n",
      " 1.0004141  1.01025    1.0072502  1.0052174  0.9949889  0.9976152\n",
      " 0.9906367  1.0347495  1.0017688  0.95344895 1.0267622  1.0648301\n",
      " 1.0212268  1.3716229  1.2784404  2.0379047  0.9515884  0.7838208\n",
      " 1.5207385  0.8608255  0.9014254  0.97698003 0.91502947 0.9507783\n",
      " 0.9856861  0.99907273 1.0016283  1.002395   1.0019016  1.0042993\n",
      " 1.0039384  0.9997358  0.99494874 0.998209   0.99999255 1.0033854\n",
      " 0.99522126 1.0142117  1.0568209  1.0881443  1.0700699  1.1283139\n",
      " 1.0615621  1.4781423  1.2601624  1.243427   0.96709603 0.7528477\n",
      " 0.9132843  0.9005354  0.9644804  1.0005981  0.9984191  0.99588513\n",
      " 0.99508744 0.99810743 0.9990183  0.999021   0.9980292  0.9982056\n",
      " 0.9994361  0.9955613  1.0013857  1.0004328  0.9933739  0.9904939\n",
      " 1.0143939  0.99011385 1.0203717  1.055523   1.0803947  1.078476\n",
      " 0.6503762  1.1050519  0.73130983 0.82767886 0.92774624 0.9536957\n",
      " 1.0018642  1.0003773  0.9976638  0.999025   1.0013348  0.9948815\n",
      " 0.99661344 0.9963912  1.0021583  0.99819094 1.0004979  0.99705356\n",
      " 0.9950517  0.9954362  0.99533063 0.99618286 0.99438167 0.9987807\n",
      " 1.012667   1.0433689  1.0555096  1.0482459  1.0827563  1.0907633\n",
      " 0.91230845 0.8949051  0.93553704 1.0013602  1.0064825  1.0143136\n",
      " 0.99897355 0.99309903 0.9990474  1.00034    1.0015942  0.99945056\n",
      " 0.99610174 0.9964917  1.0028068  0.9985037  0.99706376 0.9913571\n",
      " 1.0070906  1.0050577  0.99925005 1.0121185  1.0120078  1.0086763\n",
      " 0.95488536 1.0019596  1.4736388  1.1827079  1.0439031  0.98412746\n",
      " 0.9739931  1.0035201  1.0096328  1.0054417  0.9972069  0.99335074\n",
      " 0.9966291  0.99869776 0.9979613  1.0016032  0.99531627 0.993832\n",
      " 0.997044   0.9898819  0.9925112  0.99151665 0.9958746  0.9978689\n",
      " 0.9998209  1.0046617  0.9895568  0.9858321  0.9703126  1.2406021\n",
      " 1.1502498  1.0925207  1.037851   0.9961246  0.9815     0.99611944\n",
      " 1.0084759  1.0033221  0.99906886 0.99333143 0.9975484  0.99942446\n",
      " 0.99888337 1.000092   0.9965196  0.99617374 0.99684125 0.998663\n",
      " 0.98889124 0.98897105 0.9973001  0.99785393 1.0031903  1.0141424\n",
      " 1.0278885  1.0172311  1.0488944  1.1750215  1.0275612  1.008323\n",
      " 0.94973004 0.9669768  0.9777952  0.9902357  1.0011364  0.99778044\n",
      " 0.991041   0.9900391  0.99596804 1.0017391  0.99400795 1.0010507\n",
      " 0.999834   0.9959925  0.9904248  0.990696   0.98514503 0.98585004\n",
      " 0.9933993  0.9950062  0.9997383  1.0091287  1.017258   1.0190196\n",
      " 1.021001   1.0338066  1.0462328  1.0032539  0.9919904  0.9733438\n",
      " 0.9927919  0.9932169  1.0013976  0.9939849  0.98927504 0.98624384\n",
      " 0.991564   0.9958591  0.98622084 0.9968191  0.9958584  0.9974931\n",
      " 0.99118006 0.98389524 0.980924   0.98690945 0.98922473 0.99505156\n",
      " 1.0003613  1.0058973  1.0136775  1.0225924  1.0031719  1.0120711\n",
      " 1.0925151  1.0173889  0.981678   0.9810188  0.9882252  0.9884052\n",
      " 1.0013096  0.99478865 0.9908341  0.99002403 0.9903928  0.9905936\n",
      " 0.9890949  0.99675024 0.99220973 0.99085516 0.98647493 0.98240936\n",
      " 0.97594064 0.9841489  0.9916002  0.9955577  0.9968072  0.99969864\n",
      " 1.0046641  1.0079535  0.99871045 0.99183834 1.0862899  0.99989575\n",
      " 0.9790788  0.9670478  0.9876228  0.9878882  0.99765265 1.0023518\n",
      " 0.9857786  0.9902166  0.9940879  0.9969275  0.99236655 0.9925676\n",
      " 0.99517447 0.9994604  1.0026275  0.9836306  0.98323274 0.98611635\n",
      " 0.98954946 0.9965319  0.99785143 0.9992668  0.9986054  1.0011379\n",
      " 0.9850822  1.0103489  0.9862823  0.9665248  0.9699662  0.9771526\n",
      " 0.9890818  0.9861166  0.99292016 0.9990098  0.99099094 0.99139005\n",
      " 0.99472106 0.9942713  0.995663   1.0026093  0.9935571  0.99079776\n",
      " 1.0029157  0.98958737 0.97808915 0.9799537  0.9835831  0.98724025\n",
      " 0.9899151  0.991381   1.0010345  1.0047442  0.98402643 1.0353156\n",
      " 1.0000331  0.96954143 0.991286   0.9970102  1.0107083  0.99851936\n",
      " 0.9937118  0.9996297  0.9861895  0.99897665 0.99038064 1.0061773\n",
      " 1.0049505  0.9974924  0.996812   0.99905515 1.0007389  0.9964219\n",
      " 0.97508186 0.9853236  0.9890863  0.99129725 0.9977917  1.0007589\n",
      " 1.0052241  0.99739134 0.9872238  0.97594184 0.964986   0.98852557\n",
      " 0.9722185  0.9878948  1.0040526  0.993484   0.9978114  0.9972641\n",
      " 0.9853417  1.0020609  1.0010418  0.99801344 0.9983169  0.9933866\n",
      " 0.9934877  0.9918264  0.995294   0.9906158  0.9785699  0.9813419\n",
      " 0.9885428  0.99330914 0.9968937  0.99244803 0.9957053  0.9921149\n",
      " 0.9730557  0.9440634  0.96058136 0.95145935 0.9861783  0.97903985\n",
      " 0.99227667 0.99070215 1.002114   1.0026731  0.9875257  1.0033667\n",
      " 0.9946808  0.995448   0.99193466 0.99022216 0.98919255 0.9915422\n",
      " 1.0009762  1.0014113  0.9843424  0.9830724  0.9909889  0.99151665\n",
      " 0.9913896  0.9898687  0.99582016 0.98871076 0.9832403  0.95171726\n",
      " 0.9483449  0.96141857 0.9755412  0.9720953  0.9885789  0.9946544\n",
      " 0.9966124  1.0013845  0.98855233 0.9937451  1.0022899  0.99852186\n",
      " 0.9945291  0.9958089  0.9903385  0.9893972  0.99482304 0.991434\n",
      " 0.9890409  0.97764885 0.9950104  0.9851302  0.9924433  0.9880627\n",
      " 0.993642   0.9922749  0.98752713 0.9566819  0.93808866 0.9668311\n",
      " 0.9843035  0.9785005  0.9898536  0.99343824 0.9991571  0.99333763\n",
      " 0.9917627  0.9961083  0.99476165 0.9958428  0.9879223  0.9841462\n",
      " 0.987889   0.98777926 0.99704355 0.980869   0.98158824 0.9787803\n",
      " 0.9912546  0.97897387 0.99146366 0.989967   0.98968095 0.99339294\n",
      " 0.99461323 0.9634232  0.912578   0.99356216 1.020331   0.9987693\n",
      " 0.9996603  0.9926737  0.9974039  0.99489194 0.989417   0.99273336\n",
      " 0.98712146 0.9759693  0.9744912  0.977473   0.98643434 0.98138195\n",
      " 0.9880261  0.98495233 0.9849452  0.9814122  0.9946362  0.9892278\n",
      " 0.9930743  0.9966928  0.99831676 1.0073411  1.0063328  0.9714564\n",
      " 0.97804964 1.0290312  1.0305907  0.9993147  1.0044451  0.9961607\n",
      " 0.99966925 0.9983937  0.9893737  1.0013511  1.0019053  0.9904576\n",
      " 0.99015963 0.98535615 0.9934253  0.99416286 0.9895963  0.9864537\n",
      " 0.9909714  0.9835283  0.99958533 0.995615   0.99505514 1.0004125\n",
      " 1.0011015  1.0080009  1.0045375  1.0137721  1.0463296  1.0537454\n",
      " 1.0427424  1.0070938  1.0052596  0.9957208  0.9993689  0.99711025\n",
      " 0.99186665 0.9912989  0.98906314 0.9836053  0.9890593  0.9890001\n",
      " 0.99110913 0.9804414  0.9826679  0.98966336 0.9921642  0.9847602\n",
      " 0.9983602  0.9981995  0.9931545  0.99940467 0.9961885  1.0044537\n",
      " 0.9937532  1.0123914  1.0453718  1.0238068  1.0292747  1.0116467\n",
      " 1.007397   1.0042015  1.0027792  0.9936927  0.9933813  0.99279535\n",
      " 0.98719627 0.98071307 0.9886545  0.98436934 0.9883705  0.9778528\n",
      " 0.9757723  0.98008686 0.98964185 0.9831205  0.99599886 0.9974546\n",
      " 0.9942136  1.0025455  1.0050485  1.0156956  1.0082211  1.021867\n",
      " 0.9904788  1.0198907  1.0156572  1.0105981  1.0055224  0.99868023\n",
      " 0.993994   0.98693055 0.987966   0.98614806 0.97610205 0.97396815\n",
      " 0.97437346 0.9809658  0.97821635 0.9752058  0.9785474  0.97899896\n",
      " 0.9935334  0.98939043 0.9934195  0.99442536 0.9965002  1.0013795\n",
      " 1.003862   1.013178   1.0120479  1.0297183  0.962481   1.0080234\n",
      " 1.006301   1.0028533  1.0048118  1.0020101  1.0048766  0.9892346\n",
      " 0.99400145 0.98803973 0.9836311  0.9835823  0.98081636 0.9848241\n",
      " 0.9838497  0.98558074 0.9874505  0.9843277  0.9940389  0.99536103\n",
      " 0.99625003 1.0011921  1.0029663  1.0110984  1.0142927  1.0126976\n",
      " 1.0116532  0.99614227 1.0205554  0.991259   0.9850687  1.0066237\n",
      " 1.0114795  1.009902   1.0108314  0.9867525  0.9965512  0.9968671\n",
      " 0.9902945  0.99204487 0.9898899  0.9931242  0.996868   0.9962326\n",
      " 0.9967361  0.9907822  1.001823   1.0023135  1.002914   1.0084324\n",
      " 1.0012441  1.0067382  1.0099068  1.00208    0.98931986 0.94887984\n",
      " 0.9283885  1.0023401  0.99384695 1.0144958  1.0066131  1.006102\n",
      " 1.0043875  0.98678595 0.9958674  0.9971743  0.9914875  0.9899426\n",
      " 0.9896472  0.991268   0.98924166 0.99697375 0.99433744 0.98820686\n",
      " 1.000378   1.0027047  1.0047442  1.0014641  1.0044034  1.010077\n",
      " 1.019073   1.013722   0.9974459  0.9494389  0.7974798  0.939397\n",
      " 0.98772675 1.0311117  1.0189203  1.0086137  1.0075539  0.9967396\n",
      " 0.99439937 0.9982868  0.99749225 0.99630207 0.99494445 0.9943224\n",
      " 0.99443907 0.9990177  0.9975398  0.99229527 1.0024681  1.0045978\n",
      " 1.0058527  1.0035243  1.0120276  1.0186795  1.0288111  1.0168568\n",
      " 1.0197316  0.9930617  1.0945818  0.99559486 0.9831154  1.0252088\n",
      " 1.025227   1.0200776  1.0145788  0.9946202  0.9922832  1.0014904\n",
      " 0.9988159  0.99633896 0.99668795 0.9928971  0.9902243  0.99982125\n",
      " 0.9991123  0.98997027 0.9949894  0.9964389  0.99674296 0.9951232\n",
      " 1.0043468  1.0157614  1.0187147  0.9971744  0.9804649  0.8964888\n",
      " 1.7638171  1.1682385  1.1702031  1.0953912  1.0638372  1.0500901\n",
      " 1.021507   1.0122644  0.99934787 0.996308   0.99086064 0.99566996\n",
      " 0.99485683 0.98841107 0.9855003  0.99762714 1.0001235  0.9898178\n",
      " 0.99302685 0.99238765 1.0117711  1.0130814  1.0085759  1.0431529\n",
      " 1.0290791  1.0101016  0.90654343 1.1259955 ]\n",
      "Train data shape:  (4000, 784)\n",
      "Train labels shape:  (4000, 10)\n",
      "Test data shape:  (6000, 784)\n",
      "Test labels shape:  (6000, 10)\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from sklearn.model_selection import train_test_split\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n",
    "from PIL import Image\n",
    "import tensorflow as tf\n",
    "\n",
    "def one_hot_encode(y, num_classes):\n",
    "    return np.eye(num_classes)[y]\n",
    "\n",
    "def preprocess_line(line):\n",
    "    # Split the line by commas to handle comma-separated values\n",
    "    preprocessed = line.strip().split(',')\n",
    "    return preprocessed\n",
    "\n",
    "def load_data(fashion=False, digit=None, normalize=False, one_hot=True):\n",
    "    if fashion:\n",
    "        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()\n",
    "    else:\n",
    "        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "\n",
    "    if digit is not None and 0 <= digit <= 9:\n",
    "        train = test = {y: [] for y in range(10)}\n",
    "        for x, y in zip(x_train, y_train):\n",
    "            train[y].append(x)\n",
    "        for x, y in zip(x_test, y_test):\n",
    "            test[y].append(x)\n",
    "\n",
    "        for y in range(10):\n",
    "            train[y] = np.asarray(train[y])\n",
    "            test[y] = np.asarray(test[y])\n",
    "\n",
    "        x_train = train[digit]\n",
    "        x_test = test[digit]\n",
    "\n",
    "    x_train = x_train.reshape((-1, x_train.shape[1] * x_train.shape[2])).astype(np.float32)\n",
    "    x_test = x_test.reshape((-1, x_test.shape[1] * x_test.shape[2])).astype(np.float32)\n",
    "\n",
    "    if normalize:\n",
    "        # normalize data\n",
    "        scaler = StandardScaler().fit(x_train)\n",
    "        x_train = scaler.transform(x_train)\n",
    "        x_test = scaler.transform(x_test)\n",
    "\n",
    "    num_classes = np.unique(y_train).shape[0]\n",
    "\n",
    "    if one_hot:\n",
    "        y_train = one_hot_encode(y_train, num_classes)\n",
    "        y_test = one_hot_encode(y_test, num_classes)\n",
    "\n",
    "    print(\"x_train shape: {}, y_train shape: {}\".format(x_train.shape, y_train.shape))\n",
    "    print(\"x_test shape: {}, y_test shape: {}\".format(x_test.shape, y_test.shape))\n",
    "\n",
    "    return (x_train, y_train), (x_test, y_test)\n",
    "\n",
    "\n",
    "\n",
    "def load_mnist(one_hot=True):\n",
    "    train, test = load_data(fashion = False, normalize = True, one_hot=one_hot)\n",
    "    \n",
    "    # Set seed\n",
    "    np.random.seed(SEED)\n",
    "    random.seed(SEED)\n",
    "    tf.random.set_seed(SEED)\n",
    "    \n",
    "    x_train, x_test, y_train, y_test = train_test_split(test[0], test[1], test_size = 0.6)\n",
    "    return (x_train, y_train), (x_test, y_test)\n",
    "\n",
    "\n",
    "def load_fashion(one_hot=True):\n",
    "    train, test = load_data(fashion = True, normalize = True, one_hot=one_hot)\n",
    "    \n",
    "    # Set seed\n",
    "    np.random.seed(SEED)\n",
    "    random.seed(SEED)\n",
    "    tf.random.set_seed(SEED)\n",
    "    \n",
    "    x_train, x_test, y_train, y_test = train_test_split(test[0], test[1], test_size = 0.6)\n",
    "    return (x_train, y_train), (x_test, y_test)\n",
    "\n",
    "(X_train, Y_train), (X_test, Y_test) = load_fashion()\n",
    "############\n",
    "\n",
    "def calculate_mean_std(dataset):\n",
    "    means = np.mean(dataset, axis=(0))  \n",
    "    stds = np.std(dataset, axis=(0))  \n",
    "    return means, stds\n",
    "\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "print(\"Normalized Training Set Mean and SD:\", train_mean, train_std)\n",
    "\n",
    "\n",
    "# Check sample sizes of split\n",
    "print('Train data shape: ', X_train.shape)\n",
    "print('Train labels shape: ', Y_train.shape)\n",
    "#print('Validation data shape: ', X_val.shape)\n",
    "#print('Validation labels shape: ', Y_val.shape)\n",
    "print('Test data shape: ', X_test.shape)\n",
    "print('Test labels shape: ', Y_test.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2025-01-19T18:41:48.512816Z",
     "iopub.status.busy": "2025-01-19T18:41:48.512318Z",
     "iopub.status.idle": "2025-01-19T18:41:55.053310Z",
     "shell.execute_reply": "2025-01-19T18:41:55.052792Z",
     "shell.execute_reply.started": "2025-01-19T18:41:48.512794Z"
    },
    "id": "6yy9KHxyWOBY",
    "outputId": "7cfef7ad-a3ae-49a2-92d7-820a174079d9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"VanillaLeNet300100\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " input_1 (InputLayer)        [(None, 784)]             0         \n",
      "                                                                 \n",
      " dense (Dense)               (None, 300)               235500    \n",
      "                                                                 \n",
      " dense_1 (Dense)             (None, 100)               30100     \n",
      "                                                                 \n",
      " dense_2 (Dense)             (None, 10)                1010      \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 266,610\n",
      "Trainable params: 266,610\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n",
      "\n",
      "Epoch 1: Current learning rate = 1.500e-01\n",
      "Epoch 1/100\n",
      "16/16 [==============================] - 1s 2ms/step - loss: 1.8380 - accuracy: 0.5180\n",
      "\n",
      "Epoch 2: Current learning rate = 1.500e-01\n",
      "Epoch 2/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 1.0922 - accuracy: 0.6528\n",
      "\n",
      "Epoch 3: Current learning rate = 1.498e-01\n",
      "Epoch 3/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.8719 - accuracy: 0.7147\n",
      "\n",
      "Epoch 4: Current learning rate = 1.496e-01\n",
      "Epoch 4/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.7682 - accuracy: 0.7577\n",
      "\n",
      "Epoch 5: Current learning rate = 1.493e-01\n",
      "Epoch 5/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.6005 - accuracy: 0.7960\n",
      "\n",
      "Epoch 6: Current learning rate = 1.489e-01\n",
      "Epoch 6/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.6377 - accuracy: 0.7720\n",
      "\n",
      "Epoch 7: Current learning rate = 1.485e-01\n",
      "Epoch 7/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.4826 - accuracy: 0.8210\n",
      "\n",
      "Epoch 8: Current learning rate = 1.479e-01\n",
      "Epoch 8/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.3943 - accuracy: 0.8620\n",
      "\n",
      "Epoch 9: Current learning rate = 1.473e-01\n",
      "Epoch 9/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.3024 - accuracy: 0.8915\n",
      "\n",
      "Epoch 10: Current learning rate = 1.466e-01\n",
      "Epoch 10/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.2727 - accuracy: 0.9028\n",
      "\n",
      "Epoch 11: Current learning rate = 1.458e-01\n",
      "Epoch 11/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.2106 - accuracy: 0.9247\n",
      "\n",
      "Epoch 12: Current learning rate = 1.450e-01\n",
      "Epoch 12/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.2196 - accuracy: 0.9325\n",
      "\n",
      "Epoch 13: Current learning rate = 1.440e-01\n",
      "Epoch 13/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.1921 - accuracy: 0.9327\n",
      "\n",
      "Epoch 14: Current learning rate = 1.430e-01\n",
      "Epoch 14/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.1612 - accuracy: 0.9475\n",
      "\n",
      "Epoch 15: Current learning rate = 1.419e-01\n",
      "Epoch 15/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.1351 - accuracy: 0.9560\n",
      "\n",
      "Epoch 16: Current learning rate = 1.407e-01\n",
      "Epoch 16/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.0985 - accuracy: 0.9705\n",
      "\n",
      "Epoch 17: Current learning rate = 1.395e-01\n",
      "Epoch 17/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.0909 - accuracy: 0.9712\n",
      "\n",
      "Epoch 18: Current learning rate = 1.382e-01\n",
      "Epoch 18/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.0876 - accuracy: 0.9728\n",
      "\n",
      "Epoch 19: Current learning rate = 1.368e-01\n",
      "Epoch 19/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.0830 - accuracy: 0.9743\n",
      "\n",
      "Epoch 20: Current learning rate = 1.353e-01\n",
      "Epoch 20/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.1185 - accuracy: 0.9650\n",
      "\n",
      "Epoch 21: Current learning rate = 1.338e-01\n",
      "Epoch 21/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.1068 - accuracy: 0.9695\n",
      "\n",
      "Epoch 22: Current learning rate = 1.322e-01\n",
      "Epoch 22/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.3413 - accuracy: 0.9595\n",
      "\n",
      "Epoch 23: Current learning rate = 1.305e-01\n",
      "Epoch 23/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.0999 - accuracy: 0.9663\n",
      "\n",
      "Epoch 24: Current learning rate = 1.288e-01\n",
      "Epoch 24/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.0864 - accuracy: 0.9745\n",
      "\n",
      "Epoch 25: Current learning rate = 1.270e-01\n",
      "Epoch 25/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.1758 - accuracy: 0.9610\n",
      "\n",
      "Epoch 26: Current learning rate = 1.252e-01\n",
      "Epoch 26/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.3517 - accuracy: 0.9532\n",
      "\n",
      "Epoch 27: Current learning rate = 1.233e-01\n",
      "Epoch 27/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.4641 - accuracy: 0.9153\n",
      "\n",
      "Epoch 28: Current learning rate = 1.213e-01\n",
      "Epoch 28/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 0.8206 - accuracy: 0.9075\n",
      "\n",
      "Epoch 29: Current learning rate = 1.193e-01\n",
      "Epoch 29/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 1.1859 - accuracy: 0.8062\n",
      "\n",
      "Epoch 30: Current learning rate = 1.173e-01\n",
      "Epoch 30/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 18.5996 - accuracy: 0.4933\n",
      "\n",
      "Epoch 31: Current learning rate = 1.152e-01\n",
      "Epoch 31/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.5939 - accuracy: 0.1082\n",
      "\n",
      "Epoch 32: Current learning rate = 1.130e-01\n",
      "Epoch 32/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3277 - accuracy: 0.1165\n",
      "\n",
      "Epoch 33: Current learning rate = 1.109e-01\n",
      "Epoch 33/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.7591 - accuracy: 0.1075\n",
      "\n",
      "Epoch 34: Current learning rate = 1.086e-01\n",
      "Epoch 34/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3045 - accuracy: 0.0997\n",
      "\n",
      "Epoch 35: Current learning rate = 1.064e-01\n",
      "Epoch 35/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3020 - accuracy: 0.1023\n",
      "\n",
      "Epoch 36: Current learning rate = 1.041e-01\n",
      "Epoch 36/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3018 - accuracy: 0.1020\n",
      "\n",
      "Epoch 37: Current learning rate = 1.017e-01\n",
      "Epoch 37/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3016 - accuracy: 0.1007\n",
      "\n",
      "Epoch 38: Current learning rate = 9.937e-02\n",
      "Epoch 38/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3017 - accuracy: 0.1047\n",
      "\n",
      "Epoch 39: Current learning rate = 9.698e-02\n",
      "Epoch 39/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3016 - accuracy: 0.0985\n",
      "\n",
      "Epoch 40: Current learning rate = 9.456e-02\n",
      "Epoch 40/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3017 - accuracy: 0.1018\n",
      "\n",
      "Epoch 41: Current learning rate = 9.213e-02\n",
      "Epoch 41/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3017 - accuracy: 0.1063\n",
      "\n",
      "Epoch 42: Current learning rate = 8.967e-02\n",
      "Epoch 42/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3017 - accuracy: 0.1067\n",
      "\n",
      "Epoch 43: Current learning rate = 8.720e-02\n",
      "Epoch 43/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3015 - accuracy: 0.1067\n",
      "\n",
      "Epoch 44: Current learning rate = 8.471e-02\n",
      "Epoch 44/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3015 - accuracy: 0.1047\n",
      "\n",
      "Epoch 45: Current learning rate = 8.221e-02\n",
      "Epoch 45/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3016 - accuracy: 0.1047\n",
      "\n",
      "Epoch 46: Current learning rate = 7.971e-02\n",
      "Epoch 46/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3014 - accuracy: 0.1063\n",
      "\n",
      "Epoch 47: Current learning rate = 7.720e-02\n",
      "Epoch 47/100\n",
      "16/16 [==============================] - 0s 3ms/step - loss: 2.3015 - accuracy: 0.0983\n",
      "\n",
      "Epoch 48: Current learning rate = 7.469e-02\n",
      "Epoch 48/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3014 - accuracy: 0.1067\n",
      "\n",
      "Epoch 49: Current learning rate = 7.217e-02\n",
      "Epoch 49/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3014 - accuracy: 0.1067\n",
      "\n",
      "Epoch 50: Current learning rate = 6.966e-02\n",
      "Epoch 50/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3014 - accuracy: 0.1028\n",
      "\n",
      "Epoch 51: Current learning rate = 6.716e-02\n",
      "Epoch 51/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3014 - accuracy: 0.0988\n",
      "\n",
      "Epoch 52: Current learning rate = 6.467e-02\n",
      "Epoch 52/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3017 - accuracy: 0.1065\n",
      "\n",
      "Epoch 53: Current learning rate = 6.218e-02\n",
      "Epoch 53/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3012 - accuracy: 0.1065\n",
      "\n",
      "Epoch 54: Current learning rate = 5.971e-02\n",
      "Epoch 54/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3013 - accuracy: 0.1065\n",
      "\n",
      "Epoch 55: Current learning rate = 5.726e-02\n",
      "Epoch 55/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3015 - accuracy: 0.1065\n",
      "\n",
      "Epoch 56: Current learning rate = 5.483e-02\n",
      "Epoch 56/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3011 - accuracy: 0.1067\n",
      "\n",
      "Epoch 57: Current learning rate = 5.242e-02\n",
      "Epoch 57/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3012 - accuracy: 0.1067\n",
      "\n",
      "Epoch 58: Current learning rate = 5.004e-02\n",
      "Epoch 58/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3012 - accuracy: 0.0993\n",
      "\n",
      "Epoch 59: Current learning rate = 4.768e-02\n",
      "Epoch 59/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3013 - accuracy: 0.1067\n",
      "\n",
      "Epoch 60: Current learning rate = 4.536e-02\n",
      "Epoch 60/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3012 - accuracy: 0.1067\n",
      "\n",
      "Epoch 61: Current learning rate = 4.307e-02\n",
      "Epoch 61/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3011 - accuracy: 0.1047\n",
      "\n",
      "Epoch 62: Current learning rate = 4.081e-02\n",
      "Epoch 62/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3011 - accuracy: 0.1065\n",
      "\n",
      "Epoch 63: Current learning rate = 3.859e-02\n",
      "Epoch 63/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3012 - accuracy: 0.1065\n",
      "\n",
      "Epoch 64: Current learning rate = 3.642e-02\n",
      "Epoch 64/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3012 - accuracy: 0.1007\n",
      "\n",
      "Epoch 65: Current learning rate = 3.428e-02\n",
      "Epoch 65/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3012 - accuracy: 0.1065\n",
      "\n",
      "Epoch 66: Current learning rate = 3.220e-02\n",
      "Epoch 66/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3012 - accuracy: 0.1050\n",
      "\n",
      "Epoch 67: Current learning rate = 3.016e-02\n",
      "Epoch 67/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3011 - accuracy: 0.1015\n",
      "\n",
      "Epoch 68: Current learning rate = 2.817e-02\n",
      "Epoch 68/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3012 - accuracy: 0.1065\n",
      "\n",
      "Epoch 69: Current learning rate = 2.623e-02\n",
      "Epoch 69/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3010 - accuracy: 0.1067\n",
      "\n",
      "Epoch 70: Current learning rate = 2.435e-02\n",
      "Epoch 70/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3010 - accuracy: 0.1067\n",
      "\n",
      "Epoch 71: Current learning rate = 2.253e-02\n",
      "Epoch 71/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3010 - accuracy: 0.1067\n",
      "\n",
      "Epoch 72: Current learning rate = 2.076e-02\n",
      "Epoch 72/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3011 - accuracy: 0.0990\n",
      "\n",
      "Epoch 73: Current learning rate = 1.905e-02\n",
      "Epoch 73/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3010 - accuracy: 0.1028\n",
      "\n",
      "Epoch 74: Current learning rate = 1.741e-02\n",
      "Epoch 74/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3010 - accuracy: 0.1067\n",
      "\n",
      "Epoch 75: Current learning rate = 1.583e-02\n",
      "Epoch 75/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3010 - accuracy: 0.1067\n",
      "\n",
      "Epoch 76: Current learning rate = 1.432e-02\n",
      "Epoch 76/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1067\n",
      "\n",
      "Epoch 77: Current learning rate = 1.288e-02\n",
      "Epoch 77/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3010 - accuracy: 0.0988\n",
      "\n",
      "Epoch 78: Current learning rate = 1.151e-02\n",
      "Epoch 78/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1067\n",
      "\n",
      "Epoch 79: Current learning rate = 1.021e-02\n",
      "Epoch 79/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1032\n",
      "\n",
      "Epoch 80: Current learning rate = 8.977e-03\n",
      "Epoch 80/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1065\n",
      "\n",
      "Epoch 81: Current learning rate = 7.822e-03\n",
      "Epoch 81/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1065\n",
      "\n",
      "Epoch 82: Current learning rate = 6.742e-03\n",
      "Epoch 82/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1065\n",
      "\n",
      "Epoch 83: Current learning rate = 5.739e-03\n",
      "Epoch 83/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1053\n",
      "\n",
      "Epoch 84: Current learning rate = 4.814e-03\n",
      "Epoch 84/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1040\n",
      "\n",
      "Epoch 85: Current learning rate = 3.968e-03\n",
      "Epoch 85/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1010\n",
      "\n",
      "Epoch 86: Current learning rate = 3.201e-03\n",
      "Epoch 86/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1007\n",
      "\n",
      "Epoch 87: Current learning rate = 2.515e-03\n",
      "Epoch 87/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1018\n",
      "\n",
      "Epoch 88: Current learning rate = 1.910e-03\n",
      "Epoch 88/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1067\n",
      "\n",
      "Epoch 89: Current learning rate = 1.388e-03\n",
      "Epoch 89/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1067\n",
      "\n",
      "Epoch 90: Current learning rate = 9.481e-04\n",
      "Epoch 90/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1067\n",
      "\n",
      "Epoch 91: Current learning rate = 5.914e-04\n",
      "Epoch 91/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1067\n",
      "\n",
      "Epoch 92: Current learning rate = 3.182e-04\n",
      "Epoch 92/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1067\n",
      "\n",
      "Epoch 93: Current learning rate = 1.289e-04\n",
      "Epoch 93/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1067\n",
      "\n",
      "Epoch 94: Current learning rate = 2.368e-05\n",
      "Epoch 94/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1067\n",
      "\n",
      "Epoch 95: Current learning rate = 0.000e+00\n",
      "Epoch 95/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1067\n",
      "\n",
      "Epoch 96: Current learning rate = 0.000e+00\n",
      "Epoch 96/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1067\n",
      "\n",
      "Epoch 97: Current learning rate = 0.000e+00\n",
      "Epoch 97/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1067\n",
      "\n",
      "Epoch 98: Current learning rate = 0.000e+00\n",
      "Epoch 98/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1067\n",
      "\n",
      "Epoch 99: Current learning rate = 0.000e+00\n",
      "Epoch 99/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1067\n",
      "\n",
      "Epoch 100: Current learning rate = 0.000e+00\n",
      "Epoch 100/100\n",
      "16/16 [==============================] - 0s 2ms/step - loss: 2.3009 - accuracy: 0.1067\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 2.3086 - accuracy: 0.0957\n",
      "\n",
      "Test loss 2.308608293533325\n",
      "Test accuracy 0.09566666930913925\n"
     ]
    }
   ],
   "source": [
    "# Vanilla (dense, unfactorized) LeNet-300-100 on fmnist with lambda = 0, used as a baseline\n",
    "\n",
    "# Model definition\n",
    "################################################################################\n",
    "MODEL = 'lenet300100_fmnist'\n",
    "LA = 0  # no regularization for the baseline\n",
    "# NOTE(review): INIT is not passed to LeNet300100 below, so it has no effect on this run;\n",
    "# it is kept as a global only in case later cells read it.\n",
    "INIT = tf.keras.initializers.HeNormal\n",
    "# INIT_LR, EPOCHS, BATCH_SIZE, LR_SCHEDULE, PRETRAIN_OPT, SEED, ... are defined in earlier cells.\n",
    "################################################################################\n",
    "\n",
    "# Directories and saving paths\n",
    "fmt_la = f\"{LA:.1e}\"\n",
    "RUN_NAME = f\"{MODEL}_dep1-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{INIT_LR:.1e}-vanilla-bs{BATCH_SIZE}\"\n",
    "RUN_PATH = os.path.join(LENET_FILE_PATH, RUN_NAME)\n",
    "os.makedirs(RUN_PATH, exist_ok=True)\n",
    "\n",
    "################################################################################\n",
    "# Set seeds for numpy, python and tensorflow\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "# Callbacks\n",
    "# NOTE(review): early_stopping and early_abort_cb are built but not passed to fit() below;\n",
    "# early_stopping monitors 'val_accuracy', which fit() only reports when validation data is supplied.\n",
    "early_stopping = EarlyStopping(monitor='val_accuracy', patience=PAT, restore_best_weights=RESTORE_WEIGHTS)\n",
    "print_lr_cb = PrintLRCallback()\n",
    "terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "early_abort_cb = TerminateBadRuns(grace=GRACE, minacc = MINACC)\n",
    "\n",
    "# Define model\n",
    "vanilla_lenet300100 = LeNet300100(input_shape=(X_train.shape[1],), n_classes = CLASS_NUM, la=LA, units1=300, units2=100)\n",
    "\n",
    "# Pretrain optimizer\n",
    "optimizer = get_optimizer(lr_schedule=LR_SCHEDULE, init_lr=INIT_LR, lr_decay_fact=LR_DECAY_FACT, epochs=EPOCHS,\n",
    "                          dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT, momentum=MOMENTUM, alpha=0,\n",
    "                          large_lr_start=LARGE_LRSTART, warmup=WARMUP)\n",
    "\n",
    "# Compile model\n",
    "vanilla_lenet300100.compile(optimizer=optimizer,\n",
    "                            loss='categorical_crossentropy',\n",
    "                            metrics=['accuracy'])\n",
    "\n",
    "# summary() prints the table itself and returns None; wrapping it in print() adds a stray 'None' line.\n",
    "vanilla_lenet300100.summary()\n",
    "\n",
    "################################################################################\n",
    "\n",
    "# Training (TerminateOnNaN stops the run if the loss becomes NaN instead of finishing all epochs)\n",
    "pre_hist = vanilla_lenet300100.fit(x=X_train, y=Y_train, epochs=EPOCHS, batch_size=BATCH_SIZE,\n",
    "                                   callbacks=[print_lr_cb, terminate_nan_cb])\n",
    "\n",
    "# Evaluate after training\n",
    "pretrain_loss, pretrain_acc = vanilla_lenet300100.evaluate(X_test, Y_test)\n",
    "print('\\nTest loss', pretrain_loss)\n",
    "print('Test accuracy', pretrain_acc)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-19T18:42:13.143800Z",
     "iopub.status.busy": "2025-01-19T18:42:13.142960Z",
     "iopub.status.idle": "2025-01-19T19:28:36.037500Z",
     "shell.execute_reply": "2025-01-19T19:28:36.036850Z",
     "shell.execute_reply.started": "2025-01-19T18:42:13.143769Z"
    }
   },
   "outputs": [],
   "source": [
    "# Input-Hadamard LeNet-300-100 on fmnist: sweep over depth x repetition x lambda\n",
    "\n",
    "DEPTH_LIST = [2, 3, 4]\n",
    "REPS = 5\n",
    "BASE_SEED = SEED\n",
    "MODEL = 'lenet300100_fmnist'  # Keep original dataset\n",
    "INIT_TYPE = 'ones'  # only used as a tag in RUN_NAME here (not passed to the model)\n",
    "\n",
    "# One result dict per (depth, rep, lambda) run; written to a combined CSV after the sweep.\n",
    "all_pretrain_results = []\n",
    "\n",
    "for depth in DEPTH_LIST:\n",
    "    for rep in range(REPS):\n",
    "        # Every lambda within a repetition uses the same seed.\n",
    "        current_seed = BASE_SEED + rep\n",
    "\n",
    "        for LA_ITER in LAMBDA_LIST:\n",
    "            DEPTH = depth  # Use the loop variable\n",
    "            LA = LA_ITER\n",
    "            print(f'Starting run with depth={DEPTH}, lambda={LA:.2e}, repetition={rep+1}/{REPS}')\n",
    "\n",
    "            # Directories and saving paths\n",
    "            fmt_la = f\"{LA:.1e}\"\n",
    "            RUN_NAME = f\"{MODEL}_dep{DEPTH}-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{INIT_LR:.1e}-{INIT_TYPE}-bs{BATCH_SIZE}_rep{rep+1}\"\n",
    "            RUN_PATH = os.path.join(LENET_FILE_PATH, f\"depth_{DEPTH}\", f\"rep_{rep+1}\", RUN_NAME)\n",
    "            os.makedirs(RUN_PATH, exist_ok=True)\n",
    "\n",
    "            # Release Keras global state from previous iterations; without this, memory grows\n",
    "            # with every model built in this loop.\n",
    "            tf.keras.backend.clear_session()\n",
    "\n",
    "            # Seed after clear_session so the reset cannot interfere with it\n",
    "            np.random.seed(current_seed)\n",
    "            random.seed(current_seed)\n",
    "            tf.random.set_seed(current_seed)\n",
    "\n",
    "            INIT = tf.keras.initializers.HeNormal()\n",
    "            terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "\n",
    "            # Define model\n",
    "            hadamard_lenet300100 = InpHadamardLeNet300100(input_shape=(X_train.shape[1],), n_classes=CLASS_NUM, depth=DEPTH, la=LA,\n",
    "                                                          init=INIT, use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS)\n",
    "\n",
    "            # Pretrain optimizer\n",
    "            optimizer = get_optimizer(lr_schedule=LR_SCHEDULE, init_lr=INIT_LR, lr_decay_fact=LR_DECAY_FACT, epochs=EPOCHS,\n",
    "                                      dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT, momentum=MOMENTUM, alpha=0,\n",
    "                                      large_lr_start=LARGE_LRSTART, warmup=WARMUP)\n",
    "\n",
    "            # Compile model\n",
    "            hadamard_lenet300100.compile(optimizer=optimizer,\n",
    "                                         loss='categorical_crossentropy',\n",
    "                                         metrics=['accuracy'])\n",
    "\n",
    "            # summary() prints the table itself and returns None, so it is not wrapped in print()\n",
    "            hadamard_lenet300100.summary()\n",
    "\n",
    "            ################################################################################\n",
    "\n",
    "            # Training\n",
    "            pre_hist = hadamard_lenet300100.fit(x=X_train, y=Y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, verbose=0,\n",
    "                                                callbacks=[terminate_nan_cb])\n",
    "\n",
    "            # Evaluate once after training (weights do not change afterwards, so a second evaluate() is redundant)\n",
    "            pretrain_loss, pretrain_acc = hadamard_lenet300100.evaluate(X_test, Y_test)\n",
    "            df, pretrain_sparsity = compute_input_sparsity(hadamard_lenet300100, DEPTH)\n",
    "            # A fully sparse input layer would divide by zero and abort the whole sweep; report an infinite rate instead.\n",
    "            pretrain_compression_rate = float('inf') if pretrain_sparsity >= 1 else 1 / (1 - pretrain_sparsity)\n",
    "            print('\\nTest loss', pretrain_loss)\n",
    "            print('Test accuracy', pretrain_acc)\n",
    "            print('Sparsity (pretrain)', pretrain_sparsity)\n",
    "            print('Compression rate (pretrain)', pretrain_compression_rate)\n",
    "\n",
    "            # Store formatted results in dict (key order defines the CSV column order)\n",
    "            pretrain_res_dict = {\n",
    "                'Run': int(rep + 1),\n",
    "                'Pre Opt': PRETRAIN_OPT,\n",
    "                'Depth': int(DEPTH),\n",
    "                'Lambda': f'{LA:.2e}',\n",
    "                'Init LR': f'{INIT_LR:.2e}',\n",
    "                'LR Schedule': LR_SCHEDULE,\n",
    "                'Batch size': int(BATCH_SIZE),\n",
    "                'Pre Epochs': int(EPOCHS),\n",
    "                'Pre Loss': f'{pretrain_loss:.3f}',\n",
    "                'Pre Acc': f'{pretrain_acc * 100:.4f}%',\n",
    "                'Pre Sparsity': f'{pretrain_sparsity * 100:.4f}%',\n",
    "                'Pre CR': f'{pretrain_compression_rate:.2f}'\n",
    "            }\n",
    "            all_pretrain_results.append(pretrain_res_dict)\n",
    "\n",
    "            # One-row frame built directly (concatenating onto an empty frame is deprecated in pandas)\n",
    "            pretrain_res_df = pd.DataFrame([pretrain_res_dict])\n",
    "\n",
    "            # Save per-run results to CSV\n",
    "            pretrain_csv_file_path = os.path.join(RUN_PATH, f'pretraining_{MODEL}_depth{DEPTH}_rep{rep+1}.csv')\n",
    "            pretrain_res_df.to_csv(pretrain_csv_file_path, index=False)\n",
    "            print(f'Pretrain results saved to {pretrain_csv_file_path}')\n",
    "            print(\"\\nPretraining Results:\")\n",
    "            print(pretrain_res_df)\n",
    "\n",
    "# Combined table of every run in the sweep\n",
    "pretrain_summary_df = pd.DataFrame(all_pretrain_results)\n",
    "pretrain_summary_path = os.path.join(LENET_FILE_PATH, f'pretraining_{MODEL}_all_runs.csv')\n",
    "pretrain_summary_df.to_csv(pretrain_summary_path, index=False)\n",
    "print(f'Combined pretrain results saved to {pretrain_summary_path}')\n",
    "pretrain_summary_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-20T13:37:29.320514Z",
     "iopub.status.busy": "2025-01-20T13:37:29.319872Z",
     "iopub.status.idle": "2025-01-20T13:37:30.289533Z",
     "shell.execute_reply": "2025-01-20T13:37:30.289077Z",
     "shell.execute_reply.started": "2025-01-20T13:37:29.320514Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Normalized Training Set Mean and SD: [-3.24146985e-03  3.50020714e-02  8.90986398e-02 -4.09977557e-03\n",
      " -4.60193958e-03 -6.55798614e-03 -2.84731085e-03 -1.16354357e-02\n",
      " -2.95641571e-02 -1.35834853e-03 -2.62191752e-03 -4.16113576e-03\n",
      "  2.39855493e-03 -1.66834856e-03  3.14865564e-03 -2.14544198e-04\n",
      " -6.56517595e-03 -7.31649483e-03 -4.51138802e-03  2.06359359e-03\n",
      " -2.91451986e-04 -9.64563899e-03 -1.72601989e-03  7.53892818e-04\n",
      " -1.55366980e-03  2.08137129e-02  1.08526191e-02  3.57899405e-02\n",
      "  3.76836769e-03  4.23307437e-03  1.31419199e-02 -1.25490846e-02\n",
      " -1.18861953e-02 -9.49829724e-03 -3.21947746e-02 -1.88831203e-02\n",
      " -7.12603796e-03  1.03516066e-02  1.64518468e-02  1.68899540e-02\n",
      "  1.84550378e-02  2.45181005e-02  2.35549156e-02  1.57501064e-02\n",
      "  4.95406287e-03  1.62830036e-02  2.16211341e-02  1.20076323e-02\n",
      " -4.41346411e-03  8.65958538e-03  1.57369561e-02  1.87500492e-02\n",
      "  1.80054680e-02  1.23995300e-02  9.10846982e-03  9.73001216e-03\n",
      "  7.49451807e-03  7.73095991e-03 -7.68886693e-03 -1.76386852e-02\n",
      " -1.36168897e-02 -3.20894197e-02 -9.53118689e-03  8.97258148e-03\n",
      "  1.13728810e-02  5.11359936e-03  8.43429100e-03  2.33553145e-02\n",
      "  1.24908397e-02  7.23513169e-03  5.25113055e-03  7.02910777e-03\n",
      "  1.67902522e-02  1.41728567e-02  2.00983752e-02  1.68732293e-02\n",
      "  7.41685508e-03  1.08532904e-05  1.16455620e-02  9.77592077e-04\n",
      "  7.38237100e-03  5.60628669e-03  1.18602896e-02  9.36703477e-03\n",
      " -4.71358048e-03  1.81554165e-03 -2.17479113e-02 -1.36672100e-02\n",
      " -2.12533567e-02 -1.13295233e-02  6.89093163e-03  1.59896556e-02\n",
      "  1.38822496e-02  8.67076125e-03  1.97429340e-02  1.00746788e-02\n",
      "  6.69226097e-03  1.54435281e-02  1.88798215e-02  1.07481321e-02\n",
      "  9.39126965e-03  1.79504193e-02  1.99522097e-02  1.92106236e-02\n",
      "  1.19977500e-02  6.41437341e-03  3.76900542e-04  6.62188884e-03\n",
      "  1.12041933e-02  1.54783539e-02  6.16374006e-03 -1.34121208e-03\n",
      "  1.04355272e-02  4.90792794e-03 -1.27791939e-02 -2.06226222e-02\n",
      " -2.18296852e-02  7.68736564e-03  2.90283523e-02  3.08515802e-02\n",
      "  1.01884315e-02  1.41542440e-03  1.95453819e-02  1.34634804e-02\n",
      "  2.05074474e-02  5.60368644e-03  5.59299998e-03  1.02447988e-02\n",
      "  1.50295598e-02  1.46457748e-02  1.88842174e-02  7.00653298e-03\n",
      "  1.17180850e-02  1.72173493e-02  4.94886748e-03  8.91567580e-03\n",
      "  8.86446331e-03  4.07411670e-03 -7.85611477e-03  6.26306981e-03\n",
      "  2.98483614e-02  2.20476221e-02  6.45708013e-03 -1.73305511e-03\n",
      " -8.36921763e-03  1.27537632e-02  3.50079015e-02  1.81897562e-02\n",
      "  5.50323399e-03  4.25501959e-03  1.21615492e-02  1.03994720e-02\n",
      "  1.37493787e-02  6.36998564e-03  1.12198165e-03  8.18203483e-03\n",
      "  9.96159203e-03  1.26122171e-02  2.76903287e-02  1.49028981e-02\n",
      "  8.35114066e-03  2.01347712e-02  1.60656963e-02  1.00508612e-02\n",
      "  3.92063661e-03 -3.08877463e-03 -4.28652717e-03  3.88022847e-02\n",
      "  1.47320926e-02  1.38088586e-02  2.04025418e-03 -6.06773049e-03\n",
      " -2.20272108e-03  8.42715055e-03  2.30567474e-02  1.59115642e-02\n",
      "  4.40632086e-03  6.28947862e-04  8.67252611e-03  2.50084000e-03\n",
      "  1.05789572e-03  5.27262175e-03  9.27112158e-03  1.47027392e-02\n",
      "  2.03473065e-02  1.04838228e-02  3.45936753e-02  1.15300072e-02\n",
      "  1.17647303e-02  1.72595959e-02  1.80222429e-02  1.54447062e-02\n",
      "  2.32500732e-02  8.19552317e-03  2.29843184e-02  4.27465849e-02\n",
      " -2.95497361e-03 -7.72496976e-04 -1.42707033e-02 -1.77589860e-02\n",
      " -2.76548485e-03  6.00193255e-03  1.82977803e-02  5.91139914e-03\n",
      " -5.05407155e-03 -4.24689194e-03  1.02214620e-03  7.35726906e-04\n",
      " -3.99989309e-03  6.44025113e-03  2.03928500e-02  1.07727181e-02\n",
      "  1.47489663e-02  2.01325230e-02  3.08272131e-02  1.87509488e-02\n",
      "  1.31416908e-02  2.16146037e-02  2.48543266e-02  1.95318274e-02\n",
      "  1.90912019e-02  1.15282759e-02  1.72114074e-02  1.75281093e-02\n",
      " -4.34914342e-04  6.49112801e-04  2.17957981e-03 -9.66576114e-03\n",
      "  7.31824338e-03  1.40350228e-02  2.39299815e-02  1.09695103e-02\n",
      " -9.12690442e-03 -7.38857826e-03  1.82140979e-03  5.43886842e-03\n",
      " -1.34857828e-02  9.14262794e-03  1.52161345e-02  3.74294212e-03\n",
      "  1.72984730e-02  2.32707839e-02  2.06036624e-02  2.52944101e-02\n",
      "  1.84361860e-02  2.30238661e-02  2.64812056e-02  2.16659699e-02\n",
      "  1.88271310e-02  2.49353610e-02  7.91690592e-03  1.28865913e-02\n",
      "  9.83919017e-03  1.89277809e-03 -1.29288423e-03 -1.91338325e-03\n",
      "  6.95666205e-03  1.22104548e-02  2.29894146e-02  1.90817639e-02\n",
      " -4.40158090e-03 -2.65659275e-03  7.84028042e-03  6.04661996e-04\n",
      " -1.19264470e-02  1.02456398e-02  1.24818599e-02  2.29945639e-03\n",
      "  3.58873489e-03  1.58448108e-02  8.64929892e-03  2.22379118e-02\n",
      "  2.18818486e-02  3.26141864e-02  3.26328799e-02  2.39232313e-02\n",
      "  7.94191100e-03  1.19983032e-02  3.13847908e-03  7.41759222e-03\n",
      "  9.85473767e-03 -3.15799913e-03 -5.38659748e-03 -8.41633137e-03\n",
      "  6.58888696e-03  1.09175975e-02  2.26549860e-02  2.52482891e-02\n",
      " -6.36749668e-03 -9.80483741e-03 -2.72569068e-05 -1.06364097e-02\n",
      " -1.72995217e-02 -8.86679045e-04 -1.41735151e-02 -1.48038333e-02\n",
      " -1.61951147e-02 -2.16940721e-03 -1.70281134e-03  1.95160024e-02\n",
      "  2.08090469e-02  3.37962434e-02  2.84436606e-02  2.00534537e-02\n",
      "  8.10593739e-03  3.58286384e-03 -3.47118825e-03  9.60062910e-03\n",
      " -5.99575287e-04 -1.42719625e-02 -6.29611174e-03 -3.69128888e-03\n",
      "  6.98779291e-03  1.21918237e-02  1.45923551e-02  1.79849006e-02\n",
      "  2.11910973e-03 -1.12883020e-02 -3.98944318e-03 -1.71009172e-02\n",
      " -1.38483718e-02 -1.03464490e-02 -1.10036451e-02 -4.15757531e-03\n",
      " -1.51342321e-02  6.59712451e-03  1.65977795e-02  2.09589824e-02\n",
      "  1.18348990e-02  1.68577153e-02  1.51924137e-02  1.15159182e-02\n",
      "  8.59583262e-03 -3.48898838e-03 -1.26731191e-02  1.11808861e-02\n",
      " -1.00415631e-03 -1.11097796e-02 -1.42033317e-03  1.10455193e-02\n",
      "  2.39075869e-02  2.55473144e-02  1.91194322e-02  1.42413601e-02\n",
      " -6.02416834e-03 -9.29465704e-03 -1.03793573e-02 -2.37691570e-02\n",
      " -2.41956767e-02  1.23394364e-02  4.04320331e-03  4.97306697e-03\n",
      " -1.69195943e-02 -3.30131454e-03  2.38798298e-02  1.03203543e-02\n",
      "  7.96440523e-03  3.03245131e-02  1.80220809e-02  2.07187794e-02\n",
      "  1.00715552e-02 -5.75410854e-03 -1.53732589e-02 -1.23584168e-02\n",
      " -1.07629569e-02 -5.15339756e-03 -6.04680413e-03  4.15101275e-03\n",
      "  2.00565215e-02  1.49135971e-02  1.50412302e-02  7.48063577e-03\n",
      " -1.58211552e-02 -9.66805778e-03 -2.35544313e-02 -2.59678252e-02\n",
      " -1.61386244e-02  1.07258204e-02  3.82859539e-03 -5.92057419e-04\n",
      " -5.57956146e-03  2.61169113e-03  1.65661722e-02  3.74642015e-03\n",
      "  8.58489145e-03  3.03768478e-02  1.53019046e-02  1.03019867e-02\n",
      "  3.12992197e-04 -1.12691149e-02 -2.20677238e-02 -3.09290122e-02\n",
      " -1.70237850e-02 -1.91039573e-02 -8.43686424e-03 -7.40929041e-03\n",
      "  3.56637058e-03  9.85611114e-04  1.20396083e-02  4.11019055e-03\n",
      " -2.41639949e-02 -2.07810216e-02 -1.24820014e-02 -1.01963999e-02\n",
      "  6.55351765e-03  2.60825604e-02  3.44907516e-03  4.96170996e-03\n",
      " -1.30345663e-02 -1.31515311e-02  1.66644901e-02  2.16905214e-02\n",
      "  1.26907360e-02  2.94462182e-02  1.22296717e-02  1.66509598e-02\n",
      "  1.22202560e-02 -3.82068451e-03 -1.59763694e-02 -2.55850852e-02\n",
      " -2.51963101e-02 -3.14252637e-02 -1.79041475e-02 -1.35027450e-02\n",
      " -5.96027635e-03 -1.08394457e-03 -9.06872656e-03 -9.99130495e-03\n",
      " -1.71328969e-02  2.71130982e-03 -8.54783878e-03 -1.99966617e-02\n",
      " -1.32298897e-04  1.13134012e-02 -1.42069999e-03  6.09499495e-03\n",
      " -6.20217947e-03 -5.87947387e-03  1.04708737e-02  2.73785535e-02\n",
      "  1.15972254e-02  2.83443872e-02  1.17514599e-02  1.56214004e-02\n",
      "  5.76471537e-03 -4.49284445e-03 -1.56441629e-02 -2.02929489e-02\n",
      " -3.56063209e-02 -3.05611212e-02 -1.24331368e-02 -1.29128229e-02\n",
      " -3.24089848e-03  1.45510770e-03 -1.85387419e-03 -1.12318108e-03\n",
      " -3.94390104e-03 -6.45412307e-04 -7.09560979e-03 -1.73076075e-02\n",
      " -1.01463608e-02  1.08109219e-02 -1.36665627e-03  1.05844373e-02\n",
      " -1.05727687e-02 -1.66294561e-03  2.43370086e-02  2.56550945e-02\n",
      "  1.20375091e-02  1.77099835e-02  4.27499227e-03  1.09973811e-02\n",
      "  4.62335674e-03 -2.79790000e-03 -7.55562168e-03 -2.47008540e-02\n",
      " -3.89256030e-02 -5.94134629e-03  2.00134721e-02  1.14927776e-02\n",
      "  1.59125105e-02  1.74571723e-02  1.99880432e-02  1.39567014e-02\n",
      "  2.39219167e-03  1.07850339e-02  3.30125098e-03  1.06469644e-02\n",
      "  9.64021217e-03  1.72360297e-02  3.27369384e-03  1.20961592e-02\n",
      "  1.32001226e-03  4.65136254e-03  2.73462608e-02  2.23633293e-02\n",
      "  1.72815826e-02  2.40312461e-02  1.90236736e-02  2.74667758e-02\n",
      "  1.79167539e-02  1.45224845e-02  2.52217194e-03 -1.40289366e-02\n",
      " -9.49648954e-03  1.67570896e-02  2.48924308e-02  1.48510197e-02\n",
      "  2.34984756e-02  1.84107590e-02  2.29204185e-02  1.09700449e-02\n",
      "  4.10654163e-03  8.00897181e-03 -1.14461277e-02  7.69499596e-03\n",
      "  1.85023565e-02  1.62013453e-02 -3.62084783e-03  1.24567933e-02\n",
      "  1.07568558e-02  7.87689630e-03  2.42281985e-02  1.92865636e-02\n",
      "  1.68159809e-02  2.73622889e-02  2.07526591e-02  2.78006960e-02\n",
      "  2.50637215e-02  1.67445894e-02  1.85114460e-03 -4.79112007e-03\n",
      "  2.43378561e-02  3.55292112e-02  3.01191863e-02  2.29799319e-02\n",
      "  2.80635692e-02  2.37510670e-02  2.91995443e-02  1.13418540e-02\n",
      "  1.04357377e-02  2.38615759e-02  6.22076448e-03  1.47007788e-02\n",
      "  1.60328764e-02  1.83605179e-02 -1.57727581e-03  1.94946900e-02\n",
      "  1.44812549e-02 -5.07370615e-03  1.08086942e-02  1.82577241e-02\n",
      "  1.46387778e-02  2.64895894e-02  1.64483953e-02  2.15900037e-02\n",
      "  1.58854071e-02  1.51738552e-02 -1.99169246e-03 -4.91505256e-03\n",
      "  2.80089509e-02  2.77092140e-02  3.24208699e-02  3.95269059e-02\n",
      "  4.25007381e-02  4.10222039e-02  4.46563102e-02  1.30418763e-02\n",
      "  1.28516173e-02  2.71496195e-02  1.52195431e-02  2.04199348e-02\n",
      "  1.69910826e-02  1.71835441e-02  9.97835863e-03  3.35642099e-02\n",
      "  2.69292276e-02  1.26692979e-02  1.97466370e-02  2.51621660e-02\n",
      "  1.88389085e-02  3.04785538e-02  2.36136280e-02  2.88658626e-02\n",
      "  2.84945965e-02  3.01410630e-02  1.16217211e-02  1.42262438e-02\n",
      "  5.12350956e-03  1.88326240e-02  2.31008548e-02  3.87077369e-02\n",
      "  3.76534350e-02  3.76285426e-02  3.99692208e-02  2.23133415e-02\n",
      "  1.92212407e-02  2.98664048e-02  2.12706029e-02  3.03554870e-02\n",
      "  3.55004743e-02  2.67215148e-02  1.98274255e-02  3.02497726e-02\n",
      "  1.96499508e-02  1.38431573e-02  2.84225866e-02  3.63588557e-02\n",
      "  3.40980887e-02  3.81107740e-02  3.63259465e-02  3.89332250e-02\n",
      "  3.83234695e-02  3.34627703e-02  1.87947657e-02  1.99926253e-02\n",
      "  1.72581887e-04  1.53934266e-02  1.20200645e-02  2.40869746e-02\n",
      "  3.02629508e-02  3.50647792e-02  4.34984416e-02  1.87423415e-02\n",
      "  2.35664584e-02  2.95071322e-02  1.58030707e-02  2.21294593e-02\n",
      "  2.37706807e-02  2.05446742e-02  1.35012148e-02  2.82926541e-02\n",
      "  1.78736579e-02  1.79263037e-02  2.62293741e-02  3.77412587e-02\n",
      "  3.41196395e-02  3.92291248e-02  3.65459211e-02  4.30990458e-02\n",
      "  4.37668152e-02  2.99241785e-02  1.74579471e-02  2.02124403e-03\n",
      "  6.42167777e-03  1.18750264e-03 -1.11572887e-03  2.40669940e-02\n",
      "  3.19256037e-02  3.48247364e-02  3.97264324e-02  8.61376338e-03\n",
      "  2.78050732e-02  4.28961366e-02  2.06672810e-02  2.12216526e-02\n",
      "  2.26726066e-02  2.28803121e-02  1.50749302e-02  2.55984105e-02\n",
      "  2.39844974e-02  1.78811979e-02  2.81970073e-02  4.05967981e-02\n",
      "  3.61265726e-02  3.42770703e-02  2.72472613e-02  3.08909118e-02\n",
      "  3.59589793e-02  1.85448043e-02  2.03878898e-03 -8.33288953e-03\n",
      " -6.80755312e-03  5.90216042e-03  4.84660501e-03  2.60295924e-02\n",
      "  2.58698370e-02  3.34796458e-02  2.79022995e-02  2.14220677e-03\n",
      "  1.99755952e-02  2.78575700e-02  1.09758209e-02  1.44785084e-02\n",
      "  1.99851617e-02  1.33757759e-02  4.93340567e-03  2.91063245e-02\n",
      "  2.40062922e-02  1.40671898e-02  2.79013030e-02  3.80676314e-02\n",
      "  3.95431705e-02  2.27731597e-02  2.08142437e-02  2.61080209e-02\n",
      "  3.65753546e-02  2.24250723e-02  1.71915395e-03 -2.54349574e-03\n",
      " -1.88358948e-02 -4.83348733e-03  4.89943707e-03  3.38398777e-02\n",
      "  3.06707025e-02  2.99117137e-02  2.78845094e-02  3.66679952e-03\n",
      "  1.32801682e-02  2.13778820e-02  5.27086528e-03  1.03532430e-02\n",
      "  9.75729153e-03  4.41166107e-03  1.35788193e-03  1.98175143e-02\n",
      "  1.44739710e-02  6.88119046e-03  1.69988424e-02  3.06651928e-02\n",
      "  3.05812787e-02  1.71711333e-02  2.72734575e-02  3.22281271e-02\n",
      "  3.65212373e-02  2.23197434e-02  1.15554994e-02  1.09978812e-02\n",
      "  1.00692492e-02  1.91419269e-03 -1.53961172e-03  3.02008409e-02\n",
      "  3.26466858e-02  3.44693437e-02  2.71769017e-02  3.07245390e-03\n",
      " -1.38861826e-03  9.84916370e-03 -5.11447666e-03 -3.41302203e-03\n",
      "  9.93191265e-04 -7.13662850e-03 -8.33349116e-03  8.69519357e-03\n",
      "  4.85636760e-03 -8.41387361e-03 -3.20522278e-03  9.90457926e-03\n",
      "  9.88449063e-03  1.16746910e-02  1.98120940e-02  2.66295653e-02\n",
      "  2.34886818e-02  4.95604426e-03 -2.31104949e-03 -8.09284858e-03\n",
      "  1.06914397e-02  1.58955418e-02  2.28102449e-02  3.38844284e-02\n",
      "  4.22481112e-02  4.30888534e-02  2.75083221e-02  1.28283957e-02\n",
      "  9.26356018e-03  5.10467682e-03 -5.06989844e-03 -3.65820969e-03\n",
      " -5.37981698e-03 -2.26781778e-02 -2.54808310e-02 -2.83276872e-03\n",
      " -6.95864065e-03 -1.92377493e-02 -7.20793800e-03  4.67667216e-03\n",
      "  1.32660307e-02  1.87259074e-02  1.55533925e-02  3.16411480e-02\n",
      "  1.84662919e-02  7.44097307e-03 -9.36161447e-03 -6.52152812e-05] [0.34163132 2.8825371  4.7776055  1.202886   1.3719418  1.1060272\n",
      " 1.0675011  0.9624039  0.92120767 0.99875206 0.9915094  0.9909271\n",
      " 1.0004141  1.01025    1.0072502  1.0052174  0.9949889  0.9976152\n",
      " 0.9906367  1.0347495  1.0017688  0.95344895 1.0267622  1.0648301\n",
      " 1.0212268  1.3716229  1.2784404  2.0379047  0.9515884  0.7838208\n",
      " 1.5207385  0.8608255  0.9014254  0.97698003 0.91502947 0.9507783\n",
      " 0.9856861  0.99907273 1.0016283  1.002395   1.0019016  1.0042993\n",
      " 1.0039384  0.9997358  0.99494874 0.998209   0.99999255 1.0033854\n",
      " 0.99522126 1.0142117  1.0568209  1.0881443  1.0700699  1.1283139\n",
      " 1.0615621  1.4781423  1.2601624  1.243427   0.96709603 0.7528477\n",
      " 0.9132843  0.9005354  0.9644804  1.0005981  0.9984191  0.99588513\n",
      " 0.99508744 0.99810743 0.9990183  0.999021   0.9980292  0.9982056\n",
      " 0.9994361  0.9955613  1.0013857  1.0004328  0.9933739  0.9904939\n",
      " 1.0143939  0.99011385 1.0203717  1.055523   1.0803947  1.078476\n",
      " 0.6503762  1.1050519  0.73130983 0.82767886 0.92774624 0.9536957\n",
      " 1.0018642  1.0003773  0.9976638  0.999025   1.0013348  0.9948815\n",
      " 0.99661344 0.9963912  1.0021583  0.99819094 1.0004979  0.99705356\n",
      " 0.9950517  0.9954362  0.99533063 0.99618286 0.99438167 0.9987807\n",
      " 1.012667   1.0433689  1.0555096  1.0482459  1.0827563  1.0907633\n",
      " 0.91230845 0.8949051  0.93553704 1.0013602  1.0064825  1.0143136\n",
      " 0.99897355 0.99309903 0.9990474  1.00034    1.0015942  0.99945056\n",
      " 0.99610174 0.9964917  1.0028068  0.9985037  0.99706376 0.9913571\n",
      " 1.0070906  1.0050577  0.99925005 1.0121185  1.0120078  1.0086763\n",
      " 0.95488536 1.0019596  1.4736388  1.1827079  1.0439031  0.98412746\n",
      " 0.9739931  1.0035201  1.0096328  1.0054417  0.9972069  0.99335074\n",
      " 0.9966291  0.99869776 0.9979613  1.0016032  0.99531627 0.993832\n",
      " 0.997044   0.9898819  0.9925112  0.99151665 0.9958746  0.9978689\n",
      " 0.9998209  1.0046617  0.9895568  0.9858321  0.9703126  1.2406021\n",
      " 1.1502498  1.0925207  1.037851   0.9961246  0.9815     0.99611944\n",
      " 1.0084759  1.0033221  0.99906886 0.99333143 0.9975484  0.99942446\n",
      " 0.99888337 1.000092   0.9965196  0.99617374 0.99684125 0.998663\n",
      " 0.98889124 0.98897105 0.9973001  0.99785393 1.0031903  1.0141424\n",
      " 1.0278885  1.0172311  1.0488944  1.1750215  1.0275612  1.008323\n",
      " 0.94973004 0.9669768  0.9777952  0.9902357  1.0011364  0.99778044\n",
      " 0.991041   0.9900391  0.99596804 1.0017391  0.99400795 1.0010507\n",
      " 0.999834   0.9959925  0.9904248  0.990696   0.98514503 0.98585004\n",
      " 0.9933993  0.9950062  0.9997383  1.0091287  1.017258   1.0190196\n",
      " 1.021001   1.0338066  1.0462328  1.0032539  0.9919904  0.9733438\n",
      " 0.9927919  0.9932169  1.0013976  0.9939849  0.98927504 0.98624384\n",
      " 0.991564   0.9958591  0.98622084 0.9968191  0.9958584  0.9974931\n",
      " 0.99118006 0.98389524 0.980924   0.98690945 0.98922473 0.99505156\n",
      " 1.0003613  1.0058973  1.0136775  1.0225924  1.0031719  1.0120711\n",
      " 1.0925151  1.0173889  0.981678   0.9810188  0.9882252  0.9884052\n",
      " 1.0013096  0.99478865 0.9908341  0.99002403 0.9903928  0.9905936\n",
      " 0.9890949  0.99675024 0.99220973 0.99085516 0.98647493 0.98240936\n",
      " 0.97594064 0.9841489  0.9916002  0.9955577  0.9968072  0.99969864\n",
      " 1.0046641  1.0079535  0.99871045 0.99183834 1.0862899  0.99989575\n",
      " 0.9790788  0.9670478  0.9876228  0.9878882  0.99765265 1.0023518\n",
      " 0.9857786  0.9902166  0.9940879  0.9969275  0.99236655 0.9925676\n",
      " 0.99517447 0.9994604  1.0026275  0.9836306  0.98323274 0.98611635\n",
      " 0.98954946 0.9965319  0.99785143 0.9992668  0.9986054  1.0011379\n",
      " 0.9850822  1.0103489  0.9862823  0.9665248  0.9699662  0.9771526\n",
      " 0.9890818  0.9861166  0.99292016 0.9990098  0.99099094 0.99139005\n",
      " 0.99472106 0.9942713  0.995663   1.0026093  0.9935571  0.99079776\n",
      " 1.0029157  0.98958737 0.97808915 0.9799537  0.9835831  0.98724025\n",
      " 0.9899151  0.991381   1.0010345  1.0047442  0.98402643 1.0353156\n",
      " 1.0000331  0.96954143 0.991286   0.9970102  1.0107083  0.99851936\n",
      " 0.9937118  0.9996297  0.9861895  0.99897665 0.99038064 1.0061773\n",
      " 1.0049505  0.9974924  0.996812   0.99905515 1.0007389  0.9964219\n",
      " 0.97508186 0.9853236  0.9890863  0.99129725 0.9977917  1.0007589\n",
      " 1.0052241  0.99739134 0.9872238  0.97594184 0.964986   0.98852557\n",
      " 0.9722185  0.9878948  1.0040526  0.993484   0.9978114  0.9972641\n",
      " 0.9853417  1.0020609  1.0010418  0.99801344 0.9983169  0.9933866\n",
      " 0.9934877  0.9918264  0.995294   0.9906158  0.9785699  0.9813419\n",
      " 0.9885428  0.99330914 0.9968937  0.99244803 0.9957053  0.9921149\n",
      " 0.9730557  0.9440634  0.96058136 0.95145935 0.9861783  0.97903985\n",
      " 0.99227667 0.99070215 1.002114   1.0026731  0.9875257  1.0033667\n",
      " 0.9946808  0.995448   0.99193466 0.99022216 0.98919255 0.9915422\n",
      " 1.0009762  1.0014113  0.9843424  0.9830724  0.9909889  0.99151665\n",
      " 0.9913896  0.9898687  0.99582016 0.98871076 0.9832403  0.95171726\n",
      " 0.9483449  0.96141857 0.9755412  0.9720953  0.9885789  0.9946544\n",
      " 0.9966124  1.0013845  0.98855233 0.9937451  1.0022899  0.99852186\n",
      " 0.9945291  0.9958089  0.9903385  0.9893972  0.99482304 0.991434\n",
      " 0.9890409  0.97764885 0.9950104  0.9851302  0.9924433  0.9880627\n",
      " 0.993642   0.9922749  0.98752713 0.9566819  0.93808866 0.9668311\n",
      " 0.9843035  0.9785005  0.9898536  0.99343824 0.9991571  0.99333763\n",
      " 0.9917627  0.9961083  0.99476165 0.9958428  0.9879223  0.9841462\n",
      " 0.987889   0.98777926 0.99704355 0.980869   0.98158824 0.9787803\n",
      " 0.9912546  0.97897387 0.99146366 0.989967   0.98968095 0.99339294\n",
      " 0.99461323 0.9634232  0.912578   0.99356216 1.020331   0.9987693\n",
      " 0.9996603  0.9926737  0.9974039  0.99489194 0.989417   0.99273336\n",
      " 0.98712146 0.9759693  0.9744912  0.977473   0.98643434 0.98138195\n",
      " 0.9880261  0.98495233 0.9849452  0.9814122  0.9946362  0.9892278\n",
      " 0.9930743  0.9966928  0.99831676 1.0073411  1.0063328  0.9714564\n",
      " 0.97804964 1.0290312  1.0305907  0.9993147  1.0044451  0.9961607\n",
      " 0.99966925 0.9983937  0.9893737  1.0013511  1.0019053  0.9904576\n",
      " 0.99015963 0.98535615 0.9934253  0.99416286 0.9895963  0.9864537\n",
      " 0.9909714  0.9835283  0.99958533 0.995615   0.99505514 1.0004125\n",
      " 1.0011015  1.0080009  1.0045375  1.0137721  1.0463296  1.0537454\n",
      " 1.0427424  1.0070938  1.0052596  0.9957208  0.9993689  0.99711025\n",
      " 0.99186665 0.9912989  0.98906314 0.9836053  0.9890593  0.9890001\n",
      " 0.99110913 0.9804414  0.9826679  0.98966336 0.9921642  0.9847602\n",
      " 0.9983602  0.9981995  0.9931545  0.99940467 0.9961885  1.0044537\n",
      " 0.9937532  1.0123914  1.0453718  1.0238068  1.0292747  1.0116467\n",
      " 1.007397   1.0042015  1.0027792  0.9936927  0.9933813  0.99279535\n",
      " 0.98719627 0.98071307 0.9886545  0.98436934 0.9883705  0.9778528\n",
      " 0.9757723  0.98008686 0.98964185 0.9831205  0.99599886 0.9974546\n",
      " 0.9942136  1.0025455  1.0050485  1.0156956  1.0082211  1.021867\n",
      " 0.9904788  1.0198907  1.0156572  1.0105981  1.0055224  0.99868023\n",
      " 0.993994   0.98693055 0.987966   0.98614806 0.97610205 0.97396815\n",
      " 0.97437346 0.9809658  0.97821635 0.9752058  0.9785474  0.97899896\n",
      " 0.9935334  0.98939043 0.9934195  0.99442536 0.9965002  1.0013795\n",
      " 1.003862   1.013178   1.0120479  1.0297183  0.962481   1.0080234\n",
      " 1.006301   1.0028533  1.0048118  1.0020101  1.0048766  0.9892346\n",
      " 0.99400145 0.98803973 0.9836311  0.9835823  0.98081636 0.9848241\n",
      " 0.9838497  0.98558074 0.9874505  0.9843277  0.9940389  0.99536103\n",
      " 0.99625003 1.0011921  1.0029663  1.0110984  1.0142927  1.0126976\n",
      " 1.0116532  0.99614227 1.0205554  0.991259   0.9850687  1.0066237\n",
      " 1.0114795  1.009902   1.0108314  0.9867525  0.9965512  0.9968671\n",
      " 0.9902945  0.99204487 0.9898899  0.9931242  0.996868   0.9962326\n",
      " 0.9967361  0.9907822  1.001823   1.0023135  1.002914   1.0084324\n",
      " 1.0012441  1.0067382  1.0099068  1.00208    0.98931986 0.94887984\n",
      " 0.9283885  1.0023401  0.99384695 1.0144958  1.0066131  1.006102\n",
      " 1.0043875  0.98678595 0.9958674  0.9971743  0.9914875  0.9899426\n",
      " 0.9896472  0.991268   0.98924166 0.99697375 0.99433744 0.98820686\n",
      " 1.000378   1.0027047  1.0047442  1.0014641  1.0044034  1.010077\n",
      " 1.019073   1.013722   0.9974459  0.9494389  0.7974798  0.939397\n",
      " 0.98772675 1.0311117  1.0189203  1.0086137  1.0075539  0.9967396\n",
      " 0.99439937 0.9982868  0.99749225 0.99630207 0.99494445 0.9943224\n",
      " 0.99443907 0.9990177  0.9975398  0.99229527 1.0024681  1.0045978\n",
      " 1.0058527  1.0035243  1.0120276  1.0186795  1.0288111  1.0168568\n",
      " 1.0197316  0.9930617  1.0945818  0.99559486 0.9831154  1.0252088\n",
      " 1.025227   1.0200776  1.0145788  0.9946202  0.9922832  1.0014904\n",
      " 0.9988159  0.99633896 0.99668795 0.9928971  0.9902243  0.99982125\n",
      " 0.9991123  0.98997027 0.9949894  0.9964389  0.99674296 0.9951232\n",
      " 1.0043468  1.0157614  1.0187147  0.9971744  0.9804649  0.8964888\n",
      " 1.7638171  1.1682385  1.1702031  1.0953912  1.0638372  1.0500901\n",
      " 1.021507   1.0122644  0.99934787 0.996308   0.99086064 0.99566996\n",
      " 0.99485683 0.98841107 0.9855003  0.99762714 1.0001235  0.9898178\n",
      " 0.99302685 0.99238765 1.0117711  1.0130814  1.0085759  1.0431529\n",
      " 1.0290791  1.0101016  0.90654343 1.1259955 ]\n",
      "Train data shape:  (4000, 784)\n",
      "Train labels shape:  (4000,)\n",
      "Test data shape:  (6000, 784)\n",
      "Test labels shape:  (6000,)\n"
     ]
    }
   ],
   "source": [
    "# Load Fashion-MNIST and inspect per-feature statistics\n",
    "# (input data for the HSIC Lasso + SVM experiment, following Ziyin and Liu, 2023)\n",
    "\n",
    "(X_train, Y_train), (X_test, Y_test) = load_fashion(one_hot=False)\n",
    "\n",
    "def calculate_mean_std(dataset):\n",
    "    \"\"\"Return the per-feature (column-wise, axis 0) mean and standard deviation of `dataset`.\"\"\"\n",
    "    means = np.mean(dataset, axis=0)\n",
    "    stds = np.std(dataset, axis=0)\n",
    "    return means, stds\n",
    "\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "# NOTE(review): test_mean/test_std are not used in this cell; kept in case later cells rely on them.\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "\n",
    "# Summarise the per-feature statistic vectors instead of dumping every entry.\n",
    "# If load_fashion standardizes the inputs (TODO confirm), means should be ~0 and SDs ~1.\n",
    "print(f'Training set per-feature mean: min = {train_mean.min():.4f}, max = {train_mean.max():.4f}')\n",
    "print(f'Training set per-feature SD:   min = {train_std.min():.4f}, max = {train_std.max():.4f}')\n",
    "\n",
    "# Check sample sizes of the split\n",
    "print('Train data shape: ', X_train.shape)\n",
    "print('Train labels shape: ', Y_train.shape)\n",
    "print('Test data shape: ', X_test.shape)\n",
    "print('Test labels shape: ', Y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-01-20T13:37:48.802128Z",
     "iopub.status.busy": "2025-01-20T13:37:48.801484Z",
     "iopub.status.idle": "2025-01-20T16:04:12.903083Z",
     "shell.execute_reply": "2025-01-20T16:04:12.902419Z",
     "shell.execute_reply.started": "2025-01-20T13:37:48.802099Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Starting repetition 1/5\n",
      "Loading dataset: FMNIST for repetition 1/5\n",
      "Loading dataset: FMNIST with one_hot = False for repetition 1\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running LassoNet on FMNIST (repetition 1)\n",
      "Repetition 1: sparsity = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.010204081632653073, 0.030612244897959218, 0.06887755102040816, 0.14923469387755106, 0.21811224489795922, 0.27933673469387754, 0.32780612244897955, 0.3698979591836735, 0.3966836734693877, 0.4119897959183674, 0.42091836734693877, 0.42984693877551017, 0.4375, 0.44515306122448983, 0.45663265306122447, 0.46556122448979587, 0.47193877551020413, 0.4821428571428571, 0.48852040816326525, 0.49234693877551017, 0.4987244897959183, 0.5114795918367347, 0.5165816326530612, 0.5255102040816326, 0.5306122448979591, 0.5408163265306123, 0.5459183673469388, 0.5573979591836735, 0.5663265306122449, 0.5701530612244898, 0.5714285714285714, 0.5803571428571428, 0.5918367346938775, 0.5931122448979591, 0.5994897959183674, 0.6045918367346939, 0.6096938775510203, 0.6109693877551021, 0.6173469387755102, 0.6186224489795918, 0.6275510204081632, 0.625, 0.6211734693877551, 0.6352040816326531, 0.6313775510204082, 0.6403061224489797, 0.6377551020408163, 0.6454081632653061, 0.6428571428571428, 0.6568877551020409, 0.6543367346938775, 0.6658163265306123, 0.6632653061224489, 0.6696428571428572, 0.6670918367346939, 0.6785714285714286, 0.6760204081632653, 0.6836734693877551, 0.6875, 0.6977040816326531, 0.6875, 0.7028061224489797, 0.6951530612244898, 0.7091836734693877, 0.7066326530612245, 0.7181122448979591, 0.7168367346938775, 0.7295918367346939, 0.7257653061224489, 0.7372448979591837, 0.7334183673469388, 0.7461734693877551, 0.7372448979591837, 0.75, 0.7487244897959184, 0.7614795918367347, 0.7563775510204082, 0.7716836734693877, 0.7602040816326531, 0.778061224489796, 0.7678571428571428, 0.7818877551020408, 0.7716836734693877, 0.7831632653061225, 0.7716836734693877, 
0.7857142857142857, 0.7767857142857143, 0.7920918367346939, 0.7793367346938775, 0.7920918367346939, 0.7806122448979592, 0.7908163265306123, 0.7933673469387755, 0.798469387755102, 0.7997448979591837, 0.798469387755102, 0.7997448979591837, 0.8035714285714286, 0.8048469387755102, 0.8073979591836735, 0.8125, 0.8176020408163265, 0.8227040816326531, 0.8239795918367347, 0.8227040816326531, 0.8252551020408163, 0.8290816326530612, 0.8278061224489796, 0.8290816326530612, 0.8329081632653061, 0.8329081632653061, 0.8380102040816326, 0.840561224489796, 0.8418367346938775, 0.846938775510204, 0.8494897959183674, 0.8533163265306123, 0.8520408163265306, 0.8545918367346939, 0.8533163265306123, 0.8533163265306123, 0.8558673469387755, 0.8596938775510204, 0.8596938775510204, 0.8647959183673469, 0.8673469387755102, 0.8711734693877551, 0.8711734693877551, 0.8737244897959184, 0.875, 0.875, 0.8762755102040817, 0.8762755102040817, 0.8775510204081632, 0.8801020408163265, 0.8788265306122449, 0.8788265306122449, 0.8775510204081632, 0.8826530612244898, 0.8788265306122449, 0.8775510204081632, 0.8801020408163265, 0.8826530612244898, 0.8852040816326531, 0.8852040816326531, 0.889030612244898, 0.889030612244898, 0.8915816326530612, 0.8903061224489796, 0.8928571428571429, 0.8979591836734694, 0.9017857142857143, 0.903061224489796, 0.9068877551020408, 0.909438775510204, 0.9107142857142857, 0.9145408163265306, 0.9145408163265306, 0.9158163265306123, 0.9183673469387755, 0.9209183673469388, 0.9209183673469388, 0.9209183673469388, 0.9209183673469388, 0.9183673469387755, 0.9272959183673469, 0.9272959183673469, 0.923469387755102, 0.9311224489795918, 0.9285714285714286, 0.9323979591836735, 0.9323979591836735, 0.9336734693877551, 0.9323979591836735, 0.9362244897959183, 0.9349489795918368, 0.9387755102040817, 0.9413265306122449, 0.9413265306122449, 0.9438775510204082, 0.9426020408163265, 0.9438775510204082, 0.9438775510204082, 0.9413265306122449, 0.9426020408163265, 0.9426020408163265, 0.9438775510204082, 
0.9413265306122449, 0.9451530612244898, 0.9489795918367347, 0.9413265306122449, 0.9489795918367347, 0.9502551020408163, 0.9502551020408163, 0.9553571428571429, 0.9540816326530612, 0.9566326530612245, 0.9566326530612245, 0.9566326530612245, 0.9579081632653061, 0.9630102040816326, 0.9630102040816326, 0.9630102040816326, 0.965561224489796, 0.9668367346938775, 0.9693877551020408, 0.9693877551020408, 0.9693877551020408, 0.9693877551020408, 0.9693877551020408, 0.9706632653061225, 0.9719387755102041, 0.9719387755102041, 0.9732142857142857, 0.9732142857142857, 0.9744897959183674, 0.9744897959183674, 0.9757653061224489, 0.9744897959183674, 0.9770408163265306, 0.9783163265306123, 0.9744897959183674, 0.9783163265306123, 0.9757653061224489, 0.9783163265306123, 0.9770408163265306, 0.9783163265306123, 0.9821428571428571, 0.9795918367346939, 0.9821428571428571, 0.9795918367346939, 0.9834183673469388, 0.9834183673469388, 0.9795918367346939, 0.9834183673469388, 0.9834183673469388, 0.9834183673469388, 0.9872448979591837, 0.9846938775510204, 0.985969387755102, 0.9897959183673469, 0.9897959183673469, 0.9910714285714286, 0.9936224489795918, 0.9948979591836735, 0.9961734693877551, 0.9961734693877551, 0.9987244897959183, 1.0] and test accuracy = [0.8273333333333334, 0.8226666666666667, 0.8243333333333334, 0.825, 0.825, 0.825, 0.8251666666666667, 0.8251666666666667, 0.8256666666666667, 0.8256666666666667, 0.8256666666666667, 0.8256666666666667, 0.8256666666666667, 0.8258333333333333, 0.8258333333333333, 0.826, 0.826, 0.8258333333333333, 0.8256666666666667, 0.8256666666666667, 0.8256666666666667, 0.826, 0.826, 0.8263333333333334, 0.826, 0.8256666666666667, 0.8263333333333334, 0.8263333333333334, 0.8261666666666667, 0.826, 0.826, 0.8261666666666667, 0.8258333333333333, 0.8258333333333333, 0.826, 0.8261666666666667, 0.8263333333333334, 0.8268333333333333, 0.8265, 0.826, 0.826, 0.826, 0.8263333333333334, 0.8268333333333333, 0.8273333333333334, 0.8278333333333333, 0.828, 0.8278333333333333, 
0.8281666666666667, 0.828, 0.8276666666666667, 0.8276666666666667, 0.8276666666666667, 0.8266666666666667, 0.8266666666666667, 0.8253333333333334, 0.8258333333333333, 0.8246666666666667, 0.8241666666666667, 0.8245, 0.8238333333333333, 0.8253333333333334, 0.8243333333333334, 0.8236666666666667, 0.8231666666666667, 0.822, 0.8176666666666667, 0.8151666666666667, 0.8135, 0.8118333333333333, 0.8103333333333333, 0.81, 0.8091666666666667, 0.8096666666666666, 0.8098333333333333, 0.8098333333333333, 0.809, 0.809, 0.8081666666666667, 0.8076666666666666, 0.8076666666666666, 0.808, 0.8078333333333333, 0.8073333333333333, 0.8073333333333333, 0.8073333333333333, 0.807, 0.8068333333333333, 0.8063333333333333, 0.8065, 0.806, 0.8065, 0.8058333333333333, 0.8061666666666667, 0.8058333333333333, 0.8055, 0.8053333333333333, 0.805, 0.805, 0.8045, 0.8036666666666666, 0.8026666666666666, 0.8023333333333333, 0.8006666666666666, 0.8001666666666667, 0.799, 0.7988333333333333, 0.7985, 0.7981666666666667, 0.798, 0.7976666666666666, 0.7966666666666666, 0.797, 0.7963333333333333, 0.796, 0.7956666666666666, 0.795, 0.7935, 0.793, 0.7928333333333333, 0.7926666666666666, 0.7926666666666666, 0.791, 0.7913333333333333, 0.791, 0.7901666666666667, 0.7898333333333334, 0.7893333333333333, 0.7893333333333333, 0.7885, 0.7883333333333333, 0.7883333333333333, 0.7878333333333334, 0.7871666666666667, 0.7871666666666667, 0.7868333333333334, 0.7868333333333334, 0.7861666666666667, 0.786, 0.7856666666666666, 0.7848333333333334, 0.7838333333333334, 0.784, 0.7841666666666667, 0.7831666666666667, 0.7823333333333333, 0.7821666666666667, 0.781, 0.7805, 0.78, 0.7798333333333334, 0.7793333333333333, 0.7785, 0.778, 0.7765, 0.776, 0.7755, 0.7743333333333333, 0.7731666666666667, 0.7721666666666667, 0.7715, 0.7711666666666667, 0.7706666666666667, 0.77, 0.7691666666666667, 0.7696666666666667, 0.7693333333333333, 0.7686666666666667, 0.7675, 0.767, 0.7668333333333334, 0.7661666666666667, 0.765, 0.7645, 0.7633333333333333, 
0.7628333333333334, 0.7615, 0.7618333333333334, 0.761, 0.761, 0.7608333333333334, 0.7605, 0.7598333333333334, 0.7595, 0.7585, 0.758, 0.7581666666666667, 0.7575, 0.7571666666666667, 0.7563333333333333, 0.7551666666666667, 0.7551666666666667, 0.7538333333333334, 0.753, 0.7516666666666667, 0.7511666666666666, 0.7511666666666666, 0.7505, 0.7498333333333334, 0.7491666666666666, 0.7481666666666666, 0.747, 0.7465, 0.7455, 0.7443333333333333, 0.7431666666666666, 0.741, 0.74, 0.7385, 0.7371666666666666, 0.7368333333333333, 0.7366666666666667, 0.7346666666666667, 0.7323333333333333, 0.7318333333333333, 0.7298333333333333, 0.729, 0.7281666666666666, 0.7276666666666667, 0.725, 0.7243333333333334, 0.7223333333333334, 0.72, 0.7161666666666666, 0.713, 0.712, 0.7085, 0.704, 0.7011666666666667, 0.6988333333333333, 0.6975, 0.6945, 0.6931666666666667, 0.6888333333333333, 0.6846666666666666, 0.6801666666666667, 0.6771666666666667, 0.6738333333333333, 0.6693333333333333, 0.6646666666666666, 0.662, 0.6601666666666667, 0.655, 0.6445, 0.6355, 0.6271666666666667, 0.6198333333333333, 0.6135, 0.605, 0.6016666666666667, 0.5926666666666667, 0.5845, 0.5761666666666667, 0.5643333333333334, 0.553, 0.5448333333333333, 0.5388333333333334, 0.5338333333333334, 0.5278333333333334, 0.5188333333333334, 0.5081666666666667, 0.49883333333333335, 0.4891666666666667, 0.4785, 0.4736666666666667, 0.469, 0.4623333333333333, 0.456, 0.4455, 0.43233333333333335, 0.4146666666666667, 0.373, 0.3635, 0.3585, 0.3536666666666667, 0.3456666666666667, 0.3415, 0.33366666666666667, 0.3295, 0.3278333333333333, 0.326, 0.32083333333333336, 0.31583333333333335, 0.31033333333333335, 0.296, 0.2925, 0.2851666666666667, 0.2826666666666667, 0.2793333333333333, 0.2755, 0.27366666666666667, 0.27266666666666667, 0.27166666666666667, 0.26966666666666667, 0.259, 0.234, 0.21016666666666667, 0.1845, 0.18366666666666667, 0.18366666666666667, 0.1865, 0.184, 0.182, 0.1735, 0.17233333333333334, 0.17166666666666666, 0.169, 0.16483333333333333, 
0.09566666666666666, 0.09566666666666666]\n",
      "Results successfully saved to results/input_sparsity/FMNIST/LassoNet/rep_1/FMNIST_LassoNet_rep1_res.csv\n",
      "Loading dataset: FMNIST with one_hot = False for repetition 1\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running HSIC_dnn on FMNIST (repetition 1)\n",
      "Sequence of features is [  1   5  10  20 105 190 275 360 444 529 614 699 784]\n",
      "Downstream model for HSIC is nn\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 1\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.8937 - accuracy: 0.2092\n",
      "test acc for vanilla model is [1.8936635255813599, 0.20916666090488434]\n",
      "1.8936635255813599 0.20916666090488434\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 5\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.2733 - accuracy: 0.4792\n",
      "test acc for vanilla model is [1.2732633352279663, 0.4791666567325592]\n",
      "1.2732633352279663 0.4791666567325592\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 10\n",
      "188/188 [==============================] - 1s 1ms/step - loss: 1.3189 - accuracy: 0.5085\n",
      "test acc for vanilla model is [1.3189114332199097, 0.5084999799728394]\n",
      "1.3189114332199097 0.5084999799728394\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 20\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.3016 - accuracy: 0.6913\n",
      "test acc for vanilla model is [1.3016175031661987, 0.6913333535194397]\n",
      "1.3016175031661987 0.6913333535194397\n",
      "Current number of features: 105\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 105\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0162 - accuracy: 0.8125\n",
      "test acc for vanilla model is [1.0162208080291748, 0.8125]\n",
      "1.0162208080291748 0.8125\n",
      "Current number of features: 190\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 190\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0485 - accuracy: 0.8197\n",
      "test acc for vanilla model is [1.0485023260116577, 0.8196666836738586]\n",
      "1.0485023260116577 0.8196666836738586\n",
      "Current number of features: 275\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0665 - accuracy: 0.8193\n",
      "test acc for vanilla model is [1.0664904117584229, 0.8193333148956299]\n",
      "1.0664904117584229 0.8193333148956299\n",
      "Current number of features: 360\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0487 - accuracy: 0.8170\n",
      "test acc for vanilla model is [1.0486897230148315, 0.8169999718666077]\n",
      "1.0486897230148315 0.8169999718666077\n",
      "Current number of features: 444\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0154 - accuracy: 0.8215\n",
      "test acc for vanilla model is [1.0153988599777222, 0.8215000033378601]\n",
      "1.0153988599777222 0.8215000033378601\n",
      "Current number of features: 529\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0638 - accuracy: 0.8110\n",
      "test acc for vanilla model is [1.0638443231582642, 0.8109999895095825]\n",
      "1.0638443231582642 0.8109999895095825\n",
      "Current number of features: 614\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0703 - accuracy: 0.8200\n",
      "test acc for vanilla model is [1.070309042930603, 0.8199999928474426]\n",
      "1.070309042930603 0.8199999928474426\n",
      "Current number of features: 699\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0594 - accuracy: 0.8193\n",
      "test acc for vanilla model is [1.059424638748169, 0.8193333148956299]\n",
      "1.059424638748169 0.8193333148956299\n",
      "Current number of features: 784\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0237 - accuracy: 0.8225\n",
      "test acc for vanilla model is [1.0237088203430176, 0.8224999904632568]\n",
      "1.0237088203430176 0.8224999904632568\n",
      "Repetition 1: sparsity = [0.9987244897959183, 0.9936224489795918, 0.9872448979591837, 0.9744897959183674, 0.8660714285714286, 0.7576530612244898, 0.6492346938775511, 0.5408163265306123, 0.4336734693877551, 0.3252551020408163, 0.21683673469387754, 0.10841836734693877, 0.0] and test accuracy = [0.20916666090488434, 0.4791666567325592, 0.5084999799728394, 0.6913333535194397, 0.8125, 0.8196666836738586, 0.8193333148956299, 0.8169999718666077, 0.8215000033378601, 0.8109999895095825, 0.8199999928474426, 0.8193333148956299, 0.8224999904632568]\n",
      "Results successfully saved to results/input_sparsity/FMNIST/HSIC_dnn/rep_1/FMNIST_HSIC_dnn_rep1_res.csv\n",
      "Loading dataset: FMNIST with one_hot = False for repetition 1\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running HSIC_svm on FMNIST (repetition 1)\n",
      "Sequence of features is [  1   5  10  20 105 190 275 360 444 529 614 699 784]\n",
      "Downstream model for HSIC is svm\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 105\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 190\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 275\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 360\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 444\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 529\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 614\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 699\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 784\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Repetition 1: sparsity = [0.9987244897959183, 0.9936224489795918, 0.9872448979591837, 0.9744897959183674, 0.8660714285714286, 0.7576530612244898, 0.6492346938775511, 0.5408163265306123, 0.4336734693877551, 0.3252551020408163, 0.21683673469387754, 0.10841836734693877, 0.0] and test accuracy = [0.206, 0.4673333333333333, 0.496, 0.7031666666666667, 0.822, 0.8285, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333]\n",
      "Results successfully saved to results/input_sparsity/FMNIST/HSIC_svm/rep_1/FMNIST_HSIC_svm_rep1_res.csv\n",
      "\n",
      "Starting repetition 2/5\n",
      "Loading dataset: FMNIST for repetition 2/5\n",
      "Loading dataset: FMNIST with one_hot = False for repetition 2\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running LassoNet on FMNIST (repetition 2)\n",
      "Repetition 2: sparsity = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.005102040816326481, 0.014030612244897989, 0.05102040816326525, 0.10841836734693877, 0.17984693877551017, 0.2589285714285714, 0.30612244897959184, 0.3494897959183674, 0.3852040816326531, 0.4017857142857143, 0.41326530612244894, 0.4247448979591837, 0.4323979591836735, 0.44515306122448983, 0.4502551020408163, 0.4604591836734694, 0.4642857142857143, 0.4744897959183674, 0.48086734693877553, 0.4897959183673469, 0.4948979591836735, 0.4987244897959183, 0.4987244897959183, 0.5089285714285714, 0.5191326530612245, 0.528061224489796, 0.5331632653061225, 0.5446428571428572, 0.5561224489795918, 0.5612244897959184, 0.5714285714285714, 0.5765306122448979, 0.5867346938775511, 0.590561224489796, 0.5994897959183674, 0.6033163265306123, 0.6109693877551021, 0.6186224489795918, 0.625, 0.6301020408163265, 0.6352040816326531, 0.6466836734693877, 0.6454081632653061, 0.653061224489796, 0.6556122448979591, 0.659438775510204, 0.6645408163265306, 0.6658163265306123, 0.6721938775510203, 0.6747448979591837, 0.6823979591836735, 0.6849489795918368, 0.6849489795918368, 0.6887755102040816, 0.6900510204081632, 0.6951530612244898, 0.6913265306122449, 0.6951530612244898, 0.6913265306122449, 0.7002551020408163, 0.7028061224489797, 0.7142857142857143, 0.715561224489796, 0.7244897959183674, 0.7193877551020409, 0.7359693877551021, 0.7372448979591837, 0.7436224489795918, 0.7372448979591837, 0.7474489795918368, 0.7436224489795918, 0.7563775510204082, 0.7487244897959184, 0.7627551020408163, 0.7614795918367347, 0.7716836734693877, 0.7653061224489796, 0.7767857142857143, 0.7755102040816326, 0.784438775510204, 0.778061224489796, 0.7857142857142857, 0.7831632653061225, 
0.7920918367346939, 0.7882653061224489, 0.8022959183673469, 0.7920918367346939, 0.8061224489795918, 0.7971938775510204, 0.8112244897959184, 0.8022959183673469, 0.8112244897959184, 0.8010204081632653, 0.8150510204081632, 0.8073979591836735, 0.8137755102040816, 0.8061224489795918, 0.8163265306122449, 0.8035714285714286, 0.8125, 0.8137755102040816, 0.8163265306122449, 0.8163265306122449, 0.8163265306122449, 0.8150510204081632, 0.8227040816326531, 0.826530612244898, 0.826530612244898, 0.8316326530612245, 0.8354591836734694, 0.8392857142857143, 0.840561224489796, 0.8431122448979592, 0.8456632653061225, 0.8507653061224489, 0.8507653061224489, 0.8558673469387755, 0.8584183673469388, 0.8584183673469388, 0.8584183673469388, 0.8622448979591837, 0.860969387755102, 0.8622448979591837, 0.8635204081632653, 0.8647959183673469, 0.8660714285714286, 0.8673469387755102, 0.8647959183673469, 0.8698979591836735, 0.8724489795918368, 0.875, 0.875, 0.875, 0.8775510204081632, 0.8775510204081632, 0.8788265306122449, 0.8826530612244898, 0.8864795918367347, 0.8903061224489796, 0.8915816326530612, 0.8915816326530612, 0.8928571428571429, 0.8941326530612245, 0.8941326530612245, 0.8941326530612245, 0.8954081632653061, 0.8979591836734694, 0.9005102040816326, 0.9017857142857143, 0.9043367346938775, 0.9043367346938775, 0.9056122448979592, 0.9068877551020408, 0.9081632653061225, 0.9081632653061225, 0.9107142857142857, 0.909438775510204, 0.9107142857142857, 0.9107142857142857, 0.9132653061224489, 0.9145408163265306, 0.9145408163265306, 0.9158163265306123, 0.9170918367346939, 0.9170918367346939, 0.9196428571428571, 0.9209183673469388, 0.9196428571428571, 0.9272959183673469, 0.9285714285714286, 0.9285714285714286, 0.9285714285714286, 0.9285714285714286, 0.9272959183673469, 0.9285714285714286, 0.9349489795918368, 0.9375, 0.9375, 0.9362244897959183, 0.9387755102040817, 0.9400510204081632, 0.9413265306122449, 0.9413265306122449, 0.9413265306122449, 0.9426020408163265, 0.9426020408163265, 0.9426020408163265, 
0.9426020408163265, 0.9477040816326531, 0.9464285714285714, 0.9502551020408163, 0.951530612244898, 0.9528061224489796, 0.9528061224489796, 0.9553571428571429, 0.9528061224489796, 0.9566326530612245, 0.9566326530612245, 0.9566326530612245, 0.9579081632653061, 0.9604591836734694, 0.9604591836734694, 0.9617346938775511, 0.9642857142857143, 0.9642857142857143, 0.9681122448979592, 0.9681122448979592, 0.9681122448979592, 0.9681122448979592, 0.9681122448979592, 0.9719387755102041, 0.9732142857142857, 0.9732142857142857, 0.9732142857142857, 0.9693877551020408, 0.9770408163265306, 0.9757653061224489, 0.9757653061224489, 0.9770408163265306, 0.9770408163265306, 0.9744897959183674, 0.9770408163265306, 0.9770408163265306, 0.9744897959183674, 0.9783163265306123, 0.9795918367346939, 0.9770408163265306, 0.9783163265306123, 0.9795918367346939, 0.9795918367346939, 0.9795918367346939, 0.9783163265306123, 0.9808673469387755, 0.9795918367346939, 0.9834183673469388, 0.985969387755102, 0.985969387755102, 0.9885204081632653, 0.9885204081632653, 0.9885204081632653, 0.9910714285714286, 0.9923469387755102, 0.9936224489795918, 0.9948979591836735, 0.9974489795918368, 0.9974489795918368, 0.9974489795918368, 1.0] and test accuracy = [0.8203333333333334, 0.817, 0.8185, 0.8193333333333334, 0.8193333333333334, 0.8188333333333333, 0.8185, 0.8178333333333333, 0.818, 0.8178333333333333, 0.8176666666666667, 0.818, 0.8178333333333333, 0.8176666666666667, 0.8178333333333333, 0.818, 0.8183333333333334, 0.8186666666666667, 0.8186666666666667, 0.819, 0.8195, 0.8195, 0.8196666666666667, 0.82, 0.8198333333333333, 0.8193333333333334, 0.8191666666666667, 0.8191666666666667, 0.8195, 0.8196666666666667, 0.8196666666666667, 0.8196666666666667, 0.8196666666666667, 0.8191666666666667, 0.819, 0.8193333333333334, 0.8196666666666667, 0.8195, 0.8195, 0.8198333333333333, 0.8201666666666667, 0.8203333333333334, 0.8206666666666667, 0.8206666666666667, 0.8205, 0.8206666666666667, 0.8201666666666667, 0.8201666666666667, 
0.8208333333333333, 0.8215, 0.8213333333333334, 0.8213333333333334, 0.8218333333333333, 0.8233333333333334, 0.8228333333333333, 0.8236666666666667, 0.8243333333333334, 0.8235, 0.8243333333333334, 0.8245, 0.8255, 0.8263333333333334, 0.8265, 0.8225, 0.823, 0.8206666666666667, 0.817, 0.8145, 0.8126666666666666, 0.8111666666666667, 0.8115, 0.8116666666666666, 0.8118333333333333, 0.8106666666666666, 0.8091666666666667, 0.8086666666666666, 0.8083333333333333, 0.808, 0.8076666666666666, 0.8071666666666667, 0.8066666666666666, 0.8063333333333333, 0.8056666666666666, 0.8048333333333333, 0.804, 0.8033333333333333, 0.803, 0.8033333333333333, 0.8028333333333333, 0.8031666666666667, 0.8031666666666667, 0.803, 0.8026666666666666, 0.8023333333333333, 0.802, 0.8011666666666667, 0.8006666666666666, 0.8, 0.7998333333333333, 0.7995, 0.7996666666666666, 0.799, 0.7978333333333333, 0.7973333333333333, 0.7963333333333333, 0.7956666666666666, 0.795, 0.7948333333333333, 0.795, 0.7945, 0.794, 0.7938333333333333, 0.7933333333333333, 0.7923333333333333, 0.792, 0.792, 0.7905, 0.7898333333333334, 0.7896666666666666, 0.789, 0.7886666666666666, 0.7883333333333333, 0.7881666666666667, 0.7871666666666667, 0.7868333333333334, 0.7863333333333333, 0.786, 0.7855, 0.785, 0.784, 0.7831666666666667, 0.7828333333333334, 0.7818333333333334, 0.7808333333333334, 0.78, 0.7796666666666666, 0.7793333333333333, 0.779, 0.7781666666666667, 0.7773333333333333, 0.7771666666666667, 0.7768333333333334, 0.777, 0.776, 0.7751666666666667, 0.775, 0.7748333333333334, 0.7738333333333334, 0.7733333333333333, 0.7725, 0.7725, 0.7718333333333334, 0.772, 0.7711666666666667, 0.7708333333333334, 0.7701666666666667, 0.7696666666666667, 0.7686666666666667, 0.769, 0.7683333333333333, 0.7676666666666667, 0.7675, 0.7678333333333334, 0.7676666666666667, 0.7666666666666667, 0.766, 0.7653333333333333, 0.7638333333333334, 0.763, 0.763, 0.7625, 0.7606666666666667, 0.7601666666666667, 0.7588333333333334, 0.759, 0.7578333333333334, 
0.7573333333333333, 0.7566666666666667, 0.756, 0.7551666666666667, 0.7543333333333333, 0.7536666666666667, 0.7535, 0.7526666666666667, 0.752, 0.7505, 0.75, 0.7495, 0.7486666666666667, 0.7476666666666667, 0.7465, 0.7455, 0.7448333333333333, 0.7443333333333333, 0.7435, 0.7428333333333333, 0.7416666666666667, 0.7411666666666666, 0.7401666666666666, 0.7388333333333333, 0.7376666666666667, 0.7358333333333333, 0.7348333333333333, 0.7331666666666666, 0.732, 0.7308333333333333, 0.7293333333333333, 0.7271666666666666, 0.725, 0.7231666666666666, 0.7218333333333333, 0.72, 0.7188333333333333, 0.7173333333333334, 0.7158333333333333, 0.715, 0.7141666666666666, 0.7131666666666666, 0.712, 0.7111666666666666, 0.71, 0.7088333333333333, 0.7071666666666667, 0.7038333333333333, 0.7001666666666667, 0.6961666666666667, 0.693, 0.6908333333333333, 0.6853333333333333, 0.6823333333333333, 0.6798333333333333, 0.6781666666666667, 0.6748333333333333, 0.6731666666666667, 0.6695, 0.6671666666666667, 0.6653333333333333, 0.6623333333333333, 0.6613333333333333, 0.6595, 0.6566666666666666, 0.6533333333333333, 0.6523333333333333, 0.6485, 0.6468333333333334, 0.644, 0.6391666666666667, 0.6348333333333334, 0.6303333333333333, 0.6256666666666667, 0.6193333333333333, 0.6108333333333333, 0.6036666666666667, 0.5973333333333334, 0.5908333333333333, 0.5855, 0.5801666666666667, 0.5733333333333334, 0.565, 0.5556666666666666, 0.547, 0.5363333333333333, 0.5245, 0.5101666666666667, 0.49216666666666664, 0.47583333333333333, 0.4565, 0.43416666666666665, 0.4053333333333333, 0.378, 0.3626666666666667, 0.3551666666666667, 0.348, 0.3426666666666667, 0.33466666666666667, 0.32866666666666666, 0.32566666666666666, 0.32166666666666666, 0.31933333333333336, 0.31483333333333335, 0.30783333333333335, 0.2965, 0.2831666666666667, 0.276, 0.27, 0.2683333333333333, 0.26616666666666666, 0.2645, 0.26366666666666666, 0.26216666666666666, 0.26316666666666666, 0.263, 0.2643333333333333, 0.18216666666666667, 0.18133333333333335, 
0.18066666666666667, 0.18083333333333335, 0.18083333333333335, 0.181, 0.18066666666666667, 0.1805, 0.1805, 0.1805, 0.181, 0.1795, 0.17916666666666667, 0.1795, 0.09566666666666666, 0.09566666666666666]\n",
      "Results successfully saved to results/input_sparsity/FMNIST/LassoNet/rep_2/FMNIST_LassoNet_rep2_res.csv\n",
      "Loading dataset: FMNIST with one_hot = False for repetition 2\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running HSIC_dnn on FMNIST (repetition 2)\n",
      "Sequence of features is [  1   5  10  20 105 190 275 360 444 529 614 699 784]\n",
      "Downstream model for HSIC is nn\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 1\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.8937 - accuracy: 0.2090\n",
      "test acc for vanilla model is [1.8936642408370972, 0.20900000631809235]\n",
      "1.8936642408370972 0.20900000631809235\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 5\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.2767 - accuracy: 0.4850\n",
      "test acc for vanilla model is [1.2767225503921509, 0.48500001430511475]\n",
      "1.2767225503921509 0.48500001430511475\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 10\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.3108 - accuracy: 0.5097\n",
      "test acc for vanilla model is [1.3108198642730713, 0.5096666812896729]\n",
      "1.3108198642730713 0.5096666812896729\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 20\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.2943 - accuracy: 0.6890\n",
      "test acc for vanilla model is [1.2943229675292969, 0.6890000104904175]\n",
      "1.2943229675292969 0.6890000104904175\n",
      "Current number of features: 105\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 105\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0155 - accuracy: 0.8122\n",
      "test acc for vanilla model is [1.0155208110809326, 0.812166690826416]\n",
      "1.0155208110809326 0.812166690826416\n",
      "Current number of features: 190\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 190\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0497 - accuracy: 0.8177\n",
      "test acc for vanilla model is [1.0496528148651123, 0.8176666498184204]\n",
      "1.0496528148651123 0.8176666498184204\n",
      "Current number of features: 275\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0687 - accuracy: 0.8167\n",
      "test acc for vanilla model is [1.0686954259872437, 0.8166666626930237]\n",
      "1.0686954259872437 0.8166666626930237\n",
      "Current number of features: 360\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0541 - accuracy: 0.8155\n",
      "test acc for vanilla model is [1.0540603399276733, 0.815500020980835]\n",
      "1.0540603399276733 0.815500020980835\n",
      "Current number of features: 444\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0084 - accuracy: 0.8217\n",
      "test acc for vanilla model is [1.0083684921264648, 0.8216666579246521]\n",
      "1.0083684921264648 0.8216666579246521\n",
      "Current number of features: 529\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0569 - accuracy: 0.8140\n",
      "test acc for vanilla model is [1.056942105293274, 0.8140000104904175]\n",
      "1.056942105293274 0.8140000104904175\n",
      "Current number of features: 614\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0627 - accuracy: 0.8188\n",
      "test acc for vanilla model is [1.0626580715179443, 0.8188333511352539]\n",
      "1.0626580715179443 0.8188333511352539\n",
      "Current number of features: 699\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0538 - accuracy: 0.8185\n",
      "test acc for vanilla model is [1.0538156032562256, 0.8184999823570251]\n",
      "1.0538156032562256 0.8184999823570251\n",
      "Current number of features: 784\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0215 - accuracy: 0.8232\n",
      "test acc for vanilla model is [1.021511197090149, 0.8231666684150696]\n",
      "1.021511197090149 0.8231666684150696\n",
      "Repetition 2: sparsity = [0.9987244897959183, 0.9936224489795918, 0.9872448979591837, 0.9744897959183674, 0.8660714285714286, 0.7576530612244898, 0.6492346938775511, 0.5408163265306123, 0.4336734693877551, 0.3252551020408163, 0.21683673469387754, 0.10841836734693877, 0.0] and test accuracy = [0.20900000631809235, 0.48500001430511475, 0.5096666812896729, 0.6890000104904175, 0.812166690826416, 0.8176666498184204, 0.8166666626930237, 0.815500020980835, 0.8216666579246521, 0.8140000104904175, 0.8188333511352539, 0.8184999823570251, 0.8231666684150696]\n",
      "Results successfully saved to results/input_sparsity/FMNIST/HSIC_dnn/rep_2/FMNIST_HSIC_dnn_rep2_res.csv\n",
      "Loading dataset: FMNIST with one_hot = False for repetition 2\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running HSIC_svm on FMNIST (repetition 2)\n",
      "Sequence of features is [  1   5  10  20 105 190 275 360 444 529 614 699 784]\n",
      "Downstream model for HSIC is svm\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 105\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 190\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 275\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 360\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 444\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 529\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 614\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 699\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 784\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Repetition 2: sparsity = [0.9987244897959183, 0.9936224489795918, 0.9872448979591837, 0.9744897959183674, 0.8660714285714286, 0.7576530612244898, 0.6492346938775511, 0.5408163265306123, 0.4336734693877551, 0.3252551020408163, 0.21683673469387754, 0.10841836734693877, 0.0] and test accuracy = [0.206, 0.4673333333333333, 0.496, 0.7031666666666667, 0.822, 0.8285, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333]\n",
      "Results successfully saved to results/input_sparsity/FMNIST/HSIC_svm/rep_2/FMNIST_HSIC_svm_rep2_res.csv\n",
      "\n",
      "Starting repetition 3/5\n",
      "Loading dataset: FMNIST for repetition 3/5\n",
      "Loading dataset: FMNIST with one_hot = False for repetition 3\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running LassoNet on FMNIST (repetition 3)\n",
      "Repetition 3: sparsity = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0025510204081632404, 0.02168367346938771, 0.05994897959183676, 0.12882653061224492, 0.19132653061224492, 0.26147959183673475, 0.30612244897959184, 0.34183673469387754, 0.38265306122448983, 0.4005102040816326, 0.40943877551020413, 0.4158163265306123, 0.42091836734693877, 0.43112244897959184, 0.4413265306122449, 0.44897959183673475, 0.4604591836734694, 0.46556122448979587, 0.47576530612244894, 0.48341836734693877, 0.4872448979591837, 0.49362244897959184, 0.4987244897959183, 0.5089285714285714, 0.5204081632653061, 0.5255102040816326, 0.5369897959183674, 0.5420918367346939, 0.5535714285714286, 0.5599489795918368, 0.5625, 0.5625, 0.5688775510204082, 0.5778061224489797, 0.5790816326530612, 0.5918367346938775, 0.5943877551020409, 0.596938775510204, 0.6084183673469388, 0.6135204081632653, 0.6211734693877551, 0.6198979591836735, 0.6301020408163265, 0.6428571428571428, 0.6377551020408163, 0.6479591836734694, 0.6505102040816326, 0.6543367346938775, 0.6581632653061225, 0.6607142857142857, 0.6658163265306123, 0.6721938775510203, 0.6683673469387755, 0.6772959183673469, 0.6772959183673469, 0.6887755102040816, 0.6900510204081632, 0.6977040816326531, 0.6977040816326531, 0.7104591836734694, 0.7104591836734694, 0.7130102040816326, 0.715561224489796, 0.7206632653061225, 0.7142857142857143, 0.7206632653061225, 0.7168367346938775, 0.7232142857142857, 0.7193877551020409, 0.7308673469387755, 0.7308673469387755, 0.7410714285714286, 0.7461734693877551, 0.7525510204081632, 0.7512755102040816, 0.7563775510204082, 0.7614795918367347, 0.7665816326530612, 0.7665816326530612, 0.7767857142857143, 0.778061224489796, 0.7869897959183674, 0.7831632653061225, 
0.7933673469387755, 0.7857142857142857, 0.7971938775510204, 0.7946428571428572, 0.7997448979591837, 0.7971938775510204, 0.8035714285714286, 0.798469387755102, 0.8086734693877551, 0.8022959183673469, 0.8137755102040816, 0.8073979591836735, 0.8201530612244898, 0.8112244897959184, 0.8214285714285714, 0.8137755102040816, 0.8150510204081632, 0.8188775510204082, 0.8201530612244898, 0.8188775510204082, 0.8214285714285714, 0.8252551020408163, 0.8290816326530612, 0.8341836734693877, 0.8354591836734694, 0.8354591836734694, 0.8380102040816326, 0.8418367346938775, 0.8456632653061225, 0.8443877551020408, 0.8482142857142857, 0.8533163265306123, 0.8545918367346939, 0.8571428571428572, 0.8571428571428572, 0.860969387755102, 0.8622448979591837, 0.8635204081632653, 0.8635204081632653, 0.8635204081632653, 0.8660714285714286, 0.8698979591836735, 0.8698979591836735, 0.8724489795918368, 0.8737244897959184, 0.8737244897959184, 0.8762755102040817, 0.8762755102040817, 0.8762755102040817, 0.8775510204081632, 0.8788265306122449, 0.8788265306122449, 0.8813775510204082, 0.8813775510204082, 0.8788265306122449, 0.8788265306122449, 0.8788265306122449, 0.8864795918367347, 0.8864795918367347, 0.889030612244898, 0.8864795918367347, 0.8954081632653061, 0.8941326530612245, 0.8992346938775511, 0.9043367346938775, 0.9043367346938775, 0.9056122448979592, 0.9068877551020408, 0.9081632653061225, 0.909438775510204, 0.9119897959183674, 0.9145408163265306, 0.9145408163265306, 0.9158163265306123, 0.9209183673469388, 0.923469387755102, 0.923469387755102, 0.9260204081632653, 0.9285714285714286, 0.9311224489795918, 0.9298469387755102, 0.9375, 0.9349489795918368, 0.9387755102040817, 0.9362244897959183, 0.9375, 0.9362244897959183, 0.9375, 0.9375, 0.9387755102040817, 0.9400510204081632, 0.9387755102040817, 0.9426020408163265, 0.9413265306122449, 0.9426020408163265, 0.9451530612244898, 0.9451530612244898, 0.9451530612244898, 0.9477040816326531, 0.9464285714285714, 0.9489795918367347, 0.9502551020408163, 
0.9502551020408163, 0.9502551020408163, 0.9502551020408163, 0.9502551020408163, 0.9502551020408163, 0.9502551020408163, 0.9502551020408163, 0.951530612244898, 0.9528061224489796, 0.9502551020408163, 0.9528061224489796, 0.9566326530612245, 0.9579081632653061, 0.9591836734693877, 0.9642857142857143, 0.9642857142857143, 0.965561224489796, 0.965561224489796, 0.965561224489796, 0.965561224489796, 0.965561224489796, 0.965561224489796, 0.9668367346938775, 0.9706632653061225, 0.9706632653061225, 0.9719387755102041, 0.9732142857142857, 0.9732142857142857, 0.9757653061224489, 0.9770408163265306, 0.9770408163265306, 0.9770408163265306, 0.9783163265306123, 0.9783163265306123, 0.9783163265306123, 0.9757653061224489, 0.9808673469387755, 0.9808673469387755, 0.9757653061224489, 0.9821428571428571, 0.9795918367346939, 0.9834183673469388, 0.9834183673469388, 0.9808673469387755, 0.9821428571428571, 0.9846938775510204, 0.9846938775510204, 0.9846938775510204, 0.9846938775510204, 0.9834183673469388, 0.9834183673469388, 0.9834183673469388, 0.9834183673469388, 0.9846938775510204, 0.985969387755102, 0.9872448979591837, 0.9910714285714286, 0.9910714285714286, 0.9910714285714286, 0.9936224489795918, 0.9948979591836735, 0.9961734693877551, 0.9987244897959183, 1.0] and test accuracy = [0.8303333333333334, 0.827, 0.8276666666666667, 0.8275, 0.8276666666666667, 0.8278333333333333, 0.8278333333333333, 0.8281666666666667, 0.8285, 0.828, 0.8281666666666667, 0.8281666666666667, 0.8283333333333334, 0.8285, 0.8285, 0.8285, 0.8288333333333333, 0.8291666666666667, 0.8288333333333333, 0.8288333333333333, 0.8288333333333333, 0.8286666666666667, 0.829, 0.8293333333333334, 0.8293333333333334, 0.8293333333333334, 0.8295, 0.8291666666666667, 0.8291666666666667, 0.829, 0.8291666666666667, 0.8295, 0.8293333333333334, 0.8291666666666667, 0.829, 0.829, 0.8295, 0.8295, 0.8293333333333334, 0.8293333333333334, 0.8295, 0.8295, 0.8298333333333333, 0.8306666666666667, 0.8305, 0.8308333333333333, 0.8308333333333333, 
0.8301666666666667, 0.8301666666666667, 0.8295, 0.8293333333333334, 0.8301666666666667, 0.83, 0.8303333333333334, 0.8293333333333334, 0.8298333333333333, 0.8301666666666667, 0.8301666666666667, 0.8291666666666667, 0.8303333333333334, 0.829, 0.8298333333333333, 0.8273333333333334, 0.825, 0.823, 0.8198333333333333, 0.8178333333333333, 0.8143333333333334, 0.813, 0.812, 0.8111666666666667, 0.8093333333333333, 0.8083333333333333, 0.808, 0.808, 0.8081666666666667, 0.8078333333333333, 0.808, 0.8076666666666666, 0.8075, 0.8061666666666667, 0.8065, 0.8065, 0.806, 0.8055, 0.806, 0.8056666666666666, 0.8041666666666667, 0.8033333333333333, 0.8023333333333333, 0.802, 0.8011666666666667, 0.8006666666666666, 0.8006666666666666, 0.8, 0.7993333333333333, 0.7983333333333333, 0.7985, 0.7985, 0.7976666666666666, 0.7963333333333333, 0.7958333333333333, 0.7948333333333333, 0.7936666666666666, 0.7928333333333333, 0.7926666666666666, 0.7926666666666666, 0.7923333333333333, 0.7918333333333333, 0.791, 0.7906666666666666, 0.7903333333333333, 0.79, 0.7898333333333334, 0.7893333333333333, 0.7891666666666667, 0.7893333333333333, 0.7888333333333334, 0.7886666666666666, 0.7875, 0.7873333333333333, 0.7873333333333333, 0.7868333333333334, 0.7868333333333334, 0.7865, 0.7851666666666667, 0.7845, 0.784, 0.7836666666666666, 0.7826666666666666, 0.7826666666666666, 0.7825, 0.7821666666666667, 0.7813333333333333, 0.781, 0.7805, 0.7796666666666666, 0.7793333333333333, 0.779, 0.7786666666666666, 0.778, 0.7771666666666667, 0.7765, 0.7761666666666667, 0.7758333333333334, 0.7746666666666666, 0.7738333333333334, 0.7725, 0.7718333333333334, 0.7711666666666667, 0.7711666666666667, 0.7708333333333334, 0.771, 0.7715, 0.7713333333333333, 0.7703333333333333, 0.7695, 0.769, 0.7688333333333334, 0.7683333333333333, 0.7675, 0.7665, 0.7663333333333333, 0.765, 0.763, 0.7628333333333334, 0.7623333333333333, 0.7618333333333334, 0.7603333333333333, 0.7596666666666667, 0.7595, 0.7593333333333333, 0.7585, 0.7581666666666667, 
0.7571666666666667, 0.757, 0.7565, 0.7563333333333333, 0.756, 0.7543333333333333, 0.7536666666666667, 0.7536666666666667, 0.7536666666666667, 0.7525, 0.7525, 0.7518333333333334, 0.7508333333333334, 0.7503333333333333, 0.7491666666666666, 0.748, 0.7465, 0.7453333333333333, 0.7445, 0.7433333333333333, 0.7428333333333333, 0.7411666666666666, 0.7411666666666666, 0.7393333333333333, 0.7368333333333333, 0.7355, 0.7341666666666666, 0.7333333333333333, 0.7325, 0.7318333333333333, 0.7295, 0.729, 0.7286666666666667, 0.7265, 0.7245, 0.7223333333333334, 0.7205, 0.7186666666666667, 0.7168333333333333, 0.715, 0.7135, 0.7128333333333333, 0.7113333333333334, 0.7086666666666667, 0.707, 0.7043333333333334, 0.7028333333333333, 0.7008333333333333, 0.6995, 0.6936666666666667, 0.6926666666666667, 0.6913333333333334, 0.6895, 0.687, 0.6866666666666666, 0.6851666666666667, 0.685, 0.6836666666666666, 0.6813333333333333, 0.6801666666666667, 0.6766666666666666, 0.6725, 0.6691666666666667, 0.6665, 0.663, 0.6596666666666666, 0.6556666666666666, 0.6513333333333333, 0.6458333333333334, 0.6418333333333334, 0.6363333333333333, 0.6345, 0.6323333333333333, 0.6248333333333334, 0.6186666666666667, 0.6105, 0.6021666666666666, 0.5908333333333333, 0.5853333333333334, 0.5755, 0.5675, 0.559, 0.555, 0.5476666666666666, 0.5408333333333334, 0.536, 0.5318333333333334, 0.5281666666666667, 0.5208333333333334, 0.51, 0.49616666666666664, 0.4846666666666667, 0.47683333333333333, 0.4675, 0.45616666666666666, 0.442, 0.43033333333333335, 0.414, 0.35283333333333333, 0.3338333333333333, 0.32033333333333336, 0.31216666666666665, 0.3016666666666667, 0.2911666666666667, 0.2886666666666667, 0.2866666666666667, 0.2841666666666667, 0.2813333333333333, 0.2801666666666667, 0.2778333333333333, 0.2763333333333333, 0.2748333333333333, 0.275, 0.2748333333333333, 0.2753333333333333, 0.2765, 0.23016666666666666, 0.1925, 0.19316666666666665, 0.1925, 0.19216666666666668, 0.19033333333333333, 0.19083333333333333, 0.18883333333333333, 
0.18616666666666667, 0.18083333333333335, 0.18083333333333335, 0.181, 0.181, 0.18116666666666667, 0.18083333333333335, 0.18083333333333335, 0.17983333333333335, 0.17933333333333334, 0.179, 0.09566666666666666]\n",
      "Results successfully saved to results/input_sparsity/FMNIST/LassoNet/rep_3/FMNIST_LassoNet_rep3_res.csv\n",
      "Loading dataset: FMNIST with one_hot = False for repetition 3\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running HSIC_dnn on FMNIST (repetition 3)\n",
      "Sequence of features is [  1   5  10  20 105 190 275 360 444 529 614 699 784]\n",
      "Downstream model for HSIC is nn\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 1\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.8928 - accuracy: 0.2098\n",
      "test acc for vanilla model is [1.8927700519561768, 0.20983333885669708]\n",
      "1.8927700519561768 0.20983333885669708\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 5\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.2780 - accuracy: 0.4907\n",
      "test acc for vanilla model is [1.2779649496078491, 0.4906666576862335]\n",
      "1.2779649496078491 0.4906666576862335\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 10\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.3067 - accuracy: 0.5163\n",
      "test acc for vanilla model is [1.306650996208191, 0.5163333415985107]\n",
      "1.306650996208191 0.5163333415985107\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 20\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.2949 - accuracy: 0.6892\n",
      "test acc for vanilla model is [1.2949368953704834, 0.6891666650772095]\n",
      "1.2949368953704834 0.6891666650772095\n",
      "Current number of features: 105\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 105\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0124 - accuracy: 0.8135\n",
      "test acc for vanilla model is [1.012393593788147, 0.8134999871253967]\n",
      "1.012393593788147 0.8134999871253967\n",
      "Current number of features: 190\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 190\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0562 - accuracy: 0.8183\n",
      "test acc for vanilla model is [1.056201457977295, 0.8183333277702332]\n",
      "1.056201457977295 0.8183333277702332\n",
      "Current number of features: 275\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0660 - accuracy: 0.8185\n",
      "test acc for vanilla model is [1.0659924745559692, 0.8184999823570251]\n",
      "1.0659924745559692 0.8184999823570251\n",
      "Current number of features: 360\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0528 - accuracy: 0.8183\n",
      "test acc for vanilla model is [1.0527647733688354, 0.8183333277702332]\n",
      "1.0527647733688354 0.8183333277702332\n",
      "Current number of features: 444\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0075 - accuracy: 0.8220\n",
      "test acc for vanilla model is [1.007523536682129, 0.8220000267028809]\n",
      "1.007523536682129 0.8220000267028809\n",
      "Current number of features: 529\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0620 - accuracy: 0.8132\n",
      "test acc for vanilla model is [1.0620142221450806, 0.8131666779518127]\n",
      "1.0620142221450806 0.8131666779518127\n",
      "Current number of features: 614\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0656 - accuracy: 0.8168\n",
      "test acc for vanilla model is [1.0656169652938843, 0.8168333172798157]\n",
      "1.0656169652938843 0.8168333172798157\n",
      "Current number of features: 699\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0596 - accuracy: 0.8198\n",
      "test acc for vanilla model is [1.0596377849578857, 0.8198333382606506]\n",
      "1.0596377849578857 0.8198333382606506\n",
      "Current number of features: 784\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0208 - accuracy: 0.8242\n",
      "test acc for vanilla model is [1.0207977294921875, 0.8241666555404663]\n",
      "1.0207977294921875 0.8241666555404663\n",
      "Repetition 3: sparsity = [0.9987244897959183, 0.9936224489795918, 0.9872448979591837, 0.9744897959183674, 0.8660714285714286, 0.7576530612244898, 0.6492346938775511, 0.5408163265306123, 0.4336734693877551, 0.3252551020408163, 0.21683673469387754, 0.10841836734693877, 0.0] and test accuracy = [0.20983333885669708, 0.4906666576862335, 0.5163333415985107, 0.6891666650772095, 0.8134999871253967, 0.8183333277702332, 0.8184999823570251, 0.8183333277702332, 0.8220000267028809, 0.8131666779518127, 0.8168333172798157, 0.8198333382606506, 0.8241666555404663]\n",
      "Results successfully saved to results/input_sparsity/FMNIST/HSIC_dnn/rep_3/FMNIST_HSIC_dnn_rep3_res.csv\n",
      "Loading dataset: FMNIST with one_hot = False for repetition 3\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running HSIC_svm on FMNIST (repetition 3)\n",
      "Sequence of features is [  1   5  10  20 105 190 275 360 444 529 614 699 784]\n",
      "Downstream model for HSIC is svm\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 105\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 190\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 275\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 360\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 444\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 529\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 614\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 699\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 784\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Repetition 3: sparsity = [0.9987244897959183, 0.9936224489795918, 0.9872448979591837, 0.9744897959183674, 0.8660714285714286, 0.7576530612244898, 0.6492346938775511, 0.5408163265306123, 0.4336734693877551, 0.3252551020408163, 0.21683673469387754, 0.10841836734693877, 0.0] and test accuracy = [0.206, 0.4673333333333333, 0.496, 0.7031666666666667, 0.822, 0.8285, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333]\n",
      "Results successfully saved to results/input_sparsity/FMNIST/HSIC_svm/rep_3/FMNIST_HSIC_svm_rep3_res.csv\n",
      "\n",
      "Starting repetition 4/5\n",
      "Loading dataset: FMNIST for repetition 4/5\n",
      "Loading dataset: FMNIST with one_hot = False for repetition 4\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running LassoNet on FMNIST (repetition 4)\n",
      "Repetition 4: sparsity = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.010204081632653073, 0.036989795918367374, 0.09438775510204078, 0.16326530612244894, 0.23341836734693877, 0.2869897959183674, 0.3303571428571429, 0.35459183673469385, 0.3966836734693877, 0.40816326530612246, 0.41836734693877553, 0.4272959183673469, 0.4336734693877551, 0.4426020408163265, 0.45408163265306123, 0.45663265306122447, 0.45790816326530615, 0.46683673469387754, 0.4732142857142857, 0.4821428571428571, 0.48852040816326525, 0.4961734693877551, 0.5051020408163265, 0.5178571428571428, 0.5267857142857143, 0.5318877551020409, 0.5408163265306123, 0.5484693877551021, 0.5510204081632653, 0.5561224489795918, 0.5586734693877551, 0.5625, 0.5688775510204082, 0.5739795918367347, 0.5727040816326531, 0.5841836734693877, 0.5841836734693877, 0.5854591836734694, 0.596938775510204, 0.6033163265306123, 0.6084183673469388, 0.6135204081632653, 0.6198979591836735, 0.6352040816326531, 0.6377551020408163, 0.6479591836734694, 0.6479591836734694, 0.6581632653061225, 0.6517857142857143, 0.6645408163265306, 0.6645408163265306, 0.6734693877551021, 0.6709183673469388, 0.6798469387755102, 0.6760204081632653, 0.6836734693877551, 0.6849489795918368, 0.6938775510204082, 0.6887755102040816, 0.7028061224489797, 0.6977040816326531, 0.7091836734693877, 0.7079081632653061, 0.7104591836734694, 0.7117346938775511, 0.7193877551020409, 0.7193877551020409, 0.7232142857142857, 0.7206632653061225, 0.7244897959183674, 0.7257653061224489, 0.7295918367346939, 0.7295918367346939, 0.7397959183673469, 0.7359693877551021, 0.7474489795918368, 0.7410714285714286, 0.7538265306122449, 0.7551020408163265, 0.7665816326530612, 0.7627551020408163, 0.7716836734693877, 0.7665816326530612, 
0.7806122448979592, 0.778061224489796, 0.7882653061224489, 0.7869897959183674, 0.798469387755102, 0.7933673469387755, 0.8061224489795918, 0.8010204081632653, 0.8112244897959184, 0.8035714285714286, 0.8150510204081632, 0.8086734693877551, 0.8201530612244898, 0.8137755102040816, 0.8188775510204082, 0.8201530612244898, 0.8188775510204082, 0.8201530612244898, 0.8227040816326531, 0.8227040816326531, 0.8227040816326531, 0.8227040816326531, 0.8214285714285714, 0.8239795918367347, 0.8278061224489796, 0.8290816326530612, 0.8329081632653061, 0.8354591836734694, 0.8329081632653061, 0.8329081632653061, 0.8367346938775511, 0.8380102040816326, 0.8456632653061225, 0.8507653061224489, 0.8520408163265306, 0.8545918367346939, 0.8558673469387755, 0.8571428571428572, 0.8584183673469388, 0.8622448979591837, 0.8673469387755102, 0.8711734693877551, 0.8724489795918368, 0.8737244897959184, 0.875, 0.875, 0.8737244897959184, 0.875, 0.875, 0.8775510204081632, 0.8788265306122449, 0.8801020408163265, 0.8801020408163265, 0.8826530612244898, 0.8839285714285714, 0.8877551020408163, 0.8903061224489796, 0.8903061224489796, 0.889030612244898, 0.8928571428571429, 0.8928571428571429, 0.8941326530612245, 0.8928571428571429, 0.8979591836734694, 0.8966836734693877, 0.8992346938775511, 0.8992346938775511, 0.9017857142857143, 0.9043367346938775, 0.9068877551020408, 0.9056122448979592, 0.9081632653061225, 0.9107142857142857, 0.909438775510204, 0.9107142857142857, 0.9107142857142857, 0.9107142857142857, 0.9119897959183674, 0.9132653061224489, 0.9145408163265306, 0.9107142857142857, 0.9158163265306123, 0.9107142857142857, 0.9145408163265306, 0.9145408163265306, 0.9158163265306123, 0.9170918367346939, 0.9209183673469388, 0.9209183673469388, 0.9209183673469388, 0.9260204081632653, 0.9272959183673469, 0.9272959183673469, 0.9285714285714286, 0.9311224489795918, 0.9336734693877551, 0.9362244897959183, 0.9375, 0.9387755102040817, 0.9413265306122449, 0.9400510204081632, 0.9400510204081632, 0.9426020408163265, 
0.9451530612244898, 0.9477040816326531, 0.9502551020408163, 0.9502551020408163, 0.9502551020408163, 0.9502551020408163, 0.9528061224489796, 0.9553571428571429, 0.9553571428571429, 0.9566326530612245, 0.9540816326530612, 0.9540816326530612, 0.9553571428571429, 0.9553571428571429, 0.9579081632653061, 0.9604591836734694, 0.9642857142857143, 0.9668367346938775, 0.9681122448979592, 0.965561224489796, 0.9681122448979592, 0.9719387755102041, 0.9668367346938775, 0.9744897959183674, 0.9732142857142857, 0.9744897959183674, 0.9744897959183674, 0.9744897959183674, 0.9757653061224489, 0.9783163265306123, 0.9770408163265306, 0.9770408163265306, 0.9770408163265306, 0.9783163265306123, 0.9783163265306123, 0.9783163265306123, 0.9783163265306123, 0.9795918367346939, 0.9795918367346939, 0.9808673469387755, 0.9783163265306123, 0.9821428571428571, 0.9808673469387755, 0.9808673469387755, 0.9834183673469388, 0.9834183673469388, 0.985969387755102, 0.985969387755102, 0.985969387755102, 0.985969387755102, 0.985969387755102, 0.9872448979591837, 0.9872448979591837, 0.9872448979591837, 0.9885204081632653, 0.9910714285714286, 0.9936224489795918, 0.9961734693877551, 0.9987244897959183, 1.0] and test accuracy = [0.8243333333333334, 0.8211666666666667, 0.8225, 0.8221666666666667, 0.8228333333333333, 0.8223333333333334, 0.8226666666666667, 0.8228333333333333, 0.8231666666666667, 0.8238333333333333, 0.8238333333333333, 0.8241666666666667, 0.8245, 0.8248333333333333, 0.8246666666666667, 0.8245, 0.8246666666666667, 0.8246666666666667, 0.8246666666666667, 0.8245, 0.8243333333333334, 0.8243333333333334, 0.8243333333333334, 0.8245, 0.8245, 0.8246666666666667, 0.8245, 0.8248333333333333, 0.8248333333333333, 0.8248333333333333, 0.825, 0.8246666666666667, 0.8248333333333333, 0.825, 0.825, 0.8246666666666667, 0.8245, 0.8243333333333334, 0.8243333333333334, 0.8246666666666667, 0.8253333333333334, 0.8253333333333334, 0.8251666666666667, 0.8251666666666667, 0.8251666666666667, 0.8253333333333334, 
0.8258333333333333, 0.8263333333333334, 0.8261666666666667, 0.8265, 0.8266666666666667, 0.8273333333333334, 0.8268333333333333, 0.827, 0.8275, 0.8271666666666667, 0.8268333333333333, 0.8266666666666667, 0.8276666666666667, 0.8281666666666667, 0.8295, 0.8288333333333333, 0.8268333333333333, 0.823, 0.8196666666666667, 0.8191666666666667, 0.8148333333333333, 0.8133333333333334, 0.813, 0.8123333333333334, 0.8121666666666667, 0.8118333333333333, 0.8118333333333333, 0.8115, 0.811, 0.811, 0.8103333333333333, 0.8091666666666667, 0.8091666666666667, 0.8083333333333333, 0.8083333333333333, 0.8081666666666667, 0.8081666666666667, 0.8076666666666666, 0.808, 0.8071666666666667, 0.807, 0.8063333333333333, 0.8058333333333333, 0.8058333333333333, 0.8055, 0.8053333333333333, 0.8051666666666667, 0.8041666666666667, 0.8036666666666666, 0.8031666666666667, 0.803, 0.8021666666666667, 0.8015, 0.8011666666666667, 0.8006666666666666, 0.8, 0.7998333333333333, 0.7998333333333333, 0.799, 0.7988333333333333, 0.7986666666666666, 0.7976666666666666, 0.797, 0.7965, 0.796, 0.796, 0.7958333333333333, 0.7948333333333333, 0.7941666666666667, 0.7936666666666666, 0.7933333333333333, 0.7926666666666666, 0.7928333333333333, 0.792, 0.7916666666666666, 0.7908333333333334, 0.7913333333333333, 0.7906666666666666, 0.7908333333333334, 0.79, 0.79, 0.789, 0.7888333333333334, 0.7876666666666666, 0.7876666666666666, 0.7873333333333333, 0.787, 0.7858333333333334, 0.7863333333333333, 0.7846666666666666, 0.7843333333333333, 0.7831666666666667, 0.7828333333333334, 0.7826666666666666, 0.7831666666666667, 0.7826666666666666, 0.7826666666666666, 0.7805, 0.7801666666666667, 0.7798333333333334, 0.7791666666666667, 0.7798333333333334, 0.7795, 0.7796666666666666, 0.7793333333333333, 0.779, 0.7785, 0.7783333333333333, 0.7776666666666666, 0.7766666666666666, 0.7768333333333334, 0.7761666666666667, 0.7756666666666666, 0.7751666666666667, 0.7748333333333334, 0.7731666666666667, 0.7715, 0.7706666666666667, 0.7705, 0.77, 
0.7698333333333334, 0.7691666666666667, 0.7686666666666667, 0.7683333333333333, 0.7666666666666667, 0.7653333333333333, 0.7648333333333334, 0.7643333333333333, 0.7631666666666667, 0.7623333333333333, 0.7613333333333333, 0.7608333333333334, 0.7601666666666667, 0.7601666666666667, 0.7595, 0.7583333333333333, 0.757, 0.7555, 0.7545, 0.7538333333333334, 0.7543333333333333, 0.7538333333333334, 0.7538333333333334, 0.7528333333333334, 0.7516666666666667, 0.7516666666666667, 0.7505, 0.7491666666666666, 0.7486666666666667, 0.7473333333333333, 0.7458333333333333, 0.7443333333333333, 0.7431666666666666, 0.7421666666666666, 0.74, 0.7381666666666666, 0.7371666666666666, 0.7341666666666666, 0.7338333333333333, 0.7325, 0.7308333333333333, 0.7278333333333333, 0.7266666666666667, 0.7253333333333334, 0.7241666666666666, 0.722, 0.72, 0.718, 0.7171666666666666, 0.714, 0.7118333333333333, 0.71, 0.7075, 0.7048333333333333, 0.7013333333333334, 0.6991666666666667, 0.6966666666666667, 0.6888333333333333, 0.6863333333333334, 0.6818333333333333, 0.6773333333333333, 0.6715, 0.67, 0.6676666666666666, 0.665, 0.6626666666666666, 0.6601666666666667, 0.6578333333333334, 0.6556666666666666, 0.6511666666666667, 0.6473333333333333, 0.6443333333333333, 0.6395, 0.634, 0.6293333333333333, 0.6251666666666666, 0.6186666666666667, 0.6171666666666666, 0.6131666666666666, 0.6098333333333333, 0.6031666666666666, 0.5993333333333334, 0.5918333333333333, 0.5796666666666667, 0.568, 0.5581666666666667, 0.545, 0.5361666666666667, 0.525, 0.514, 0.5055, 0.49866666666666665, 0.491, 0.48383333333333334, 0.473, 0.44333333333333336, 0.40016666666666667, 0.39166666666666666, 0.387, 0.38266666666666665, 0.377, 0.37283333333333335, 0.368, 0.36383333333333334, 0.36, 0.356, 0.35383333333333333, 0.35083333333333333, 0.34833333333333333, 0.3456666666666667, 0.341, 0.3378333333333333, 0.33116666666666666, 0.32083333333333336, 0.3155, 0.30966666666666665, 0.3085, 0.30433333333333334, 0.30033333333333334, 0.299, 0.2966666666666667, 
0.2951666666666667, 0.2895, 0.28583333333333333, 0.27466666666666667, 0.265, 0.22566666666666665, 0.20883333333333334, 0.2005, 0.18983333333333333, 0.17583333333333334, 0.1755, 0.17416666666666666, 0.1735, 0.1725, 0.17183333333333334, 0.17016666666666666, 0.16683333333333333, 0.15933333333333333, 0.145, 0.09566666666666666]\n",
      "Results successfully saved to results/input_sparsity/FMNIST/LassoNet/rep_4/FMNIST_LassoNet_rep4_res.csv\n",
      "Loading dataset: FMNIST with one_hot = False for repetition 4\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running HSIC_dnn on FMNIST (repetition 4)\n",
      "Sequence of features is [  1   5  10  20 105 190 275 360 444 529 614 699 784]\n",
      "Downstream model for HSIC is nn\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 1\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.8935 - accuracy: 0.2102\n",
      "test acc for vanilla model is [1.893547534942627, 0.21016666293144226]\n",
      "1.893547534942627 0.21016666293144226\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 5\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.2803 - accuracy: 0.4747\n",
      "test acc for vanilla model is [1.28031587600708, 0.47466665506362915]\n",
      "1.28031587600708 0.47466665506362915\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 10\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.3063 - accuracy: 0.5115\n",
      "test acc for vanilla model is [1.3062816858291626, 0.5115000009536743]\n",
      "1.3062816858291626 0.5115000009536743\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 20\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.2950 - accuracy: 0.6902\n",
      "test acc for vanilla model is [1.2950242757797241, 0.6901666522026062]\n",
      "1.2950242757797241 0.6901666522026062\n",
      "Current number of features: 105\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 105\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0176 - accuracy: 0.8145\n",
      "test acc for vanilla model is [1.0176457166671753, 0.8144999742507935]\n",
      "1.0176457166671753 0.8144999742507935\n",
      "Current number of features: 190\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 190\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0523 - accuracy: 0.8183\n",
      "test acc for vanilla model is [1.0522758960723877, 0.8183333277702332]\n",
      "1.0522758960723877 0.8183333277702332\n",
      "Current number of features: 275\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0693 - accuracy: 0.8180\n",
      "test acc for vanilla model is [1.0692574977874756, 0.8180000185966492]\n",
      "1.0692574977874756 0.8180000185966492\n",
      "Current number of features: 360\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0497 - accuracy: 0.8165\n",
      "test acc for vanilla model is [1.049741506576538, 0.8165000081062317]\n",
      "1.049741506576538 0.8165000081062317\n",
      "Current number of features: 444\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0095 - accuracy: 0.8213\n",
      "test acc for vanilla model is [1.0095041990280151, 0.8213333487510681]\n",
      "1.0095041990280151 0.8213333487510681\n",
      "Current number of features: 529\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0616 - accuracy: 0.8145\n",
      "test acc for vanilla model is [1.061554193496704, 0.8144999742507935]\n",
      "1.061554193496704 0.8144999742507935\n",
      "Current number of features: 614\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0653 - accuracy: 0.8187\n",
      "test acc for vanilla model is [1.0652586221694946, 0.8186666369438171]\n",
      "1.0652586221694946 0.8186666369438171\n",
      "Current number of features: 699\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0552 - accuracy: 0.8203\n",
      "test acc for vanilla model is [1.0551810264587402, 0.8203333616256714]\n",
      "1.0551810264587402 0.8203333616256714\n",
      "Current number of features: 784\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0182 - accuracy: 0.8245\n",
      "test acc for vanilla model is [1.0181609392166138, 0.8245000243186951]\n",
      "1.0181609392166138 0.8245000243186951\n",
      "Repetition 4: sparsity = [0.9987244897959183, 0.9936224489795918, 0.9872448979591837, 0.9744897959183674, 0.8660714285714286, 0.7576530612244898, 0.6492346938775511, 0.5408163265306123, 0.4336734693877551, 0.3252551020408163, 0.21683673469387754, 0.10841836734693877, 0.0] and test accuracy = [0.21016666293144226, 0.47466665506362915, 0.5115000009536743, 0.6901666522026062, 0.8144999742507935, 0.8183333277702332, 0.8180000185966492, 0.8165000081062317, 0.8213333487510681, 0.8144999742507935, 0.8186666369438171, 0.8203333616256714, 0.8245000243186951]\n",
      "Results successfully saved to results/input_sparsity/FMNIST/HSIC_dnn/rep_4/FMNIST_HSIC_dnn_rep4_res.csv\n",
      "Loading dataset: FMNIST with one_hot = False for repetition 4\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running HSIC_svm on FMNIST (repetition 4)\n",
      "Sequence of features is [  1   5  10  20 105 190 275 360 444 529 614 699 784]\n",
      "Downstream model for HSIC is svm\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 105\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 190\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 275\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 360\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 444\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 529\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 614\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 699\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 784\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Repetition 4: sparsity = [0.9987244897959183, 0.9936224489795918, 0.9872448979591837, 0.9744897959183674, 0.8660714285714286, 0.7576530612244898, 0.6492346938775511, 0.5408163265306123, 0.4336734693877551, 0.3252551020408163, 0.21683673469387754, 0.10841836734693877, 0.0] and test accuracy = [0.206, 0.4673333333333333, 0.496, 0.7031666666666667, 0.822, 0.8285, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333]\n",
      "Results successfully saved to results/input_sparsity/FMNIST/HSIC_svm/rep_4/FMNIST_HSIC_svm_rep4_res.csv\n",
      "\n",
      "Starting repetition 5/5\n",
      "Loading dataset: FMNIST for repetition 5/5\n",
      "Loading dataset: FMNIST with one_hot = False for repetition 5\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running LassoNet on FMNIST (repetition 5)\n",
      "Repetition 5: sparsity = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.007653061224489832, 0.02168367346938771, 0.06632653061224492, 0.13010204081632648, 0.20663265306122447, 0.2716836734693877, 0.3227040816326531, 0.3520408163265306, 0.375, 0.3852040816326531, 0.3979591836734694, 0.40816326530612246, 0.41709183673469385, 0.4272959183673469, 0.4362244897959183, 0.4477040816326531, 0.45663265306122447, 0.4604591836734694, 0.46556122448979587, 0.4770408163265306, 0.48341836734693877, 0.4910714285714286, 0.5012755102040816, 0.5012755102040816, 0.5089285714285714, 0.5127551020408163, 0.5267857142857143, 0.5306122448979591, 0.5395408163265306, 0.5471938775510203, 0.5484693877551021, 0.5548469387755102, 0.5586734693877551, 0.5637755102040816, 0.5727040816326531, 0.5778061224489797, 0.5880102040816326, 0.5956632653061225, 0.6058673469387755, 0.6058673469387755, 0.6147959183673469, 0.6147959183673469, 0.6173469387755102, 0.6224489795918368, 0.6288265306122449, 0.6339285714285714, 0.6377551020408163, 0.6441326530612245, 0.6428571428571428, 0.6466836734693877, 0.6607142857142857, 0.6568877551020409, 0.6683673469387755, 0.6645408163265306, 0.6709183673469388, 0.6709183673469388, 0.6785714285714286, 0.6747448979591837, 0.6798469387755102, 0.6875, 0.6951530612244898, 0.6938775510204082, 0.6977040816326531, 0.6926020408163265, 0.6989795918367347, 0.6989795918367347, 0.7117346938775511, 0.7117346938775511, 0.7206632653061225, 0.721938775510204, 0.7346938775510203, 0.7321428571428572, 0.7436224489795918, 0.7385204081632653, 0.7461734693877551, 0.7410714285714286, 0.7512755102040816, 0.7436224489795918, 0.7551020408163265, 0.7436224489795918, 0.7589285714285714, 0.7512755102040816, 0.7691326530612245, 
0.7602040816326531, 0.7729591836734694, 0.7653061224489796, 0.7818877551020408, 0.7716836734693877, 0.7895408163265306, 0.7767857142857143, 0.7959183673469388, 0.7831632653061225, 0.798469387755102, 0.7857142857142857, 0.7895408163265306, 0.7971938775510204, 0.8048469387755102, 0.8048469387755102, 0.8061224489795918, 0.8137755102040816, 0.8125, 0.8150510204081632, 0.8150510204081632, 0.8188775510204082, 0.8227040816326531, 0.8290816326530612, 0.8303571428571428, 0.8329081632653061, 0.8341836734693877, 0.8341836734693877, 0.8392857142857143, 0.840561224489796, 0.8431122448979592, 0.840561224489796, 0.8418367346938775, 0.8380102040816326, 0.8431122448979592, 0.8482142857142857, 0.8456632653061225, 0.8456632653061225, 0.8494897959183674, 0.8520408163265306, 0.8520408163265306, 0.8571428571428572, 0.8571428571428572, 0.8571428571428572, 0.8571428571428572, 0.8571428571428572, 0.8584183673469388, 0.860969387755102, 0.8622448979591837, 0.8635204081632653, 0.8647959183673469, 0.8686224489795918, 0.8711734693877551, 0.8698979591836735, 0.8724489795918368, 0.8724489795918368, 0.8801020408163265, 0.8826530612244898, 0.8826530612244898, 0.8877551020408163, 0.8877551020408163, 0.8915816326530612, 0.8915816326530612, 0.8928571428571429, 0.8915816326530612, 0.8928571428571429, 0.8928571428571429, 0.8941326530612245, 0.8928571428571429, 0.8954081632653061, 0.8992346938775511, 0.8979591836734694, 0.8979591836734694, 0.8992346938775511, 0.9043367346938775, 0.9081632653061225, 0.9107142857142857, 0.9145408163265306, 0.9145408163265306, 0.9145408163265306, 0.9145408163265306, 0.9183673469387755, 0.9183673469387755, 0.9196428571428571, 0.923469387755102, 0.923469387755102, 0.9285714285714286, 0.9221938775510204, 0.9298469387755102, 0.9298469387755102, 0.9285714285714286, 0.9298469387755102, 0.9323979591836735, 0.9349489795918368, 0.9349489795918368, 0.9362244897959183, 0.9349489795918368, 0.9362244897959183, 0.9400510204081632, 0.9426020408163265, 0.9426020408163265, 
0.9426020408163265, 0.9451530612244898, 0.9477040816326531, 0.9477040816326531, 0.9477040816326531, 0.9451530612244898, 0.9451530612244898, 0.9451530612244898, 0.9477040816326531, 0.9464285714285714, 0.951530612244898, 0.9540816326530612, 0.9553571428571429, 0.9566326530612245, 0.9540816326530612, 0.9553571428571429, 0.9579081632653061, 0.9566326530612245, 0.9604591836734694, 0.9579081632653061, 0.9604591836734694, 0.9604591836734694, 0.9642857142857143, 0.9642857142857143, 0.965561224489796, 0.965561224489796, 0.9668367346938775, 0.9668367346938775, 0.9693877551020408, 0.9719387755102041, 0.9732142857142857, 0.9668367346938775, 0.9693877551020408, 0.9795918367346939, 0.9719387755102041, 0.9795918367346939, 0.9732142857142857, 0.9795918367346939, 0.9757653061224489, 0.9757653061224489, 0.9757653061224489, 0.9795918367346939, 0.9770408163265306, 0.9808673469387755, 0.9821428571428571, 0.9821428571428571, 0.9795918367346939, 0.9821428571428571, 0.9821428571428571, 0.9821428571428571, 0.9846938775510204, 0.9846938775510204, 0.9834183673469388, 0.985969387755102, 0.985969387755102, 0.985969387755102, 0.985969387755102, 0.985969387755102, 0.9872448979591837, 0.9885204081632653, 0.9936224489795918, 0.9961734693877551, 0.9961734693877551, 0.9961734693877551, 1.0] and test accuracy = [0.822, 0.819, 0.8195, 0.8191666666666667, 0.819, 0.8195, 0.8193333333333334, 0.8195, 0.8198333333333333, 0.82, 0.82, 0.8198333333333333, 0.8196666666666667, 0.819, 0.8191666666666667, 0.819, 0.819, 0.819, 0.8191666666666667, 0.8195, 0.8195, 0.8195, 0.8196666666666667, 0.82, 0.82, 0.8205, 0.8203333333333334, 0.8205, 0.8208333333333333, 0.8208333333333333, 0.821, 0.8211666666666667, 0.8211666666666667, 0.8215, 0.8215, 0.8213333333333334, 0.8216666666666667, 0.8218333333333333, 0.8216666666666667, 0.8218333333333333, 0.8218333333333333, 0.822, 0.8225, 0.8226666666666667, 0.8225, 0.8225, 0.8228333333333333, 0.8223333333333334, 0.8225, 0.8223333333333334, 0.8221666666666667, 0.822, 0.822, 
0.8216666666666667, 0.8213333333333334, 0.822, 0.8215, 0.8233333333333334, 0.8236666666666667, 0.8246666666666667, 0.8253333333333334, 0.8258333333333333, 0.8256666666666667, 0.8245, 0.8213333333333334, 0.8203333333333334, 0.8166666666666667, 0.816, 0.8153333333333334, 0.8151666666666667, 0.8128333333333333, 0.8126666666666666, 0.8121666666666667, 0.8118333333333333, 0.8111666666666667, 0.8116666666666666, 0.8108333333333333, 0.8105, 0.8108333333333333, 0.8105, 0.8096666666666666, 0.8093333333333333, 0.8091666666666667, 0.8083333333333333, 0.8078333333333333, 0.8076666666666666, 0.8068333333333333, 0.8066666666666666, 0.8066666666666666, 0.8058333333333333, 0.8053333333333333, 0.8053333333333333, 0.805, 0.8043333333333333, 0.8036666666666666, 0.8038333333333333, 0.8028333333333333, 0.8025, 0.8023333333333333, 0.8021666666666667, 0.802, 0.8018333333333333, 0.8011666666666667, 0.8006666666666666, 0.801, 0.8001666666666667, 0.7998333333333333, 0.7995, 0.7991666666666667, 0.7985, 0.7975, 0.7971666666666667, 0.7975, 0.7965, 0.796, 0.795, 0.7935, 0.793, 0.7918333333333333, 0.7913333333333333, 0.7903333333333333, 0.7896666666666666, 0.7885, 0.7881666666666667, 0.7875, 0.787, 0.7868333333333334, 0.7863333333333333, 0.7865, 0.7853333333333333, 0.7846666666666666, 0.7838333333333334, 0.7831666666666667, 0.7826666666666666, 0.7821666666666667, 0.7816666666666666, 0.7805, 0.7806666666666666, 0.7805, 0.7801666666666667, 0.779, 0.7783333333333333, 0.7773333333333333, 0.777, 0.7758333333333334, 0.7753333333333333, 0.7745, 0.7738333333333334, 0.7731666666666667, 0.7733333333333333, 0.7728333333333334, 0.7728333333333334, 0.7723333333333333, 0.7718333333333334, 0.7708333333333334, 0.7705, 0.7695, 0.7691666666666667, 0.7688333333333334, 0.7681666666666667, 0.7673333333333333, 0.7665, 0.7658333333333334, 0.7656666666666667, 0.7648333333333334, 0.7645, 0.7631666666666667, 0.7631666666666667, 0.7628333333333334, 0.7625, 0.7616666666666667, 0.761, 0.76, 0.76, 0.7595, 0.7586666666666667, 
0.758, 0.7578333333333334, 0.7575, 0.7578333333333334, 0.7571666666666667, 0.7556666666666667, 0.755, 0.7545, 0.7538333333333334, 0.753, 0.7525, 0.752, 0.7516666666666667, 0.75, 0.75, 0.749, 0.748, 0.7466666666666667, 0.7456666666666667, 0.7458333333333333, 0.744, 0.7415, 0.7401666666666666, 0.7386666666666667, 0.7365, 0.7358333333333333, 0.7336666666666667, 0.7325, 0.7308333333333333, 0.7288333333333333, 0.7276666666666667, 0.7265, 0.7245, 0.7221666666666666, 0.7208333333333333, 0.7188333333333333, 0.7185, 0.7166666666666667, 0.7158333333333333, 0.7141666666666666, 0.7135, 0.7131666666666666, 0.7108333333333333, 0.7083333333333334, 0.707, 0.7048333333333333, 0.7011666666666667, 0.6976666666666667, 0.6945, 0.69, 0.6878333333333333, 0.6858333333333333, 0.6848333333333333, 0.682, 0.6815, 0.6801666666666667, 0.6783333333333333, 0.6771666666666667, 0.6751666666666667, 0.6736666666666666, 0.6711666666666667, 0.6701666666666667, 0.6683333333333333, 0.6661666666666667, 0.6623333333333333, 0.6585, 0.6541666666666667, 0.6481666666666667, 0.6405, 0.6303333333333333, 0.6273333333333333, 0.6221666666666666, 0.6158333333333333, 0.6091666666666666, 0.6025, 0.5961666666666666, 0.588, 0.5823333333333334, 0.5801666666666667, 0.575, 0.5708333333333333, 0.567, 0.5605, 0.55, 0.5413333333333333, 0.5346666666666666, 0.5255, 0.5116666666666667, 0.48933333333333334, 0.44366666666666665, 0.4028333333333333, 0.386, 0.37, 0.3626666666666667, 0.35633333333333334, 0.35183333333333333, 0.3465, 0.3426666666666667, 0.3403333333333333, 0.3363333333333333, 0.3335, 0.3288333333333333, 0.323, 0.31683333333333336, 0.31133333333333335, 0.30783333333333335, 0.30516666666666664, 0.30133333333333334, 0.2985, 0.29683333333333334, 0.29583333333333334, 0.2951666666666667, 0.29033333333333333, 0.288, 0.2843333333333333, 0.273, 0.2555, 0.22616666666666665, 0.2115, 0.21166666666666667, 0.21116666666666667, 0.21066666666666667, 0.21216666666666667, 0.21216666666666667, 0.21283333333333335, 0.21366666666666667, 
0.20983333333333334, 0.17466666666666666, 0.1745, 0.173, 0.1705, 0.12683333333333333, 0.09566666666666666]\n",
      "Results successfully saved to results/input_sparsity/FMNIST/LassoNet/rep_5/FMNIST_LassoNet_rep5_res.csv\n",
      "Loading dataset: FMNIST with one_hot = False for repetition 5\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running HSIC_dnn on FMNIST (repetition 5)\n",
      "Sequence of features is [  1   5  10  20 105 190 275 360 444 529 614 699 784]\n",
      "Downstream model for HSIC is nn\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 1\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.8930 - accuracy: 0.2090\n",
      "test acc for vanilla model is [1.8930317163467407, 0.20900000631809235]\n",
      "1.8930317163467407 0.20900000631809235\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 5\n",
      "188/188 [==============================] - 1s 1ms/step - loss: 1.2732 - accuracy: 0.4735\n",
      "test acc for vanilla model is [1.2732398509979248, 0.47350001335144043]\n",
      "1.2732398509979248 0.47350001335144043\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 10\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.3119 - accuracy: 0.5093\n",
      "test acc for vanilla model is [1.3118869066238403, 0.5093333125114441]\n",
      "1.3118869066238403 0.5093333125114441\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 20\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.2771 - accuracy: 0.6920\n",
      "test acc for vanilla model is [1.2771281003952026, 0.6919999718666077]\n",
      "1.2771281003952026 0.6919999718666077\n",
      "Current number of features: 105\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 105\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0196 - accuracy: 0.8140\n",
      "test acc for vanilla model is [1.0195657014846802, 0.8140000104904175]\n",
      "1.0195657014846802 0.8140000104904175\n",
      "Current number of features: 190\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 190\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0548 - accuracy: 0.8162\n",
      "test acc for vanilla model is [1.0548123121261597, 0.8161666393280029]\n",
      "1.0548123121261597 0.8161666393280029\n",
      "Current number of features: 275\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0700 - accuracy: 0.8168\n",
      "test acc for vanilla model is [1.0700392723083496, 0.8168333172798157]\n",
      "1.0700392723083496 0.8168333172798157\n",
      "Current number of features: 360\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0454 - accuracy: 0.8167\n",
      "test acc for vanilla model is [1.0453853607177734, 0.8166666626930237]\n",
      "1.0453853607177734 0.8166666626930237\n",
      "Current number of features: 444\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0116 - accuracy: 0.8247\n",
      "test acc for vanilla model is [1.011564016342163, 0.8246666789054871]\n",
      "1.011564016342163 0.8246666789054871\n",
      "Current number of features: 529\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0577 - accuracy: 0.8122\n",
      "test acc for vanilla model is [1.0576584339141846, 0.812166690826416]\n",
      "1.0576584339141846 0.812166690826416\n",
      "Current number of features: 614\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 1ms/step - loss: 1.0595 - accuracy: 0.8178\n",
      "test acc for vanilla model is [1.059504508972168, 0.8178333044052124]\n",
      "1.059504508972168 0.8178333044052124\n",
      "Current number of features: 699\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0582 - accuracy: 0.8210\n",
      "test acc for vanilla model is [1.0581660270690918, 0.8209999799728394]\n",
      "1.0581660270690918 0.8209999799728394\n",
      "Current number of features: 784\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Input size is 220\n",
      "188/188 [==============================] - 0s 2ms/step - loss: 1.0240 - accuracy: 0.8238\n",
      "test acc for vanilla model is [1.0239503383636475, 0.8238333463668823]\n",
      "1.0239503383636475 0.8238333463668823\n",
      "Repetition 5: sparsity = [0.9987244897959183, 0.9936224489795918, 0.9872448979591837, 0.9744897959183674, 0.8660714285714286, 0.7576530612244898, 0.6492346938775511, 0.5408163265306123, 0.4336734693877551, 0.3252551020408163, 0.21683673469387754, 0.10841836734693877, 0.0] and test accuracy = [0.20900000631809235, 0.47350001335144043, 0.5093333125114441, 0.6919999718666077, 0.8140000104904175, 0.8161666393280029, 0.8168333172798157, 0.8166666626930237, 0.8246666789054871, 0.812166690826416, 0.8178333044052124, 0.8209999799728394, 0.8238333463668823]\n",
      "Results successfully saved to results/input_sparsity/FMNIST/HSIC_dnn/rep_5/FMNIST_HSIC_dnn_rep5_res.csv\n",
      "Loading dataset: FMNIST with one_hot = False for repetition 5\n",
      "x_train shape: (60000, 784), y_train shape: (60000,)\n",
      "x_test shape: (10000, 784), y_test shape: (10000,)\n",
      "Running HSIC_svm on FMNIST (repetition 5)\n",
      "Sequence of features is [  1   5  10  20 105 190 275 360 444 529 614 699 784]\n",
      "Downstream model for HSIC is svm\n",
      "Current number of features: 1\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 5\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 10\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 20\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 105\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 190\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 275\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 360\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 444\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 529\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 614\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 699\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Current number of features: 784\n",
      "Block HSIC Lasso B = 20.\n",
      "M set to 3.\n",
      "Using Gaussian kernel for the features, Delta kernel for the outcomes.\n",
      "Repetition 5: sparsity = [0.9987244897959183, 0.9936224489795918, 0.9872448979591837, 0.9744897959183674, 0.8660714285714286, 0.7576530612244898, 0.6492346938775511, 0.5408163265306123, 0.4336734693877551, 0.3252551020408163, 0.21683673469387754, 0.10841836734693877, 0.0] and test accuracy = [0.206, 0.4673333333333333, 0.496, 0.7031666666666667, 0.822, 0.8285, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333, 0.8318333333333333]\n",
      "Results successfully saved to results/input_sparsity/FMNIST/HSIC_svm/rep_5/FMNIST_HSIC_svm_rep5_res.csv\n",
      "\n",
      "Combined results saved to results/input_sparsity/FMNIST/all_results_summary.csv\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "from models import hadamard_nn_depth_2, hadamard_nn_depth_3, hadamard_nn_depth_4, hsic_dnn, hsic_svm, lassoNet\n",
    "import tensorflow as tf\n",
    "import config_inputsparse_compar\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "# Constants\n",
    "# Number of independent repetitions; repetition r (0-based) is seeded with BASE_SEED + r.\n",
    "REPS = 5\n",
    "BASE_SEED = config_inputsparse_compar.SEED\n",
    "# NOTE(review): LENET_FILE_PATH is not referenced anywhere in this cell (main() rebuilds\n",
    "# the output root with os.path.join) — confirm whether it is still needed.\n",
    "LENET_FILE_PATH = './results/input_sparsity/FMNIST/'\n",
    "\n",
    "def run_methods_on_dataset(dataset_name, load_func, results_dir, rep):\n",
    "    \"\"\"Run each feature-selection method on one dataset for one repetition and save one CSV per method.\n",
    "\n",
    "    For every (method_name, method_func, one_hot) entry the data are loaded with\n",
    "    ``load_func(one_hot=one_hot)`` and ``method_func(train_X, train_y, test_X, test_y)``\n",
    "    is expected to return ``(sparsity, accuracy, value_seq)``. One row per zipped\n",
    "    triple is written to\n",
    "    ``<results_dir>/<method_name>/rep_<rep+1>/<dataset_name>_<method_name>_rep<rep+1>_res.csv``.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: label used in paths, log messages and the ``dataset`` column.\n",
    "        load_func: callable returning ``((train_X, train_y), (test_X, test_y))``.\n",
    "        results_dir: root directory under which the per-method folders are created.\n",
    "        rep: zero-based repetition index; the seed for this repetition is ``BASE_SEED + rep``.\n",
    "    \"\"\"\n",
    "    current_seed = BASE_SEED + rep\n",
    "    print(f\"Loading dataset: {dataset_name} for repetition {rep + 1}/{REPS}\")\n",
    "\n",
    "    for method_name, method_func, one_hot in [\n",
    "        (\"LassoNet\", lassoNet, False),\n",
    "        (\"HSIC_dnn\", hsic_dnn, False),\n",
    "        (\"HSIC_svm\", hsic_svm, False)\n",
    "    ]:\n",
    "        # Re-seed all RNGs before *each* method rather than once per repetition.\n",
    "        # Otherwise a later method's random state depends on how many draws the\n",
    "        # earlier methods made, and its results could not be reproduced from the\n",
    "        # `seed` value recorded in its CSV alone.\n",
    "        np.random.seed(current_seed)\n",
    "        random.seed(current_seed)\n",
    "        tf.random.set_seed(current_seed)\n",
    "\n",
    "        # Create method-specific directory\n",
    "        method_dir = os.path.join(results_dir, method_name, f'rep_{rep + 1}')\n",
    "        os.makedirs(method_dir, exist_ok=True)\n",
    "\n",
    "        # Unique filename for the method-dataset-repetition combination\n",
    "        result_filename = os.path.join(method_dir, f'{dataset_name}_{method_name}_rep{rep + 1}_res.csv')\n",
    "\n",
    "        print(f\"Loading dataset: {dataset_name} with one_hot = {one_hot} for repetition {rep + 1}\")\n",
    "        (train_X, train_y), (test_X, test_y) = load_func(one_hot=one_hot)\n",
    "\n",
    "        print(f\"Running {method_name} on {dataset_name} (repetition {rep + 1})\")\n",
    "        sparsity, accuracy, value_seq = method_func(train_X, train_y, test_X, test_y)\n",
    "        print(f'Repetition {rep + 1}: sparsity = {sparsity} and test accuracy = {accuracy}')\n",
    "\n",
    "        # NOTE(review): zip() stops at the shortest of the three sequences, so a\n",
    "        # length mismatch in the method's output is silently truncated.\n",
    "        results = [\n",
    "            {\n",
    "                \"method\": method_name,\n",
    "                \"dataset\": dataset_name,\n",
    "                \"repetition\": rep + 1,\n",
    "                \"sparsity\": s,\n",
    "                \"accuracy\": a,\n",
    "                \"value\": v,\n",
    "                \"seed\": current_seed\n",
    "            }\n",
    "            for s, a, v in zip(sparsity, accuracy, value_seq)\n",
    "        ]\n",
    "\n",
    "        # Save the results\n",
    "        save_results_to_csv(results, result_filename)\n",
    "        print(f'Results successfully saved to {result_filename}')\n",
    "\n",
    "def save_results_to_csv(results, result_filename):\n",
    "    \"\"\"Write a list of result records (dicts) to `result_filename` as CSV, without the index column.\"\"\"\n",
    "    pd.DataFrame(results).to_csv(result_filename, index=False)\n",
    "\n",
    "def main():\n",
    "    \"\"\"Run all methods on every dataset for REPS repetitions, then merge the per-run CSVs.\"\"\"\n",
    "    # NOTE(review): `load_fashion` is neither defined nor imported in this cell; it must\n",
    "    # come from another cell or leftover kernel state — confirm before Restart & Run All.\n",
    "    loaders = {\"FMNIST\": load_fashion}\n",
    "\n",
    "    # Fixed output root (no timestamp): files from an earlier run at the same paths are overwritten.\n",
    "    results_root = os.path.join('results', 'input_sparsity', 'FMNIST')\n",
    "    os.makedirs(results_root, exist_ok=True)\n",
    "\n",
    "    for rep_idx in range(REPS):\n",
    "        print(f\"\\nStarting repetition {rep_idx + 1}/{REPS}\")\n",
    "        for name, loader in loaders.items():\n",
    "            run_methods_on_dataset(name, loader, results_root, rep_idx)\n",
    "\n",
    "    # Merge every per-method / per-repetition CSV into a single summary file.\n",
    "    combine_all_results(results_root)\n",
    "\n",
    "def _rep_sort_key(dir_name):\n",
    "    \"\"\"Order 'rep_<n>' names numerically (rep_2 before rep_10); other names sort after them.\"\"\"\n",
    "    suffix = dir_name[len('rep_'):]\n",
    "    return (0, int(suffix), dir_name) if suffix.isdecimal() else (1, 0, dir_name)\n",
    "\n",
    "\n",
    "def combine_all_results(base_results_dir):\n",
    "    \"\"\"Concatenate every per-repetition result CSV under `base_results_dir` into one summary CSV.\n",
    "\n",
    "    Looks for files ending in ``_res.csv`` in ``<base_results_dir>/<method>/rep_*/`` for the\n",
    "    methods HSIC_dnn, HSIC_svm and LassoNet, and writes their concatenation to\n",
    "    ``<base_results_dir>/all_results_summary.csv``. Writes nothing if no result file is found.\n",
    "\n",
    "    Directory listings are sorted (``os.listdir`` order is arbitrary) so the summary's row\n",
    "    order is reproducible, and ``rep_*`` entries that are not directories are skipped\n",
    "    instead of making ``os.listdir`` raise.\n",
    "    \"\"\"\n",
    "    all_results = []\n",
    "\n",
    "    for method in ['HSIC_dnn', 'HSIC_svm', 'LassoNet']:\n",
    "        method_dir = os.path.join(base_results_dir, method)\n",
    "        if not os.path.isdir(method_dir):\n",
    "            continue\n",
    "\n",
    "        for rep_dir in sorted(os.listdir(method_dir), key=_rep_sort_key):\n",
    "            rep_path = os.path.join(method_dir, rep_dir)\n",
    "            if not rep_dir.startswith('rep_') or not os.path.isdir(rep_path):\n",
    "                continue\n",
    "\n",
    "            for result_file in sorted(os.listdir(rep_path)):\n",
    "                if result_file.endswith('_res.csv'):\n",
    "                    all_results.append(pd.read_csv(os.path.join(rep_path, result_file)))\n",
    "\n",
    "    if all_results:\n",
    "        combined_df = pd.concat(all_results, ignore_index=True)\n",
    "        summary_file = os.path.join(base_results_dir, 'all_results_summary.csv')\n",
    "        combined_df.to_csv(summary_file, index=False)\n",
    "        print(f\"\\nCombined results saved to {summary_file}\")\n",
    "\n",
    "# In a Jupyter/IPython kernel __name__ is \"__main__\", so executing this cell runs the full experiment.\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
