{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-02T13:33:09.451808Z",
     "iopub.status.busy": "2025-05-02T13:33:09.450974Z",
     "iopub.status.idle": "2025-05-02T13:33:16.262599Z",
     "shell.execute_reply": "2025-05-02T13:33:16.261496Z",
     "shell.execute_reply.started": "2025-05-02T13:33:09.451741Z"
    },
    "id": "RPFnfEqRAYA5"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"TF_ENABLE_ONEDNN_OPTS\"] = \"0\"\n",
    "import time\n",
    "import random\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import keras\n",
    "import pandas as pd\n",
    "from tensorflow.keras.datasets import cifar10, cifar100\n",
    "from tensorflow.keras.models import load_model\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping  #ReduceLROnPlateau\n",
    "from tensorflow.keras.regularizers import l2\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "\n",
    "# Optional third-party packages: install them on demand when the import fails.\n",
    "import subprocess\n",
    "import sys\n",
    "\n",
    "\n",
    "def _pip_install(package):\n",
    "    \"\"\"Install `package` with pip into the interpreter running this kernel.\n",
    "\n",
    "    `sys.executable` is used so the install targets the active environment\n",
    "    instead of whichever `python` is first on PATH. Returns True when pip\n",
    "    exits with status 0, False otherwise.\n",
    "    \"\"\"\n",
    "    return subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", package]).returncode == 0\n",
    "\n",
    "\n",
    "# Only a failed import triggers an install; other exceptions (e.g. KeyboardInterrupt)\n",
    "# are no longer swallowed. The install stays best effort: a failure is reported but\n",
    "# does not abort the notebook. Re-run this cell to import a freshly installed package.\n",
    "try:\n",
    "    import pyHSICLasso\n",
    "except ImportError:\n",
    "    if not _pip_install(\"pyHSICLasso\"):\n",
    "        print(\"WARNING: pip could not install pyHSICLasso\")\n",
    "try:\n",
    "    import lassonet\n",
    "except ImportError:\n",
    "    if not _pip_install(\"lassonet\"):\n",
    "        print(\"WARNING: pip could not install lassonet\")\n",
    "    \n",
    "# Workaround for the \"cannot import etree from lxml\" error: delete the installed lxml\n",
    "# files, force a clean reinstall, then reinstall a pinned protobuf. The steps are chained\n",
    "# with \"&&\", so each one only runs when the previous one succeeded.\n",
    "_SITE_PACKAGES = \"/usr/local/lib/python3.9/dist-packages\"\n",
    "_lxml_fix_steps = [\n",
    "    f\"rm -rf {_SITE_PACKAGES}/lxml\",\n",
    "    f\"rm -rf {_SITE_PACKAGES}/lxml-*.egg-info\",\n",
    "    f\"rm -rf {_SITE_PACKAGES}/lxml-5.3.1.dist-info\",\n",
    "    \"pip install --no-cache-dir --upgrade --force-reinstall lxml\",\n",
    "    \"pip install --force-reinstall -v protobuf==3.20.2\",\n",
    "]\n",
    "os.system(\" && \".join(_lxml_fix_steps))\n",
    "\n",
    "from layers       import HadamardDense, HadamardConv2D, SparseConv2D, StrHadamardDense, StrConv2D\n",
    "from initializers import TwiceTruncatedNormalInitializer, equivar_initializer, equivar_initializer_conv2d\n",
    "from callbacks    import HadamardCallback, PrintLRCallback, TerminateBadRuns, StructuredSparsityCallback\n",
    "from models       import WideResNet, hadamard_resnet18, hadamard_resnet20, hadamard_WRN_16_4, hadamard_WRN_16_8, hadamard_WRN_28_10\n",
    "from models       import hadamard_vgg16, str_hadamard_vgg16, vanilla_vgg16, vanilla_vgg16_individual, vanilla_vgg19_individual,\\\n",
    "                         vanilla_resnet18, vanilla_resnet18_individual, vanilla_resnet18_reg\n",
    "from utils        import color_preprocessing, compute_thresholds, get_optimizer, process_sparsity_callback,\\\n",
    "                      threshold_model_weights, flatten_and_filter_weights, create_and_save_trajectory_plot, get_flops_and_profile\n",
    "\n",
    "################################################################################\n",
    "# Define VGG model architecture and training parameters\n",
    "\n",
    "# Hadamard\n",
    "DEPTH               = 2\n",
    "LA                  = 3e-4\n",
    "INIT_TYPE           ='equivar' #vanilla, equivar, root, ones\n",
    "INIT_REST           = tf.keras.initializers.Ones() \n",
    "MINPROD             = 3e-3\n",
    "INIT                = TwiceTruncatedNormalInitializer(minprod=MINPROD,depth=DEPTH) \n",
    "KERNEL_INITIALIZER  = tf.keras.initializers.HeNormal() #TwiceTruncatedNormalInitializer(minprod=MINPROD,depth=DEPTH)\n",
    "MULTFAC_INITIALIZER = tf.keras.initializers.Ones()\n",
    "USE_BIAS            = False\n",
    "#FACTORIZE_BIAS      = False\n",
    "\n",
    "# Training\n",
    "PRETRAIN_OPT       = 'sgd' # sgd, adam\n",
    "LR_SCHEDULE        = 'cosine' # piecewise, constant, cosine\n",
    "BATCH_SIZE         = 256 #256\n",
    "EPOCHS             = 250 #200\n",
    "INIT_LR            = 0.2 if PRETRAIN_OPT == 'sgd' else 1e-2 # depth 2=0.3, depth3 3 = 0.4, depth 4 = 0.6\n",
    "WARMUP             = False # linear lr warmup from 0.01 to INIT_LR over 5 epochs\n",
    "MOMENTUM           = 0.9    # only for SGD\n",
    "LR_DECAY_FACT      = 0.1   # only for piecewise\n",
    "LARGE_LRSTART      = False  # only for piecewise\n",
    "LABEL_SMOOTHING    = 0.1\n",
    "\n",
    "# Finetuning\n",
    "DO_FINETUNING      = False\n",
    "FINETUNE_OPT       = 'sgd' # sgd, adam\n",
    "FINE_SCHEDULE      = 'cosine' # piecewise, constant, cosine\n",
    "FINETUNE_EPOCHS    = 50     if FINETUNE_OPT == 'sgd' else 20\n",
    "FINE_LA            = 1 * LA if FINETUNE_OPT == 'sgd' else 0.2 * LA\n",
    "FINE_LR            = 7e-3   if FINETUNE_OPT == 'sgd' else 2e-4\n",
    "FINETUNE_ALPHA     = 1e-5   if FINETUNE_OPT == 'sgd' else 1e-5\n",
    "\n",
    "# Input settings\n",
    "CLASS_NUM          = 100\n",
    "IMG_ROWS, IMG_COLS = 32, 32\n",
    "IMG_CHANNELS       = 3\n",
    "\n",
    "# Misc\n",
    "PAT                = 1000\n",
    "RESTORE_WEIGHTS    = False\n",
    "GRACE              = 20\n",
    "FINE_GRACE         = 10\n",
    "MINACC             = (1 / CLASS_NUM) + 0.05\n",
    "SEED               = 123\n",
    "SAVE_METRICS       = True\n",
    "VERBOSE            = 2\n",
    "compression_grid   = [10,50,75,100,200,500,700,800,1000,\n",
    "                      1250,1500,1750,2000,2500,3000,3500,4000,4500,5000,6000,7000,8000,9000,10000,\n",
    "                      15000,20000,25000,30000,40000,50000,75000,100000]\n",
    "lambdas_all        = [0, 1e-6, 5e-6,  #0-2\n",
    "                       1e-5, 2e-5, 3e-5, 4e-5, 5e-5, 6e-5, 7e-5, 8e-5, 9e-5, #3-11\n",
    "                       1e-4, 2e-4, 3e-4, 4e-4, 5e-4, 6e-4, 7e-4, 8e-4, 9e-4, #12-20\n",
    "                       1e-3, 2e-3, 5e-3] #21-23\n",
    "lambdas_10          = [1e-1, 1e-2, 1e-3, 5e-4]\n",
    "lambdas             = lambdas_10\n",
    "\n",
    "# Deirectories and saving paths\n",
    "RESNET_FILE_PATH = './results/resnet18/cifar100/grid'\n",
    "\n",
    "print('Defining configs successful!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2025-05-02T13:18:57.648021Z",
     "iopub.status.busy": "2025-05-02T13:18:57.647098Z",
     "iopub.status.idle": "2025-05-02T13:19:40.499006Z",
     "shell.execute_reply": "2025-05-02T13:19:40.498214Z",
     "shell.execute_reply.started": "2025-05-02T13:18:57.647985Z"
    },
    "id": "mSeZM6CoDbOL",
    "outputId": "80a25c47-d086-4fc8-b35e-9a6d24bdcf2c"
   },
   "outputs": [],
   "source": [
    "# Data loading and pre-processing\n",
    "\n",
    "# Seed NumPy, Python's `random` and TensorFlow with the same value\n",
    "for _seed_fn in (np.random.seed, random.seed, tf.random.set_seed):\n",
    "    _seed_fn(SEED)\n",
    "\n",
    "# Load the CIFAR-100 train and test splits\n",
    "_train_split, _test_split = cifar100.load_data()\n",
    "X_train, Y_train = _train_split\n",
    "X_test, Y_test = _test_split\n",
    "\n",
    "def calculate_mean_std(dataset):\n",
    "    means = np.mean(dataset, axis=(0, 1, 2))  \n",
    "    stds = np.std(dataset, axis=(0, 1, 2))  \n",
    "    return means, stds\n",
    "\n",
    "# Statistics of the raw images, before any preprocessing\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "\n",
    "# mean = [125.3, 123.0, 113.9], std  = [63.0,  62.1,  66.7]\n",
    "for _label, _mean, _std in ((\"Training\", train_mean, train_std), (\"Testing\", test_mean, test_std)):\n",
    "    print(f\"{_label} Set Mean and SD:\", _mean, _std)\n",
    "\n",
    "# Hold out 5% of the training images as a validation set\n",
    "X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.05, shuffle=True)\n",
    "\n",
    "# Color preprocessing\n",
    "X_train, X_test, X_val = color_preprocessing(X_train, X_test, X_val)\n",
    "\n",
    "# Same statistics again, now on the preprocessed images\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "for _label, _mean, _std in ((\"Training\", train_mean, train_std), (\"Testing\", test_mean, test_std)):\n",
    "    print(f\"Normalized {_label} Set Mean and SD:\", _mean, _std)\n",
    "\n",
    "# One-hot encode the labels; the encoder is fitted on the training labels only\n",
    "encoder = OneHotEncoder().fit(Y_train)\n",
    "Y_train, Y_test, Y_val = (\n",
    "    encoder.transform(_labels).toarray() for _labels in (Y_train, Y_test, Y_val)\n",
    ")\n",
    "\n",
    "# Check sample sizes of split\n",
    "for _name, _array in (\n",
    "    ('Train data', X_train),\n",
    "    ('Train labels', Y_train),\n",
    "    ('Validation data', X_val),\n",
    "    ('Validation labels', Y_val),\n",
    "    ('Test data', X_test),\n",
    "    ('Test labels', Y_test),\n",
    "):\n",
    "    print(f'{_name} shape: ', _array.shape)\n",
    "\n",
    "# Define augmentation and data generators and fit to data\n",
    "print('Using real-time data augmentation.')\n",
    "aug_datagen = ImageDataGenerator(horizontal_flip=True,width_shift_range=0.125,height_shift_range=0.125,\n",
    "                                 rotation_range=15,fill_mode='reflect')\n",
    "aug_datagen.fit(X_train)\n",
    "trainflow_aug = aug_datagen.flow(X_train, Y_train, batch_size=BATCH_SIZE)\n",
    "\n",
    "# Validation data goes through a generator without any augmentation options\n",
    "val_datagen = ImageDataGenerator()\n",
    "val_datagen.fit(X_val)\n",
    "val_flow = val_datagen.flow(X_val, Y_val, batch_size=BATCH_SIZE)\n",
    "\n",
    "# Number of batches per epoch. `steps_per_epoch` must be an integer, so use ceiling\n",
    "# division (the last, partial batch counts as one step) instead of a float quotient.\n",
    "steps_per_ep = (len(X_train) + BATCH_SIZE - 1) // BATCH_SIZE\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-04-08T20:07:58.284998Z",
     "iopub.status.busy": "2025-04-08T20:07:58.284725Z",
     "iopub.status.idle": "2025-04-08T20:58:20.965998Z",
     "shell.execute_reply": "2025-04-08T20:58:20.965216Z",
     "shell.execute_reply.started": "2025-04-08T20:07:58.284974Z"
    }
   },
   "outputs": [],
   "source": [
    "# Cifar ResNet18 with filter sparsity D=2\n",
    "\n",
    "# Model definition\n",
    "################################################################################\n",
    "# The assignments below override the notebook-wide defaults for this cell's runs.\n",
    "MODEL = 'str_resnet18'\n",
    "DEPTH = 2\n",
    "#LA =  2e-4\n",
    "#print(f'Starting run with depth = {DEPTH} and lambda={LA:.2e}')\n",
    "# Self-assignments: keep the initializers from the config cell unchanged\n",
    "# (they raise NameError if that cell has not been run).\n",
    "KERNEL_INITIALIZER  = KERNEL_INITIALIZER #tf.keras.initializers.HeNormal()\n",
    "MULTFAC_INITIALIZER = MULTFAC_INITIALIZER #tf.keras.initializers.Ones()\n",
    "USE_BIAS            = False\n",
    "FACTORIZE_BIAS      = False\n",
    "init_lr = 0.2\n",
    "VERBOSE=1\n",
    "GRACE=20\n",
    "EPOCHS = 200\n",
    "\n",
    "# Define list of lambdas to iterate over\n",
    "lambdas_10 = [0,8e-5,2e-4,7e-4,8e-4,9e-4,1e-3,1.5e-3,2e-3,2.5e-3,3e-3,3.3e-3,3.6e-3,4.5e-3,5e-3,5.5e-3,6.5e-3,7.1e-3,8e-3,9e-3,1.2e-2,1.5e-2,2.05e-2,2.1e-2,0.2]\n",
    "\n",
    "\n",
    "# Optional debug print of the lambda grid (the loop below iterates over lambdas, not learning rates)\n",
    "#print(lambdas_10)\n",
    "\n",
    "for LA in lambdas_10:\n",
    "    print(f'\\nStarting run with learning rate: {init_lr:.2e} and depth = {DEPTH} and lambda = {LA:.2e}')\n",
    "\n",
    "    # Build the run name and output directory for the current lambda\n",
    "    fmt_la = f\"{LA:.1e}\"\n",
    "    fmt_fine_la = f\"{FINE_LA:.1e}\"\n",
    "    RUN_NAME = f\"dep{DEPTH}-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{init_lr:.1e}-ftune-{FINETUNE_OPT}-flr{FINE_LR:.1e}-fla-{fmt_fine_la}-feps{FINETUNE_EPOCHS}-{INIT_TYPE}-bs{BATCH_SIZE}\"\n",
    "    RUN_PATH = os.path.join(RESNET_FILE_PATH, RUN_NAME)\n",
    "\n",
    "    # Re-seed all RNGs so every run starts from the same random state\n",
    "    np.random.seed(SEED)\n",
    "    random.seed(SEED)\n",
    "    tf.random.set_seed(SEED)\n",
    "\n",
    "    # Define callbacks (fresh instances per run so no state carries over between lambdas)\n",
    "    early_stopping = EarlyStopping(monitor='val_accuracy', patience=PAT, restore_best_weights=RESTORE_WEIGHTS)\n",
    "    custom_sparsity_callback = StructuredSparsityCallback(save_metrics=SAVE_METRICS, verbose=VERBOSE)\n",
    "    print_lr_cb = PrintLRCallback()\n",
    "    terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "    early_abort_cb = TerminateBadRuns(grace=GRACE, minacc=MINACC)\n",
    "\n",
    "    # Define the model with the given parameters\n",
    "    hadamard_net = hadamard_resnet18(\n",
    "        depth=DEPTH,\n",
    "        kernel_initializer=KERNEL_INITIALIZER,\n",
    "        multfac_initializer=MULTFAC_INITIALIZER,\n",
    "        input_shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS),\n",
    "        la=LA,\n",
    "        n_classes=CLASS_NUM,\n",
    "        use_bias=USE_BIAS\n",
    "    )\n",
    "\n",
    "    # Create pretrain optimizer with the initial learning rate of this cell\n",
    "    optimizer = get_optimizer(\n",
    "        lr_schedule=LR_SCHEDULE,\n",
    "        init_lr=init_lr,\n",
    "        lr_decay_fact=LR_DECAY_FACT,\n",
    "        epochs=EPOCHS,\n",
    "        dat=X_train,\n",
    "        batch_size=BATCH_SIZE,\n",
    "        opt=PRETRAIN_OPT,\n",
    "        momentum=MOMENTUM,\n",
    "        alpha=0,\n",
    "        large_lr_start=LARGE_LRSTART,\n",
    "        warmup=WARMUP\n",
    "    )\n",
    "\n",
    "    # Compile model\n",
    "    hadamard_net.compile(\n",
    "        optimizer=optimizer,\n",
    "        loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=LABEL_SMOOTHING),\n",
    "        metrics=['accuracy']\n",
    "    )\n",
    "\n",
    "    # summary() prints the table itself; wrapping it in print() only added a stray \"None\" line\n",
    "    hadamard_net.summary()\n",
    "\n",
    "    # Define paths for saving pretraining weights and model\n",
    "    weights_name = f'{MODEL}_weights.h5'\n",
    "    full_model_path = f'{MODEL}_model.h5'\n",
    "    weights_path = os.path.join(RUN_PATH, weights_name)\n",
    "\n",
    "    if os.path.exists(weights_path):\n",
    "        print('Existing pretraining weights found: load and evaluate')\n",
    "        hadamard_net.load_weights(weights_path)\n",
    "        # No training ran, so the callback summary metrics are unavailable; -1 is a placeholder\n",
    "        pretrain_sparsity = -1\n",
    "        pretrain_compression_rate = -1\n",
    "        pretrain_misalignment = -1\n",
    "        pre_l2 = -1\n",
    "        # The per-layer analysis after this block reads `pre_epochs`. Reload the table that the\n",
    "        # training branch saved into this run's directory, so the code neither fails with a\n",
    "        # NameError on a fresh kernel nor silently reuses the table of the previous lambda.\n",
    "        # 'round_trip' parsing restores the saved float values exactly.\n",
    "        pre_epochs = pd.read_csv(os.path.join(RUN_PATH, 'pre_cb_epochs.csv'), float_precision='round_trip')\n",
    "    else:\n",
    "        print(f'No existing pretraining weights found: use pretraining with depth = {DEPTH} and lambda = {LA:.2e}')\n",
    "        pre_hist = hadamard_net.fit(\n",
    "            trainflow_aug,\n",
    "            steps_per_epoch=steps_per_ep,\n",
    "            batch_size=BATCH_SIZE,\n",
    "            validation_data=val_flow,\n",
    "            validation_steps=X_val.shape[0] // BATCH_SIZE,\n",
    "            epochs=EPOCHS,\n",
    "            callbacks=[custom_sparsity_callback, early_stopping, print_lr_cb, terminate_nan_cb, early_abort_cb]\n",
    "        )\n",
    "\n",
    "        # Optionally save the trained weights and model\n",
    "        # hadamard_net.save_weights(weights_path)\n",
    "        # hadamard_net.save(os.path.join(RUN_PATH, full_model_path))\n",
    "        os.makedirs(RUN_PATH, exist_ok=True)\n",
    "\n",
    "        # Summary metrics: last row, columns 1-4 of the callback's total-metrics array\n",
    "        pretrain_sparsity, pretrain_compression_rate, pretrain_misalignment, pre_l2 = custom_sparsity_callback.total_metrics_data[-1, 1:5]\n",
    "        pre_total, pre_epochs = process_sparsity_callback(hist=pre_hist, hadamard_cb=custom_sparsity_callback, lr_cb=print_lr_cb)\n",
    "        pre_total.to_csv(os.path.join(RUN_PATH, 'pre_cb_total.csv'), index=False)\n",
    "        pre_epochs.to_csv(os.path.join(RUN_PATH, 'pre_cb_epochs.csv'), index=False)\n",
    "        print(f'Callback results saved successfully to {RUN_PATH}')\n",
    "        create_and_save_trajectory_plot(cb_total=pre_total, run_path=RUN_PATH, max_loss=10, out_name='pre_trajectory_resnet_str.pdf')\n",
    "    \n",
    "    # Score the pretrained network on the test set\n",
    "    _test_metrics = hadamard_net.evaluate(X_test, Y_test)\n",
    "    pretrain_loss, pretrain_acc = _test_metrics\n",
    "    print('\\nTest loss:', pretrain_loss)\n",
    "    print('Test accuracy:', pretrain_acc)\n",
    "\n",
    "    # Reduce the per-epoch callback table to the rows of the highest epoch, dropping\n",
    "    # rows with worb == 0 (per the original note: biases rather than kernel filter weights)\n",
    "    max_epoch = pre_epochs['epoch'].max()\n",
    "    _final_kernel_rows = (pre_epochs['epoch'] == max_epoch) & (pre_epochs['worb'] != 0)\n",
    "    df = (\n",
    "        pre_epochs.loc[_final_kernel_rows]\n",
    "        .sort_values(by='layer')\n",
    "        .reset_index(drop=True)\n",
    "    )\n",
    "\n",
    "    # 1-based position of each remaining row in ascending 'layer' order\n",
    "    df['ind'] = range(1, len(df) + 1)\n",
    "\n",
    "    # Define a mapping function for ResNet18 based on new indices.\n",
    "    def get_filters(ind):\n",
    "        if ind == 1:\n",
    "            return 64  # Stem conv\n",
    "        elif ind in [2, 3, 4, 5]:\n",
    "            return 64  # Group 1 blocks\n",
    "        elif ind in [6, 7, 8, 9, 10]:\n",
    "            return 128  # Group 2 blocks\n",
    "        elif ind in [11, 12, 13, 14, 15]:\n",
    "            return 256  # Group 3 blocks\n",
    "        elif ind in [16, 17, 18, 19, 20]:\n",
    "            return 512  # Group 4 blocks\n",
    "        else:\n",
    "            return 512  # Fallback\n",
    "\n",
    "    df['filters'] = df['ind'].apply(get_filters)\n",
    "    # Round before casting: astype(int) alone truncates, so a product that lands just\n",
    "    # below a whole number through float error (e.g. 6.999999) would lose one filter.\n",
    "    df['remaining_filters'] = ((1 - df['sparsity']) * df['filters']).round().astype(int)\n",
    "    print(\"Processed sparsity data:\")\n",
    "    print(df)\n",
    "\n",
    "    def _remaining(positions):\n",
    "        \"\"\"Return the remaining-filter counts of the rows whose `ind` is in `positions`.\"\"\"\n",
    "        return list(df.loc[df['ind'].isin(positions), 'remaining_filters'])\n",
    "\n",
    "    # Assemble the new block_filters object for ResNet18.\n",
    "    # Structure: [stem_filter, [group1, group2, group3, group4]]\n",
    "    # Group 1: two blocks (each with 2 conv layers), Group 2: first block with 3 conv layers (projection) and second block with 2, etc.\n",
    "    block_filters = [\n",
    "        int(df.loc[df['ind'] == 1, 'remaining_filters'].iloc[0]),\n",
    "        [\n",
    "            [_remaining([2, 3]), _remaining([4, 5])],            # Group 1 (64 feature maps)\n",
    "            [_remaining([6, 7, 8]), _remaining([9, 10])],        # Group 2 (128 feature maps)\n",
    "            [_remaining([11, 12, 13]), _remaining([14, 15])],    # Group 3 (256 feature maps)\n",
    "            [_remaining([16, 17, 18]), _remaining([19, 20])],    # Group 4 (512 feature maps)\n",
    "        ],\n",
    "    ]\n",
    "\n",
    "    print(\"\\nNew block_filters object:\")\n",
    "    print(block_filters)\n",
    "\n",
    "    # Force consistency in projection blocks:\n",
    "    # For any block with 3 entries, set the projection filter (second element) equal to the main branch output (third element).\n",
    "    # The lists are modified in place, so no write-back into block_filters is needed.\n",
    "    for group in block_filters[1]:\n",
    "        for bf in group:\n",
    "            if len(bf) == 3:\n",
    "                bf[1] = bf[2]\n",
    "\n",
    "    print(\"\\nNew mod. block_filters object:\")\n",
    "    print(block_filters)\n",
    "\n",
    "    # Reference point: FLOPs of the unpruned (vanilla) ResNet18.\n",
    "    base_model = vanilla_resnet18_individual(n_classes=CLASS_NUM)\n",
    "    base_model.build(input_shape=(None, 32, 32, 3))\n",
    "    print(\"\\nBase model summary:\")\n",
    "    # base_model.summary()\n",
    "    base_flops, base_profile = get_flops_and_profile(base_model)\n",
    "\n",
    "    if pretrain_sparsity < 1.0:\n",
    "        # Build the same architecture with the per-layer filter counts from block_filters\n",
    "        # and compare its FLOPs against the reference model.\n",
    "        _pruned_kwargs = dict(\n",
    "            load_weights=False,\n",
    "            input_shape=(32, 32, 3),\n",
    "            n_classes=CLASS_NUM,\n",
    "            block_filters=block_filters,\n",
    "        )\n",
    "        new_model = vanilla_resnet18_individual(**_pruned_kwargs)\n",
    "        new_model.build(input_shape=(None, 32, 32, 3))\n",
    "        print(\"\\nNew model summary:\")\n",
    "        # new_model.summary()\n",
    "\n",
    "        total_flops, profile = get_flops_and_profile(new_model)\n",
    "        theo_speedup = base_flops / total_flops\n",
    "    else:\n",
    "        # Reported sparsity is not below 1.0: record zero FLOPs and an infinite speedup\n",
    "        total_flops, theo_speedup = 0, np.inf\n",
    "\n",
    "    print(\"\\nTheoretical speedup:\")\n",
    "    print(theo_speedup)\n",
    "\n",
    "\n",
    "    \n",
    "    # --- Process, store, and save the pretraining results ---\n",
    "    # One row per run; numeric values are pre-formatted as strings for the CSV.\n",
    "    pretrain_res_dict = {\n",
    "        'Pre Opt': PRETRAIN_OPT,\n",
    "        'Depth': int(DEPTH),\n",
    "        'Lambda': f'{LA:.2e}',\n",
    "        'Init Type': INIT_TYPE,\n",
    "        'Init LR': f'{init_lr:.1e}',\n",
    "        'LR Schedule': LR_SCHEDULE,\n",
    "        'Batch size': int(BATCH_SIZE),\n",
    "        'Pre Epochs': int(EPOCHS),\n",
    "        'Pre Loss': f'{pretrain_loss:.3f}',\n",
    "        'Pre Acc': f'{pretrain_acc * 100:.4f}%',\n",
    "        'Pre Sparsity': f'{pretrain_sparsity * 100:.4f}%',\n",
    "        'Pre CR': f'{pretrain_compression_rate:.2f}',\n",
    "        'Pre Misalign': f'{pretrain_misalignment:.4f}',\n",
    "        # '.4f' = four decimals; the previous '4f' meant \"minimum width 4, six decimals\"\n",
    "        'Pre L2': f'{pre_l2:.4f}',\n",
    "        'Base FLOPs': int(base_flops),\n",
    "        'Reduced FLOPs': int(total_flops),\n",
    "        'Theoretical Speedup': f'{theo_speedup:.4f}'\n",
    "    }\n",
    "    pretrain_res_df = pd.DataFrame([pretrain_res_dict])\n",
    "\n",
    "    # The results file is always (re)written, replacing any earlier file for this run\n",
    "    pretrain_csv_file_path = os.path.join(RUN_PATH, f'pretraining_{MODEL}.csv')\n",
    "    pretrain_res_df.to_csv(pretrain_csv_file_path, index=False)\n",
    "    print(f'\\nPretrain results saved to {pretrain_csv_file_path}')\n",
    "\n",
    "    print(\"\\nPretraining Results:\")\n",
    "    print(pretrain_res_df)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.status.busy": "2025-03-30T06:10:06.326267Z",
     "iopub.status.idle": "2025-03-30T06:10:06.326485Z",
     "shell.execute_reply": "2025-03-30T06:10:06.326386Z",
     "shell.execute_reply.started": "2025-03-30T06:10:06.326376Z"
    }
   },
   "outputs": [],
   "source": [
    "# Cifar ResNet18 with filter sparsity D=3\n",
    "\n",
    "# Model definition\n",
    "################################################################################\n",
    "# The assignments below override the notebook-wide defaults for this cell's runs.\n",
    "MODEL = 'str_resnet18'\n",
    "DEPTH = 3\n",
    "#LA =  2e-4\n",
    "#print(f'Starting run with depth = {DEPTH} and lambda={LA:.2e}')\n",
    "# Self-assignments: keep the initializers from the config cell unchanged\n",
    "# (they raise NameError if that cell has not been run).\n",
    "KERNEL_INITIALIZER  = KERNEL_INITIALIZER #tf.keras.initializers.HeNormal()\n",
    "MULTFAC_INITIALIZER = MULTFAC_INITIALIZER #tf.keras.initializers.Ones()\n",
    "USE_BIAS            = False\n",
    "FACTORIZE_BIAS      = False\n",
    "init_lr = 0.3\n",
    "VERBOSE=1\n",
    "GRACE=20\n",
    "EPOCHS = 200\n",
    "\n",
    "# List of lambdas to iterate over\n",
    "lambdas_10 = [0,8e-5,2e-4,7e-4,8e-4,9e-4,1e-3,1.5e-3,2e-3,2.5e-3,3e-3,3.3e-3,3.6e-3,4.5e-3,5e-3,5.5e-3,6.5e-3,7.1e-3,8e-3,9e-3,1.2e-2,1.5e-2,2.05e-2,2.1e-2,0.2]\n",
    "\n",
    "# Optional debug print of the lambda grid (the loop below iterates over lambdas, not learning rates)\n",
    "#print(lambdas_10)\n",
    "\n",
    "for LA in lambdas_10:\n",
    "    print(f'\\nStarting run with learning rate: {init_lr:.2e} and depth = {DEPTH} and lambda = {LA:.2e}')\n",
    "    \n",
    "    # Update run naming and saving paths with current learning rate\n",
    "    fmt_la = f\"{LA:.1e}\"\n",
    "    fmt_fine_la = f\"{FINE_LA:.1e}\"\n",
    "    RUN_NAME = f\"dep{DEPTH}-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{init_lr:.1e}-ftune-{FINETUNE_OPT}-flr{FINE_LR:.1e}-fla-{fmt_fine_la}-feps{FINETUNE_EPOCHS}-{INIT_TYPE}-bs{BATCH_SIZE}\"\n",
    "    RUN_PATH = os.path.join(RESNET_FILE_PATH, RUN_NAME)\n",
    "    \n",
    "    # Set seed for reproducibility in each run\n",
    "    np.random.seed(SEED)\n",
    "    random.seed(SEED)\n",
    "    tf.random.set_seed(SEED)\n",
    "    \n",
    "    # Define callbacks\n",
    "    early_stopping = EarlyStopping(monitor='val_accuracy', patience=PAT, restore_best_weights=RESTORE_WEIGHTS)\n",
    "    custom_sparsity_callback = StructuredSparsityCallback(save_metrics=SAVE_METRICS, verbose=VERBOSE)\n",
    "    print_lr_cb = PrintLRCallback()\n",
    "    terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "    early_abort_cb = TerminateBadRuns(grace=GRACE, minacc=MINACC)\n",
    "    \n",
    "    # Define the model with the given parameters\n",
    "    hadamard_net = hadamard_resnet18(\n",
    "        depth=DEPTH,\n",
    "        kernel_initializer=KERNEL_INITIALIZER,\n",
    "        multfac_initializer=MULTFAC_INITIALIZER,\n",
    "        input_shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS),\n",
    "        la=LA,\n",
    "        n_classes=CLASS_NUM,\n",
    "        use_bias=USE_BIAS\n",
    "    )\n",
    "    \n",
    "    # Create pretrain optimizer with the current initial learning rate\n",
    "    optimizer = get_optimizer(\n",
    "        lr_schedule=LR_SCHEDULE,\n",
    "        init_lr=init_lr,\n",
    "        lr_decay_fact=LR_DECAY_FACT,\n",
    "        epochs=EPOCHS,\n",
    "        dat=X_train,\n",
    "        batch_size=BATCH_SIZE,\n",
    "        opt=PRETRAIN_OPT,\n",
    "        momentum=MOMENTUM,\n",
    "        alpha=0,\n",
    "        large_lr_start=LARGE_LRSTART,\n",
    "        warmup=WARMUP\n",
    "    )\n",
    "    \n",
    "    # Compile model\n",
    "    hadamard_net.compile(\n",
    "        optimizer=optimizer,\n",
    "        loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=LABEL_SMOOTHING),\n",
    "        metrics=['accuracy']\n",
    "    )\n",
    "    \n",
    "    print(hadamard_net.summary())\n",
    "    \n",
    "    \n",
    "    # Define paths for saving pretraining weights and model\n",
    "    weights_name = f'{MODEL}_weights.h5'\n",
    "    full_model_path = f'{MODEL}_model.h5'\n",
    "    weights_path = os.path.join(RUN_PATH, weights_name)\n",
    "    \n",
    "    if os.path.exists(weights_path):\n",
    "        print('Existing pretraining weights found: load and evaluate')\n",
    "        hadamard_net.load_weights(weights_path)\n",
    "        pretrain_sparsity = -1\n",
    "        pretrain_compression_rate = -1\n",
    "        pretrain_misalignment = -1\n",
    "        pre_l2 = -1\n",
    "    else:\n",
    "        print(f'No existing pretraining weights found: use pretraining with depth = {DEPTH} and lambda = {LA:.2e}')\n",
    "        pre_hist = hadamard_net.fit(\n",
    "            trainflow_aug,\n",
    "            steps_per_epoch=steps_per_ep,\n",
    "            batch_size=BATCH_SIZE,\n",
    "            validation_data=val_flow,\n",
    "            validation_steps=X_val.shape[0] // BATCH_SIZE,\n",
    "            epochs=EPOCHS,\n",
    "            callbacks=[custom_sparsity_callback, early_stopping, print_lr_cb, terminate_nan_cb, early_abort_cb]\n",
    "        )\n",
    "        \n",
    "        # Optionally save the trained weights and model\n",
    "        # hadamard_net.save_weights(weights_path)\n",
    "        # hadamard_net.save(os.path.join(RUN_PATH, full_model_path))\n",
    "        if not os.path.exists(RUN_PATH):\n",
    "            os.makedirs(RUN_PATH)\n",
    "        \n",
    "        # Calculate and save metrics from the custom callback\n",
    "        pretrain_sparsity, pretrain_compression_rate, pretrain_misalignment, pre_l2 = custom_sparsity_callback.total_metrics_data[-1, 1:5]\n",
    "        pre_total, pre_epochs = process_sparsity_callback(hist=pre_hist, hadamard_cb=custom_sparsity_callback, lr_cb=print_lr_cb)\n",
    "        pre_total.to_csv(os.path.join(RUN_PATH, 'pre_cb_total.csv'), index=False)\n",
    "        pre_epochs.to_csv(os.path.join(RUN_PATH, 'pre_cb_epochs.csv'), index=False)\n",
    "        print(f'Callback results saved successfully to {RUN_PATH}')\n",
    "        create_and_save_trajectory_plot(cb_total=pre_total, run_path=RUN_PATH, max_loss=10, out_name='pre_trajectory_resnet_str.pdf')\n",
    "    \n",
    "    # Evaluate after pretraining\n",
    "    pretrain_loss, pretrain_acc = hadamard_net.evaluate(X_test, Y_test)\n",
    "    print('\\nTest loss:', pretrain_loss)\n",
    "    print('Test accuracy:', pretrain_acc)\n",
    "\n",
    "    # Duplicate the DataFrame\n",
    "    df = pre_epochs.copy()\n",
    "\n",
    "    # Keep only the rows corresponding to the highest epoch number\n",
    "    max_epoch = df['epoch'].max()\n",
    "    df = df[df['epoch'] == max_epoch]\n",
    "\n",
    "    # Remove rows with worb == 0 (i.e. biases rather than kernel filter weights)\n",
    "    df = df[df['worb'] != 0]\n",
    "\n",
    "    # Sort by the 'layer' column in ascending order and reset the index\n",
    "    df = df.sort_values(by='layer').reset_index(drop=True)\n",
    "    df['ind'] = range(1, len(df) + 1)\n",
    "\n",
    "    # Define a mapping function for ResNet18 based on new indices.\n",
    "    def get_filters(ind):\n",
    "        if ind == 1:\n",
    "            return 64  # Stem conv\n",
    "        elif ind in [2, 3, 4, 5]:\n",
    "            return 64  # Group 1 blocks\n",
    "        elif ind in [6, 7, 8, 9, 10]:\n",
    "            return 128  # Group 2 blocks\n",
    "        elif ind in [11, 12, 13, 14, 15]:\n",
    "            return 256  # Group 3 blocks\n",
    "        elif ind in [16, 17, 18, 19, 20]:\n",
    "            return 512  # Group 4 blocks\n",
    "        else:\n",
    "            return 512  # Fallback\n",
    "\n",
    "    df['filters'] = df['ind'].apply(get_filters)\n",
    "    df['remaining_filters'] = ((1 - df['sparsity']) * df['filters']).astype(int)\n",
    "    print(\"Processed sparsity data:\")\n",
    "    print(df)\n",
    "\n",
    "    # Assemble the new block_filters object for ResNet18.\n",
    "    # Structure: [stem_filter, [group1, group2, group3, group4]]\n",
    "    # Group 1: two blocks (each with 2 conv layers), Group 2: first block with 3 conv layers (projection) and second block with 2, etc.\n",
    "    block_filters = [\n",
    "        int(df.loc[df['ind'] == 1, 'remaining_filters'].iloc[0]),\n",
    "        [\n",
    "            [   # Group 1 (64 feature maps)\n",
    "                list(df.loc[df['ind'].isin([2, 3]), 'remaining_filters']),\n",
    "                list(df.loc[df['ind'].isin([4, 5]), 'remaining_filters'])\n",
    "            ],\n",
    "            [   # Group 2 (128 feature maps)\n",
    "                list(df.loc[df['ind'].isin([6, 7, 8]), 'remaining_filters']),\n",
    "                list(df.loc[df['ind'].isin([9, 10]), 'remaining_filters'])\n",
    "            ],\n",
    "            [   # Group 3 (256 feature maps)\n",
    "                list(df.loc[df['ind'].isin([11, 12, 13]), 'remaining_filters']),\n",
    "                list(df.loc[df['ind'].isin([14, 15]), 'remaining_filters'])\n",
    "            ],\n",
    "            [   # Group 4 (512 feature maps)\n",
    "                list(df.loc[df['ind'].isin([16, 17, 18]), 'remaining_filters']),\n",
    "                list(df.loc[df['ind'].isin([19, 20]), 'remaining_filters'])\n",
    "            ]\n",
    "        ]\n",
    "    ]\n",
    "\n",
    "    # Force consistency in projection blocks:\n",
    "    # For any block with 3 entries, set the projection filter (second element) equal to the main branch output (third element).\n",
    "    for g in range(len(block_filters[1])):\n",
    "        for b in range(len(block_filters[1][g])):\n",
    "            bf = block_filters[1][g][b]\n",
    "            if len(bf) == 3:\n",
    "                bf[1] = bf[2]\n",
    "                block_filters[1][g][b] = bf\n",
    "\n",
    "    print(\"\\nNew block_filters object:\")\n",
    "    print(block_filters)\n",
    "\n",
    "    # Initialize the base (vanilla) ResNet18 model for FLOPs reference.\n",
    "    base_model = vanilla_resnet18_individual(n_classes=CLASS_NUM)\n",
    "    base_model.build(input_shape=(None, 32, 32, 3))\n",
    "    print(\"\\nBase model summary:\")\n",
    "    # base_model.summary()\n",
    "    base_flops, base_profile = get_flops_and_profile(base_model)\n",
    "\n",
    "    if pretrain_sparsity < 1.0:\n",
    "\n",
    "        # Now, initialize the individual variant using our custom block_filters.\n",
    "        new_model = vanilla_resnet18_individual(\n",
    "            load_weights=False,\n",
    "            input_shape=(32, 32, 3),\n",
    "            n_classes=CLASS_NUM,\n",
    "            block_filters=block_filters\n",
    "        )\n",
    "\n",
    "        new_model.build(input_shape=(None, 32, 32, 3))\n",
    "        print(\"\\nNew model summary:\")\n",
    "        # new_model.summary()\n",
    "\n",
    "        total_flops, profile = get_flops_and_profile(new_model)\n",
    "        theo_speedup = base_flops / total_flops\n",
    "        print(\"\\nTheoretical speedup:\")\n",
    "        print(theo_speedup)\n",
    "    else:\n",
    "        total_flops = 0\n",
    "        theo_speedup = np.inf\n",
    "        print(\"\\nTheoretical speedup:\")\n",
    "        print(theo_speedup)\n",
    "\n",
    "    \n",
    "    # --- Process, store, and save the pretraining results ---\n",
    "    pretrain_res_dict = {\n",
    "        'Pre Opt': PRETRAIN_OPT,\n",
    "        'Depth': int(DEPTH),\n",
    "        'Lambda': f'{LA:.2e}',\n",
    "        'Init Type': INIT_TYPE,\n",
    "        'Init LR': f'{init_lr:.1e}',\n",
    "        'LR Schedule': LR_SCHEDULE,\n",
    "        'Batch size': int(BATCH_SIZE),\n",
    "        'Pre Epochs': int(EPOCHS),\n",
    "        'Pre Loss': f'{pretrain_loss:.3f}',\n",
    "        'Pre Acc': f'{pretrain_acc * 100:.4f}%',\n",
    "        'Pre Sparsity': f'{pretrain_sparsity * 100:.4f}%',\n",
    "        'Pre CR': f'{pretrain_compression_rate:.2f}',\n",
    "        'Pre Misalign': f'{pretrain_misalignment:.4f}',\n",
    "        'Pre L2': f'{pre_l2:4f}',\n",
    "        'Base FLOPs': int(base_flops),\n",
    "        'Reduced FLOPs': int(total_flops),\n",
    "        'Theoretical Speedup': f'{theo_speedup:4f}'\n",
    "    }\n",
    "    pretrain_res_df = pd.DataFrame([pretrain_res_dict])\n",
    "    \n",
    "    pretrain_csv_file_path = os.path.join(RUN_PATH, f'pretraining_{MODEL}.csv')\n",
    "    #if os.path.exists(pretrain_csv_file_path):\n",
    "    #    print('Pretrain results exist already, skipping saving of results')\n",
    "    #else:\n",
    "    pretrain_res_df.to_csv(pretrain_csv_file_path, index=False)\n",
    "    print(f'\\nPretrain results saved to {pretrain_csv_file_path}')\n",
    "    \n",
    "    print(\"\\nPretraining Results:\")\n",
    "    print(pretrain_res_df)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cifar ResNet18 with filter sparsity D=4\n",
    "\n",
    "# Model definition\n",
    "################################################################################\n",
    "MODEL = 'str_resnet18'\n",
    "DEPTH = 4\n",
    "#LA =  2e-4\n",
    "#print(f'Starting run with depth = {DEPTH} and lambda={LA:.2e}')\n",
    "KERNEL_INITIALIZER  = KERNEL_INITIALIZER #tf.keras.initializers.HeNormal()\n",
    "MULTFAC_INITIALIZER = MULTFAC_INITIALIZER #tf.keras.initializers.Ones()\n",
    "USE_BIAS            = False\n",
    "FACTORIZE_BIAS      = False\n",
    "init_lr = 0.4\n",
    "VERBOSE=1\n",
    "GRACE=20\n",
    "EPOCHS = 200\n",
    "\n",
    "lambdas_10 = [0,8e-5,2e-4,7e-4,8e-4,9e-4,1e-3,1.5e-3,2e-3,2.5e-3,3e-3,3.3e-3,3.6e-3,4.5e-3,5e-3,5.5e-3,6.5e-3,7.1e-3,8e-3,9e-3,1.2e-2,1.5e-2,2.05e-2,2.1e-2,0.2]\n",
    "\n",
    "# List of initial learning rates to iterate over\n",
    "#print(lambdas_10)\n",
    "\n",
    "for LA in lambdas_10:\n",
    "    print(f'\\nStarting run with learning rate: {init_lr:.2e} and depth = {DEPTH} and lambda = {LA:.2e}')\n",
    "    \n",
    "    # Update run naming and saving paths with current learning rate\n",
    "    fmt_la = f\"{LA:.1e}\"\n",
    "    fmt_fine_la = f\"{FINE_LA:.1e}\"\n",
    "    RUN_NAME = f\"dep{DEPTH}-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{init_lr:.1e}-ftune-{FINETUNE_OPT}-flr{FINE_LR:.1e}-fla-{fmt_fine_la}-feps{FINETUNE_EPOCHS}-{INIT_TYPE}-bs{BATCH_SIZE}\"\n",
    "    RUN_PATH = os.path.join(RESNET_FILE_PATH, RUN_NAME)\n",
    "    \n",
    "    # Set seed for reproducibility in each run\n",
    "    np.random.seed(SEED)\n",
    "    random.seed(SEED)\n",
    "    tf.random.set_seed(SEED)\n",
    "    \n",
    "    # Define callbacks\n",
    "    early_stopping = EarlyStopping(monitor='val_accuracy', patience=PAT, restore_best_weights=RESTORE_WEIGHTS)\n",
    "    custom_sparsity_callback = StructuredSparsityCallback(save_metrics=SAVE_METRICS, verbose=VERBOSE)\n",
    "    print_lr_cb = PrintLRCallback()\n",
    "    terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "    early_abort_cb = TerminateBadRuns(grace=GRACE, minacc=MINACC)\n",
    "    \n",
    "    # Define the model with the given parameters\n",
    "    hadamard_net = hadamard_resnet18(\n",
    "        depth=DEPTH,\n",
    "        kernel_initializer=KERNEL_INITIALIZER,\n",
    "        multfac_initializer=MULTFAC_INITIALIZER,\n",
    "        input_shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS),\n",
    "        la=LA,\n",
    "        n_classes=CLASS_NUM,\n",
    "        use_bias=USE_BIAS\n",
    "    )\n",
    "    \n",
    "    # Create pretrain optimizer with the current initial learning rate\n",
    "    optimizer = get_optimizer(\n",
    "        lr_schedule=LR_SCHEDULE,\n",
    "        init_lr=init_lr,\n",
    "        lr_decay_fact=LR_DECAY_FACT,\n",
    "        epochs=EPOCHS,\n",
    "        dat=X_train,\n",
    "        batch_size=BATCH_SIZE,\n",
    "        opt=PRETRAIN_OPT,\n",
    "        momentum=MOMENTUM,\n",
    "        alpha=0,\n",
    "        large_lr_start=LARGE_LRSTART,\n",
    "        warmup=WARMUP\n",
    "    )\n",
    "    \n",
    "    # Compile model\n",
    "    hadamard_net.compile(\n",
    "        optimizer=optimizer,\n",
    "        loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=LABEL_SMOOTHING),\n",
    "        metrics=['accuracy']\n",
    "    )\n",
    "    \n",
    "    print(hadamard_net.summary())\n",
    "    \n",
    "    \n",
    "    # Define paths for saving pretraining weights and model\n",
    "    weights_name = f'{MODEL}_weights.h5'\n",
    "    full_model_path = f'{MODEL}_model.h5'\n",
    "    weights_path = os.path.join(RUN_PATH, weights_name)\n",
    "    \n",
    "    if os.path.exists(weights_path):\n",
    "        print('Existing pretraining weights found: load and evaluate')\n",
    "        hadamard_net.load_weights(weights_path)\n",
    "        pretrain_sparsity = -1\n",
    "        pretrain_compression_rate = -1\n",
    "        pretrain_misalignment = -1\n",
    "        pre_l2 = -1\n",
    "    else:\n",
    "        print(f'No existing pretraining weights found: use pretraining with depth = {DEPTH} and lambda = {LA:.2e}')\n",
    "        pre_hist = hadamard_net.fit(\n",
    "            trainflow_aug,\n",
    "            steps_per_epoch=steps_per_ep,\n",
    "            batch_size=BATCH_SIZE,\n",
    "            validation_data=val_flow,\n",
    "            validation_steps=X_val.shape[0] // BATCH_SIZE,\n",
    "            epochs=EPOCHS,\n",
    "            callbacks=[custom_sparsity_callback, early_stopping, print_lr_cb, terminate_nan_cb, early_abort_cb]\n",
    "        )\n",
    "        \n",
    "        # Optionally save the trained weights and model\n",
    "        # hadamard_net.save_weights(weights_path)\n",
    "        # hadamard_net.save(os.path.join(RUN_PATH, full_model_path))\n",
    "        if not os.path.exists(RUN_PATH):\n",
    "            os.makedirs(RUN_PATH)\n",
    "        \n",
    "        # Calculate and save metrics from the custom callback\n",
    "        pretrain_sparsity, pretrain_compression_rate, pretrain_misalignment, pre_l2 = custom_sparsity_callback.total_metrics_data[-1, 1:5]\n",
    "        pre_total, pre_epochs = process_sparsity_callback(hist=pre_hist, hadamard_cb=custom_sparsity_callback, lr_cb=print_lr_cb)\n",
    "        pre_total.to_csv(os.path.join(RUN_PATH, 'pre_cb_total.csv'), index=False)\n",
    "        pre_epochs.to_csv(os.path.join(RUN_PATH, 'pre_cb_epochs.csv'), index=False)\n",
    "        print(f'Callback results saved successfully to {RUN_PATH}')\n",
    "        create_and_save_trajectory_plot(cb_total=pre_total, run_path=RUN_PATH, max_loss=10, out_name='pre_trajectory_resnet_str.pdf')\n",
    "    \n",
    "    # Evaluate after pretraining\n",
    "    pretrain_loss, pretrain_acc = hadamard_net.evaluate(X_test, Y_test)\n",
    "    print('\\nTest loss:', pretrain_loss)\n",
    "    print('Test accuracy:', pretrain_acc)\n",
    "\n",
    "    # Duplicate the DataFrame\n",
    "    df = pre_epochs.copy()\n",
    "\n",
    "    # Keep only the rows corresponding to the highest epoch number\n",
    "    max_epoch = df['epoch'].max()\n",
    "    df = df[df['epoch'] == max_epoch]\n",
    "\n",
    "    # Remove rows with worb == 0 (i.e. biases rather than kernel filter weights)\n",
    "    df = df[df['worb'] != 0]\n",
    "\n",
    "    # Sort by the 'layer' column in ascending order and reset the index\n",
    "    df = df.sort_values(by='layer').reset_index(drop=True)\n",
    "    df['ind'] = range(1, len(df) + 1)\n",
    "\n",
    "    # Define a mapping function for ResNet18 based on new indices.\n",
    "    def get_filters(ind):\n",
    "        if ind == 1:\n",
    "            return 64  # Stem conv\n",
    "        elif ind in [2, 3, 4, 5]:\n",
    "            return 64  # Group 1 blocks\n",
    "        elif ind in [6, 7, 8, 9, 10]:\n",
    "            return 128  # Group 2 blocks\n",
    "        elif ind in [11, 12, 13, 14, 15]:\n",
    "            return 256  # Group 3 blocks\n",
    "        elif ind in [16, 17, 18, 19, 20]:\n",
    "            return 512  # Group 4 blocks\n",
    "        else:\n",
    "            return 512  # Fallback\n",
    "\n",
    "    df['filters'] = df['ind'].apply(get_filters)\n",
    "    df['remaining_filters'] = ((1 - df['sparsity']) * df['filters']).astype(int)\n",
    "    print(\"Processed sparsity data:\")\n",
    "    print(df)\n",
    "\n",
    "    # Assemble new block_filters object for ResNet18\n",
    "    # Structure: [stem_filter, [group1, group2, group3, group4]]\n",
    "    # Group 1: two blocks (each with 2 conv layers), Group 2: first block with 3 conv layers (projection) and second block with 2, etc.\n",
    "    block_filters = [\n",
    "        int(df.loc[df['ind'] == 1, 'remaining_filters'].iloc[0]),\n",
    "        [\n",
    "            [   # Group 1 (64 feature maps)\n",
    "                list(df.loc[df['ind'].isin([2, 3]), 'remaining_filters']),\n",
    "                list(df.loc[df['ind'].isin([4, 5]), 'remaining_filters'])\n",
    "            ],\n",
    "            [   # Group 2 (128 feature maps)\n",
    "                list(df.loc[df['ind'].isin([6, 7, 8]), 'remaining_filters']),\n",
    "                list(df.loc[df['ind'].isin([9, 10]), 'remaining_filters'])\n",
    "            ],\n",
    "            [   # Group 3 (256 feature maps)\n",
    "                list(df.loc[df['ind'].isin([11, 12, 13]), 'remaining_filters']),\n",
    "                list(df.loc[df['ind'].isin([14, 15]), 'remaining_filters'])\n",
    "            ],\n",
    "            [   # Group 4 (512 feature maps)\n",
    "                list(df.loc[df['ind'].isin([16, 17, 18]), 'remaining_filters']),\n",
    "                list(df.loc[df['ind'].isin([19, 20]), 'remaining_filters'])\n",
    "            ]\n",
    "        ]\n",
    "    ]\n",
    "\n",
    "    # Force consistency in projection blocks:\n",
    "    # For any block with 3 entries, set the projection filter (2nd element) to the main branch output (third element).\n",
    "    for g in range(len(block_filters[1])):\n",
    "        for b in range(len(block_filters[1][g])):\n",
    "            bf = block_filters[1][g][b]\n",
    "            if len(bf) == 3:\n",
    "                bf[1] = bf[2]\n",
    "                block_filters[1][g][b] = bf\n",
    "\n",
    "    print(\"\\nNew block_filters object:\")\n",
    "    print(block_filters)\n",
    "\n",
    "    # Initialize the base (vanilla) ResNet18 model for reference flops\n",
    "    base_model = vanilla_resnet18_individual(n_classes=CLASS_NUM)\n",
    "    base_model.build(input_shape=(None, 32, 32, 3))\n",
    "    print(\"\\nBase model summary:\")\n",
    "    # base_model.summary()\n",
    "    base_flops, base_profile = get_flops_and_profile(base_model)\n",
    "\n",
    "    if pretrain_sparsity < 1.0:\n",
    "\n",
    "        # Now, initialize individual model variant using our custom block_filters.\n",
    "        new_model = vanilla_resnet18_individual(\n",
    "            load_weights=False,\n",
    "            input_shape=(32, 32, 3),\n",
    "            n_classes=CLASS_NUM,\n",
    "            block_filters=block_filters\n",
    "        )\n",
    "\n",
    "        new_model.build(input_shape=(None, 32, 32, 3))\n",
    "        print(\"\\nNew model summary:\")\n",
    "        # new_model.summary()\n",
    "\n",
    "        total_flops, profile = get_flops_and_profile(new_model)\n",
    "        theo_speedup = base_flops / total_flops\n",
    "        print(\"\\nTheoretical speedup:\")\n",
    "        print(theo_speedup)\n",
    "    else:\n",
    "        total_flops = 0\n",
    "        theo_speedup = np.inf\n",
    "        print(\"\\nTheoretical speedup:\")\n",
    "        print(theo_speedup)\n",
    "\n",
    "\n",
    "    \n",
    "    # --- Process, store, and save the pretraining results ---\n",
    "    pretrain_res_dict = {\n",
    "        'Pre Opt': PRETRAIN_OPT,\n",
    "        'Depth': int(DEPTH),\n",
    "        'Lambda': f'{LA:.2e}',\n",
    "        'Init Type': INIT_TYPE,\n",
    "        'Init LR': f'{init_lr:.1e}',\n",
    "        'LR Schedule': LR_SCHEDULE,\n",
    "        'Batch size': int(BATCH_SIZE),\n",
    "        'Pre Epochs': int(EPOCHS),\n",
    "        'Pre Loss': f'{pretrain_loss:.3f}',\n",
    "        'Pre Acc': f'{pretrain_acc * 100:.4f}%',\n",
    "        'Pre Sparsity': f'{pretrain_sparsity * 100:.4f}%',\n",
    "        'Pre CR': f'{pretrain_compression_rate:.2f}',\n",
    "        'Pre Misalign': f'{pretrain_misalignment:.4f}',\n",
    "        'Pre L2': f'{pre_l2:4f}',\n",
    "        'Base FLOPs': int(base_flops),\n",
    "        'Reduced FLOPs': int(total_flops),\n",
    "        'Theoretical Speedup': f'{theo_speedup:4f}'\n",
    "    }\n",
    "    pretrain_res_df = pd.DataFrame([pretrain_res_dict])\n",
    "    \n",
    "    pretrain_csv_file_path = os.path.join(RUN_PATH, f'pretraining_{MODEL}.csv')\n",
    "    #if os.path.exists(pretrain_csv_file_path):\n",
    "    #    print('Pretrain results exist already, skipping saving of results')\n",
    "    #else:\n",
    "    pretrain_res_df.to_csv(pretrain_csv_file_path, index=False)\n",
    "    print(f'\\nPretrain results saved to {pretrain_csv_file_path}')\n",
    "    \n",
    "    print(\"\\nPretraining Results:\")\n",
    "    print(pretrain_res_df)\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
