{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-04T14:06:57.604712Z",
     "iopub.status.busy": "2025-05-04T14:06:57.604172Z",
     "iopub.status.idle": "2025-05-04T14:07:18.574906Z",
     "shell.execute_reply": "2025-05-04T14:07:18.574336Z",
     "shell.execute_reply.started": "2025-05-04T14:06:57.604652Z"
    },
    "id": "RPFnfEqRAYA5"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"TF_ENABLE_ONEDNN_OPTS\"] = \"0\"  # set before TensorFlow is imported\n",
    "import sys\n",
    "import time\n",
    "import random\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import keras\n",
    "import pandas as pd\n",
    "from tensorflow.keras.datasets import cifar10, cifar100\n",
    "from tensorflow.keras.models import load_model\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping  #ReduceLROnPlateau\n",
    "from tensorflow.keras.regularizers import l2\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "\n",
    "# Optional packages: install on demand, but only when they are actually missing\n",
    "# (ImportError). Any other import-time error is a real problem and is left to surface.\n",
    "# pip is run via sys.executable so it installs into this kernel's interpreter rather\n",
    "# than into whichever `python` happens to be first on PATH.\n",
    "try:\n",
    "    import pyHSICLasso\n",
    "except ImportError:\n",
    "    os.system(f'\"{sys.executable}\" -m pip install pyHSICLasso')\n",
    "try:\n",
    "    import lassonet\n",
    "except ImportError:\n",
    "    os.system(f'\"{sys.executable}\" -m pip install lassonet')\n",
    "\n",
    "# Work around the \"cannot import etree from lxml\" error by deleting the lxml package\n",
    "# folders and freshly reinstalling it. Because this force-deletes files under a\n",
    "# hard-coded python3.9 dist-packages path, it only runs when the import actually\n",
    "# fails instead of on every execution of this cell.\n",
    "try:\n",
    "    from lxml import etree\n",
    "except ImportError:\n",
    "    os.system(\n",
    "        \"rm -rf /usr/local/lib/python3.9/dist-packages/lxml && \"\n",
    "        \"rm -rf /usr/local/lib/python3.9/dist-packages/lxml-*.egg-info && \"\n",
    "        \"rm -rf /usr/local/lib/python3.9/dist-packages/lxml-5.3.1.dist-info && \"\n",
    "        f'\"{sys.executable}\" -m pip install --no-cache-dir --upgrade --force-reinstall lxml'\n",
    "    )\n",
    "\n",
    "from layers       import HadamardDense, HadamardConv2D, SparseConv2D, StrHadamardDense, StrConv2D\n",
    "from initializers import TwiceTruncatedNormalInitializer, equivar_initializer, equivar_initializer_conv2d\n",
    "from callbacks    import HadamardCallback, PrintLRCallback, TerminateBadRuns, StructuredSparsityCallback\n",
    "from models       import WideResNet, hadamard_resnet18, hadamard_resnet20, hadamard_WRN_16_4\n",
    "from models       import hadamard_vgg16, str_hadamard_vgg16, vanilla_vgg16, vanilla_vgg16_individual, vanilla_vgg19_individual,\\\n",
    "                         vanilla_vgg16_reg\n",
    "from utils        import color_preprocessing, compute_thresholds, get_optimizer, process_sparsity_callback,\\\n",
    "                      threshold_model_weights, flatten_and_filter_weights, create_and_save_trajectory_plot, get_flops_and_profile\n",
    "\n",
    "################################################################################\n",
    "# Define VGG model architecture and training parameters\n",
    "\n",
    "# Hadamard\n",
    "DEPTH               = 2\n",
    "LA                  = 3e-4\n",
    "INIT_TYPE           = 'equivar'  # vanilla, equivar, root, ones\n",
    "INIT_REST           = tf.keras.initializers.Ones()\n",
    "MINPROD             = 3e-3\n",
    "INIT                = TwiceTruncatedNormalInitializer(minprod=MINPROD, depth=DEPTH)\n",
    "KERNEL_INITIALIZER  = tf.keras.initializers.HeNormal()\n",
    "MULTFAC_INITIALIZER = tf.keras.initializers.Ones()\n",
    "USE_BIAS            = False\n",
    "\n",
    "# Training\n",
    "PRETRAIN_OPT       = 'sgd'     # sgd, adam\n",
    "LR_SCHEDULE        = 'cosine'  # piecewise, constant, cosine\n",
    "BATCH_SIZE         = 256\n",
    "EPOCHS             = 250\n",
    "INIT_LR            = 0.3 if PRETRAIN_OPT == 'sgd' else 1e-2  # sgd: depth 2 = 0.3, depth 3 = 0.4, depth 4 = 0.6\n",
    "WARMUP             = False  # linear lr warmup from 0.01 to INIT_LR over 5 epochs\n",
    "MOMENTUM           = 0.9    # only for SGD\n",
    "LR_DECAY_FACT      = 0.1    # only for piecewise\n",
    "LARGE_LRSTART      = False  # only for piecewise\n",
    "LABEL_SMOOTHING    = 0.1\n",
    "\n",
    "# Finetuning\n",
    "DO_FINETUNING      = False\n",
    "FINETUNE_OPT       = 'sgd'     # sgd, adam\n",
    "FINE_SCHEDULE      = 'cosine'  # piecewise, constant, cosine\n",
    "FINETUNE_EPOCHS    = 50     if FINETUNE_OPT == 'sgd' else 20\n",
    "FINE_LA            = 1 * LA if FINETUNE_OPT == 'sgd' else 0.2 * LA\n",
    "FINE_LR            = 7e-3   if FINETUNE_OPT == 'sgd' else 2e-4\n",
    "FINETUNE_ALPHA     = 1e-5   if FINETUNE_OPT == 'sgd' else 1e-5\n",
    "\n",
    "# Input settings\n",
    "CLASS_NUM          = 10\n",
    "IMG_ROWS, IMG_COLS = 32, 32\n",
    "IMG_CHANNELS       = 3\n",
    "\n",
    "# Misc\n",
    "PAT                = 1000  # EarlyStopping patience; larger than EPOCHS, so early stopping never triggers\n",
    "RESTORE_WEIGHTS    = False\n",
    "GRACE              = 20\n",
    "FINE_GRACE         = 10\n",
    "MINACC             = (1 / CLASS_NUM) + 0.05  # chance-level accuracy + 5 points, passed to TerminateBadRuns\n",
    "SEED               = 123\n",
    "SAVE_METRICS       = True\n",
    "VERBOSE            = 2\n",
    "compression_grid   = [10,50,75,100,200,500,700,800,1000,\n",
    "                      1250,1500,1750,2000,2500,3000,3500,4000,4500,5000,6000,7000,8000,9000,10000,\n",
    "                      15000,20000,25000,30000,40000,50000,75000,100000]\n",
    "lambdas_all        = [0, 1e-6, 5e-6,  #0-2\n",
    "                       1e-5, 2e-5, 3e-5, 4e-5, 5e-5, 6e-5, 7e-5, 8e-5, 9e-5, #3-11\n",
    "                       1e-4, 2e-4, 3e-4, 4e-4, 5e-4, 6e-4, 7e-4, 8e-4, 9e-4, #12-20\n",
    "                       1e-3, 2e-3, 5e-3] #21-23\n",
    "lambdas_10          = [1e-1, 1e-2, 1e-3, 5e-4]\n",
    "lambdas             = lambdas_10\n",
    "\n",
    "# Directories and saving paths\n",
    "VGG_FILE_PATH = './results/vgg/l21_cifar10/'\n",
    "\n",
    "print('Defining configs successful!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2025-05-04T14:07:18.576283Z",
     "iopub.status.busy": "2025-05-04T14:07:18.575914Z",
     "iopub.status.idle": "2025-05-04T14:07:24.850363Z",
     "shell.execute_reply": "2025-05-04T14:07:24.849743Z",
     "shell.execute_reply.started": "2025-05-04T14:07:18.576266Z"
    },
    "id": "mSeZM6CoDbOL",
    "outputId": "80a25c47-d086-4fc8-b35e-9a6d24bdcf2c"
   },
   "outputs": [],
   "source": [
    "# Data loading and pre-processing\n",
    "\n",
    "# Seed the Python, NumPy and TensorFlow RNGs\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "# Load data\n",
    "(X_train, Y_train), (X_test, Y_test) = cifar10.load_data()\n",
    "\n",
    "def calculate_mean_std(dataset):\n",
    "    \"\"\"Return (means, stds) reduced over axes 0, 1 and 2, i.e. one value per channel.\"\"\"\n",
    "    means = np.mean(dataset, axis=(0, 1, 2))\n",
    "    stds = np.std(dataset, axis=(0, 1, 2))\n",
    "    return means, stds\n",
    "\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "\n",
    "# Raw CIFAR-10 reference values: mean = [125.3, 123.0, 113.9], std = [63.0, 62.1, 66.7]\n",
    "print(\"Training Set Mean and SD:\", train_mean, train_std)\n",
    "print(\"Testing Set Mean and SD:\", test_mean, test_std)\n",
    "\n",
    "# Create train/val/test split: hold out 5% of the training images for validation.\n",
    "# random_state is passed explicitly so the split does not silently depend on how\n",
    "# much of the global numpy RNG stream was consumed before this line.\n",
    "X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.05, shuffle=True, random_state=SEED)\n",
    "\n",
    "# Color preprocessing (its effect is checked by the statistics printed below)\n",
    "X_train, X_test, X_val = color_preprocessing(X_train, X_test, X_val)\n",
    "\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "print(\"Normalized Training Set Mean and SD:\", train_mean, train_std)\n",
    "print(\"Normalized Testing Set Mean and SD:\", test_mean, test_std)\n",
    "\n",
    "# Pre-process labels: one-hot encode with categories learned from the training split\n",
    "encoder = OneHotEncoder()\n",
    "encoder.fit(Y_train)\n",
    "Y_train = encoder.transform(Y_train).toarray()\n",
    "Y_test = encoder.transform(Y_test).toarray()\n",
    "Y_val = encoder.transform(Y_val).toarray()\n",
    "\n",
    "# Check sample sizes of split\n",
    "print('Train data shape: ', X_train.shape)\n",
    "print('Train labels shape: ', Y_train.shape)\n",
    "print('Validation data shape: ', X_val.shape)\n",
    "print('Validation labels shape: ', Y_val.shape)\n",
    "print('Test data shape: ', X_test.shape)\n",
    "print('Test labels shape: ', Y_test.shape)\n",
    "\n",
    "# Define augmentation and data generators.\n",
    "# No featurewise_* or zca_whitening option is enabled, so ImageDataGenerator.fit()\n",
    "# would compute nothing these generators use; it is therefore not called.\n",
    "print('Using real-time data augmentation.')\n",
    "aug_datagen = ImageDataGenerator(horizontal_flip=True, width_shift_range=0.125, height_shift_range=0.125,\n",
    "                                 rotation_range=15, fill_mode='reflect')\n",
    "trainflow_aug = aug_datagen.flow(X_train, Y_train, batch_size=BATCH_SIZE)\n",
    "\n",
    "# Validation batches come in a fixed order so every epoch is scored on the same\n",
    "# images (flow() shuffles by default).\n",
    "val_datagen = ImageDataGenerator()\n",
    "val_flow = val_datagen.flow(X_val, Y_val, batch_size=BATCH_SIZE, shuffle=False)\n",
    "\n",
    "# Whole number of batches per epoch; the final partial batch counts as one step.\n",
    "steps_per_ep = int(np.ceil(len(X_train) / BATCH_SIZE))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-04T14:11:21.550049Z",
     "iopub.status.busy": "2025-05-04T14:11:21.549507Z",
     "iopub.status.idle": "2025-05-04T18:54:55.064947Z",
     "shell.execute_reply": "2025-05-04T18:54:55.064220Z",
     "shell.execute_reply.started": "2025-05-04T14:11:21.550028Z"
    }
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Network-Slimming on CIFAR10 / VGG16\n",
    "Produces one CSV per gamma-regularisation strength in\n",
    "\n",
    "    {VGG_FILE_PATH}/netslim_vgg16-lagamma{lambda-gamma}/network_slimming_results_{MODEL}.csv\n",
    "\"\"\"\n",
    "\n",
    "# NOTE: this cell overrides the global VGG_FILE_PATH, EPOCHS and INIT_LR from the config cell.\n",
    "VGG_FILE_PATH = './results/vgg/ns_cifar10/'\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "# ─────────────────────────── user hyper-params ─────────────────────\n",
    "MODEL        = 'ns_vgg_16'\n",
    "EPOCHS       = 200\n",
    "INIT_LR      = 1e-1\n",
    "LAMBDA_GRID  = [2e-4, 1e-4, 5e-4, 2e-3, 0, 1e-3]   # regularization strength, passed as la_gamma\n",
    "LAMBDA_GRID.reverse()                               # run order: 1e-3, 0, 2e-3, 5e-4, 1e-4, 2e-4\n",
    "PRUNE_GRID   = [0, .05, .10, .20, .30, .40, .50,\n",
    "                .60, .70, .80, .90, .95, .98, .99]  # fraction of BN channels to prune\n",
    "MODEL_TAG    = \"netslim_vgg16\"\n",
    "# ────────────────────────────────────────────────────────────────────\n",
    "\n",
    "\n",
    "# ── build baseline VGG-16 individual to get FLOPs & default counts ─\n",
    "tf.keras.backend.clear_session()\n",
    "baseline = vanilla_vgg16_individual(\n",
    "    input_shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS), n_classes=CLASS_NUM\n",
    ")\n",
    "baseline.build((None, IMG_ROWS, IMG_COLS, IMG_CHANNELS))\n",
    "BASE_FLOPS, _ = get_flops_and_profile(baseline)\n",
    "\n",
    "conv_layers_base = [l for l in baseline.layers if isinstance(l, tf.keras.layers.Conv2D)]\n",
    "DEFAULT_COUNTS   = [l.filters for l in conv_layers_base]          # 13 values\n",
    "\n",
    "# number of convs in each of the 5 VGG blocks\n",
    "CONV_PER_BLOCK = [2, 2, 3, 3, 3]\n",
    "\n",
    "# ── helpers ────────────────────────────────────────────────────────\n",
    "def nearest_conv(layer):\n",
    "    \"\"\"Walk backwards through the Keras graph from `layer` to the closest Conv2D feeding it.\n",
    "\n",
    "    Follows `tensor._keras_history` (a private Keras attribute) one inbound layer at a\n",
    "    time, so every layer on the path is assumed to have a single input.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the walk reaches the model's InputLayer without finding a Conv2D.\n",
    "    \"\"\"\n",
    "    t = layer.input\n",
    "    while True:\n",
    "        parent = t._keras_history[0]\n",
    "        if isinstance(parent, tf.keras.layers.Conv2D):\n",
    "            return parent\n",
    "        if isinstance(parent, tf.keras.layers.InputLayer):\n",
    "            # Stop at the model input instead of looping on it forever.\n",
    "            raise ValueError(f\"No Conv2D found upstream of layer '{layer.name}'\")\n",
    "        t = parent.input\n",
    "\n",
    "def vgg_bf_from_counts(counts):\n",
    "    \"\"\"Group 13 per-conv filter counts into VGG blocks.\n",
    "\n",
    "    counts: list of length 13 -> nested list [[b1], [b2], ...] sliced by CONV_PER_BLOCK.\n",
    "    \"\"\"\n",
    "    bf, idx = [], 0\n",
    "    for num in CONV_PER_BLOCK:\n",
    "        bf.append(counts[idx:idx + num])\n",
    "        idx += num\n",
    "    return bf                                                # stem not used\n",
    "\n",
    "# ── main loop ───────────────────────────────────────────────────\n",
    "for LA_GAMMA in LAMBDA_GRID:\n",
    "    fmt_gam   = f\"{LA_GAMMA:.1e}\"\n",
    "    RUN_PATH  = Path(VGG_FILE_PATH) / f\"{MODEL_TAG}-lagamma{fmt_gam}\"\n",
    "    RUN_PATH.mkdir(parents=True, exist_ok=True)\n",
    "    print(f\"\\nγ‑regularisation λ = {fmt_gam} → {RUN_PATH}\")\n",
    "\n",
    "    # Re-seed every RNG so each lambda starts from the same state\n",
    "    np.random.seed(SEED); random.seed(SEED); tf.random.set_seed(SEED)\n",
    "\n",
    "    # train gamma-regularised VGG-16\n",
    "    net = vanilla_vgg16_reg(\n",
    "        la=0.0, la_gamma=LA_GAMMA, use_bias=False,\n",
    "        input_shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS), n_classes=CLASS_NUM\n",
    "    )\n",
    "    net.compile(\n",
    "        optimizer=get_optimizer(\n",
    "            lr_schedule=LR_SCHEDULE, init_lr=INIT_LR, epochs=EPOCHS,\n",
    "            dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT,\n",
    "            momentum=MOMENTUM, alpha=0, warmup=WARMUP\n",
    "        ),\n",
    "        loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=LABEL_SMOOTHING),\n",
    "        metrics=[\"accuracy\"]\n",
    "    )\n",
    "    net.fit(\n",
    "        trainflow_aug, steps_per_epoch=steps_per_ep,\n",
    "        validation_data=val_flow,\n",
    "        # ceil, not floor: the final partial validation batch is scored too\n",
    "        validation_steps=int(np.ceil(X_val.shape[0] / BATCH_SIZE)),\n",
    "        epochs=EPOCHS, verbose=1,\n",
    "        callbacks=[\n",
    "            tf.keras.callbacks.EarlyStopping('val_accuracy', patience=PAT, restore_best_weights=True),\n",
    "            PrintLRCallback(),\n",
    "            tf.keras.callbacks.TerminateOnNaN(),\n",
    "            TerminateBadRuns(grace=GRACE, minacc=MINACC)\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    # Pair each BN layer with the conv that feeds it, and snapshot the trained\n",
    "    # weights so every pruning ratio starts from the same network.\n",
    "    bn_layers   = [l for l in net.layers if isinstance(l, tf.keras.layers.BatchNormalization)]\n",
    "    conv_layers = [nearest_conv(bn) for bn in bn_layers]      # same order\n",
    "    orig_conv_w = [cv.get_weights() for cv in conv_layers]\n",
    "    orig_bn_w   = [bn.get_weights() for bn in bn_layers]\n",
    "\n",
    "    # One row per BN channel: BN layer index, channel index, gamma value\n",
    "    df_gamma = pd.concat([\n",
    "        pd.DataFrame({\n",
    "            \"bn\": i,\n",
    "            \"ch\": np.arange(int(bn.gamma.shape[0])),\n",
    "            \"g\" : bn.gamma.numpy().astype(np.float32)\n",
    "        }) for i, bn in enumerate(bn_layers)\n",
    "    ], ignore_index=True)\n",
    "\n",
    "    total_ch = len(df_gamma)\n",
    "    sweep    = []\n",
    "\n",
    "    for ratio in PRUNE_GRID:\n",
    "        # Global threshold: the largest |gamma| among the n_prune smallest\n",
    "        n_prune = int(np.floor(ratio * total_ch))\n",
    "        thresh  = df_gamma[\"g\"].abs().nsmallest(n_prune).max() if n_prune else 0.0\n",
    "        df_gamma[\"keep\"] = df_gamma[\"g\"].abs() > thresh\n",
    "        # Channels tied at the threshold (and exact zeros when thresh == 0) are pruned\n",
    "        # as well, so the realised count can exceed n_prune; report what was removed.\n",
    "        n_pruned = int((~df_gamma[\"keep\"]).sum())\n",
    "\n",
    "        # in-place pruning: zero the conv filter and the BN scale/shift so the\n",
    "        # pruned channel's output is exactly zero\n",
    "        for bn, cv in zip(bn_layers, conv_layers):\n",
    "            g, b, m, v = bn.get_weights()\n",
    "            idx = np.where(np.abs(g) <= thresh)[0]\n",
    "            if idx.size:\n",
    "                W, *rest = cv.get_weights()\n",
    "                W[..., idx] = 0\n",
    "                cv.set_weights([W] + rest)\n",
    "                g[idx] = 0; b[idx] = 0\n",
    "                bn.set_weights([g, b, m, v])\n",
    "\n",
    "        loss, acc = net.evaluate(X_test, Y_test, verbose=0)\n",
    "\n",
    "        # restore weights\n",
    "        for cv, w in zip(conv_layers, orig_conv_w): cv.set_weights(w)\n",
    "        for bn, w in zip(bn_layers, orig_bn_w):     bn.set_weights(w)\n",
    "\n",
    "        if n_pruned == 0:\n",
    "            sparse_flops = BASE_FLOPS\n",
    "            bf = vgg_bf_from_counts(DEFAULT_COUNTS)\n",
    "        else:\n",
    "            # surviving channels per BN layer, clipped to at least 1 for the FLOPs model\n",
    "            counts = (df_gamma[df_gamma.keep]\n",
    "                      .groupby(\"bn\")[\"ch\"].count()\n",
    "                      .reindex(range(len(bn_layers)), fill_value=0)\n",
    "                      .clip(lower=1).tolist())\n",
    "            bf = vgg_bf_from_counts(counts)\n",
    "\n",
    "            tf.keras.backend.clear_session()\n",
    "            pruned = vanilla_vgg16_individual(\n",
    "                block_filters=bf,\n",
    "                input_shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS),\n",
    "                n_classes=CLASS_NUM\n",
    "            )\n",
    "            pruned.build((None, IMG_ROWS, IMG_COLS, IMG_CHANNELS))\n",
    "            sparse_flops, _ = get_flops_and_profile(pruned)\n",
    "\n",
    "        true_sp   = n_pruned / total_ch\n",
    "        comp_rate = 1 / (1 - true_sp) if true_sp < 1 else np.inf\n",
    "        speedup   = BASE_FLOPS / sparse_flops\n",
    "\n",
    "        sweep.append({\n",
    "            \"PruneRatio\"        : f\"{ratio:.2f}\",\n",
    "            \"TrueSparsity\"      : f\"{true_sp:.4f}\",\n",
    "            \"CompressionRate\"   : f\"{comp_rate:.2f}\",\n",
    "            \"TestLoss\"          : f\"{loss:.4f}\",\n",
    "            \"TestAcc\"           : f\"{acc:.4f}\",\n",
    "            \"ReducedFLOPs\"      : int(sparse_flops),\n",
    "            \"TheoreticalSpeedup\": f\"{speedup:.4f}\",\n",
    "            \"BlockFilters\"      : str(bf)\n",
    "        })\n",
    "\n",
    "    df_res = pd.DataFrame(sweep)\n",
    "    out_path = RUN_PATH / f'network_slimming_results_{MODEL}.csv'\n",
    "    df_res.to_csv(out_path, index=False)\n",
    "    print(df_res)\n",
    "    print(f\"Saved pruning results to {out_path}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-01T19:26:12.850968Z",
     "iopub.status.busy": "2025-05-01T19:26:12.850432Z",
     "iopub.status.idle": "2025-05-01T21:55:38.725517Z",
     "shell.execute_reply": "2025-05-01T21:55:38.724921Z",
     "shell.execute_reply.started": "2025-05-01T19:26:12.850944Z"
    }
   },
   "outputs": [],
   "source": [
    "# Cifar VGG16 with filter sparsity via direct L21 regularization + filter pruning\n",
    "#\n",
    "# NOTE: this cell overrides globals from the config cell (VGG_FILE_PATH, DEPTH, EPOCHS,\n",
    "# VERBOSE, GRACE, lambdas_10); re-run the config cell before any cell that expects the\n",
    "# original values.\n",
    "\n",
    "# Directories and saving paths\n",
    "VGG_FILE_PATH = './results/vgg/l21_cifar10/'\n",
    "\n",
    "# Model definition\n",
    "################################################################################\n",
    "MODEL = 'l21_vgg_16'\n",
    "DEPTH = 1\n",
    "LA_GAMMA = 0   # no la_gamma penalty in this cell; only `la` is swept\n",
    "USE_BIAS            = False\n",
    "FACTORIZE_BIAS      = False\n",
    "init_lr = 1e-1\n",
    "VERBOSE=1\n",
    "GRACE=20\n",
    "EPOCHS = 200\n",
    "\n",
    "# Regularization strengths to sweep, run from largest to smallest\n",
    "lambdas_10 = [0,1e-4,2e-4,3e-4,5e-4,7e-4,8e-4,9e-4,1e-3,1.5e-3,2e-3,2.5e-3,3e-3,3.3e-3,3.6e-3,4.5e-3,5e-3,5.5e-3,6.5e-3,7e-3,7.1e-3,9e-3]\n",
    "lambdas_10.reverse()\n",
    "\n",
    "# Loop-invariant settings for the pruning sweep\n",
    "eps = np.finfo(np.float32).eps\n",
    "prune_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.98, 0.99, 0.995]\n",
    "conv_per_block = [2, 2, 3, 3, 3]   # convs in each of the 5 VGG blocks\n",
    "# ceil, not floor: the final partial validation batch is scored too\n",
    "val_steps = int(np.ceil(X_val.shape[0] / BATCH_SIZE))\n",
    "\n",
    "# FLOPs of the unpruned reference network: identical for every lambda, so computed once\n",
    "tf.keras.backend.clear_session()\n",
    "base_model = vanilla_vgg16_individual()\n",
    "base_model.build(input_shape=(None, IMG_ROWS, IMG_COLS, IMG_CHANNELS))\n",
    "base_flops, _ = get_flops_and_profile(base_model)\n",
    "\n",
    "\n",
    "def _bn_after_conv(model):\n",
    "    \"\"\"Map each Conv2D layer name to the BatchNormalization layer it feeds.\n",
    "\n",
    "    Walks backwards from every BatchNormalization layer, one inbound layer at a time\n",
    "    via the private Keras `_keras_history` attribute, to the nearest Conv2D. Layers on\n",
    "    the path are assumed to have a single input. BN layers with no Conv2D upstream are\n",
    "    skipped.\n",
    "    \"\"\"\n",
    "    pairs = {}\n",
    "    for layer in model.layers:\n",
    "        if not isinstance(layer, tf.keras.layers.BatchNormalization):\n",
    "            continue\n",
    "        t = layer.input\n",
    "        while True:\n",
    "            parent = t._keras_history[0]\n",
    "            if isinstance(parent, tf.keras.layers.Conv2D):\n",
    "                pairs[parent.name] = layer\n",
    "                break\n",
    "            if isinstance(parent, tf.keras.layers.InputLayer):\n",
    "                break\n",
    "            t = parent.input\n",
    "    return pairs\n",
    "\n",
    "\n",
    "for LA in lambdas_10:\n",
    "    print(f'\\nStarting run with learning rate: {init_lr:.2e} and depth = {DEPTH} and lambda = {LA:.2e}')\n",
    "\n",
    "    # Run naming and saving paths for the current lambda\n",
    "    fmt_la = f\"{LA:.1e}\"\n",
    "    fmt_la_gamma = f\"{LA_GAMMA:.1e}\"\n",
    "    RUN_NAME = f\"dep{DEPTH}-la{fmt_la}-lagamma{fmt_la_gamma}-preopt-{PRETRAIN_OPT}-{EPOCHS}eps-{LR_SCHEDULE}-lr{init_lr:.1e}-ftune-{FINETUNE_OPT}-flr{FINE_LR:.1e}-feps{FINETUNE_EPOCHS}-{INIT_TYPE}-bs{BATCH_SIZE}\"\n",
    "    RUN_PATH = os.path.join(VGG_FILE_PATH, RUN_NAME)\n",
    "    os.makedirs(RUN_PATH, exist_ok=True)\n",
    "\n",
    "    # Drop Keras global state left over from the previous lambda's models\n",
    "    tf.keras.backend.clear_session()\n",
    "\n",
    "    # Set seed for reproducibility in each run\n",
    "    np.random.seed(SEED)\n",
    "    random.seed(SEED)\n",
    "    tf.random.set_seed(SEED)\n",
    "\n",
    "    # Define callbacks\n",
    "    early_stopping = EarlyStopping(monitor='val_accuracy', patience=PAT, restore_best_weights=RESTORE_WEIGHTS)\n",
    "    print_lr_cb = PrintLRCallback()\n",
    "    terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "    early_abort_cb = TerminateBadRuns(grace=GRACE, minacc=MINACC)\n",
    "\n",
    "    # Define the model with the given parameters\n",
    "    l21_net = vanilla_vgg16_reg(\n",
    "        la=LA,\n",
    "        la_gamma=LA_GAMMA,\n",
    "        use_bias=USE_BIAS,\n",
    "        input_shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS),\n",
    "        n_classes=CLASS_NUM\n",
    "    )\n",
    "\n",
    "    # Create pretrain optimizer with the current initial learning rate\n",
    "    optimizer = get_optimizer(\n",
    "        lr_schedule=LR_SCHEDULE,\n",
    "        init_lr=init_lr,\n",
    "        lr_decay_fact=LR_DECAY_FACT,\n",
    "        epochs=EPOCHS,\n",
    "        dat=X_train,\n",
    "        batch_size=BATCH_SIZE,\n",
    "        opt=PRETRAIN_OPT,\n",
    "        momentum=MOMENTUM,\n",
    "        alpha=0,\n",
    "        large_lr_start=LARGE_LRSTART,\n",
    "        warmup=WARMUP\n",
    "    )\n",
    "\n",
    "    # Compile model\n",
    "    l21_net.compile(\n",
    "        optimizer=optimizer,\n",
    "        loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=LABEL_SMOOTHING),\n",
    "        metrics=['accuracy']\n",
    "    )\n",
    "\n",
    "    # summary() prints the table itself and returns None, so it is not wrapped in print()\n",
    "    l21_net.summary()\n",
    "\n",
    "    # Pretraining weights are cached per run so re-running the cell can skip training\n",
    "    weights_path = os.path.join(RUN_PATH, f'{MODEL}_weights.h5')\n",
    "\n",
    "    if os.path.exists(weights_path):\n",
    "        print('Existing pretraining weights found: load and evaluate')\n",
    "        l21_net.load_weights(weights_path)\n",
    "    else:\n",
    "        print(f'No existing pretraining weights found: use pretraining with depth = {DEPTH} and lambda = {LA:.2e}')\n",
    "        # No batch_size argument: the generator already yields batches of BATCH_SIZE\n",
    "        l21_net.fit(\n",
    "            trainflow_aug,\n",
    "            steps_per_epoch=steps_per_ep,\n",
    "            validation_data=val_flow,\n",
    "            validation_steps=val_steps,\n",
    "            epochs=EPOCHS,\n",
    "            callbacks=[early_stopping, print_lr_cb, terminate_nan_cb, early_abort_cb]\n",
    "        )\n",
    "        # Persist the trained weights; without this the cache branch above never triggers\n",
    "        l21_net.save_weights(weights_path)\n",
    "\n",
    "    # Per-filter L2 norms: each kernel is flattened over every axis except the last\n",
    "    # (output-channel) axis\n",
    "    conv_layers = [l for l in l21_net.layers if isinstance(l, tf.keras.layers.Conv2D)]\n",
    "    bn_of_conv = _bn_after_conv(l21_net)\n",
    "    records = []\n",
    "    for li, layer in enumerate(conv_layers):\n",
    "        W = layer.get_weights()[0]\n",
    "        norms = np.linalg.norm(W.reshape(-1, W.shape[-1]), axis=0)\n",
    "        for fi, n in enumerate(norms):\n",
    "            records.append({'layer_idx': li, 'filter_idx': fi, 'norm': n})\n",
    "    df_norms = pd.DataFrame(records)\n",
    "\n",
    "    total_filters = len(df_norms)\n",
    "    zero_filters = (df_norms['norm'] < eps).sum()\n",
    "    print(f\"Raw sparsity before pruning: {zero_filters/total_filters:.2%} ({zero_filters}/{total_filters})\")\n",
    "\n",
    "    original_weights = l21_net.get_weights()\n",
    "    results = []\n",
    "\n",
    "    for ratio in prune_list:\n",
    "        # Global magnitude pruning: drop the n_prune filters with the smallest norm\n",
    "        n_prune = int(np.floor(ratio * total_filters))\n",
    "        true_sp = n_prune / total_filters\n",
    "        cr      = 1/(1-true_sp)\n",
    "        pruned_filters = df_norms.nsmallest(n_prune, 'norm')\n",
    "\n",
    "        # Surviving filters per conv layer (0 where a layer lost every filter),\n",
    "        # grouped into the 5 VGG blocks\n",
    "        remaining = (df_norms.drop(index=pruned_filters.index)\n",
    "                     .groupby('layer_idx').size()\n",
    "                     .reindex(range(len(conv_layers)), fill_value=0)\n",
    "                     .tolist())\n",
    "        block_filters = []\n",
    "        idx = 0\n",
    "        for num in conv_per_block:\n",
    "            block_filters.append(remaining[idx:idx+num])\n",
    "            idx += num\n",
    "\n",
    "        # Prune l21_net in place. Besides the conv kernel (and bias, if any), zero the\n",
    "        # scale and shift of the BatchNormalization fed by that conv: otherwise BN still\n",
    "        # outputs beta - gamma*mean/sqrt(var+eps) for a channel that structured pruning\n",
    "        # removes entirely.\n",
    "        for li, layer in enumerate(conv_layers):\n",
    "            idxs = pruned_filters.loc[pruned_filters['layer_idx'] == li, 'filter_idx'].tolist()\n",
    "            if not idxs:\n",
    "                continue\n",
    "            Wb = layer.get_weights()\n",
    "            W = Wb[0].copy()\n",
    "            W[..., idxs] = 0.\n",
    "            if len(Wb) > 1:\n",
    "                b = Wb[1]\n",
    "                b[idxs] = 0.\n",
    "                layer.set_weights([W, b])\n",
    "            else:\n",
    "                layer.set_weights([W])\n",
    "            bn = bn_of_conv.get(layer.name)\n",
    "            if bn is not None:\n",
    "                # assumes [gamma, beta, moving_mean, moving_variance], i.e. center=True, scale=True\n",
    "                g, beta, m, v = bn.get_weights()\n",
    "                g[idxs] = 0.\n",
    "                beta[idxs] = 0.\n",
    "                bn.set_weights([g, beta, m, v])\n",
    "\n",
    "        loss, acc = l21_net.evaluate(X_test, Y_test, verbose=0)\n",
    "        l21_net.set_weights(original_weights)\n",
    "\n",
    "        # sparse model flops\n",
    "        new_model = vanilla_vgg16_individual(load_weights=False,\n",
    "                                             input_shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS),\n",
    "                                             n_classes=CLASS_NUM,\n",
    "                                             block_filters=block_filters)\n",
    "        new_model.build(input_shape=(None, IMG_ROWS, IMG_COLS, IMG_CHANNELS))\n",
    "        sparse_flops, _ = get_flops_and_profile(new_model)\n",
    "        speedup = base_flops / sparse_flops\n",
    "\n",
    "        results.append({\n",
    "            'PruneRatio':            ratio,\n",
    "            'TrueSparsity':          f'{true_sp:.4f}',\n",
    "            'CompressionRate':       f'{cr:.2f}',\n",
    "            'TestLoss':              loss,\n",
    "            'TestAcc':               acc,\n",
    "            'ReducedFLOPs':          int(sparse_flops),\n",
    "            'TheoreticalSpeedup':    f'{speedup:.4f}',\n",
    "            'BlockFilters':          block_filters\n",
    "        })\n",
    "\n",
    "    df_res = pd.DataFrame(results)\n",
    "    out_path = os.path.join(RUN_PATH, f'pretraining_{MODEL}.csv')\n",
    "    df_res.to_csv(out_path, index=False)\n",
    "    print(df_res)\n",
    "    print(f\"Saved pruning results to {out_path}\")\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
