{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-08-21T13:30:24.720781Z",
     "iopub.status.busy": "2024-08-21T13:30:24.720414Z",
     "iopub.status.idle": "2024-08-21T13:30:24.734410Z",
     "shell.execute_reply": "2024-08-21T13:30:24.733646Z",
     "shell.execute_reply.started": "2024-08-21T13:30:24.720754Z"
    },
    "id": "RPFnfEqRAYA5"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import random\n",
    "import itertools\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import keras\n",
    "import pandas as pd\n",
    "from tensorflow.keras.datasets import cifar10, fashion_mnist, mnist\n",
    "from tensorflow.keras.models import load_model\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping  #ReduceLROnPlateau\n",
    "from tensorflow.keras.regularizers import l2\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "\n",
    "from layers       import HadamardDense, HadamardConv2D\n",
    "from initializers import TwiceTruncatedNormalInitializer, equivar_initializer, equivar_initializer_conv2d\n",
    "from callbacks    import HadamardCallback, PrintLRCallback, TerminateBadRuns\n",
    "from models       import LeNet5, HadamardLeNet5, LeNet300100, HadamardLeNet300100, HadamardLeNet5BN\n",
    "from utils        import get_optimizer, process_sparsity_callback, create_and_save_trajectory_plot\n",
    "\n",
    "# tensorflow_datasets is needed for KMNIST; install it on demand if it is missing.\n",
    "try:\n",
    "    import tensorflow_datasets as tfds\n",
    "except ImportError:\n",
    "    # Install into the interpreter that runs this kernel (a bare `python` on PATH may\n",
    "    # belong to a different environment) and fail loudly if the installation fails.\n",
    "    import subprocess\n",
    "    import sys\n",
    "    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"tensorflow_datasets\"])\n",
    "    import tensorflow_datasets as tfds\n",
    "\n",
    "################################################################################\n",
    "\n",
    "# Hadamard\n",
    "DEPTH              = 2\n",
    "LA                 = 5e-4\n",
    "INIT_TYPE          = 'equivar' # vanilla, equivar, root, ones\n",
    "#INIT_REST         = tf.keras.initializers.Ones()\n",
    "MINPROD            = 3e-3\n",
    "INIT               = TwiceTruncatedNormalInitializer(minprod=MINPROD, depth=DEPTH)\n",
    "#INIT              = tf.keras.initializers.HeNormal()\n",
    "USE_BIAS           = True # only for dense layers\n",
    "FACTORIZE_BIAS     = True\n",
    "\n",
    "# Training\n",
    "PRETRAIN_OPT       = 'sgd' # sgd, adam\n",
    "LR_SCHEDULE        = 'cosine' # piecewise, constant, cosine\n",
    "BATCH_SIZE         = 256\n",
    "EPOCHS             = 75\n",
    "# author's note on learning rates: depth 2 = 0.2, depth 3 = 0.5, depth 4 = 0.7\n",
    "INIT_LR            = 0.15 if PRETRAIN_OPT == 'sgd' else 2e-3\n",
    "WARMUP             = False # linear lr warmup from 0.01 to INIT_LR over 5 epochs\n",
    "MOMENTUM           = 0.9   # only for SGD\n",
    "LR_DECAY_FACT      = 0.1   # only for piecewise\n",
    "LARGE_LRSTART      = False # only for piecewise\n",
    "\n",
    "# Misc\n",
    "CLASS_NUM          = 10\n",
    "PAT                = 100\n",
    "RESTORE_WEIGHTS    = False\n",
    "GRACE              = 5\n",
    "FINE_GRACE         = 10\n",
    "MINACC             = (1 / CLASS_NUM) + 0.05 # chance-level accuracy plus a 0.05 margin\n",
    "SEED               = 123\n",
    "SAVE_METRICS       = True\n",
    "VERBOSE            = 0\n",
    "# Seed offsets to run (each run is seeded with SEED + index); use [0, 1, 2] for three seeds\n",
    "RUN_INDICES        = [0]\n",
    "RUN_INDEX          = RUN_INDICES[0]\n",
    "MODEL_LIST         = ['lenet300100', 'lenet5bn']\n",
    "DATA_LIST          = ['mnist', 'fmnist', 'kmnist']\n",
    "\n",
    "# Model / dataset combination for this notebook run\n",
    "MODEL = 'lenet5bn' # one of MODEL_LIST\n",
    "DATA = 'mnist' # one of DATA_LIST\n",
    "\n",
    "# Directories and saving path\n",
    "LENET_FILE_PATH = './results_lenets_mnists/'\n",
    "\n",
    "# Lambda grid (written in ascending order, then reversed so the sweep starts with the largest lambda)\n",
    "LAMBDA_LIST = [\n",
    "    0,\n",
    "    1e-6,\n",
    "    1e-5,\n",
    "    1e-4,\n",
    "    2e-4,\n",
    "    4e-4,\n",
    "    6e-4,\n",
    "    7e-4,\n",
    "    8e-4,\n",
    "    9e-4,\n",
    "    1e-3,\n",
    "    1.5e-3,\n",
    "    2e-3,\n",
    "    4e-3,\n",
    "    5e-3,\n",
    "    7e-3,\n",
    "    1e-2,\n",
    "    2e-2,\n",
    "    5e-2,\n",
    "    1e-1\n",
    "]\n",
    "LAMBDA_LIST.reverse()\n",
    "# run with single lambda to drastically reduce size of finished notebook\n",
    "#LAMBDA_LIST = [1e-7]\n",
    "\n",
    "print('Defining configs successful!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-08-21T13:30:34.173247Z",
     "iopub.status.busy": "2024-08-21T13:30:34.172894Z",
     "iopub.status.idle": "2024-08-21T13:30:34.190397Z",
     "shell.execute_reply": "2024-08-21T13:30:34.189331Z",
     "shell.execute_reply.started": "2024-08-21T13:30:34.173205Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.datasets import mnist, fashion_mnist\n",
    "import tensorflow_datasets as tfds\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "def load_and_preprocess_mnists(model, data, val_split=0.02, seed=42):\n",
    "    \"\"\"Load an MNIST-style dataset, pad it for LeNet-5 if requested and scale it by 1/255.\n",
    "\n",
    "    Args:\n",
    "        model: Architecture name. 'lenet5bn' zero-pads every image by 2 pixels on each\n",
    "            side of both spatial axes; any other value leaves the image size unchanged.\n",
    "        data: Dataset name, one of 'mnist', 'fmnist' or 'kmnist'.\n",
    "        val_split: Fraction of the training set held out for validation. No validation\n",
    "            split is created unless this is greater than 0.\n",
    "        seed: Seed applied to NumPy, `random` and TensorFlow, and used as `random_state`\n",
    "            of the validation split.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(X_train, Y_train, X_val, Y_val, X_test, Y_test)`. `X_val` and `Y_val` are\n",
    "        None when no validation split is created.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `data` is not one of the supported dataset names.\n",
    "\n",
    "    Note:\n",
    "        The seeds are set globally as a side effect. For 'lenet5bn' on 'mnist'/'fmnist'\n",
    "        the image arrays are returned WITHOUT the trailing channel axis that all other\n",
    "        model/dataset combinations carry (the expand_dims step is skipped for them).\n",
    "    \"\"\"\n",
    "    # Set seeds for reproducibility\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "    tf.random.set_seed(seed)\n",
    "\n",
    "    # Load data\n",
    "    if data == 'mnist':\n",
    "        (X_train, Y_train), (X_test, Y_test) = mnist.load_data()\n",
    "    elif data == 'fmnist':\n",
    "        (X_train, Y_train), (X_test, Y_test) = fashion_mnist.load_data()\n",
    "    elif data == 'kmnist':\n",
    "        train_dataset, info_train = tfds.load('kmnist', split='train', as_supervised=True, with_info=True)\n",
    "        # The dataset metadata is only needed once, so the test split is loaded without it\n",
    "        test_dataset = tfds.load('kmnist', split='test', as_supervised=True)\n",
    "        num_classes = info_train.features['label'].num_classes\n",
    "\n",
    "        if model == 'lenet5bn':\n",
    "            def preprocess_dataset(dataset):\n",
    "                # Per-image padding: 2 pixels on each side of the first two axes, none on the last\n",
    "                return dataset.map(lambda image, label: (tf.pad(image, [[2, 2], [2, 2], [0, 0]]), label))\n",
    "\n",
    "            train_dataset = preprocess_dataset(train_dataset)\n",
    "            test_dataset = preprocess_dataset(test_dataset)\n",
    "\n",
    "        # Materialise the tf.data pipelines as NumPy arrays\n",
    "        X_train, Y_train = zip(*[(image.numpy(), label.numpy()) for image, label in train_dataset])\n",
    "        X_test, Y_test = zip(*[(image.numpy(), label.numpy()) for image, label in test_dataset])\n",
    "        X_train, Y_train = np.array(X_train), np.array(Y_train)\n",
    "        X_test, Y_test = np.array(X_test), np.array(Y_test)\n",
    "\n",
    "        print(\"Number of classes:\", num_classes)\n",
    "        print(\"Number of training examples:\", len(X_train))\n",
    "        print(\"Number of testing examples:\", len(X_test))\n",
    "    else:\n",
    "        raise ValueError('Dataset not supported.')\n",
    "\n",
    "    # Padding for LeNet-5 (KMNIST was already padded per image above)\n",
    "    if model == 'lenet5bn' and data != 'kmnist':\n",
    "        X_train = tf.pad(X_train, [[0, 0], [2,2], [2,2]]).numpy()\n",
    "        X_test = tf.pad(X_test, [[0, 0], [2,2], [2,2]]).numpy()\n",
    "\n",
    "    # Add a trailing channel dimension; skipped for KMNIST and for 'lenet5bn' (see Note in the docstring)\n",
    "    if data != 'kmnist' and model != 'lenet5bn':\n",
    "        X_train = np.expand_dims(X_train, -1)\n",
    "        X_test = np.expand_dims(X_test, -1)\n",
    "\n",
    "    # Calculate mean and std over the first three axes\n",
    "    def calculate_mean_std(dataset):\n",
    "        return np.mean(dataset, axis=(0, 1, 2)), np.std(dataset, axis=(0, 1, 2))\n",
    "\n",
    "    train_mean, train_std = calculate_mean_std(X_train)\n",
    "    test_mean, test_std = calculate_mean_std(X_test)\n",
    "    print(\"Training Set Mean and SD:\", train_mean, train_std)\n",
    "    print(\"Testing Set Mean and SD:\", test_mean, test_std)\n",
    "\n",
    "    # Create validation split if requested\n",
    "    if val_split > 0:\n",
    "        X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=val_split, shuffle=True, random_state=seed)\n",
    "    else:\n",
    "        X_val, Y_val = None, None\n",
    "\n",
    "    # Scale the pixel values by 1/255\n",
    "    X_train = X_train / 255.0\n",
    "    X_test = X_test / 255.0\n",
    "    if X_val is not None:\n",
    "        X_val = X_val / 255.0\n",
    "\n",
    "    # Statistics after scaling (training statistics now exclude the validation split)\n",
    "    train_mean, train_std = calculate_mean_std(X_train)\n",
    "    test_mean, test_std = calculate_mean_std(X_test)\n",
    "    print(\"Normalized Training Set Mean and SD:\", train_mean, train_std)\n",
    "    print(\"Normalized Testing Set Mean and SD:\", test_mean, test_std)\n",
    "\n",
    "    # Print dataset shapes\n",
    "    print(f'\\nModel: {model} on data {data}\\n')\n",
    "    print('Train data shape: ', X_train.shape)\n",
    "    print('Train labels shape: ', Y_train.shape)\n",
    "    if X_val is not None:\n",
    "        print('Validation data shape: ', X_val.shape)\n",
    "        print('Validation labels shape: ', Y_val.shape)\n",
    "    print('Test data shape: ', X_test.shape)\n",
    "    print('Test labels shape: ', Y_test.shape)\n",
    "\n",
    "    return X_train, Y_train, X_val, Y_val, X_test, Y_test\n",
    "\n",
    "# Usage example:\n",
    "# X_train, Y_train, X_val, Y_val, X_test, Y_test = load_and_preprocess_mnists('lenet5bn', 'mnist', val_split=0.02)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-08-21T13:30:37.684618Z",
     "iopub.status.busy": "2024-08-21T13:30:37.683907Z",
     "iopub.status.idle": "2024-08-21T13:30:42.071009Z",
     "shell.execute_reply": "2024-08-21T13:30:42.070162Z",
     "shell.execute_reply.started": "2024-08-21T13:30:37.684579Z"
    }
   },
   "outputs": [],
   "source": [
    "# Build the train / validation / test arrays for the configured MODEL and DATA\n",
    "X_train, Y_train, X_val, Y_val, X_test, Y_test = load_and_preprocess_mnists(\n",
    "    model=MODEL, data=DATA, val_split=0.02\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-08-21T13:30:50.861797Z",
     "iopub.status.busy": "2024-08-21T13:30:50.860959Z",
     "iopub.status.idle": "2024-08-21T15:26:00.880038Z",
     "shell.execute_reply": "2024-08-21T15:26:00.878720Z",
     "shell.execute_reply.started": "2024-08-21T13:30:50.861768Z"
    }
   },
   "outputs": [],
   "source": [
    "# Pretraining sweep: loop over all seeds, depths and lambdas for the configured DATA and MODEL.\n",
    "# Every run trains a fresh Hadamard network and writes its trajectory and summary CSVs to RUN_PATH.\n",
    "\n",
    "# Fail early instead of training with an unbound or stale network\n",
    "if MODEL not in ('lenet300100', 'lenet5bn'):\n",
    "    raise ValueError(f'Model {MODEL} not supported.')\n",
    "\n",
    "# Run-invariant settings (these deliberately override the values from the config cell)\n",
    "SAVENAME = 'lenet5' if MODEL == 'lenet5bn' else MODEL\n",
    "MODELNAME = f'{SAVENAME}_{DATA}'\n",
    "INIT_TYPE = 'equivar'\n",
    "INIT_LR = 0.6 if (MODEL == 'lenet300100' and DATA == 'kmnist') else 0.15\n",
    "VERBOSE = 0\n",
    "\n",
    "for i, RUN_INDEX in enumerate(RUN_INDICES):\n",
    "    print(f\"Start with seed number {RUN_INDEX}\")\n",
    "    for DEPTH in [2,3,4]:\n",
    "        # Hadamard LeNet-300-100/LeNet5BN on MNIST/FMNIST/KMNIST\n",
    "        for run, LA_ITER in enumerate(LAMBDA_LIST, start=1):\n",
    "            print(f\"This is run {run}/{len(LAMBDA_LIST)}\")\n",
    "            # Model definition\n",
    "            ################################################################################\n",
    "            print(MODELNAME)\n",
    "            LA = LA_ITER\n",
    "            print(f'Starting run with lambda={LA:.2e}')\n",
    "            # The initializer depends on the factorization depth, so it is rebuilt per run\n",
    "            INIT = TwiceTruncatedNormalInitializer(minprod=MINPROD,depth=DEPTH)\n",
    "            #INIT = tf.keras.initializers.HeNormal()\n",
    "            ################################################################################\n",
    "\n",
    "            # Directories and saving paths\n",
    "            fmt_la = f\"{LA:.1e}\"\n",
    "            RUN_NAME = f\"{MODEL}/{DATA}/{MODELNAME}_dep{DEPTH}-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{INIT_LR:.1e}-{INIT_TYPE}-bs{BATCH_SIZE}\"\n",
    "            RUN_PATH = os.path.join(LENET_FILE_PATH, RUN_NAME)\n",
    "\n",
    "            # Create dir\n",
    "            if not os.path.exists(RUN_PATH):\n",
    "                os.makedirs(RUN_PATH)\n",
    "                print(f'Dir {RUN_PATH} not found and created')\n",
    "            else:\n",
    "                print(f'Dir {RUN_PATH} already exists')\n",
    "\n",
    "            ################################################################################\n",
    "            # Set seed (same seed for every depth/lambda of one RUN_INDEX)\n",
    "            np.random.seed(SEED+RUN_INDEX)\n",
    "            random.seed(SEED+RUN_INDEX)\n",
    "            tf.random.set_seed(SEED+RUN_INDEX)\n",
    "\n",
    "            # Callbacks\n",
    "            #early_stopping = EarlyStopping(monitor='val_accuracy', patience=PAT, restore_best_weights=RESTORE_WEIGHTS)\n",
    "            custom_sparsity_callback = HadamardCallback(save_metrics=SAVE_METRICS, verbose=VERBOSE)\n",
    "            print_lr_cb = PrintLRCallback(verbose=VERBOSE)\n",
    "            terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "            early_abort_cb = TerminateBadRuns(grace=GRACE, minacc = MINACC)\n",
    "\n",
    "            # Define model (MODEL was validated before the loop)\n",
    "            print(MODEL)\n",
    "            if MODEL == 'lenet300100':\n",
    "                hadamard_net = HadamardLeNet300100(input_shape=(28,28,1), n_classes=CLASS_NUM, depth=DEPTH, la=LA, init_type=INIT_TYPE,\n",
    "                                        init=INIT, use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS)\n",
    "            else:\n",
    "                hadamard_net = HadamardLeNet5BN(input_shape=(32,32,1), n_classes=CLASS_NUM, depth=DEPTH, la=LA, init_type=INIT_TYPE,\n",
    "                                   init=INIT, use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS)\n",
    "\n",
    "            # Pretrain optimizer\n",
    "            optimizer = get_optimizer(lr_schedule=LR_SCHEDULE, init_lr=INIT_LR, lr_decay_fact=LR_DECAY_FACT, epochs=EPOCHS,\n",
    "                                      dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT, momentum=MOMENTUM, alpha=0,\n",
    "                                      large_lr_start=LARGE_LRSTART, warmup = WARMUP)\n",
    "\n",
    "            # Compile model\n",
    "            hadamard_net.compile(optimizer=optimizer,\n",
    "                           loss='sparse_categorical_crossentropy',\n",
    "                           metrics=['accuracy'])\n",
    "\n",
    "            # summary() prints by itself and returns None, so it is not wrapped in print()\n",
    "            hadamard_net.summary()\n",
    "\n",
    "            ################################################################################\n",
    "\n",
    "            # Training\n",
    "            pre_hist = hadamard_net.fit(x=X_train, y=Y_train, epochs=EPOCHS, batch_size=BATCH_SIZE,\n",
    "                                           validation_data=(X_val, Y_val), verbose = VERBOSE,\n",
    "                                           callbacks=[print_lr_cb,custom_sparsity_callback, terminate_nan_cb,early_abort_cb])\n",
    "\n",
    "            # Evaluate once after training; the result is reused for the summary below\n",
    "            pretrain_loss, pretrain_acc = hadamard_net.evaluate(X_test, Y_test)\n",
    "            print('\\nTest loss', pretrain_loss)\n",
    "            print('Test accuracy', pretrain_acc)\n",
    "\n",
    "            ################################################################################\n",
    "\n",
    "            # Metrics of the last recorded row of the callback (columns 1-4).\n",
    "            # NOTE(review): column order is assumed to be sparsity, compression rate,\n",
    "            # misalignment, L2 norm - confirm against HadamardCallback.\n",
    "            pretrain_sparsity, pretrain_compression_rate, pretrain_misalignment, pre_l2 = custom_sparsity_callback.total_metrics_data[-1, 1:5]\n",
    "\n",
    "            # Save pretraining trajectory\n",
    "            pre_total, pre_epochs = process_sparsity_callback(hist = pre_hist, hadamard_cb = custom_sparsity_callback, lr_cb = print_lr_cb)\n",
    "            pre_total.to_csv(os.path.join(RUN_PATH, f'pre_cb_total_{RUN_INDEX}.csv'), index=False)\n",
    "            pre_epochs.to_csv(os.path.join(RUN_PATH, f'pre_cb_epochs_{RUN_INDEX}.csv'), index=False)\n",
    "            print(f'Callback results saved successfully to {RUN_PATH}')\n",
    "\n",
    "            # Construct and save plot training metrics trajectories\n",
    "            #create_and_save_trajectory_plot(cb_total=pre_total, run_path=RUN_PATH, max_loss=10, out_name=f'pre_trajectory_{RUN_INDEX}.pdf', show=False)\n",
    "\n",
    "            print('Sparsity (pretrain)', pretrain_sparsity)\n",
    "            print('Compression rate (pretrain)', pretrain_compression_rate)\n",
    "            print('Misalignment (pretrain)', pretrain_misalignment)\n",
    "            print('Total L2 norm (pretrain)', pre_l2)\n",
    "\n",
    "            # Store formatted results in dict (key order defines the CSV column order)\n",
    "            pretrain_res_dict = {\n",
    "              'Pre Opt': PRETRAIN_OPT,'Depth': int(DEPTH),'Lambda': f'{LA:.2e}','Init Type': INIT_TYPE,'Init LR': f'{INIT_LR:.2e}',\n",
    "              'LR Schedule': LR_SCHEDULE,'Batch size': int(BATCH_SIZE), 'Pre Epochs': int(EPOCHS),'Pre Loss': f'{pretrain_loss:.3f}',\n",
    "              'Pre Acc': f'{pretrain_acc * 100:.4f}%','Pre Sparsity': f'{pretrain_sparsity * 100:.4f}%',\n",
    "              'Pre CR': f'{pretrain_compression_rate:.2f}', 'Pre Misalign': f'{pretrain_misalignment:.4f}',\n",
    "              'Pre L2': f'{pre_l2:.4f}',\n",
    "              'Run': f'{RUN_INDEX}'\n",
    "              }\n",
    "\n",
    "            # One-row results frame, built directly instead of concatenating onto an empty frame\n",
    "            pretrain_res_df = pd.DataFrame([pretrain_res_dict])\n",
    "\n",
    "            # Save df to CSV (an existing file of the same run is overwritten)\n",
    "            pretrain_csv_file_path = os.path.join(RUN_PATH, f'pretraining_{MODELNAME}_run{RUN_INDEX}.csv')\n",
    "            pretrain_res_df.to_csv(pretrain_csv_file_path, index=False)\n",
    "            print(f'Pretrain results saved to {pretrain_csv_file_path}')\n",
    "            print(\"\\nPretraining Results:\")\n",
    "            print(pretrain_res_df)\n",
    "\n",
    "print(f'Finished seeds {RUN_INDICES} for {MODEL} and {DATA}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-08-21T15:26:00.883581Z",
     "iopub.status.busy": "2024-08-21T15:26:00.883099Z",
     "iopub.status.idle": "2024-08-21T17:27:14.572906Z",
     "shell.execute_reply": "2024-08-21T17:27:14.572106Z",
     "shell.execute_reply.started": "2024-08-21T15:26:00.883542Z"
    }
   },
   "outputs": [],
   "source": [
    "# Second experiment: same pretraining sweep on Fashion-MNIST\n",
    "MODEL = 'lenet5bn' # 'lenet300100' or 'lenet5bn'\n",
    "DATA = 'fmnist' # 'mnist', 'fmnist' or 'kmnist'\n",
    "RUN_INDICES        = [0]\n",
    "\n",
    "# Fail early instead of training with an unbound or stale network\n",
    "if MODEL not in ('lenet300100', 'lenet5bn'):\n",
    "    raise ValueError(f'Model {MODEL} not supported.')\n",
    "\n",
    "# load data for model\n",
    "X_train, Y_train, X_val, Y_val, X_test, Y_test = load_and_preprocess_mnists(MODEL, DATA, val_split=0.02)\n",
    "\n",
    "#####\n",
    "# Pretraining sweep: loop over all seeds, depths and lambdas for that DATA and MODEL combination.\n",
    "# Every run trains a fresh Hadamard network and writes its trajectory and summary CSVs to RUN_PATH.\n",
    "\n",
    "# Run-invariant settings (these deliberately override the values from the config cell)\n",
    "SAVENAME = 'lenet5' if MODEL == 'lenet5bn' else MODEL\n",
    "MODELNAME = f'{SAVENAME}_{DATA}'\n",
    "INIT_TYPE = 'equivar'\n",
    "INIT_LR = 0.6 if (MODEL == 'lenet300100' and DATA == 'kmnist') else 0.15\n",
    "VERBOSE = 0\n",
    "\n",
    "for i, RUN_INDEX in enumerate(RUN_INDICES):\n",
    "    print(f\"Start with seed number {RUN_INDEX}\")\n",
    "    for DEPTH in [2,3,4]:\n",
    "        # Hadamard LeNet-300-100/LeNet5BN on MNIST/FMNIST/KMNIST\n",
    "        for run, LA_ITER in enumerate(LAMBDA_LIST, start=1):\n",
    "            print(f\"This is run {run}/{len(LAMBDA_LIST)}\")\n",
    "            # Model definition\n",
    "            ################################################################################\n",
    "            print(MODELNAME)\n",
    "            LA = LA_ITER\n",
    "            print(f'Starting run with lambda={LA:.2e}')\n",
    "            # The initializer depends on the factorization depth, so it is rebuilt per run\n",
    "            INIT = TwiceTruncatedNormalInitializer(minprod=MINPROD,depth=DEPTH)\n",
    "            #INIT = tf.keras.initializers.HeNormal()\n",
    "            ################################################################################\n",
    "\n",
    "            # Directories and saving paths\n",
    "            fmt_la = f\"{LA:.1e}\"\n",
    "            RUN_NAME = f\"{MODEL}/{DATA}/{MODELNAME}_dep{DEPTH}-la{fmt_la}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{INIT_LR:.1e}-{INIT_TYPE}-bs{BATCH_SIZE}\"\n",
    "            RUN_PATH = os.path.join(LENET_FILE_PATH, RUN_NAME)\n",
    "\n",
    "            # Create dir\n",
    "            if not os.path.exists(RUN_PATH):\n",
    "                os.makedirs(RUN_PATH)\n",
    "                print(f'Dir {RUN_PATH} not found and created')\n",
    "            else:\n",
    "                print(f'Dir {RUN_PATH} already exists')\n",
    "\n",
    "            ################################################################################\n",
    "            # Set seed (same seed for every depth/lambda of one RUN_INDEX)\n",
    "            np.random.seed(SEED+RUN_INDEX)\n",
    "            random.seed(SEED+RUN_INDEX)\n",
    "            tf.random.set_seed(SEED+RUN_INDEX)\n",
    "\n",
    "            # Callbacks\n",
    "            #early_stopping = EarlyStopping(monitor='val_accuracy', patience=PAT, restore_best_weights=RESTORE_WEIGHTS)\n",
    "            custom_sparsity_callback = HadamardCallback(save_metrics=SAVE_METRICS, verbose=VERBOSE)\n",
    "            print_lr_cb = PrintLRCallback(verbose=VERBOSE)\n",
    "            terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "            early_abort_cb = TerminateBadRuns(grace=GRACE, minacc = MINACC)\n",
    "\n",
    "            # Define model (MODEL was validated at the top of this cell)\n",
    "            print(MODEL)\n",
    "            if MODEL == 'lenet300100':\n",
    "                hadamard_net = HadamardLeNet300100(input_shape=(28,28,1), n_classes=CLASS_NUM, depth=DEPTH, la=LA, init_type=INIT_TYPE,\n",
    "                                        init=INIT, use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS)\n",
    "            else:\n",
    "                hadamard_net = HadamardLeNet5BN(input_shape=(32,32,1), n_classes=CLASS_NUM, depth=DEPTH, la=LA, init_type=INIT_TYPE,\n",
    "                                   init=INIT, use_bias=USE_BIAS, factorize_bias=FACTORIZE_BIAS)\n",
    "\n",
    "            # Pretrain optimizer\n",
    "            optimizer = get_optimizer(lr_schedule=LR_SCHEDULE, init_lr=INIT_LR, lr_decay_fact=LR_DECAY_FACT, epochs=EPOCHS,\n",
    "                                      dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT, momentum=MOMENTUM, alpha=0,\n",
    "                                      large_lr_start=LARGE_LRSTART, warmup = WARMUP)\n",
    "\n",
    "            # Compile model\n",
    "            hadamard_net.compile(optimizer=optimizer,\n",
    "                           loss='sparse_categorical_crossentropy',\n",
    "                           metrics=['accuracy'])\n",
    "\n",
    "            # summary() prints by itself and returns None, so it is not wrapped in print()\n",
    "            hadamard_net.summary()\n",
    "\n",
    "            ################################################################################\n",
    "\n",
    "            # Training\n",
    "            pre_hist = hadamard_net.fit(x=X_train, y=Y_train, epochs=EPOCHS, batch_size=BATCH_SIZE,\n",
    "                                           validation_data=(X_val, Y_val), verbose = VERBOSE,\n",
    "                                           callbacks=[print_lr_cb,custom_sparsity_callback, terminate_nan_cb,early_abort_cb])\n",
    "\n",
    "            # Evaluate once after training; the result is reused for the summary below\n",
    "            pretrain_loss, pretrain_acc = hadamard_net.evaluate(X_test, Y_test)\n",
    "            print('\\nTest loss', pretrain_loss)\n",
    "            print('Test accuracy', pretrain_acc)\n",
    "\n",
    "            ################################################################################\n",
    "\n",
    "            # Metrics of the last recorded row of the callback (columns 1-4).\n",
    "            # NOTE(review): column order is assumed to be sparsity, compression rate,\n",
    "            # misalignment, L2 norm - confirm against HadamardCallback.\n",
    "            pretrain_sparsity, pretrain_compression_rate, pretrain_misalignment, pre_l2 = custom_sparsity_callback.total_metrics_data[-1, 1:5]\n",
    "\n",
    "            # Save pretraining trajectory\n",
    "            pre_total, pre_epochs = process_sparsity_callback(hist = pre_hist, hadamard_cb = custom_sparsity_callback, lr_cb = print_lr_cb)\n",
    "            pre_total.to_csv(os.path.join(RUN_PATH, f'pre_cb_total_{RUN_INDEX}.csv'), index=False)\n",
    "            pre_epochs.to_csv(os.path.join(RUN_PATH, f'pre_cb_epochs_{RUN_INDEX}.csv'), index=False)\n",
    "            print(f'Callback results saved successfully to {RUN_PATH}')\n",
    "\n",
    "            # Construct and save plot training metrics trajectories\n",
    "            #create_and_save_trajectory_plot(cb_total=pre_total, run_path=RUN_PATH, max_loss=10, out_name=f'pre_trajectory_{RUN_INDEX}.pdf', show=False)\n",
    "\n",
    "            print('Sparsity (pretrain)', pretrain_sparsity)\n",
    "            print('Compression rate (pretrain)', pretrain_compression_rate)\n",
    "            print('Misalignment (pretrain)', pretrain_misalignment)\n",
    "            print('Total L2 norm (pretrain)', pre_l2)\n",
    "\n",
    "            # Store formatted results in dict (key order defines the CSV column order)\n",
    "            pretrain_res_dict = {\n",
    "              'Pre Opt': PRETRAIN_OPT,'Depth': int(DEPTH),'Lambda': f'{LA:.2e}','Init Type': INIT_TYPE,'Init LR': f'{INIT_LR:.2e}',\n",
    "              'LR Schedule': LR_SCHEDULE,'Batch size': int(BATCH_SIZE), 'Pre Epochs': int(EPOCHS),'Pre Loss': f'{pretrain_loss:.3f}',\n",
    "              'Pre Acc': f'{pretrain_acc * 100:.4f}%','Pre Sparsity': f'{pretrain_sparsity * 100:.4f}%',\n",
    "              'Pre CR': f'{pretrain_compression_rate:.2f}', 'Pre Misalign': f'{pretrain_misalignment:.4f}',\n",
    "              'Pre L2': f'{pre_l2:.4f}',\n",
    "              'Run': f'{RUN_INDEX}'\n",
    "              }\n",
    "\n",
    "            # One-row results frame, built directly instead of concatenating onto an empty frame\n",
    "            pretrain_res_df = pd.DataFrame([pretrain_res_dict])\n",
    "\n",
    "            # Save df to CSV (an existing file of the same run is overwritten)\n",
    "            pretrain_csv_file_path = os.path.join(RUN_PATH, f'pretraining_{MODELNAME}_run{RUN_INDEX}.csv')\n",
    "            pretrain_res_df.to_csv(pretrain_csv_file_path, index=False)\n",
    "            print(f'Pretrain results saved to {pretrain_csv_file_path}')\n",
    "            print(\"\\nPretraining Results:\")\n",
    "            print(pretrain_res_df)\n",
    "\n",
    "print(f'Finished seeds {RUN_INDICES} for {MODEL} and {DATA}')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
