{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-08-22T09:24:38.043759Z",
     "iopub.status.busy": "2024-08-22T09:24:38.043068Z",
     "iopub.status.idle": "2024-08-22T09:24:50.530960Z",
     "shell.execute_reply": "2024-08-22T09:24:50.530217Z",
     "shell.execute_reply.started": "2024-08-22T09:24:38.043724Z"
    },
    "id": "RPFnfEqRAYA5"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import random\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import keras\n",
    "import pandas as pd\n",
    "from tensorflow.keras.datasets import cifar10\n",
    "from tensorflow.keras.models import load_model\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping  #ReduceLROnPlateau\n",
    "from tensorflow.keras.regularizers import l2\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "\n",
    "from layers       import HadamardDense, HadamardConv2D\n",
    "from initializers import TwiceTruncatedNormalInitializer, equivar_initializer, equivar_initializer_conv2d\n",
    "from callbacks    import HadamardCallback, PrintLRCallback, TerminateBadRuns\n",
    "from models       import WideResNet, small_hadamard_resnet18, hadamard_resnet18, hadamard_resnet20,\\\n",
    "                         hadamard_resnet1001, hadamard_WRN_16_4, hadamard_WRN_40_4, hadamard_WRN_16_8,\\\n",
    "                         hadamard_WRN_28_10, hadamard_resnet56 \n",
    "from utils        import color_preprocessing, compute_thresholds, get_optimizer, process_sparsity_callback,\\\n",
    "                      threshold_model_weights, flatten_and_filter_weights, create_and_save_trajectory_plot\n",
    "\n",
    "################################################################################\n",
    "# Define model architecture and training parameters.\n",
    "# Everything assigned below is a notebook-level global that the later cells read.\n",
    "\n",
    "# Hadamard: values handed to the initializer and to the hadamard_* model constructor\n",
    "DEPTH              = 3\n",
    "# Lambda passed to the model as `la`; the model-definition cell overwrites it from `lambdas`\n",
    "LA                 = 1e-4\n",
    "INIT_TYPE          ='equivar' #vanilla, equivar, root, ones\n",
    "MINPROD            = 3e-3\n",
    "INIT               = TwiceTruncatedNormalInitializer(minprod=MINPROD,depth=DEPTH) \n",
    "#INIT               = tf.keras.initializers.HeNormal()\n",
    "USE_BIAS           = True # only for dense layers\n",
    "FACTORIZE_BIAS     = True\n",
    "\n",
    "# Training\n",
    "PRETRAIN_OPT       = 'sgd' # sgd, adam\n",
    "LR_SCHEDULE        = 'cosine' # piecewise, constant, cosine\n",
    "BATCH_SIZE         = 256 #256\n",
    "EPOCHS             = 250 #200\n",
    "INIT_LR            = 0.5 if PRETRAIN_OPT == 'sgd' else 2e-3 # depth 2 = 0.2, depth 3 = 0.5, depth 4 = 0.7\n",
    "WARMUP             = False # linear lr warmup from 0.01 to INIT_LR over 5 epochs\n",
    "MOMENTUM           = 0.9    # only for SGD\n",
    "LR_DECAY_FACT      = 0.1   # only for piecewise\n",
    "LARGE_LRSTART      = False  # only for piecewise\n",
    "LABEL_SMOOTHING    = 0.1\n",
    "\n",
    "# Finetuning (runs after pruning when DO_FINETUNING is True)\n",
    "DO_FINETUNING      = True\n",
    "FINETUNE_OPT       = 'sgd' # sgd, adam\n",
    "FINE_SCHEDULE      = 'cosine' # piecewise, constant, cosine\n",
    "FINETUNE_EPOCHS    = 50     if FINETUNE_OPT == 'sgd' else 20\n",
    "# NOTE: FINE_LA is reassigned in the model-definition cell once LA has been chosen\n",
    "FINE_LA            = 0.05 * LA if FINETUNE_OPT == 'sgd' else 0.2 * LA\n",
    "FINE_LR            = 0.4 * INIT_LR   if FINETUNE_OPT == 'sgd' else 2e-4\n",
    "# Passed as `alpha` to get_optimizer for finetuning; both branches currently hold the same value\n",
    "FINETUNE_ALPHA     = 1e-5   if FINETUNE_OPT == 'sgd' else 1e-5\n",
    "\n",
    "# Input settings\n",
    "CLASS_NUM          = 10\n",
    "IMG_ROWS, IMG_COLS = 32, 32\n",
    "IMG_CHANNELS       = 3\n",
    "\n",
    "# WideResNet settings\n",
    "# DEP and WIDE appear only in the commented-out WideResNet(...) calls of this notebook;\n",
    "# IN_FILTERS is not referenced by any of its cells.\n",
    "DEP                = 16\n",
    "WIDE               = 8\n",
    "IN_FILTERS         = 16\n",
    "\n",
    "# Misc\n",
    "# PAT / RESTORE_WEIGHTS configure EarlyStopping (patience / restore_best_weights).\n",
    "# With the current values PAT exceeds EPOCHS, so early stopping is effectively switched off.\n",
    "PAT                = 1000\n",
    "RESTORE_WEIGHTS    = False\n",
    "# GRACE (pretraining) / FINE_GRACE (finetuning) and MINACC are passed to TerminateBadRuns\n",
    "GRACE              = 5\n",
    "FINE_GRACE         = 5 #10\n",
    "# 1 / CLASS_NUM plus a margin of 0.05\n",
    "MINACC             = (1 / CLASS_NUM) + 0.05\n",
    "SEED               = 123\n",
    "# SAVE_METRICS / VERBOSE are forwarded to HadamardCallback\n",
    "SAVE_METRICS       = True\n",
    "VERBOSE            = 1\n",
    "# Candidate compression rates (CR); the threshold cell keeps the first six entries\n",
    "# that exceed the CR of the pretrained model\n",
    "compression_grid   = [10,50,75,100,200,500,700,800,1000,\n",
    "                      1250,1500,1750,2000,2500,3000,3500,4000,4500,5000,6000,7000,8000,9000,10000,\n",
    "                      15000,20000,25000,30000,40000,50000,75000,100000]\n",
    "# Candidate lambda values; the trailing comments give the index range of each row\n",
    "lambdas_all         = [0, 1e-6, 5e-6,  #0-2\n",
    "                       1e-5, 2e-5, 3e-5, 4e-5, 5e-5, 6e-5, 7e-5, 8e-5, 9e-5, #3-11\n",
    "                       1e-4, 2e-4, 3e-4, 4e-4, 5e-4, 6e-4, 7e-4, 8e-4, 9e-4, #12-20\n",
    "                       1e-3, 2e-3, 5e-3] #21-23\n",
    "    \n",
    "# Active list of lambdas; the model-definition cell indexes into it to set LA\n",
    "lambdas              = lambdas_all\n",
    "\n",
    "# Directories and saving paths\n",
    "# Root directory; each run creates its own sub-directory (RUN_PATH) underneath it\n",
    "RESNET_FILE_PATH = './results_resnet/resnet18_cifar10_pruning/'\n",
    "\n",
    "print('Defining configs successful!')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2024-08-22T09:25:24.899363Z",
     "iopub.status.busy": "2024-08-22T09:25:24.898383Z",
     "iopub.status.idle": "2024-08-22T09:25:46.338144Z",
     "shell.execute_reply": "2024-08-22T09:25:46.337651Z",
     "shell.execute_reply.started": "2024-08-22T09:25:24.899337Z"
    },
    "id": "mSeZM6CoDbOL",
    "outputId": "80a25c47-d086-4fc8-b35e-9a6d24bdcf2c"
   },
   "outputs": [],
   "source": [
    "# Data loading and pre-processing\n",
    "\n",
    "# Seed every RNG used below (NumPy, Python's random, TensorFlow)\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "# Load data\n",
    "(X_train, Y_train), (X_test, Y_test) = cifar10.load_data()\n",
    "\n",
    "def calculate_mean_std(dataset):\n",
    "    \"\"\"Return (means, stds) of `dataset`, reduced over axes 0, 1 and 2.\n",
    "\n",
    "    For a 4-D image array this leaves one mean and one standard deviation\n",
    "    per entry of the last axis.\n",
    "    \"\"\"\n",
    "    means = np.mean(dataset, axis=(0, 1, 2))\n",
    "    stds = np.std(dataset, axis=(0, 1, 2))\n",
    "    return means, stds\n",
    "\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "\n",
    "print(\"Training Set Mean and SD:\", train_mean, train_std)\n",
    "print(\"Testing Set Mean and SD:\", test_mean, test_std)\n",
    "\n",
    "# Create train/val/test split: 10% of the training images are held out for validation.\n",
    "# No random_state is passed, so the shuffle draws from the NumPy RNG seeded above.\n",
    "X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.1, shuffle=True)\n",
    "\n",
    "# Color preprocessing\n",
    "X_train, X_test, X_val = color_preprocessing(X_train, X_test, X_val)\n",
    "\n",
    "# Re-compute the statistics on the preprocessed arrays as a sanity check\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "print(\"Normalized Training Set Mean and SD:\", train_mean, train_std)\n",
    "print(\"Normalized Testing Set Mean and SD:\", test_mean, test_std)\n",
    "\n",
    "# Pre-process labels: one-hot encode, with the categories learned from the training labels\n",
    "encoder = OneHotEncoder()\n",
    "encoder.fit(Y_train)\n",
    "Y_train = encoder.transform(Y_train).toarray()\n",
    "Y_test = encoder.transform(Y_test).toarray()\n",
    "Y_val = encoder.transform(Y_val).toarray()\n",
    "\n",
    "# Check sample sizes of split\n",
    "print('Train data shape: ', X_train.shape)\n",
    "print('Train labels shape: ', Y_train.shape)\n",
    "print('Validation data shape: ', X_val.shape)\n",
    "print('Validation labels shape: ', Y_val.shape)\n",
    "print('Test data shape: ', X_test.shape)\n",
    "print('Test labels shape: ', Y_test.shape)\n",
    "\n",
    "# Define augmentation and data generators.\n",
    "# ImageDataGenerator.fit() is intentionally not called: it only computes statistics for the\n",
    "# featurewise_* / zca_whitening options, and none of them is enabled on these generators.\n",
    "print('Using real-time data augmentation.')\n",
    "aug_datagen = ImageDataGenerator(horizontal_flip=True, width_shift_range=0.125, height_shift_range=0.125,\n",
    "                                 rotation_range=15, fill_mode='reflect')\n",
    "trainflow_aug = aug_datagen.flow(X_train, Y_train, batch_size=BATCH_SIZE)\n",
    "\n",
    "# Validation batches are served unaugmented and in a fixed order (flow() shuffles by default)\n",
    "val_datagen = ImageDataGenerator()\n",
    "val_flow = val_datagen.flow(X_val, Y_val, batch_size=BATCH_SIZE, shuffle=False)\n",
    "\n",
    "# Batches per epoch as an integer, rounded up so the final partial batch is included\n",
    "# (plain `/` division yielded a float, which is not a valid step count)\n",
    "steps_per_ep = int(np.ceil(len(X_train) / BATCH_SIZE))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2024-08-21T21:21:21.246245Z",
     "iopub.status.busy": "2024-08-21T21:21:21.245974Z",
     "iopub.status.idle": "2024-08-21T21:21:28.543596Z",
     "shell.execute_reply": "2024-08-21T21:21:28.542891Z",
     "shell.execute_reply.started": "2024-08-21T21:21:21.246226Z"
    },
    "id": "6yy9KHxyWOBY",
    "outputId": "7cfef7ad-a3ae-49a2-92d7-820a174079d9"
   },
   "outputs": [],
   "source": [
    "# ResNet-18 #1\n",
    "\n",
    "# Model definition\n",
    "################################################################################\n",
    "# Per-run overrides of the config cell; DEPTH and INIT_LR are used as configured there.\n",
    "MODEL = 'resnet18' # 'wrn16-8'\n",
    "LA = lambdas[0] #LA\n",
    "print(f'Starting run with lambda={LA:.2e}')\n",
    "FINE_LA = 1 * LA if FINETUNE_OPT == 'sgd' else 0.2 * LA\n",
    "INIT_TYPE = 'equivar'\n",
    "INIT = TwiceTruncatedNormalInitializer(minprod=MINPROD,depth=DEPTH) # DWF INIT\n",
    "#INIT = tf.keras.initializers.HeNormal()\n",
    "\n",
    "################################################################################\n",
    "\n",
    "# Directories and saving paths\n",
    "# The run name encodes the hyperparameters, so each configuration gets its own directory\n",
    "fmt_la = f\"{LA:.1e}\"\n",
    "fmt_fine_la = f\"{FINE_LA:.1e}\"\n",
    "RUN_NAME = (\n",
    "    f\"dep{DEPTH}-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{INIT_LR:.1e}-\"\n",
    "    f\"ftune-{FINETUNE_OPT}-flr{FINE_LR:.1e}-fla-{fmt_fine_la}-feps{FINETUNE_EPOCHS}-{INIT_TYPE}-bs{BATCH_SIZE}\"\n",
    ")\n",
    "RUN_PATH = os.path.join(RESNET_FILE_PATH, RUN_NAME)\n",
    "\n",
    "# Create the run directory; exist_ok avoids the separate existence check\n",
    "os.makedirs(RUN_PATH, exist_ok=True)\n",
    "\n",
    "################################################################################\n",
    "\n",
    "# Set seed (again, so that model initialisation does not depend on what ran before this cell)\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "# Callbacks used during pretraining\n",
    "early_stopping = EarlyStopping(monitor='val_accuracy', patience=PAT, restore_best_weights=RESTORE_WEIGHTS)\n",
    "custom_sparsity_callback = HadamardCallback(save_metrics=SAVE_METRICS, verbose=VERBOSE)\n",
    "print_lr_cb = PrintLRCallback()\n",
    "terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "early_abort_cb = TerminateBadRuns(grace=GRACE, minacc=MINACC)\n",
    "\n",
    "# Define model\n",
    "#hadamard_net = WideResNet(dep=DEP, k=WIDE, input_shape = (IMG_ROWS, IMG_COLS, IMG_CHANNELS), n_classes = CLASS_NUM, depth=DEPTH,\n",
    "#                             init_type=INIT_TYPE, init = INIT, la=LA, use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS)\n",
    "hadamard_net = hadamard_resnet18(use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS, depth=DEPTH,\n",
    "                                 init_type=INIT_TYPE, init=INIT, la=LA,\n",
    "                                 input_shape=(IMG_ROWS,IMG_COLS,IMG_CHANNELS), n_classes=CLASS_NUM)\n",
    "\n",
    "# Pretrain optimizer\n",
    "optimizer = get_optimizer(lr_schedule=LR_SCHEDULE, init_lr=INIT_LR, lr_decay_fact=LR_DECAY_FACT, epochs=EPOCHS,\n",
    "                          dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT, momentum=MOMENTUM, alpha=0,\n",
    "                          large_lr_start=LARGE_LRSTART, warmup=WARMUP)\n",
    "\n",
    "# Compile model\n",
    "hadamard_net.compile(optimizer=optimizer,\n",
    "                     loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=LABEL_SMOOTHING),\n",
    "                     metrics=['accuracy'])\n",
    "\n",
    "# summary() prints the table itself and returns None, so it must not be wrapped in print()\n",
    "hadamard_net.summary()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2024-08-21T21:21:34.068402Z",
     "iopub.status.busy": "2024-08-21T21:21:34.067543Z",
     "iopub.status.idle": "2024-08-21T22:39:28.275160Z",
     "shell.execute_reply": "2024-08-21T22:39:28.274487Z",
     "shell.execute_reply.started": "2024-08-21T21:21:34.068374Z"
    },
    "id": "tl0LSlRbWRPE",
    "outputId": "45be0d83-54f3-488a-9580-92fdf717cf61"
   },
   "outputs": [],
   "source": [
    "## Pre-training\n",
    "\n",
    "# File names / paths of the cached pretraining artefacts of this run\n",
    "weights_name = f'{MODEL}_weights.h5'\n",
    "full_model_path = f'{MODEL}_model.h5'\n",
    "weights_path = os.path.join(RUN_PATH, weights_name)\n",
    "\n",
    "if os.path.exists(weights_path):\n",
    "    # Cached run: reuse the stored weights. The callback metrics are not recomputed here,\n",
    "    # so -1 marks them as 'not available' in the printed and saved results.\n",
    "    print('Existing pretraining weights found: load and evaluate')\n",
    "    hadamard_net.load_weights(weights_path)\n",
    "    pretrain_sparsity = -1\n",
    "    pretrain_compression_rate = -1\n",
    "    pretrain_misalignment = -1\n",
    "    pre_l2 = -1\n",
    "else:\n",
    "    print(f'No existing pretraining weights found: starting pretraining with lambda={LA:.2e}')\n",
    "    # batch_size is not passed to fit(): the generator already yields batches of BATCH_SIZE.\n",
    "    # validation_steps is rounded up so the final partial validation batch is evaluated too.\n",
    "    pre_hist = hadamard_net.fit(trainflow_aug,\n",
    "                    steps_per_epoch=steps_per_ep,\n",
    "                    validation_data=val_flow,\n",
    "                    validation_steps=int(np.ceil(X_val.shape[0] / BATCH_SIZE)),\n",
    "                    epochs=EPOCHS,\n",
    "                    callbacks=[early_stopping, custom_sparsity_callback, print_lr_cb, terminate_nan_cb, early_abort_cb])\n",
    "\n",
    "    # Save newly trained model weights\n",
    "    hadamard_net.save_weights(weights_path)\n",
    "    #hadamard_net.save(os.path.join(RUN_PATH, full_model_path))\n",
    "\n",
    "    # Sparsity, compression rate, misalignment and L2 norm from the last row of the callback's metrics\n",
    "    pretrain_sparsity, pretrain_compression_rate, pretrain_misalignment, pre_l2 = custom_sparsity_callback.total_metrics_data[-1, 1:5]\n",
    "\n",
    "    # Save pretraining trajectory\n",
    "    pre_total, pre_epochs = process_sparsity_callback(hist=pre_hist, hadamard_cb=custom_sparsity_callback, lr_cb=print_lr_cb)\n",
    "    pre_total.to_csv(os.path.join(RUN_PATH, 'pre_cb_total.csv'), index=False)\n",
    "    pre_epochs.to_csv(os.path.join(RUN_PATH, 'pre_cb_epochs.csv'), index=False)\n",
    "    print(f'Callback results saved successfully to {RUN_PATH}')\n",
    "\n",
    "    # Construct and save plot training metrics trajectories\n",
    "    create_and_save_trajectory_plot(cb_total=pre_total, run_path=RUN_PATH, max_loss=10, out_name='pre_trajectory.pdf')\n",
    "\n",
    "# Evaluate after pretraining\n",
    "pretrain_loss, pretrain_acc = hadamard_net.evaluate(X_test, Y_test)\n",
    "print('\\nTest loss', pretrain_loss)\n",
    "print('Test accuracy', pretrain_acc)\n",
    "print('Sparsity (pretrain)', pretrain_sparsity)\n",
    "print('Compression rate (pretrain)', pretrain_compression_rate)\n",
    "print('Misalignment (pretrain)', pretrain_misalignment)\n",
    "print('Total L2 norm (pretrain)', pre_l2)\n",
    "\n",
    "# Column order of the single-row results table\n",
    "pretrain_res_columns = ['Pre Opt', 'Depth', 'Lambda', 'Init Type', 'Init LR', 'LR Schedule', 'Batch size',\n",
    "                        'Pre Epochs', 'Pre Loss','Pre Acc', 'Pre Sparsity', 'Pre CR', 'Pre Misalign', 'Pre L2']\n",
    "# Store formatted results in dict\n",
    "pretrain_res_dict = {\n",
    "  'Pre Opt': PRETRAIN_OPT,'Depth': int(DEPTH),'Lambda': f'{LA:.2e}','Init Type': INIT_TYPE,'Init LR': f'{INIT_LR:.2e}',\n",
    "  'LR Schedule': LR_SCHEDULE,'Batch size': int(BATCH_SIZE), 'Pre Epochs': int(EPOCHS),'Pre Loss': f'{pretrain_loss:.3f}',\n",
    "  'Pre Acc': f'{pretrain_acc * 100:.4f}%','Pre Sparsity': f'{pretrain_sparsity * 100:.4f}%',\n",
    "  'Pre CR': f'{pretrain_compression_rate:.2f}', 'Pre Misalign': f'{pretrain_misalignment:.4f}',\n",
    "  # '.4f' = four decimals (the former '4f' set a minimum width of 4 instead)\n",
    "  'Pre L2': f'{pre_l2:.4f}'\n",
    "  }\n",
    "\n",
    "# Build the results table directly from the record\n",
    "pretrain_res_df = pd.DataFrame([pretrain_res_dict], columns=pretrain_res_columns)\n",
    "\n",
    "# Save df to CSV if new pretrain \n",
    "pretrain_csv_file_path = os.path.join(RUN_PATH, f'pretraining_{MODEL}.csv')\n",
    "if os.path.exists(pretrain_csv_file_path):\n",
    "    print('Pretrain results exist already, skipping saving of results')\n",
    "else:\n",
    "    pretrain_res_df.to_csv(pretrain_csv_file_path, index=False)\n",
    "    print(f'Pretrain results saved to {pretrain_csv_file_path}')\n",
    "print(\"\\nPretraining Results:\")\n",
    "print(pretrain_res_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-08-21T22:39:28.276858Z",
     "iopub.status.busy": "2024-08-21T22:39:28.276635Z",
     "iopub.status.idle": "2024-08-21T22:39:35.991628Z",
     "shell.execute_reply": "2024-08-21T22:39:35.991131Z",
     "shell.execute_reply.started": "2024-08-21T22:39:28.276842Z"
    }
   },
   "outputs": [],
   "source": [
    "# Derive the pruning thresholds for the target compression rates from the reconstructed weights\n",
    "\n",
    "# Rebuild the architecture and restore the pretrained weights into the fresh copy\n",
    "#model_clone = WideResNet(dep=DEP, k=WIDE, input_shape = (IMG_ROWS, IMG_COLS, IMG_CHANNELS), n_classes = CLASS_NUM, depth=DEPTH,\n",
    "#                         init_type=INIT_TYPE, init = INIT, la=FINE_LA, use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS)\n",
    "model_clone = hadamard_resnet18(\n",
    "    use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS, depth=DEPTH,\n",
    "    init_type=INIT_TYPE, init=INIT, la=FINE_LA,\n",
    "    input_shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS), n_classes=CLASS_NUM,\n",
    ")\n",
    "model_clone.build(input_shape=(None, IMG_ROWS, IMG_COLS, IMG_CHANNELS))\n",
    "model_clone.load_weights(weights_path)\n",
    "\n",
    "# Threshold at float32 machine epsilon: gives the intrinsic compression rate and the\n",
    "# reconstructed weights of the pretrained model\n",
    "model_clone, reconstructed_weights, pre_cr_clone = threshold_model_weights(\n",
    "    model_clone, threshold=np.finfo(np.float32).eps, mode='model'\n",
    ")\n",
    "\n",
    "# Target CR values: the first six grid entries above the pretrained model's CR\n",
    "#pruning_cr_vals = pre_cr_clone * pruning_cr_factors\n",
    "pruning_cr_vals = list(filter(lambda grid_cr: grid_cr > pre_cr_clone, compression_grid))[:6]\n",
    "print(pruning_cr_vals)\n",
    "\n",
    "# One threshold per target CR, computed on the flattened reconstructed weights\n",
    "flattened_weights = np.array(flatten_and_filter_weights(reconstructed_weights))\n",
    "thresholds = compute_thresholds(pruning_cr_vals, flattened_weights)\n",
    "\n",
    "messages = [f\"- CR {rate:.2f}, Threshold: {cut:.2e}\" for rate, cut in zip(pruning_cr_vals, thresholds)]\n",
    "message = \"The computed thresholds are:\\n\" + \"\\n\".join(messages)\n",
    "print(message)\n",
    "\n",
    "# Percentiles of the reconstructed weights, plus their extremes\n",
    "# NOTE(review): the 'abs. val.' labels assume flatten_and_filter_weights returns magnitudes - confirm\n",
    "quantile_levels = [1, 80, 90, 95, 99, 99.5, 99.9, 99.95, 99.99]\n",
    "quantile_labels = [\"Q1\", \"Q80\", \"Q90\", \"Q95\", \"Q99\", \"Q99.5\", \"Q99.9\", \"Q99.95\", \"Q99.99\"]\n",
    "quantiles = np.percentile(flattened_weights, quantile_levels)\n",
    "minabs = np.min(flattened_weights)\n",
    "maxabs = np.max(flattened_weights)\n",
    "\n",
    "# Assemble the report in display order: minimum, percentiles, maximum\n",
    "quantile_values = {\"Min. abs. val.\": minabs}\n",
    "quantile_values.update(zip(quantile_labels, quantiles))\n",
    "quantile_values[\"Max. abs. val.\"] = maxabs\n",
    "\n",
    "for quantile, value in quantile_values.items():\n",
    "    print(f\"{quantile}: {value:.2e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2024-08-21T22:39:35.992662Z",
     "iopub.status.busy": "2024-08-21T22:39:35.992458Z",
     "iopub.status.idle": "2024-08-21T23:59:28.591060Z",
     "shell.execute_reply": "2024-08-21T23:59:28.590405Z",
     "shell.execute_reply.started": "2024-08-21T22:39:35.992612Z"
    },
    "id": "qU9Nnv4NfuPu",
    "outputId": "29059632-73de-4818-9837-1eb093e9594c"
   },
   "outputs": [],
   "source": [
    "# Pruning + Finetuning\n",
    "\n",
    "if DO_FINETUNING:\n",
    "    # Column order of the results table\n",
    "    result_columns = ['Pre Opt','Depth', 'Lambda', 'Init Type', 'Init LR', 'LR Schedule', 'Batch size',\n",
    "                      'Pre Epochs', 'Pre Loss','Pre Acc', 'Pre Sparsity', 'Pre CR', 'Pre Misalignment',\n",
    "                      'Pre L2', 'Prune Threshold', 'Pruned Acc', 'Fine-tune Opt', 'Fine-tune Epochs',\n",
    "                      'Fine-tune LR', 'Fine-tune LA', 'Fine-tune Alpha','Fine-tune Loss', 'Fine-tune Acc',\n",
    "                      'Fine-tune CR', 'Fine-tune Sparsity', 'Fine-tune Misalignment', 'Fine-tune L2']\n",
    "    # One dict per threshold; the table is built from this list rather than being\n",
    "    # grown with pd.concat on every iteration\n",
    "    result_records = []\n",
    "    csv_file_path = os.path.join(RUN_PATH, f'finetuning_{MODEL}.csv')\n",
    "    # Rounded up so the final partial validation batch is evaluated too\n",
    "    val_steps = int(np.ceil(X_val.shape[0] / BATCH_SIZE))\n",
    "\n",
    "    # loop over thresholds: prune+finetune\n",
    "    for ind, threshold in enumerate(thresholds):\n",
    "\n",
    "        print(f'\\nStarting with threshold {ind}: {threshold:.2e}')\n",
    "        # Re-seed so that every threshold starts from the same RNG state\n",
    "        np.random.seed(SEED)\n",
    "        random.seed(SEED)\n",
    "        tf.random.set_seed(SEED)\n",
    "\n",
    "        ## Instantiate pretrained model and finetuning optimizer\n",
    "\n",
    "        #model_clone = WideResNet(dep=DEP, k=WIDE, input_shape = (IMG_ROWS, IMG_COLS, IMG_CHANNELS), n_classes = CLASS_NUM, depth=DEPTH,\n",
    "        #                         init_type=INIT_TYPE, init = INIT, la=FINE_LA, use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS)\n",
    "        model_clone = hadamard_resnet18(use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS, depth=DEPTH,\n",
    "                                        init_type=INIT_TYPE, init=INIT, la=FINE_LA,\n",
    "                                        input_shape=(IMG_ROWS,IMG_COLS,IMG_CHANNELS), n_classes=CLASS_NUM)\n",
    "        model_clone.build(input_shape=(None, IMG_ROWS, IMG_COLS, IMG_CHANNELS))\n",
    "        model_clone.load_weights(weights_path)\n",
    "\n",
    "        # Prune pretrained model\n",
    "        model_clone, reconstructed_pruned_weights, pruned_cr_clone = threshold_model_weights(model_clone,\n",
    "                                                                                             threshold=threshold,\n",
    "                                                                                             mode='model')\n",
    "        print(f'CR of pruned pretrained model is {pruned_cr_clone:.2f}')\n",
    "\n",
    "        # Optimizer and callbacks for finetuning (fresh instances per threshold)\n",
    "        fine_optimizer = get_optimizer(lr_schedule=FINE_SCHEDULE, init_lr=FINE_LR, lr_decay_fact=LR_DECAY_FACT,\n",
    "                                       epochs=FINETUNE_EPOCHS, dat=X_train, batch_size=BATCH_SIZE, opt=FINETUNE_OPT,\n",
    "                                       momentum=MOMENTUM, alpha=FINETUNE_ALPHA, warmup=False)\n",
    "        fine_sparsity_callback = HadamardCallback(save_metrics=SAVE_METRICS, verbose=VERBOSE)\n",
    "        fine_lr_cb = PrintLRCallback()\n",
    "        fine_abort_cb = TerminateBadRuns(grace=FINE_GRACE, minacc=MINACC)\n",
    "        terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "\n",
    "        # Compile model with the new optimizer\n",
    "        model_clone.compile(optimizer=fine_optimizer,\n",
    "                            loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=LABEL_SMOOTHING),\n",
    "                            metrics=['accuracy'])\n",
    "\n",
    "        # Test accuracy directly after pruning, before any finetuning\n",
    "        pruned_loss, pruned_acc = model_clone.evaluate(X_test, Y_test)\n",
    "        print(f'\\nTest accuracy of pruned pretrained model is {pruned_acc * 100:.4f}% ')\n",
    "\n",
    "        ## Fine-tune the pruned model\n",
    "        # batch_size is not passed to fit(): the generator already yields batches of BATCH_SIZE\n",
    "        fine_hist = model_clone.fit(trainflow_aug,\n",
    "                        steps_per_epoch=steps_per_ep,\n",
    "                        validation_data=val_flow,\n",
    "                        validation_steps=val_steps,\n",
    "                        epochs=FINETUNE_EPOCHS,\n",
    "                        callbacks=[fine_sparsity_callback, fine_lr_cb, fine_abort_cb, terminate_nan_cb])\n",
    "\n",
    "        # Save finetuning trajectory\n",
    "        fine_total, fine_epochs = process_sparsity_callback(hist=fine_hist, hadamard_cb=fine_sparsity_callback, lr_cb=fine_lr_cb)\n",
    "        fine_total.to_csv(os.path.join(RUN_PATH, f'fine_cb_total_{ind}.csv'), index=False)\n",
    "        fine_epochs.to_csv(os.path.join(RUN_PATH, f'fine_cb_epochs_{ind}.csv'), index=False)\n",
    "        print(f'Callback results saved successfully to {RUN_PATH}')\n",
    "\n",
    "        # Evaluate fine-tuned model; the remaining metrics come from the last row of the callback's data\n",
    "        finetune_loss, finetune_acc = model_clone.evaluate(X_test, Y_test)\n",
    "        finetune_sparsity, finetune_compression_rate, finetune_misalignment, fine_l2 = fine_sparsity_callback.total_metrics_data[-1, 1:5]\n",
    "\n",
    "        # Construct and save plot finetuning metrics trajectories\n",
    "        create_and_save_trajectory_plot(cb_total=fine_total, run_path=RUN_PATH, max_loss=10,\n",
    "                                        out_name=f'fine_trajectory_{ind}.pdf', show=False)\n",
    "\n",
    "        # Store formatted finetuning results in dictionary\n",
    "        results_dict = {\n",
    "          'Pre Opt': PRETRAIN_OPT,'Depth': int(DEPTH),'Lambda': f'{LA:.2e}','Init Type': INIT_TYPE,'Init LR': f'{INIT_LR:.2e}',\n",
    "          'LR Schedule': LR_SCHEDULE,'Batch size': BATCH_SIZE, 'Pre Epochs': int(EPOCHS),'Pre Loss': f'{pretrain_loss:.3f}',\n",
    "          'Pre Acc': f'{pretrain_acc * 100:.4f}%','Pre Sparsity': f'{pretrain_sparsity * 100:.4f}%',\n",
    "          'Pre CR': f'{pretrain_compression_rate:.2f}', 'Pre Misalignment': f'{pretrain_misalignment:.4f}',\n",
    "          'Pre L2': f'{pre_l2:.4f}', 'Prune Threshold': f'{threshold:.2e}', 'Pruned Acc': f'{pruned_acc * 100:.4f}%',\n",
    "          'Fine-tune Opt': FINETUNE_OPT,'Fine-tune Epochs': int(FINETUNE_EPOCHS),'Fine-tune LR': f'{FINE_LR:.2e}',\n",
    "          'Fine-tune LA': f'{FINE_LA:.2e}','Fine-tune Alpha': f'{FINETUNE_ALPHA:.2e}','Fine-tune Loss': f'{finetune_loss:.3f}',\n",
    "          'Fine-tune Acc': f'{finetune_acc * 100:.4f}%','Fine-tune CR': f'{finetune_compression_rate:.2f}',\n",
    "          'Fine-tune Sparsity': f'{finetune_sparsity * 100:.4f}%', 'Fine-tune Misalignment': f'{finetune_misalignment:.2f}',\n",
    "          'Fine-tune L2': f'{fine_l2:.4f}'\n",
    "          }\n",
    "\n",
    "        # Record this iteration and rewrite the CSV, so finished runs survive a later failure\n",
    "        result_records.append(results_dict)\n",
    "        pd.DataFrame(result_records, columns=result_columns).to_csv(csv_file_path, index=False)\n",
    "        print(f'\\nThreshold: {threshold} - Results: {results_dict}')\n",
    "\n",
    "    # Build and save the final finetune results df\n",
    "    results_df = pd.DataFrame(result_records, columns=result_columns)\n",
    "    results_df.to_csv(csv_file_path, index=False)\n",
    "\n",
    "    print(f'Results saved to {csv_file_path}')\n",
    "\n",
    "    print(\"\\nFinal Fine-tuning Results\")\n",
    "    print(results_df)\n",
    "else:\n",
    "    print('\\n Skipping finetuning as specified')\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
