{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-23T08:45:08.608633Z",
     "iopub.status.busy": "2024-06-23T08:45:08.608056Z",
     "iopub.status.idle": "2024-06-23T08:45:11.603019Z",
     "shell.execute_reply": "2024-06-23T08:45:11.602485Z",
     "shell.execute_reply.started": "2024-06-23T08:45:08.608609Z"
    },
    "id": "RPFnfEqRAYA5"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import random\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import keras\n",
    "import pandas as pd\n",
    "from tensorflow.keras.datasets import cifar10, cifar100\n",
    "from tensorflow.keras.models import load_model\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping  #ReduceLROnPlateau\n",
    "from tensorflow.keras.regularizers import l2\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "\n",
    "from layers       import HadamardDense, HadamardConv2D\n",
    "from initializers import TwiceTruncatedNormalInitializer, equivar_initializer, equivar_initializer_conv2d\n",
    "from callbacks    import HadamardCallback, PrintLRCallback, TerminateBadRuns\n",
    "from models       import WideResNet, hadamard_resnet18, hadamard_resnet20, hadamard_resnet34, hadamard_vgg16, hadamard_vgg19\n",
    "from utils        import color_preprocessing, compute_thresholds, get_optimizer, process_sparsity_callback,\\\n",
    "                         create_and_save_trajectory_plot\n",
    "\n",
    "################################################################################\n",
    "# Experiment configuration: model architecture and training hyper-parameters\n",
    "# (all common resnets/WRNs available).\n",
    "# DEPTH, WARMUP, INIT_LR and LA (lambda) are swept by the training loop further below.\n",
    "\n",
    "DATA               = 'cifar100' # 'cifar100' 'cifar10'\n",
    "NET                = 'vgg19' # 'wrn168' 'resnet34' 'resnet18' 'vgg16' 'vgg19'\n",
    "#DEPTH              = 4\n",
    "#WARMUP             = True\n",
    "\n",
    "# Hadamard\n",
    "#LA                 = 5e-4\n",
    "INIT_TYPE          = 'equivar' # vanilla, equivar, root, ones\n",
    "#INIT_REST         = tf.keras.initializers.Ones()\n",
    "USE_BIAS           = True # use only for dense layers\n",
    "FACTORIZE_BIAS     = True\n",
    "\n",
    "# Training\n",
    "PRETRAIN_OPT       = 'sgd' # sgd, adam\n",
    "LR_SCHEDULE        = 'cosine' # piecewise, constant, cosine\n",
    "BATCH_SIZE         = 128 #256\n",
    "EPOCHS             = 200 #200 #250\n",
    "WARMUP_EPS         = 2      # passed to get_optimizer as warmup_eps\n",
    "WARMUP_INIT_LR     = 0.005  # passed to get_optimizer as warmup_init_lr\n",
    "\n",
    "MOMENTUM           = 0.9    # only for SGD\n",
    "LR_DECAY_FACT      = 0.1    # only for piecewise\n",
    "LARGE_LRSTART      = False  # only for piecewise\n",
    "LABEL_SMOOTHING    = 0.1\n",
    "\n",
    "# Input settings\n",
    "# Fail fast on an unsupported data set name instead of silently assuming 10 classes.\n",
    "CLASSES_PER_DATASET = {'cifar10': 10, 'cifar100': 100}\n",
    "if DATA not in CLASSES_PER_DATASET:\n",
    "    raise ValueError(f\"Unsupported DATA={DATA!r}; expected one of {sorted(CLASSES_PER_DATASET)}\")\n",
    "CLASS_NUM          = CLASSES_PER_DATASET[DATA]\n",
    "\n",
    "IMG_ROWS, IMG_COLS = 32, 32\n",
    "IMG_CHANNELS       = 3\n",
    "\n",
    "# WideResNet settings if applicable\n",
    "DEP                = 16\n",
    "WIDE               = 8\n",
    "IN_FILTERS         = 16\n",
    "\n",
    "# Misc\n",
    "#PAT                = 1000\n",
    "#RESTORE_WEIGHTS    = False\n",
    "# GRACE and MINACC are handed to TerminateBadRuns; MINACC is chance-level accuracy plus 0.05.\n",
    "GRACE              = 8\n",
    "MINACC             = (1 / CLASS_NUM) + 0.05\n",
    "SEED               = 123\n",
    "SAVE_METRICS       = True\n",
    "VERBOSE            = 1\n",
    "# Previous (coarser) lambda grid, kept for reference:\n",
    "#lambdas_all         = [0, 1e-6, 5e-6,  #0-2\n",
    "#                       1e-5, 2e-5, 3e-5, 4e-5, 5e-5, 6e-5, 7e-5, 8e-5, 9e-5, #3-11\n",
    "#                       1e-4, 2e-4, 3e-4, 4e-4, 5e-4, 6e-4, 7e-4, 8e-4, 9e-4, #12-20\n",
    "#                       1e-3, 2e-3, 5e-3, 7e-3, 1e-2, 1e-1] #21-26\n",
    "\n",
    "# Regularization strengths, in increasing order (index ranges in trailing comments).\n",
    "lambdas_all         = [0, 1e-6, 5e-6, 8e-6,  #0-3\n",
    "                       1e-5, 1.5e-5, 2e-5, 2.5e-5, 3e-5, 3.5e-5, 4e-5, 4.5e-5, 5e-5, 5.5e-5, 6e-5, 6.5e-5, 7e-5, 7.5e-5, #4-17\n",
    "                       8e-5, 8.5e-5, 9e-5, 9.5e-5, 1e-4, 1.5e-4, 2e-4, 2.5e-4, 3e-4, 3.5e-4, 4e-4, 4.5e-4, #18-29\n",
    "                       5e-4, 6e-4, 7e-4, 8e-4, 9e-4, #30-34\n",
    "                       1e-3, 2e-3, 5e-3, 7e-3, 1e-2, 1e-1] #35-40\n",
    "\n",
    "# Guard against typos in the grid: values must be unique and strictly increasing.\n",
    "assert lambdas_all == sorted(set(lambdas_all)), \"lambdas_all must be strictly increasing\"\n",
    "\n",
    "print('Defining configs successful!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2024-06-23T08:45:18.806927Z",
     "iopub.status.busy": "2024-06-23T08:45:18.806421Z",
     "iopub.status.idle": "2024-06-23T08:45:22.190764Z",
     "shell.execute_reply": "2024-06-23T08:45:22.190294Z",
     "shell.execute_reply.started": "2024-06-23T08:45:18.806901Z"
    },
    "id": "mSeZM6CoDbOL",
    "outputId": "80a25c47-d086-4fc8-b35e-9a6d24bdcf2c"
   },
   "outputs": [],
   "source": [
    "# Data loading and pre-processing\n",
    "\n",
    "# Seed the NumPy, Python and TensorFlow RNGs before anything random happens in this cell\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "print(f\"Data set: {DATA}\")\n",
    "\n",
    "# Load data; an unknown name is an error rather than leaving X_train/Y_train undefined\n",
    "if DATA == 'cifar10':\n",
    "    (X_train, Y_train), (X_test, Y_test) = cifar10.load_data()\n",
    "elif DATA == 'cifar100':\n",
    "    (X_train, Y_train), (X_test, Y_test) = cifar100.load_data()\n",
    "else:\n",
    "    raise ValueError(f\"Unsupported data set {DATA!r}; expected 'cifar10' or 'cifar100'\")\n",
    "\n",
    "def calculate_mean_std(dataset):\n",
    "    \"\"\"Return (mean, std) of `dataset`, each reduced over axes 0, 1 and 2.\"\"\"\n",
    "    reduce_axes = (0, 1, 2)\n",
    "    return np.mean(dataset, axis=reduce_axes), np.std(dataset, axis=reduce_axes)\n",
    "\n",
    "# Statistics of the raw images, reduced over axes 0, 1 and 2\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "\n",
    "# mean = [125.3, 123.0, 113.9], std  = [63.0,  62.1,  66.7] for cifar10\n",
    "print(\"Training Set Mean and SD:\", train_mean, train_std)\n",
    "print(\"Testing Set Mean and SD:\", test_mean, test_std)\n",
    "\n",
    "# Create train/val/test split: 5% of the training images are held out for validation.\n",
    "# random_state is not set, so the split is drawn from NumPy's global RNG.\n",
    "X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.05, shuffle=True)\n",
    "\n",
    "# Color preprocessing (project helper; presumably per-channel normalization -- see utils.color_preprocessing)\n",
    "X_train, X_test, X_val = color_preprocessing(X_train, X_test, X_val)\n",
    "\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "print(\"Normalized Training Set Mean and SD:\", train_mean, train_std)\n",
    "print(\"Normalized Testing Set Mean and SD:\", test_mean, test_std)\n",
    "\n",
    "# Pre-process labels: one-hot encode with categories learned from the training labels\n",
    "encoder = OneHotEncoder()\n",
    "encoder.fit(Y_train)\n",
    "Y_train = encoder.transform(Y_train).toarray()\n",
    "Y_test = encoder.transform(Y_test).toarray()\n",
    "Y_val = encoder.transform(Y_val).toarray()\n",
    "\n",
    "# Every class must appear in the training split, otherwise the one-hot width and the\n",
    "# network output size disagree.\n",
    "assert Y_train.shape[1] == CLASS_NUM, f\"expected {CLASS_NUM} classes, got {Y_train.shape[1]}\"\n",
    "\n",
    "# Check sample sizes of split\n",
    "print('Train data shape: ', X_train.shape)\n",
    "print('Train labels shape: ', Y_train.shape)\n",
    "print('Validation data shape: ', X_val.shape)\n",
    "print('Validation labels shape: ', Y_val.shape)\n",
    "print('Test data shape: ', X_test.shape)\n",
    "print('Test labels shape: ', Y_test.shape)\n",
    "\n",
    "# Define augmentation and data generators and fit to data\n",
    "print('Using real-time data augmentation.')\n",
    "aug_datagen = ImageDataGenerator(horizontal_flip=True, width_shift_range=0.125, height_shift_range=0.125,\n",
    "                                 rotation_range=15, fill_mode='reflect')\n",
    "aug_datagen.fit(X_train)\n",
    "trainflow_aug = aug_datagen.flow(X_train, Y_train, batch_size=BATCH_SIZE)\n",
    "\n",
    "val_datagen = ImageDataGenerator()\n",
    "val_datagen.fit(X_val)\n",
    "val_flow = val_datagen.flow(X_val, Y_val, batch_size=BATCH_SIZE)\n",
    "\n",
    "# steps_per_epoch must be an integer number of batches; round up so the final\n",
    "# partial batch is included.\n",
    "steps_per_ep = int(np.ceil(len(X_train) / BATCH_SIZE))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2024-06-23T08:45:23.693018Z",
     "iopub.status.busy": "2024-06-23T08:45:23.692725Z",
     "iopub.status.idle": "2024-06-23T09:07:28.775488Z",
     "shell.execute_reply": "2024-06-23T09:07:28.775005Z",
     "shell.execute_reply.started": "2024-06-23T08:45:23.692989Z"
    },
    "id": "6yy9KHxyWOBY",
    "outputId": "7cfef7ad-a3ae-49a2-92d7-820a174079d9"
   },
   "outputs": [],
   "source": [
    "# Sweep grid: warm-up on/off x learning-rate index x factorization depth x lambda.\n",
    "# Each combination trains one model and writes its results into its own run directory.\n",
    "WARMUP_VALUES = [False, True]\n",
    "DEPTH_VALUES = [2, 3, 4]\n",
    "NUM_LRS = 4\n",
    "\n",
    "# Depth-specific learning-rate grids; each list is indexed by lr_index in range(NUM_LRS)\n",
    "DEPTH_LR_RANGES = {\n",
    "    2: [0.15, 0.2, 0.3, 0.5],\n",
    "    3: [0.3, 0.4, 0.6, 0.8],\n",
    "    4: [0.6, 0.8, 1.1, 1.5],\n",
    "}\n",
    "FALLBACK_LR_RANGE = [0.05, 0.1, 0.12, 0.15]  # used together with DEPTH = 1 for any other depth\n",
    "\n",
    "# Architectures that are built with the same keyword arguments (WideResNet additionally takes dep/k)\n",
    "NET_BUILDERS = {\n",
    "    'resnet34': hadamard_resnet34,\n",
    "    'resnet18': hadamard_resnet18,\n",
    "    'vgg16': hadamard_vgg16,\n",
    "    'vgg19': hadamard_vgg19,\n",
    "}\n",
    "# Fail before the sweep starts instead of printing a message and training a stale/undefined model\n",
    "if NET != 'wrn168' and NET not in NET_BUILDERS:\n",
    "    raise ValueError(f\"Architecture not supported: {NET!r}\")\n",
    "\n",
    "# Initialization (depth-dependent scale)\n",
    "MINPROD = 3e-3\n",
    "\n",
    "RESULT_COLUMNS = ['Pre Opt', 'Depth', 'Lambda', 'Init Type', 'Init LR', 'LR Schedule', 'Batch size',\n",
    "                  'Pre Epochs', 'Pre Loss', 'Pre Acc', 'Pre Sparsity', 'Pre CR', 'Pre Misalign', 'Pre L2']\n",
    "\n",
    "for WARMUP in WARMUP_VALUES:\n",
    "    warmup_key = 'warmup' if WARMUP else 'nowarmup'\n",
    "    print(f\"Using {warmup_key} in this run!\")\n",
    "\n",
    "    for lr_index in range(NUM_LRS):\n",
    "        # Report progress against the number of learning rates (not the number of depths)\n",
    "        print(f\"Using lr index number {lr_index+1} / {NUM_LRS}!\")\n",
    "\n",
    "        for dep_index, DEPTH in enumerate(DEPTH_VALUES):\n",
    "            print(f\"Using depth={DEPTH} with dep_index={dep_index}!\")\n",
    "            if DEPTH in DEPTH_LR_RANGES:\n",
    "                dep_lr_range = DEPTH_LR_RANGES[DEPTH]\n",
    "            else:\n",
    "                DEPTH = 1\n",
    "                dep_lr_range = FALLBACK_LR_RANGE\n",
    "                print(\"Incorrect depth specified, check code!\")\n",
    "\n",
    "            print(f\"Learning rate range is {dep_lr_range}\")\n",
    "            INIT_LR = dep_lr_range[lr_index]\n",
    "            print(f\"Current learning rate for depth {DEPTH} is {INIT_LR:.1e}!\")\n",
    "\n",
    "            # Create directory and saving path\n",
    "            NET_FILE_PATH = f\"./results_{NET}_{DATA}/depth{DEPTH}/{warmup_key}/lr{INIT_LR:.1e}\"\n",
    "            os.makedirs(NET_FILE_PATH, exist_ok=True)\n",
    "\n",
    "            # Loop over lambdas and save plot and results in each iteration\n",
    "            for lam in lambdas_all:\n",
    "\n",
    "                # Model definition\n",
    "                ################################################################################\n",
    "\n",
    "                MODEL = f\"{NET}_{DATA}\" # 'wrn16-8'\n",
    "                LA = lam\n",
    "                print(f'Starting run with lambda={LA:.2e} with depth={DEPTH}, warmup={WARMUP} and init_lr={INIT_LR:.1e}')\n",
    "\n",
    "                ################################################################################\n",
    "\n",
    "                # Directories and saving paths\n",
    "                fmt_la = f\"{LA:.1e}\"\n",
    "                RUN_NAME = f\"dep{DEPTH}-la{fmt_la}-opt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{INIT_LR:.1e}-{warmup_key}-{INIT_TYPE}-bs{BATCH_SIZE}\"\n",
    "                RUN_PATH = os.path.join(NET_FILE_PATH, RUN_NAME)\n",
    "                os.makedirs(RUN_PATH, exist_ok=True)\n",
    "\n",
    "                ################################################################################\n",
    "\n",
    "                # Many models are built in this loop; release the Keras global state left\n",
    "                # behind by the previous run so memory does not grow across the sweep.\n",
    "                tf.keras.backend.clear_session()\n",
    "\n",
    "                # Set seed (after clearing the session, so every run starts from the same RNG state)\n",
    "                np.random.seed(SEED)\n",
    "                random.seed(SEED)\n",
    "                tf.random.set_seed(SEED)\n",
    "\n",
    "                # Callbacks\n",
    "                #early_stopping = EarlyStopping(monitor='val_accuracy', patience=PAT, restore_best_weights=RESTORE_WEIGHTS)\n",
    "                custom_sparsity_callback = HadamardCallback(save_metrics=SAVE_METRICS, verbose=VERBOSE)\n",
    "                print_lr_cb = PrintLRCallback()\n",
    "                terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "                early_abort_cb = TerminateBadRuns(grace=GRACE, minacc = MINACC)\n",
    "\n",
    "                INIT = TwiceTruncatedNormalInitializer(minprod=MINPROD, depth=DEPTH) #tf.keras.initializers.HeNormal\n",
    "\n",
    "                # Define model: all architectures share these keyword arguments\n",
    "                print(f\"Architecture = {NET}\")\n",
    "                net_kwargs = dict(use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS, depth=DEPTH,\n",
    "                                  init_type=INIT_TYPE, init=INIT, la=LA,\n",
    "                                  input_shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS), n_classes=CLASS_NUM)\n",
    "                if NET == 'wrn168':\n",
    "                    hadamard_net = WideResNet(dep=DEP, k=WIDE, **net_kwargs)\n",
    "                else:\n",
    "                    hadamard_net = NET_BUILDERS[NET](**net_kwargs)\n",
    "\n",
    "                # Pretrain optimizer\n",
    "                optimizer = get_optimizer(lr_schedule=LR_SCHEDULE, init_lr=INIT_LR, lr_decay_fact=LR_DECAY_FACT, epochs=EPOCHS,\n",
    "                                          dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT, momentum=MOMENTUM, alpha=0,\n",
    "                                          large_lr_start=LARGE_LRSTART, warmup=WARMUP, warmup_eps=WARMUP_EPS,\n",
    "                                          warmup_init_lr=WARMUP_INIT_LR)\n",
    "\n",
    "                # Compile model\n",
    "                hadamard_net.compile(optimizer=optimizer,\n",
    "                                     loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=LABEL_SMOOTHING),\n",
    "                                     metrics=['accuracy'])\n",
    "\n",
    "                # summary() prints by itself and returns None, so it is not wrapped in print()\n",
    "                hadamard_net.summary()\n",
    "\n",
    "                ## Pre-training\n",
    "                # batch_size is not passed to fit(): the generators already yield batches of BATCH_SIZE.\n",
    "                print(f'Starting pretraining with lambda={LA:.2e}')\n",
    "                pre_hist = hadamard_net.fit(\n",
    "                    trainflow_aug,\n",
    "                    steps_per_epoch=steps_per_ep,\n",
    "                    validation_data=val_flow,\n",
    "                    validation_steps=X_val.shape[0] // BATCH_SIZE,\n",
    "                    epochs=EPOCHS,\n",
    "                    verbose=VERBOSE,\n",
    "                    callbacks=[custom_sparsity_callback, print_lr_cb, terminate_nan_cb, early_abort_cb],\n",
    "                )\n",
    "\n",
    "                # Sparsity, compression rate, misalignment and L2 norm at the last recorded step\n",
    "                # (columns 1-4 of the last row of the callback's metrics array)\n",
    "                pretrain_sparsity, pretrain_compression_rate, pretrain_misalignment, pre_l2 = custom_sparsity_callback.total_metrics_data[-1, 1:5]\n",
    "\n",
    "                # Save pretraining trajectory\n",
    "                pre_total, pre_epochs = process_sparsity_callback(hist=pre_hist, hadamard_cb=custom_sparsity_callback, lr_cb=print_lr_cb)\n",
    "                pre_total.to_csv(os.path.join(RUN_PATH, 'pre_cb_total.csv'), index=False)\n",
    "                pre_epochs.to_csv(os.path.join(RUN_PATH, 'pre_cb_epochs.csv'), index=False)\n",
    "                print(f'Callback results saved successfully to {RUN_PATH}')\n",
    "\n",
    "                # Construct and save plot training metrics trajectories\n",
    "                create_and_save_trajectory_plot(cb_total=pre_total, run_path=RUN_PATH, max_loss=10, out_name='pre_trajectory.pdf')\n",
    "\n",
    "                # Evaluate after pretraining\n",
    "                pretrain_loss, pretrain_acc = hadamard_net.evaluate(X_test, Y_test)\n",
    "                print('\\nTest loss', pretrain_loss)\n",
    "                print('Test accuracy', pretrain_acc)\n",
    "                print('Sparsity (pretrain)', pretrain_sparsity)\n",
    "                print('Compression rate (pretrain)', pretrain_compression_rate)\n",
    "                print('Misalignment (pretrain)', pretrain_misalignment)\n",
    "                print('Total L2 norm (pretrain)', pre_l2)\n",
    "\n",
    "                # Store formatted results in dict\n",
    "                pretrain_res_dict = {\n",
    "                    'Pre Opt': PRETRAIN_OPT,\n",
    "                    'Depth': int(DEPTH),\n",
    "                    'Lambda': f'{LA:.2e}',\n",
    "                    'Init Type': INIT_TYPE,\n",
    "                    'Init LR': f'{INIT_LR:.2e}',\n",
    "                    'LR Schedule': LR_SCHEDULE,\n",
    "                    'Batch size': int(BATCH_SIZE),\n",
    "                    'Pre Epochs': int(EPOCHS),\n",
    "                    'Pre Loss': f'{pretrain_loss:.3f}',\n",
    "                    'Pre Acc': f'{pretrain_acc * 100:.4f}%',\n",
    "                    'Pre Sparsity': f'{pretrain_sparsity * 100:.4f}%',\n",
    "                    'Pre CR': f'{pretrain_compression_rate:.2f}',\n",
    "                    'Pre Misalign': f'{pretrain_misalignment:.4f}',\n",
    "                    'Pre L2': f'{pre_l2:.4f}'\n",
    "                }\n",
    "\n",
    "                # Build the one-row results frame directly (no concat with an empty frame)\n",
    "                pretrain_res_df = pd.DataFrame([pretrain_res_dict], columns=RESULT_COLUMNS)\n",
    "                # Save df to CSV if new pretrain\n",
    "                pretrain_csv_file_path = os.path.join(RUN_PATH, f'pretraining_{MODEL}.csv')\n",
    "                if os.path.exists(pretrain_csv_file_path):\n",
    "                    print('Pretrain results exist already, skipping saving of results')\n",
    "                else:\n",
    "                    pretrain_res_df.to_csv(pretrain_csv_file_path, index=False)\n",
    "                    print(f'Pretrain results saved to {pretrain_csv_file_path}')\n",
    "                print(\"\\nPretraining Results:\")\n",
    "                print(pretrain_res_df)\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
