{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-03-28T17:22:59.133618Z",
     "iopub.status.busy": "2025-03-28T17:22:59.133196Z",
     "iopub.status.idle": "2025-03-28T17:23:01.557511Z",
     "shell.execute_reply": "2025-03-28T17:23:01.556867Z",
     "shell.execute_reply.started": "2025-03-28T17:22:59.133584Z"
    },
    "id": "RPFnfEqRAYA5"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"TF_ENABLE_ONEDNN_OPTS\"] = \"0\"\n",
    "import time\n",
    "import random\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import keras\n",
    "import pandas as pd\n",
    "from tensorflow.keras.datasets import cifar10, cifar100, mnist\n",
    "from tensorflow.keras.models import load_model\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping  #ReduceLROnPlateau\n",
    "from tensorflow.keras.regularizers import l2\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "\n",
    "# Install optional feature-selection baselines on demand (best effort: a failed\n",
    "# install is not fatal here). `sys.executable -m pip` targets the kernel's own\n",
    "# interpreter rather than whatever `python` resolves to on PATH.\n",
    "import shlex\n",
    "import subprocess\n",
    "import sys\n",
    "\n",
    "try:\n",
    "    import pyHSICLasso\n",
    "except ImportError:\n",
    "    subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"pyHSICLasso\"], check=False)\n",
    "try:\n",
    "    import lassonet\n",
    "except ImportError:\n",
    "    subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"lassonet\"], check=False)\n",
    "\n",
    "# Work around the \"cannot import etree from lxml\" error by wiping the lxml install and\n",
    "# reinstalling it from scratch. Only do this when lxml is actually broken: wiping it\n",
    "# unconditionally would delete a working (possibly already loaded) package on every run.\n",
    "# NOTE(review): paths assume a Debian-style Python 3.9 layout under /usr/local - confirm.\n",
    "try:\n",
    "    import lxml.etree\n",
    "except ImportError:\n",
    "    os.system(\n",
    "        \"rm -rf /usr/local/lib/python3.9/dist-packages/lxml && \"\n",
    "        \"rm -rf /usr/local/lib/python3.9/dist-packages/lxml-*.egg-info && \"\n",
    "        \"rm -rf /usr/local/lib/python3.9/dist-packages/lxml-5.3.1.dist-info && \"\n",
    "        f\"{shlex.quote(sys.executable)} -m pip install --no-cache-dir --upgrade --force-reinstall lxml\"\n",
    "    )\n",
    "\n",
    "from layers       import HadamardDense, HadamardConv2D, SparseConv2D, StrHadamardDense, StrHadamardDenseV2, StrConv2D\n",
    "from initializers import TwiceTruncatedNormalInitializer, equivar_initializer, equivar_initializer_conv2d\n",
    "from callbacks    import HadamardCallback, PrintLRCallback, TerminateBadRuns, StructuredSparsityCallback\n",
    "from models       import hadamard_vgg16, str_hadamard_vgg16, vanilla_vgg16, StrHadamardLeNet5BN, LeNet5BN, HadamardLeNet5BN,\\\n",
    "                         StrHadamardLeNet300100, LeNet300100\n",
    "from utils        import color_preprocessing, compute_thresholds, get_optimizer, process_sparsity_callback,\\\n",
    "                      threshold_model_weights, flatten_and_filter_weights, create_and_save_trajectory_plot\n",
    "\n",
    "################################################################################\n",
    "# Model architecture and training parameters (LeNet-300-100 on MNIST)\n",
    "\n",
    "# Hadamard\n",
    "DEPTH               = 3\n",
    "LA                  = 6e-3     # lambda (presumably the regularization strength)\n",
    "INIT_REST           = tf.keras.initializers.Ones()\n",
    "KERNEL_INITIALIZER  = tf.keras.initializers.HeNormal()\n",
    "MULTFAC_INITIALIZER = tf.keras.initializers.Ones()\n",
    "#USE_BIAS            = True\n",
    "#FACTORIZE_BIAS      = False\n",
    "\n",
    "# Training\n",
    "PRETRAIN_OPT       = 'sgd'      # sgd, adam\n",
    "LR_SCHEDULE        = 'constant' # piecewise, constant, cosine\n",
    "BATCH_SIZE         = 256\n",
    "EPOCHS             = 200\n",
    "INIT_LR            = 5e-4 if PRETRAIN_OPT == 'sgd' else 1e-2\n",
    "# NOTE(review): older notes listed per-depth LRs (depth 2 = 0.2, depth 3 = 0.4,\n",
    "# depth 4 = 0.6); they do not match INIT_LR above - confirm which is current.\n",
    "WARMUP             = False  # linear lr warmup from 0.01 to INIT_LR over 5 epochs\n",
    "MOMENTUM           = 0.0    # only for SGD\n",
    "LR_DECAY_FACT      = 0.1    # only for piecewise\n",
    "LARGE_LRSTART      = False  # only for piecewise\n",
    "LABEL_SMOOTHING    = 0.1\n",
    "\n",
    "# Finetuning\n",
    "DO_FINETUNING      = False\n",
    "FINETUNE_OPT       = 'sgd'    # sgd, adam\n",
    "FINE_SCHEDULE      = 'cosine' # piecewise, constant, cosine\n",
    "FINETUNE_EPOCHS    = 50       if FINETUNE_OPT == 'sgd' else 20\n",
    "FINE_LA            = 1 * LA   if FINETUNE_OPT == 'sgd' else 0.2 * LA\n",
    "FINE_LR            = 7e-3     if FINETUNE_OPT == 'sgd' else 2e-4\n",
    "FINETUNE_ALPHA     = 1e-5     # same value for sgd and adam\n",
    "\n",
    "# Input settings\n",
    "CLASS_NUM          = 10\n",
    "IMG_ROWS, IMG_COLS = 28, 28\n",
    "IMG_CHANNELS       = 1\n",
    "\n",
    "# Misc\n",
    "PAT                = 1000\n",
    "RESTORE_WEIGHTS    = False\n",
    "GRACE              = 20\n",
    "FINE_GRACE         = 10\n",
    "MINACC             = (1 / CLASS_NUM) + 0.05  # chance-level accuracy + 0.05; `minacc` for TerminateBadRuns\n",
    "SEED               = 123\n",
    "SAVE_METRICS       = True\n",
    "VERBOSE            = 2\n",
    "compression_grid   = [10,50,75,100,200,500,700,800,1000,\n",
    "                      1250,1500,1750,2000,2500,3000,3500,4000,4500,5000,6000,7000,8000,9000,10000,\n",
    "                      15000,20000,25000,30000,40000,50000,75000,100000]\n",
    "\n",
    "# Directories and saving paths\n",
    "LENET_FILE_PATH = './results/lenet300100/mnist/'\n",
    "\n",
    "print('Defining configs successful!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2025-03-28T17:14:09.073769Z",
     "iopub.status.busy": "2025-03-28T17:14:09.072982Z",
     "iopub.status.idle": "2025-03-28T17:14:10.243883Z",
     "shell.execute_reply": "2025-03-28T17:14:10.243186Z",
     "shell.execute_reply.started": "2025-03-28T17:14:09.073733Z"
    },
    "id": "mSeZM6CoDbOL",
    "outputId": "80a25c47-d086-4fc8-b35e-9a6d24bdcf2c"
   },
   "outputs": [],
   "source": [
    "# Data loading and pre-processing\n",
    "\n",
    "# Seed TensorFlow, NumPy and Python's RNGs before anything random happens\n",
    "tf.random.set_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "\n",
    "# Load MNIST and append a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1)\n",
    "(X_train, Y_train), (X_test, Y_test) = mnist.load_data()\n",
    "X_train, X_test = (np.expand_dims(images, axis=-1) for images in (X_train, X_test))\n",
    "\n",
    "def calculate_mean_std(dataset):\n",
    "    \"\"\"Return per-channel (mean, std) of an image array.\n",
    "\n",
    "    Both statistics are reduced over axes 0, 1 and 2 - sample, row and column\n",
    "    for an (N, H, W, C) array - leaving one value per channel.\n",
    "    \"\"\"\n",
    "    reduce_axes = (0, 1, 2)\n",
    "    channel_mean = np.mean(dataset, axis=reduce_axes)\n",
    "    channel_std = np.std(dataset, axis=reduce_axes)\n",
    "    return channel_mean, channel_std\n",
    "\n",
    "# Raw pixel statistics (0-255 scale) before the split\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "\n",
    "print(\"Training Set Mean and SD:\", train_mean, train_std)\n",
    "print(\"Testing Set Mean and SD:\", test_mean, test_std)\n",
    "\n",
    "# Create train/val/test split: hold out 10% of the training images for validation.\n",
    "# random_state=SEED makes the split explicit instead of depending on the global NumPy\n",
    "# RNG not having been consumed since it was seeded.\n",
    "X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.1, shuffle=True,\n",
    "                                                  random_state=SEED)\n",
    "\n",
    "# Normalize the images to [0, 1]\n",
    "X_train = X_train / 255.0\n",
    "X_val = X_val / 255.0\n",
    "X_test = X_test / 255.0\n",
    "\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "print(\"Normalized Training Set Mean and SD:\", train_mean, train_std)\n",
    "print(\"Normalized Testing Set Mean and SD:\", test_mean, test_std)\n",
    "\n",
    "# Sanity checks: features and labels stay aligned, and pixels are in [0, 1]\n",
    "assert len(X_train) == len(Y_train) and len(X_val) == len(Y_val) and len(X_test) == len(Y_test)\n",
    "assert X_train.min() >= 0.0 and X_train.max() <= 1.0\n",
    "\n",
    "# Check sample sizes of split\n",
    "print('Train data shape: ', X_train.shape)\n",
    "print('Train labels shape: ', Y_train.shape)\n",
    "print('Validation data shape: ', X_val.shape)\n",
    "print('Validation labels shape: ', Y_val.shape)\n",
    "print('Test data shape: ', X_test.shape)\n",
    "print('Test labels shape: ', Y_test.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2025-03-28T15:28:48.888455Z",
     "iopub.status.busy": "2025-03-28T15:28:48.887734Z",
     "iopub.status.idle": "2025-03-28T15:28:53.421967Z",
     "shell.execute_reply": "2025-03-28T15:28:53.421063Z",
     "shell.execute_reply.started": "2025-03-28T15:28:48.888429Z"
    },
    "id": "6yy9KHxyWOBY",
    "outputId": "7cfef7ad-a3ae-49a2-92d7-820a174079d9"
   },
   "outputs": [],
   "source": [
    "# Vanilla LeNet-300-100 on MNIST (test cell)\n",
    "# Unfactorized baseline: the run name hard-codes `dep1` and a `-vanilla-` tag. LA and\n",
    "# INIT_LR are read from the config cell; LA only appears in the run name because\n",
    "# LeNet300100() is built without a lambda.\n",
    "\n",
    "# Model definition\n",
    "################################################################################\n",
    "MODEL = 'lenet300100_mnist'\n",
    "# NOTE(review): INIT is not passed to LeNet300100() below, so it has no effect here.\n",
    "INIT = tf.keras.initializers.HeNormal()\n",
    "################################################################################\n",
    "\n",
    "# Directories and saving paths\n",
    "fmt_la = f\"{LA:.1e}\"\n",
    "RUN_NAME = f\"{MODEL}_dep1-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{INIT_LR:.1e}-vanilla-bs{BATCH_SIZE}\"\n",
    "RUN_PATH = os.path.join(LENET_FILE_PATH, RUN_NAME)\n",
    "\n",
    "# Create dir (no-op if it already exists)\n",
    "os.makedirs(RUN_PATH, exist_ok=True)\n",
    "\n",
    "################################################################################\n",
    "# Set seed\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "# Callbacks\n",
    "# NOTE(review): only print_lr_cb is passed to fit() below; early_stopping,\n",
    "# terminate_nan_cb and early_abort_cb are built but unused in this cell.\n",
    "early_stopping = EarlyStopping(monitor='val_accuracy', patience=PAT, restore_best_weights=RESTORE_WEIGHTS)\n",
    "print_lr_cb = PrintLRCallback()\n",
    "terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "early_abort_cb = TerminateBadRuns(grace=GRACE, minacc = MINACC)\n",
    "\n",
    "# Define model\n",
    "vanilla_lenet300100 = LeNet300100()\n",
    "\n",
    "# Pretrain optimizer\n",
    "optimizer = get_optimizer(lr_schedule=LR_SCHEDULE, init_lr=INIT_LR, lr_decay_fact=LR_DECAY_FACT, epochs=EPOCHS,\n",
    "                          dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT, momentum=MOMENTUM, alpha=0,\n",
    "                          large_lr_start=LARGE_LRSTART, warmup=WARMUP)\n",
    "\n",
    "# Compile model\n",
    "vanilla_lenet300100.compile(optimizer=optimizer,\n",
    "                            loss='sparse_categorical_crossentropy',\n",
    "                            metrics=['accuracy'])\n",
    "\n",
    "# summary() prints the table itself and returns None, so don't wrap it in print()\n",
    "vanilla_lenet300100.summary()\n",
    "\n",
    "################################################################################\n",
    "\n",
    "# Training\n",
    "pre_hist = vanilla_lenet300100.fit(x=X_train, y=Y_train, epochs=EPOCHS, batch_size=BATCH_SIZE,\n",
    "                                   validation_data=(X_val, Y_val), callbacks=[print_lr_cb])\n",
    "\n",
    "# Evaluate after training\n",
    "pretrain_loss, pretrain_acc = vanilla_lenet300100.evaluate(X_test, Y_test)\n",
    "print('\\nTest loss', pretrain_loss)\n",
    "print('Test accuracy', pretrain_acc)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-03-23T16:00:59.232801Z",
     "iopub.status.busy": "2025-03-23T16:00:59.232411Z"
    }
   },
   "outputs": [],
   "source": [
    "# StrHadamard LeNet300100 on MNIST (test cell)\n",
    "# Single exploratory run. Its hyperparameters live in TEST_* names so that running\n",
    "# this cell does not overwrite the config-cell globals (DEPTH, LA, INIT_LR,\n",
    "# LR_SCHEDULE, BATCH_SIZE, EPOCHS) that the other cells, including the sweep, read.\n",
    "\n",
    "# Model definition\n",
    "################################################################################\n",
    "MODEL = 'lenet300100_mnist' # 'wrn16-8'\n",
    "TEST_DEPTH = 2\n",
    "TEST_LA = 2e-2\n",
    "print(f'Starting run with lambda={TEST_LA:.2e}')\n",
    "TEST_INIT_LR = 5e-3\n",
    "MULTFAC_INIT = tf.keras.initializers.Ones()\n",
    "USE_BIAS = False\n",
    "FACTORIZE_BIAS = False\n",
    "TEST_BATCH_SIZE = 256\n",
    "TEST_EPOCHS = 200\n",
    "TEST_LR_SCHEDULE = 'cosine'\n",
    "################################################################################\n",
    "\n",
    "# Directories and saving paths. StrHadamard runs use the same naming scheme as the\n",
    "# depth/lambda sweep (no '-vanilla-' tag, which marks the unfactorized baseline).\n",
    "fmt_la = f\"{TEST_LA:.1e}\"\n",
    "RUN_NAME = f\"{MODEL}_dep{TEST_DEPTH}-la{fmt_la}-preopt-{PRETRAIN_OPT}-{TEST_EPOCHS}eps-{TEST_LR_SCHEDULE}-lr{TEST_INIT_LR:.1e}-bs{TEST_BATCH_SIZE}\"\n",
    "RUN_PATH = os.path.join(LENET_FILE_PATH, RUN_NAME)\n",
    "\n",
    "# Saving is disabled for this test cell; re-enable together with the to_csv calls below.\n",
    "#os.makedirs(RUN_PATH, exist_ok=True)\n",
    "\n",
    "################################################################################\n",
    "# Set seed\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "# Callbacks (early_stopping is built but not passed to fit())\n",
    "early_stopping = EarlyStopping(monitor='val_accuracy', patience=PAT, restore_best_weights=RESTORE_WEIGHTS)\n",
    "custom_sparsity_callback = StructuredSparsityCallback(save_metrics=SAVE_METRICS, verbose=VERBOSE)\n",
    "print_lr_cb = PrintLRCallback()\n",
    "terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "early_abort_cb = TerminateBadRuns(grace=GRACE, minacc = MINACC)\n",
    "\n",
    "# Define model\n",
    "hadamard_lenet300100 = StrHadamardLeNet300100(input_shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS), n_classes=CLASS_NUM,\n",
    "                                              depth=TEST_DEPTH, la=TEST_LA, init_rest=MULTFAC_INIT,\n",
    "                                              use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS)\n",
    "\n",
    "# Pretrain optimizer\n",
    "optimizer = get_optimizer(lr_schedule=TEST_LR_SCHEDULE, init_lr=TEST_INIT_LR, lr_decay_fact=LR_DECAY_FACT,\n",
    "                          epochs=TEST_EPOCHS, dat=X_train, batch_size=TEST_BATCH_SIZE, opt=PRETRAIN_OPT,\n",
    "                          momentum=MOMENTUM, alpha=0, large_lr_start=LARGE_LRSTART, warmup=WARMUP)\n",
    "\n",
    "# Compile model\n",
    "hadamard_lenet300100.compile(optimizer=optimizer,\n",
    "                             loss='sparse_categorical_crossentropy',\n",
    "                             metrics=['accuracy'])\n",
    "\n",
    "# summary() prints the table itself and returns None, so don't wrap it in print()\n",
    "hadamard_lenet300100.summary()\n",
    "\n",
    "################################################################################\n",
    "\n",
    "# Training\n",
    "pre_hist = hadamard_lenet300100.fit(x=X_train, y=Y_train, epochs=TEST_EPOCHS, batch_size=TEST_BATCH_SIZE,\n",
    "                                    validation_data=(X_val, Y_val),\n",
    "                                    callbacks=[print_lr_cb, custom_sparsity_callback, terminate_nan_cb, early_abort_cb])\n",
    "\n",
    "# Evaluate after training\n",
    "pretrain_loss, pretrain_acc = hadamard_lenet300100.evaluate(X_test, Y_test)\n",
    "print('\\nTest loss', pretrain_loss)\n",
    "print('Test accuracy', pretrain_acc)\n",
    "\n",
    "# Sparsity, compression rate, misalignment and L2 from the last recorded row\n",
    "# (columns 1-4 of the callback's total_metrics_data)\n",
    "pretrain_sparsity, pretrain_compression_rate, pretrain_misalignment, pre_l2 = custom_sparsity_callback.total_metrics_data[-1, 1:5]\n",
    "# Collect the pretraining trajectory (saving to RUN_PATH is disabled for this test cell)\n",
    "pre_total, pre_epochs = process_sparsity_callback(hist = pre_hist, hadamard_cb = custom_sparsity_callback, lr_cb = print_lr_cb)\n",
    "#pre_total.to_csv(os.path.join(RUN_PATH, 'pre_cb_total.csv'), index=False)\n",
    "#pre_epochs.to_csv(os.path.join(RUN_PATH, 'pre_cb_epochs.csv'), index=False)\n",
    "#print(f'Callback results saved successfully to {RUN_PATH}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-03-28T18:35:43.082734Z",
     "iopub.status.busy": "2025-03-28T18:35:43.081889Z",
     "iopub.status.idle": "2025-03-28T18:57:44.969713Z",
     "shell.execute_reply": "2025-03-28T18:57:44.968355Z",
     "shell.execute_reply.started": "2025-03-28T18:35:43.082721Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import random\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.callbacks import EarlyStopping\n",
    "# Depth x lambda sweep for StrHadamard LeNet-300-100 on MNIST.\n",
    "# Relies on names from the earlier cells: StrHadamardLeNet300100, PrintLRCallback,\n",
    "# StructuredSparsityCallback, TerminateBadRuns, process_sparsity_callback, get_optimizer,\n",
    "# create_and_save_trajectory_plot, and the config/data values INIT_LR, PRETRAIN_OPT, EPOCHS,\n",
    "# LR_SCHEDULE, LENET_FILE_PATH, SEED, PAT, RESTORE_WEIGHTS, SAVE_METRICS, VERBOSE, GRACE,\n",
    "# MINACC, X_train, Y_train, X_val, Y_val, X_test, Y_test, IMG_ROWS, IMG_COLS, IMG_CHANNELS,\n",
    "# CLASS_NUM, LR_DECAY_FACT, MOMENTUM, LARGE_LRSTART, WARMUP.\n",
    "# The loop variables are lower-case (depth, la) so the sweep does not overwrite the\n",
    "# config globals DEPTH and LA that other cells read.\n",
    "\n",
    "# Largest lambda first\n",
    "lambdas = [0, 5e-4, 3e-3, 1e-1, 2e-1, 3e-1, 4e-1, 5e-1, 6e-1, 7e-1, 8e-1]\n",
    "lambdas.reverse()\n",
    "depths = [2, 3, 4]\n",
    "\n",
    "# Settings shared by every run of the sweep\n",
    "MODEL = 'lenet300100_mnist'\n",
    "BATCH_SIZE = 256\n",
    "USE_BIAS = False\n",
    "FACTORIZE_BIAS = False\n",
    "\n",
    "for depth in depths:\n",
    "    for la in lambdas:\n",
    "        print(f'Starting run with lambda={la:.2e} and DEPTH={depth}')\n",
    "        # Release the previous run's model/graph state; without this, Keras' global state\n",
    "        # keeps growing across the sweep's runs. Done before seeding so the seeds below\n",
    "        # apply to the fresh session.\n",
    "        tf.keras.backend.clear_session()\n",
    "        MULTFAC_INIT = tf.keras.initializers.Ones()\n",
    "\n",
    "        fmt_la = f\"{la:.1e}\"\n",
    "        RUN_NAME = f\"{MODEL}_dep{depth}-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{INIT_LR:.1e}-bs{BATCH_SIZE}\"\n",
    "        RUN_PATH = os.path.join(LENET_FILE_PATH, RUN_NAME)\n",
    "        os.makedirs(RUN_PATH, exist_ok=True)\n",
    "\n",
    "        np.random.seed(SEED)\n",
    "        random.seed(SEED)\n",
    "        tf.random.set_seed(SEED)\n",
    "\n",
    "        # early_stopping is built but not passed to fit()\n",
    "        early_stopping = EarlyStopping(monitor='val_accuracy', patience=PAT, restore_best_weights=RESTORE_WEIGHTS)\n",
    "        custom_sparsity_callback = StructuredSparsityCallback(save_metrics=SAVE_METRICS, verbose=VERBOSE)\n",
    "        print_lr_cb = PrintLRCallback()\n",
    "        terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "        early_abort_cb = TerminateBadRuns(grace=GRACE, minacc=MINACC)\n",
    "\n",
    "        hadamard_lenet300100 = StrHadamardLeNet300100(\n",
    "            input_shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS),\n",
    "            n_classes=CLASS_NUM,\n",
    "            depth=depth,\n",
    "            la=la,\n",
    "            init_rest=MULTFAC_INIT,\n",
    "            use_bias=USE_BIAS,\n",
    "            factorize_bias=FACTORIZE_BIAS\n",
    "        )\n",
    "\n",
    "        optimizer = get_optimizer(\n",
    "            lr_schedule=LR_SCHEDULE,\n",
    "            init_lr=INIT_LR,\n",
    "            lr_decay_fact=LR_DECAY_FACT,\n",
    "            epochs=EPOCHS,\n",
    "            dat=X_train,\n",
    "            batch_size=BATCH_SIZE,\n",
    "            opt=PRETRAIN_OPT,\n",
    "            momentum=MOMENTUM,\n",
    "            alpha=0,\n",
    "            large_lr_start=LARGE_LRSTART,\n",
    "            warmup=WARMUP\n",
    "        )\n",
    "\n",
    "        hadamard_lenet300100.compile(optimizer=optimizer,\n",
    "                                     loss='sparse_categorical_crossentropy',\n",
    "                                     metrics=['accuracy'])\n",
    "        # summary() prints the table itself and returns None, so don't wrap it in print()\n",
    "        hadamard_lenet300100.summary()\n",
    "\n",
    "        pre_hist = hadamard_lenet300100.fit(\n",
    "            x=X_train, y=Y_train,\n",
    "            epochs=EPOCHS,\n",
    "            batch_size=BATCH_SIZE,\n",
    "            validation_data=(X_val, Y_val),\n",
    "            callbacks=[print_lr_cb, custom_sparsity_callback, terminate_nan_cb, early_abort_cb]\n",
    "        )\n",
    "\n",
    "        pretrain_loss, pretrain_acc = hadamard_lenet300100.evaluate(X_test, Y_test)\n",
    "        print('\\nTest loss', pretrain_loss)\n",
    "        print('Test accuracy', pretrain_acc)\n",
    "\n",
    "        pre_total, pre_epochs = process_sparsity_callback(\n",
    "            hist=pre_hist,\n",
    "            hadamard_cb=custom_sparsity_callback,\n",
    "            lr_cb=print_lr_cb\n",
    "        )\n",
    "        pre_total.to_csv(os.path.join(RUN_PATH, 'pre_cb_total.csv'), index=False)\n",
    "        pre_epochs.to_csv(os.path.join(RUN_PATH, 'pre_cb_epochs.csv'), index=False)\n",
    "        print(f'Callback results saved successfully to {RUN_PATH}')\n",
    "\n",
    "        # Construct and save plot training metrics trajectories\n",
    "        # NOTE(review): if create_and_save_trajectory_plot leaves matplotlib figures open,\n",
    "        # they accumulate across the sweep - confirm it closes them.\n",
    "        create_and_save_trajectory_plot(cb_total=pre_total, run_path=RUN_PATH, max_loss=10, out_name='pre_trajectory.pdf')\n",
    "        print(\"Full callback DataFrame (pre_total):\")\n",
    "        print(pre_total)\n",
    "        # Optionally, print the detailed epoch-wise callback DataFrame\n",
    "        print(\"\\nEpoch-wise callback metrics (pre_epochs):\")\n",
    "        print(pre_epochs)\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
