{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-04T08:27:07.234194Z",
     "iopub.status.busy": "2025-05-04T08:27:07.233653Z",
     "iopub.status.idle": "2025-05-04T08:27:20.288218Z",
     "shell.execute_reply": "2025-05-04T08:27:20.287267Z",
     "shell.execute_reply.started": "2025-05-04T08:27:07.234127Z"
    },
    "id": "RPFnfEqRAYA5"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"TF_ENABLE_ONEDNN_OPTS\"] = \"0\"\n",
    "import time\n",
    "import random\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import keras\n",
    "import pandas as pd\n",
    "from tensorflow.keras.datasets import cifar10, cifar100\n",
    "from tensorflow.keras.models import load_model\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping  #ReduceLROnPlateau\n",
    "from tensorflow.keras.regularizers import l2\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "\n",
    "import importlib\n",
    "import subprocess\n",
    "import sys\n",
    "\n",
    "\n",
    "def _ensure_package(module_name, pip_name=None):\n",
    "    \"\"\"Import `module_name`, pip-installing it first when it is missing.\n",
    "\n",
    "    The install uses the kernel's own interpreter (`sys.executable`) so the\n",
    "    package lands in the environment this notebook actually runs in.\n",
    "\n",
    "    Returns the imported module, or None when the install or the retried\n",
    "    import fails (best effort: a warning is printed instead of raising).\n",
    "    \"\"\"\n",
    "    try:\n",
    "        return importlib.import_module(module_name)\n",
    "    except ImportError:\n",
    "        pass\n",
    "    target = pip_name or module_name\n",
    "    status = subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", target])\n",
    "    if status.returncode != 0:\n",
    "        print(f\"WARNING: pip install {target} failed (exit code {status.returncode})\")\n",
    "        return None\n",
    "    importlib.invalidate_caches()\n",
    "    try:\n",
    "        return importlib.import_module(module_name)\n",
    "    except ImportError as err:\n",
    "        print(f\"WARNING: {module_name} is still not importable after install: {err}\")\n",
    "        return None\n",
    "\n",
    "\n",
    "pyHSICLasso = _ensure_package(\"pyHSICLasso\")\n",
    "lassonet = _ensure_package(\"lassonet\")\n",
    "\n",
    "# Repair the \"cannot import etree from lxml\" error by removing the stale folders and\n",
    "# reinstalling. Only done when the import really fails: the reinstall is slow and\n",
    "# deletes files from the system site-packages directory.\n",
    "try:\n",
    "    from lxml import etree as _lxml_etree  # noqa: F401\n",
    "except ImportError:\n",
    "    os.system(\n",
    "        \"rm -rf /usr/local/lib/python3.9/dist-packages/lxml && \"\n",
    "        \"rm -rf /usr/local/lib/python3.9/dist-packages/lxml-*.egg-info && \"\n",
    "        \"rm -rf /usr/local/lib/python3.9/dist-packages/lxml-5.3.1.dist-info && \"\n",
    "        \"pip install --no-cache-dir --upgrade --force-reinstall lxml\"\n",
    "    )\n",
    "    # Drop the half-imported package so later imports pick up the fresh install.\n",
    "    sys.modules.pop(\"lxml\", None)\n",
    "    importlib.invalidate_caches()\n",
    "\n",
    "from layers       import HadamardDense, HadamardConv2D, SparseConv2D, StrHadamardDense, StrConv2D\n",
    "from initializers import TwiceTruncatedNormalInitializer, equivar_initializer, equivar_initializer_conv2d\n",
    "from callbacks    import HadamardCallback, PrintLRCallback, TerminateBadRuns, StructuredSparsityCallback\n",
    "from models       import WideResNet, hadamard_resnet18, hadamard_resnet20, hadamard_WRN_16_4, hadamard_WRN_16_8, hadamard_WRN_28_10\n",
    "from models       import hadamard_vgg16, str_hadamard_vgg16, vanilla_vgg16, vanilla_vgg16_individual, vanilla_vgg19_individual,\\\n",
    "                         vanilla_resnet18, vanilla_resnet18_individual, vanilla_resnet18_reg\n",
    "from utils        import color_preprocessing, compute_thresholds, get_optimizer, process_sparsity_callback,\\\n",
    "                      threshold_model_weights, flatten_and_filter_weights, create_and_save_trajectory_plot, get_flops_and_profile\n",
    "\n",
    "################################################################################\n",
    "# Model architecture and training configuration\n",
    "\n",
    "# --- Hadamard parametrisation ---\n",
    "DEPTH = 2\n",
    "LA = 3e-4\n",
    "INIT_TYPE = 'equivar'  # one of: vanilla, equivar, root, ones\n",
    "INIT_REST = tf.keras.initializers.Ones()\n",
    "MINPROD = 3e-3\n",
    "INIT = TwiceTruncatedNormalInitializer(minprod=MINPROD, depth=DEPTH)\n",
    "KERNEL_INITIALIZER = tf.keras.initializers.HeNormal()\n",
    "MULTFAC_INITIALIZER = tf.keras.initializers.Ones()\n",
    "USE_BIAS = False\n",
    "# FACTORIZE_BIAS = False\n",
    "\n",
    "# --- Pre-training ---\n",
    "PRETRAIN_OPT = 'sgd'    # sgd or adam\n",
    "LR_SCHEDULE = 'cosine'  # piecewise, constant or cosine\n",
    "BATCH_SIZE = 256\n",
    "EPOCHS = 250\n",
    "# Initial learning rate depends on the optimizer\n",
    "# (author's SGD notes: depth 2 = 0.3, depth 3 = 0.4, depth 4 = 0.6).\n",
    "if PRETRAIN_OPT == 'sgd':\n",
    "    INIT_LR = 0.2\n",
    "else:\n",
    "    INIT_LR = 1e-2\n",
    "WARMUP = False         # linear lr warmup from 0.01 to INIT_LR over 5 epochs\n",
    "MOMENTUM = 0.9         # SGD only\n",
    "LR_DECAY_FACT = 0.1    # piecewise schedule only\n",
    "LARGE_LRSTART = False  # piecewise schedule only\n",
    "LABEL_SMOOTHING = 0.1\n",
    "\n",
    "# --- Fine-tuning (values are selected per optimizer) ---\n",
    "DO_FINETUNING = False\n",
    "FINETUNE_OPT = 'sgd'      # sgd or adam\n",
    "FINE_SCHEDULE = 'cosine'  # piecewise, constant or cosine\n",
    "if FINETUNE_OPT == 'sgd':\n",
    "    FINETUNE_EPOCHS, FINE_LA, FINE_LR, FINETUNE_ALPHA = 50, 1 * LA, 7e-3, 1e-5\n",
    "else:\n",
    "    FINETUNE_EPOCHS, FINE_LA, FINE_LR, FINETUNE_ALPHA = 20, 0.2 * LA, 2e-4, 1e-5\n",
    "\n",
    "# --- Input settings ---\n",
    "CLASS_NUM = 100\n",
    "IMG_ROWS = 32\n",
    "IMG_COLS = 32\n",
    "IMG_CHANNELS = 3\n",
    "\n",
    "# --- Misc ---\n",
    "PAT = 1000\n",
    "RESTORE_WEIGHTS = False\n",
    "GRACE = 20\n",
    "FINE_GRACE = 10\n",
    "MINACC = 1 / CLASS_NUM + 0.05\n",
    "SEED = 123\n",
    "SAVE_METRICS = True\n",
    "VERBOSE = 2\n",
    "lambdas_all = [\n",
    "    0, 1e-6, 5e-6,                                         # indices 0-2\n",
    "    1e-5, 2e-5, 3e-5, 4e-5, 5e-5, 6e-5, 7e-5, 8e-5, 9e-5,  # indices 3-11\n",
    "    1e-4, 2e-4, 3e-4, 4e-4, 5e-4, 6e-4, 7e-4, 8e-4, 9e-4,  # indices 12-20\n",
    "    1e-3, 2e-3, 5e-3,                                      # indices 21-23\n",
    "]\n",
    "lambdas_10 = [1e-1, 1e-2, 1e-3, 5e-4]\n",
    "lambdas = lambdas_10\n",
    "\n",
    "# --- Directories and saving paths ---\n",
    "RESNET_FILE_PATH = './results/resnet18/cifar100/l21/'\n",
    "\n",
    "print('Defining configs successful!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2025-05-04T08:27:35.142645Z",
     "iopub.status.busy": "2025-05-04T08:27:35.142153Z",
     "iopub.status.idle": "2025-05-04T08:27:46.273351Z",
     "shell.execute_reply": "2025-05-04T08:27:46.272642Z",
     "shell.execute_reply.started": "2025-05-04T08:27:35.142547Z"
    },
    "id": "mSeZM6CoDbOL",
    "outputId": "80a25c47-d086-4fc8-b35e-9a6d24bdcf2c"
   },
   "outputs": [],
   "source": [
    "# Data loading and pre-processing\n",
    "\n",
    "# Seed every RNG in play: NumPy, Python's `random`, TensorFlow\n",
    "for _seed_fn in (np.random.seed, random.seed, tf.random.set_seed):\n",
    "    _seed_fn(SEED)\n",
    "\n",
    "# Load the CIFAR-100 train/test arrays\n",
    "_cifar_train, _cifar_test = cifar100.load_data()\n",
    "X_train, Y_train = _cifar_train\n",
    "X_test, Y_test = _cifar_test\n",
    "\n",
    "def calculate_mean_std(dataset):\n",
    "    \"\"\"Return the mean and standard deviation of `dataset` along its last axis.\n",
    "\n",
    "    Both statistics are reduced over axes 0, 1 and 2, so one value per entry\n",
    "    of the remaining (last) axis is returned.\n",
    "    \"\"\"\n",
    "    reduce_axes = (0, 1, 2)\n",
    "    return np.mean(dataset, axis=reduce_axes), np.std(dataset, axis=reduce_axes)\n",
    "\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "\n",
    "# Raw statistics (author's reference: mean = [125.3, 123.0, 113.9], std = [63.0, 62.1, 66.7])\n",
    "print(\"Training Set Mean and SD:\", train_mean, train_std)\n",
    "print(\"Testing Set Mean and SD:\", test_mean, test_std)\n",
    "\n",
    "# Create train/val/test split: 5% of the training images become the validation set\n",
    "X_train, X_val, Y_train, Y_val = train_test_split(X_train,Y_train,test_size = 0.05,shuffle = True)\n",
    "\n",
    "# Color preprocessing (project helper; takes and returns train, test, val in that order)\n",
    "X_train, X_test, X_val = color_preprocessing(X_train, X_test, X_val)\n",
    "\n",
    "train_mean, train_std = calculate_mean_std(X_train)\n",
    "test_mean, test_std = calculate_mean_std(X_test)\n",
    "print(\"Normalized Training Set Mean and SD:\", train_mean, train_std)\n",
    "print(\"Normalized Testing Set Mean and SD:\", test_mean, test_std)\n",
    "\n",
    "# Pre-process labels: one-hot encode with the categories seen in the training labels\n",
    "encoder = OneHotEncoder()\n",
    "encoder.fit(Y_train)\n",
    "Y_train = encoder.transform(Y_train).toarray()\n",
    "Y_test = encoder.transform(Y_test).toarray()\n",
    "Y_val =  encoder.transform(Y_val).toarray()\n",
    "\n",
    "# Check sample sizes of split\n",
    "print('Train data shape: ', X_train.shape)\n",
    "print('Train labels shape: ', Y_train.shape)\n",
    "print('Validation data shape: ', X_val.shape)\n",
    "print('Validation labels shape: ', Y_val.shape)\n",
    "print('Test data shape: ', X_test.shape)\n",
    "print('Test labels shape: ', Y_test.shape)\n",
    "\n",
    "# Define augmentation and data generators and fit to data\n",
    "print('Using real-time data augmentation.')\n",
    "aug_datagen = ImageDataGenerator(horizontal_flip=True,width_shift_range=0.125,height_shift_range=0.125,\n",
    "                                 rotation_range=15,fill_mode='reflect')\n",
    "aug_datagen.fit(X_train)\n",
    "trainflow_aug = aug_datagen.flow(X_train, Y_train, batch_size=BATCH_SIZE)\n",
    "\n",
    "# Validation data is served without augmentation\n",
    "val_datagen = ImageDataGenerator()\n",
    "val_datagen.fit(X_val)\n",
    "val_flow = val_datagen.flow(X_val, Y_val, batch_size=BATCH_SIZE)\n",
    "\n",
    "# Keras expects an integer step count. Round up so the final partial batch is\n",
    "# still visited once per epoch (the plain division produced a float).\n",
    "steps_per_ep = int(np.ceil(len(X_train) / BATCH_SIZE))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-03T16:32:04.859964Z",
     "iopub.status.busy": "2025-05-03T16:32:04.859359Z",
     "iopub.status.idle": "2025-05-03T16:33:38.802775Z",
     "shell.execute_reply": "2025-05-03T16:33:38.801339Z",
     "shell.execute_reply.started": "2025-05-03T16:32:04.859940Z"
    }
   },
   "outputs": [],
   "source": [
    "# CIFAR resnet18 with direct L21 regularization and post-hoc pruning\n",
    "\n",
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "\n",
    "# requires:\n",
    "# vanilla_resnet18_reg, vanilla_resnet18_individual\n",
    "# get_optimizer, get_flops_and_profile\n",
    "# trainflow_aug, val_flow, X_train, X_val, X_test, Y_test\n",
    "# IMG_ROWS, IMG_COLS, IMG_CHANNELS, CLASS_NUM\n",
    "# PRETRAIN_OPT, LR_SCHEDULE, FINETUNE_OPT, FINE_LR, FINETUNE_EPOCHS\n",
    "# INIT_TYPE, BATCH_SIZE, SEED, MOMENTUM, WARMUP, RESTORE_WEIGHTS, PAT, MINACC\n",
    "# PrintLRCallback, TerminateBadRuns, RESNET_FILE_PATH\n",
    "\n",
    "# Settings of the L21 experiment (these override the config-cell values of the same name)\n",
    "MODEL = 'l21_resnet18'\n",
    "DEPTH = 1\n",
    "USE_BIAS = False\n",
    "init_lr = 0.2\n",
    "GRACE = 20\n",
    "EPOCHS = 200\n",
    "LA_GAMMA = 0\n",
    "\n",
    "# Lambda grid, visited from the strongest to the weakest penalty\n",
    "lambdas = list(reversed([0, 1e-4, 2e-4, 5e-4, 7e-4, 1e-3, 2e-3, 3e-3, 4e-3]))\n",
    "# Fractions of all conv filters to prune after training\n",
    "prune_list = [0, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.98, 0.99]\n",
    "\n",
    "# One full train -> prune -> profile sweep per group-lasso strength\n",
    "for LA in lambdas:\n",
    "    print(f'Starting with lambda={LA:.2e}')\n",
    "    fmt_la      = f\"{LA:.1e}\"\n",
    "    fmt_lagamma = f\"{LA_GAMMA:.1e}\"\n",
    "    RUN_NAME    = (\n",
    "        f\"dep{DEPTH}\"\n",
    "        f\"-la{fmt_la}\"\n",
    "        f\"-lagamma{fmt_lagamma}\"\n",
    "        f\"-preopt-{PRETRAIN_OPT}\"\n",
    "        f\"-{EPOCHS}eps-{LR_SCHEDULE}\"\n",
    "        f\"-lr{init_lr:.1e}\"\n",
    "        f\"-ftune-{FINETUNE_OPT}\"\n",
    "        f\"-flr{FINE_LR:.1e}\"\n",
    "        f\"-feps{FINETUNE_EPOCHS}\"\n",
    "        f\"-{INIT_TYPE}\"\n",
    "        f\"-bs{BATCH_SIZE}\"\n",
    "    )\n",
    "    RUN_PATH = os.path.join(RESNET_FILE_PATH, RUN_NAME)\n",
    "    # Create the run directory up front so weights and results can always be written.\n",
    "    os.makedirs(RUN_PATH, exist_ok=True)\n",
    "\n",
    "    # Reproducibility: reseed every RNG so each lambda starts from the same state\n",
    "    np.random.seed(SEED)\n",
    "    random.seed(SEED)\n",
    "    tf.random.set_seed(SEED)\n",
    "\n",
    "    # Callbacks\n",
    "    early_stopping   = tf.keras.callbacks.EarlyStopping(\n",
    "        monitor='val_accuracy', patience=PAT, restore_best_weights=RESTORE_WEIGHTS\n",
    "    )\n",
    "    print_lr_cb      = PrintLRCallback()\n",
    "    terminate_nan_cb = tf.keras.callbacks.TerminateOnNaN()\n",
    "    early_abort_cb   = TerminateBadRuns(grace=GRACE, minacc=MINACC)\n",
    "\n",
    "    # 1) Build the L2,1-regularized network, then load cached weights or train it\n",
    "    l21_net = vanilla_resnet18_reg(\n",
    "        la=LA,\n",
    "        la_gamma=LA_GAMMA,\n",
    "        use_bias=USE_BIAS,\n",
    "        input_shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS),\n",
    "        n_classes=CLASS_NUM\n",
    "    )\n",
    "    l21_net.summary()\n",
    "    optimizer = get_optimizer(\n",
    "        lr_schedule=LR_SCHEDULE,\n",
    "        init_lr=init_lr,\n",
    "        epochs=EPOCHS,\n",
    "        dat=X_train,\n",
    "        batch_size=BATCH_SIZE,\n",
    "        opt=PRETRAIN_OPT,\n",
    "        momentum=MOMENTUM,\n",
    "        alpha=0,\n",
    "        warmup=WARMUP\n",
    "    )\n",
    "    l21_net.compile(\n",
    "        optimizer=optimizer,\n",
    "        loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=LABEL_SMOOTHING),\n",
    "        metrics=['accuracy']\n",
    "    )\n",
    "\n",
    "    weights_path = os.path.join(RUN_PATH, f\"res18_la{fmt_la}.h5\")\n",
    "    if os.path.exists(weights_path):\n",
    "        l21_net.load_weights(weights_path)\n",
    "    else:\n",
    "        l21_net.fit(\n",
    "            trainflow_aug,\n",
    "            # Keras needs an integer step count; round up so a fractional\n",
    "            # value still covers the last partial batch.\n",
    "            steps_per_epoch=int(np.ceil(steps_per_ep)),\n",
    "            validation_data=val_flow,\n",
    "            # One full pass over the validation iterator. Floor division\n",
    "            # silently skipped the final partial batch.\n",
    "            validation_steps=len(val_flow),\n",
    "            epochs=EPOCHS,\n",
    "            callbacks=[early_stopping, print_lr_cb, terminate_nan_cb, early_abort_cb]\n",
    "        )\n",
    "        # Weight caching is currently disabled:\n",
    "        # l21_net.save_weights(weights_path)\n",
    "\n",
    "    # 2) Gather filter norms: one L2 norm per output filter of every Conv2D layer\n",
    "    conv_layers   = [l for l in l21_net.layers if isinstance(l, tf.keras.layers.Conv2D)]\n",
    "    records       = []\n",
    "    for li, layer in enumerate(conv_layers):\n",
    "        W     = layer.get_weights()[0]\n",
    "        # Flatten everything except the last (output-filter) axis\n",
    "        norms = np.linalg.norm(W.reshape(-1, W.shape[-1]), axis=0)\n",
    "        for fj, nv in enumerate(norms):\n",
    "            records.append({'layer_idx': li, 'filter_idx': fj, 'norm': nv})\n",
    "    df_norms       = pd.DataFrame(records)\n",
    "    total_filters = len(df_norms)\n",
    "    # Snapshot of the trained weights, restored after every pruning trial\n",
    "    original_wts   = [layer.get_weights() for layer in conv_layers]\n",
    "\n",
    "    # 3) Baseline FLOPs of the unpruned reference architecture\n",
    "    base_model     = vanilla_resnet18_individual(\n",
    "        input_shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS),\n",
    "        n_classes=CLASS_NUM\n",
    "    )\n",
    "    base_model.build(input_shape=(None, IMG_ROWS, IMG_COLS, IMG_CHANNELS))\n",
    "    base_flops, _  = get_flops_and_profile(base_model)\n",
    "\n",
    "    # 4) Prune-and-evaluate loop\n",
    "    results = []\n",
    "    for ratio in prune_list:\n",
    "        n_pr      = int(np.floor(ratio * total_filters))\n",
    "        # Globally smallest-norm filters across all conv layers\n",
    "        prune_ids = df_norms.nsmallest(n_pr, 'norm').index\n",
    "\n",
    "        # 4a) Zero out selected filters in trained model\n",
    "        df_norms['keep'] = ~df_norms.index.isin(prune_ids)\n",
    "        for li, layer in enumerate(conv_layers):\n",
    "            Wb = original_wts[li]\n",
    "            W  = Wb[0].copy()\n",
    "            b  = Wb[1].copy() if len(Wb)>1 else None\n",
    "            to_zero = df_norms.query(\"layer_idx==@li and ~keep\")['filter_idx'].tolist()\n",
    "            if to_zero:\n",
    "                W[..., to_zero] = 0.\n",
    "                if b is not None:\n",
    "                    b[to_zero] = 0.\n",
    "            layer.set_weights([W, b] if b is not None else [W])\n",
    "\n",
    "        loss, acc = l21_net.evaluate(X_test, Y_test, verbose=0)\n",
    "\n",
    "        # Restore original weights so the next ratio starts from the trained model\n",
    "        for layer, Wb in zip(conv_layers, original_wts):\n",
    "            layer.set_weights(Wb)\n",
    "\n",
    "        # 4b) Build block_filters for small model\n",
    "        # Count survivors per conv layer\n",
    "        remaining = df_norms.groupby('layer_idx')['keep'].sum()\n",
    "        remaining = remaining.reindex(range(len(conv_layers)), fill_value=0)\n",
    "        remaining = remaining.clip(lower=1)  # never drop to zero\n",
    "\n",
    "        # Map conv-layer positions into the four ResNet-18 groups (two blocks each).\n",
    "        # NOTE(review): three-entry blocks presumably hold conv1, projection, conv2;\n",
    "        # confirm against the layer order of vanilla_resnet18_reg.\n",
    "        stem = int(remaining[0])\n",
    "        groups = []\n",
    "        groups_idx = [\n",
    "            ([1,2],    [3,4]),     # group1\n",
    "            ([5,6,7],  [8,9]),     # group2\n",
    "            ([10,11,12],[13,14]),  # group3\n",
    "            ([15,16,17],[18,19])   # group4\n",
    "        ]\n",
    "        for block_pair in groups_idx:\n",
    "            blk = []\n",
    "            for conv_idxs in block_pair:\n",
    "                cnts = [int(remaining[i]) for i in conv_idxs]\n",
    "                blk.append(cnts)\n",
    "            groups.append(blk)\n",
    "\n",
    "        # Sync projection blocks: the middle entry of a three-entry block is set to the last one\n",
    "        for gi in range(len(groups)):\n",
    "            for bi in range(len(groups[gi])):\n",
    "                bf = groups[gi][bi]\n",
    "                bf[0] = max(1, bf[0])\n",
    "                bf[-1] = max(1, bf[-1])\n",
    "                if len(bf)==3:\n",
    "                    bf[1] = bf[2]\n",
    "                groups[gi][bi] = bf\n",
    "\n",
    "        block_filters = [max(1, stem), groups]\n",
    "\n",
    "        # 4c) Compute FLOPs of pruned architecture\n",
    "        small = vanilla_resnet18_individual(\n",
    "            block_filters=block_filters,\n",
    "            input_shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS),\n",
    "            n_classes=CLASS_NUM\n",
    "        )\n",
    "        small.build(input_shape=(None, IMG_ROWS, IMG_COLS, IMG_CHANNELS))\n",
    "        sparse_flops, _ = get_flops_and_profile(small)\n",
    "\n",
    "        # Guard the compression rate against division by zero at 100% sparsity\n",
    "        sparsity    = n_pr / total_filters\n",
    "        compression = 1 / (1 - sparsity) if sparsity < 1 else np.inf\n",
    "\n",
    "        results.append({\n",
    "            'PruneRatio':         ratio,\n",
    "            'TrueSparsity':       f\"{sparsity:.4f}\",\n",
    "            'CompressionRate':    f\"{compression:.2f}\",\n",
    "            'TestLoss':           loss,\n",
    "            'TestAcc':            acc,\n",
    "            'ReducedFLOPs':       int(sparse_flops),\n",
    "            'TheoreticalSpeedup': f\"{base_flops/sparse_flops:.4f}\",\n",
    "            'BlockFilters':       block_filters\n",
    "        })\n",
    "\n",
    "    # 5) Save results\n",
    "    df_res = pd.DataFrame(results)\n",
    "    out_csv = os.path.join(RUN_PATH, f'pruning_resnet18_la{fmt_la}.csv')\n",
    "    df_res.to_csv(out_csv, index=False)\n",
    "    print(df_res)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-04T08:27:51.770162Z",
     "iopub.status.busy": "2025-05-04T08:27:51.769638Z",
     "iopub.status.idle": "2025-05-04T08:27:54.135481Z",
     "shell.execute_reply": "2025-05-04T08:27:54.134959Z",
     "shell.execute_reply.started": "2025-05-04T08:27:51.770143Z"
    }
   },
   "outputs": [],
   "source": [
    "# Redefine vanilla_resnet18_reg to match exactly vanilla_resnet18_individual\n",
    "\n",
    "import tensorflow as tf\n",
    "import pandas as pd\n",
    "import os\n",
    "\n",
    "from utils import GroupLassoRegularizer\n",
    "\n",
    "# helpers\n",
    "def bn_relu_reg(x, la_gamma):\n",
    "    \"\"\"Apply BatchNormalization, with an L1 penalty `la_gamma` on gamma, then ReLU to `x`.\"\"\"\n",
    "    gamma_penalty = tf.keras.regularizers.L1(l1=la_gamma)\n",
    "    normalized = tf.keras.layers.BatchNormalization(gamma_regularizer=gamma_penalty)(x)\n",
    "    return tf.keras.layers.ReLU()(normalized)\n",
    "\n",
    "def conv_reg(filters, k, s, la, use_bias):\n",
    "    \"\"\"Build a 'same'-padded Conv2D layer with a group-lasso kernel penalty.\n",
    "\n",
    "    Args:\n",
    "        filters: number of output filters.\n",
    "        k: kernel size.\n",
    "        s: stride.\n",
    "        la: strength passed to GroupLassoRegularizer (`lam`), applied along axis 3.\n",
    "        use_bias: whether the layer has a bias term.\n",
    "\n",
    "    Returns:\n",
    "        The (not yet called) Conv2D layer, initialised with he_normal.\n",
    "    \"\"\"\n",
    "    group_penalty = GroupLassoRegularizer(lam=la, axis=3)\n",
    "    return tf.keras.layers.Conv2D(\n",
    "        filters=filters,\n",
    "        kernel_size=k,\n",
    "        strides=s,\n",
    "        padding='same',\n",
    "        use_bias=use_bias,\n",
    "        kernel_initializer='he_normal',\n",
    "        kernel_regularizer=group_penalty,\n",
    "    )\n",
    "\n",
    "def shortcut_reg(x, filters, stride, mode, la=None):\n",
    "    \"\"\"Return the residual shortcut branch for input `x`.\n",
    "\n",
    "    Args:\n",
    "        x: input tensor of the block.\n",
    "        filters: number of channels the block outputs.\n",
    "        stride: stride of the block's first convolution.\n",
    "        mode: 'A' (strided pooling + zero padding of channels), 'B' (1x1\n",
    "            projection conv) or 'B_original' (1x1 projection conv + BatchNorm).\n",
    "        la: group-lasso strength of the projection conv. When omitted, the\n",
    "            module-level `_la` set by Resnet_v_reg is used (0 if undefined),\n",
    "            which keeps older four-argument calls working.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `mode` is not one of the supported shortcut types.\n",
    "    \"\"\"\n",
    "    if la is None:\n",
    "        la = globals().get('_la', 0)\n",
    "    in_channels = x.shape[-1]\n",
    "    # Identity is only valid when neither channel count nor spatial size changes;\n",
    "    # with stride > 1 the main branch is downsampled and the add would fail.\n",
    "    if in_channels == filters and stride == 1:\n",
    "        return x\n",
    "    if mode == 'A':\n",
    "        y = tf.keras.layers.MaxPool2D(1, stride)(x) if stride > 1 else x\n",
    "        return tf.pad(y, [(0,0),(0,0),(0,0),(0, filters - in_channels)])\n",
    "    if mode in ('B', 'B_original'):\n",
    "        y = conv_reg(filters, 1, stride, la, False)(x)\n",
    "        return tf.keras.layers.BatchNormalization()(y) if mode == 'B_original' else y\n",
    "    raise ValueError(f\"Unknown shortcut mode: {mode!r}\")\n",
    "\n",
    "def original_block_reg(x, f, s, la, la_gamma, mode):\n",
    "    \"\"\"Post-activation residual block: conv-BN-ReLU, conv-BN, add shortcut, ReLU.\n",
    "\n",
    "    Only the first BatchNorm carries the gamma penalty `la_gamma`.\n",
    "    \"\"\"\n",
    "    c1 = conv_reg(f, 3, s, la, False)(x)\n",
    "    c1 = bn_relu_reg(c1, la_gamma)\n",
    "    c2 = conv_reg(f, 3, 1, la, False)(c1)\n",
    "    c2 = tf.keras.layers.BatchNormalization()(c2)\n",
    "    # Pass `la` explicitly instead of relying on the module-level `_la`\n",
    "    return tf.keras.layers.ReLU()(c2 + shortcut_reg(x, f, s, mode, la))\n",
    "\n",
    "def preact_block_reg(x, f, s, la, la_gamma, first, mode):\n",
    "    \"\"\"Pre-activation residual block: BN-ReLU-conv, BN-ReLU-conv, add shortcut.\n",
    "\n",
    "    When `first` is true the shortcut branches off after the leading BN-ReLU\n",
    "    instead of from the raw block input.\n",
    "    \"\"\"\n",
    "    flow = bn_relu_reg(x, la_gamma)\n",
    "    if first: x = flow\n",
    "    c1 = conv_reg(f, 3, s, la, False)(flow)\n",
    "    c1 = bn_relu_reg(c1, la_gamma)\n",
    "    c2 = conv_reg(f, 3, 1, la, False)(c1)\n",
    "    return c2 + shortcut_reg(x, f, s, mode, la)\n",
    "\n",
    "def Resnet_v_reg(input_shape, n_classes,\n",
    "                 group_sizes=(2,2,2,2), features=(64,128,256,512),\n",
    "                 strides=(1,2,2,2), shortcut_type='B', block_type='original',\n",
    "                 use_bias=False, la=0, la_gamma=0, preact_shortcuts=False,\n",
    "                 name='resnet18_reg'):\n",
    "    \"\"\"Build a regularized ResNet as a functional Keras model.\n",
    "\n",
    "    Args:\n",
    "        input_shape: shape of one input sample (without the batch axis).\n",
    "        n_classes: width of the final softmax layer.\n",
    "        group_sizes, features, strides: per group, the number of blocks, the\n",
    "            filter count and the stride of the group's first block.\n",
    "        shortcut_type: 'A' or 'B'; 'B' is mapped to the 'B_original' mode.\n",
    "        block_type: 'original' for post-activation blocks, anything else for\n",
    "            pre-activation blocks.\n",
    "        use_bias: bias switch of the stem conv and the Dense head.\n",
    "        la: group-lasso strength of the convs, also the L2 strength of the head.\n",
    "        la_gamma: L1 strength on the regularized BatchNorm gammas.\n",
    "        preact_shortcuts: forces `first=True` for every pre-activation block.\n",
    "        name: model name.\n",
    "    \"\"\"\n",
    "\n",
    "    # Kept for backward compatibility with code that reads `_la`; the blocks\n",
    "    # below receive `la` explicitly and no longer depend on it.\n",
    "    global _la; _la = la\n",
    "    mode = 'B_original' if shortcut_type == 'B' else shortcut_type\n",
    "\n",
    "    inp = tf.keras.layers.Input(shape=input_shape)\n",
    "    x = conv_reg(features[0], 3, strides[0], la, use_bias)(inp)\n",
    "    x = bn_relu_reg(x, la_gamma)  # stem BN+ReLU always present\n",
    "\n",
    "    for i, (n_blk, f, s) in enumerate(zip(group_sizes, features, strides)):\n",
    "        for j in range(n_blk):\n",
    "            # Only the first block of a group applies the group's stride\n",
    "            st = s if j == 0 else 1\n",
    "            if block_type == 'original':\n",
    "                x = original_block_reg(x, f, st, la, la_gamma, mode)\n",
    "            else:\n",
    "                first = (i > 0 and j == 0) or preact_shortcuts\n",
    "                x = preact_block_reg(x, f, st, la, la_gamma, first, mode)\n",
    "\n",
    "    x = tf.keras.layers.GlobalAveragePooling2D()(x)\n",
    "    out = tf.keras.layers.Dense(\n",
    "        n_classes, activation='softmax',\n",
    "        kernel_regularizer=tf.keras.regularizers.L2(l2=la),\n",
    "        use_bias=use_bias\n",
    "    )(x)\n",
    "    return tf.keras.Model(inp, out, name=name)\n",
    "\n",
    "def vanilla_resnet18_reg(load_weights=False, input_shape=(32,32,3),\n",
    "                         n_classes=100, use_bias=False, la=0, la_gamma=0,\n",
    "                         block_type='original', shortcut_type='B',\n",
    "                         preact_shortcuts=False):\n",
    "    \"\"\"Build the regularized ResNet-18 (four groups of two blocks, 64 to 512 filters).\n",
    "\n",
    "    With `load_weights=True` the checkpoint 'saved_models/vanilla_resnet18_reg.tf'\n",
    "    is loaded; a missing checkpoint only prints a message.\n",
    "    \"\"\"\n",
    "    resnet18_layout = dict(\n",
    "        group_sizes=(2,2,2,2),\n",
    "        features=(64,128,256,512),\n",
    "        strides=(1,2,2,2),\n",
    "    )\n",
    "    net = Resnet_v_reg(\n",
    "        input_shape=input_shape,\n",
    "        n_classes=n_classes,\n",
    "        shortcut_type=shortcut_type,\n",
    "        block_type=block_type,\n",
    "        use_bias=use_bias,\n",
    "        la=la,\n",
    "        la_gamma=la_gamma,\n",
    "        preact_shortcuts=preact_shortcuts,\n",
    "        name='vanilla_resnet18_reg',\n",
    "        **resnet18_layout\n",
    "    )\n",
    "    if not load_weights:\n",
    "        return net\n",
    "    checkpoint = os.path.join('saved_models','vanilla_resnet18_reg.tf')\n",
    "    try:\n",
    "        net.load_weights(checkpoint)\n",
    "    except tf.errors.NotFoundError:\n",
    "        print(\"No weights found!\")\n",
    "    return net\n",
    "\n",
    "# ------------------------------------------------------------------\n",
    "# 1) Up‑the‑graph walk: BatchNorm mapped to nearest preceding Conv2D\n",
    "# ------------------------------------------------------------------\n",
    "def nearest_conv(layer: tf.keras.layers.Layer) -> tf.keras.layers.Conv2D:\n",
    "    \"\"\"Return the first Conv2D ancestor of `layer` in the functional graph.\n",
    "\n",
    "    The walk follows one producer at a time. When a layer has several inputs\n",
    "    (for example a residual add) its first input is followed.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the model input is reached without meeting a Conv2D\n",
    "            (the previous version looped forever in that case).\n",
    "    \"\"\"\n",
    "    current = layer\n",
    "    while True:\n",
    "        t = current.input\n",
    "        if isinstance(t, (list, tuple)):\n",
    "            t = t[0]                             # multi-input layer: follow the first branch\n",
    "        parent = t._keras_history[0]             # (layer, node_idx, tensor_idx)\n",
    "        if isinstance(parent, tf.keras.layers.Conv2D):\n",
    "            return parent\n",
    "        if parent is current or isinstance(parent, tf.keras.layers.InputLayer):\n",
    "            raise ValueError(f\"No Conv2D ancestor found for layer {layer.name!r}\")\n",
    "        current = parent                         # walk one edge further up\n",
    "\n",
    "# ------------------------------------------------------------------\n",
    "# 2) summary printout\n",
    "# ------------------------------------------------------------------\n",
    "def inspect_conv_bn_pairs(model: tf.keras.Model, *, verbose: bool = True) -> pd.DataFrame:\n",
    "    \"\"\"Pair every BatchNorm with its upstream Conv2D and check dimensions.\n",
    "\n",
    "    Args:\n",
    "        model: Keras model whose layers are scanned.\n",
    "        verbose: when true, print a short report including any mismatches.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with one row per BatchNorm layer and the columns\n",
    "        conv, bn, filters, gammas, match. It is empty, but still has these\n",
    "        columns, when the model contains no BatchNorm layer.\n",
    "    \"\"\"\n",
    "    columns = [\"conv\", \"bn\", \"filters\", \"gammas\", \"match\"]\n",
    "    rows, conv_seen = [], set()\n",
    "\n",
    "    for bn in model.layers:\n",
    "        if not isinstance(bn, tf.keras.layers.BatchNormalization):\n",
    "            continue\n",
    "        cv = nearest_conv(bn)\n",
    "        conv_seen.add(cv)\n",
    "        n_filters = int(cv.filters)\n",
    "        n_gammas = int(bn.gamma.shape[-1])\n",
    "        rows.append({\n",
    "            \"conv\":    cv.name,\n",
    "            \"bn\":      bn.name,\n",
    "            \"filters\": n_filters,\n",
    "            \"gammas\":  n_gammas,\n",
    "            \"match\":   n_filters == n_gammas,\n",
    "        })\n",
    "\n",
    "    # Fixing the columns keeps df['match'] valid even when no BatchNorm was found\n",
    "    df = pd.DataFrame(rows, columns=columns)\n",
    "\n",
    "    if verbose:\n",
    "        all_match = bool(df[\"match\"].all())\n",
    "        print(\"-\" * 60)\n",
    "        print(f\"Conv2D layers referenced : {len(conv_seen)}\")\n",
    "        print(f\"BatchNorm layers scanned  : {len(df)}\")\n",
    "        print(f\"All dimensions match      : {all_match}\")\n",
    "        if not all_match:\n",
    "            print(\"\\nMISMATCHES:\")\n",
    "            print(df.loc[~df['match'].astype(bool), ['conv', 'bn', 'filters', 'gammas']])\n",
    "        print(\"-\" * 60)\n",
    "    return df\n",
    "\n",
    "# ------------------------------------------------------------------\n",
    "# 3) demo summary\n",
    "# ------------------------------------------------------------------\n",
    "if __name__ == \"__main__\":\n",
    "    # (a) baseline model with default filter counts\n",
    "    baseline = vanilla_resnet18_individual(input_shape=(32, 32, 3), n_classes=100)\n",
    "    baseline.build((None, 32, 32, 3))\n",
    "    inspect_conv_bn_pairs(baseline)\n",
    "\n",
    "    # summary() prints the table itself and returns None, so wrapping it in\n",
    "    # print() only added a stray \"None\" line.\n",
    "    baseline.summary()\n",
    "\n",
    "    # (b) gamma-regularised model\n",
    "    gamma_net = vanilla_resnet18_reg(\n",
    "        la=0.0, la_gamma=1e-3,\n",
    "        input_shape=(32, 32, 3), n_classes=100\n",
    "    )\n",
    "    gamma_net.build((None, 32, 32, 3))\n",
    "    inspect_conv_bn_pairs(gamma_net)\n",
    "\n",
    "    gamma_net.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-04T08:28:01.970607Z",
     "iopub.status.busy": "2025-05-04T08:28:01.970051Z",
     "iopub.status.idle": "2025-05-04T08:28:01.973782Z",
     "shell.execute_reply": "2025-05-04T08:28:01.973208Z",
     "shell.execute_reply.started": "2025-05-04T08:28:01.970570Z"
    }
   },
   "outputs": [],
   "source": [
    "# From here on results are written to the network-slimming folder\n",
    "RESNET_FILE_PATH = \"./results/resnet18/cifar100/ns/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-05-04T08:28:06.145228Z",
     "iopub.status.busy": "2025-05-04T08:28:06.144482Z"
    }
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Network-Slimming on CIFAR-100 / ResNet-18\n",
    "\"\"\"\n",
    "\n",
    "from pathlib import Path\n",
    "import os, numpy as np, pandas as pd, tensorflow as tf\n",
    "\n",
    "# Sweep settings\n",
    "EPOCHS = 200\n",
    "LAMBDA_GRID = [1e-4, 5e-4, 1e-3, 0, 2e-3]\n",
    "PRUNE_GRID = [\n",
    "    0, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50,\n",
    "    0.60, 0.70, 0.80, 0.90, 0.95, 0.98, 0.99,\n",
    "]\n",
    "MODEL_TAG = \"netslim_resnet18\"\n",
    "\n",
    "\n",
    "# Reference (unpruned) architecture: FLOPs and the filter count behind each BatchNorm\n",
    "tf.keras.backend.clear_session()\n",
    "baseline = vanilla_resnet18_individual(\n",
    "    n_classes=CLASS_NUM,\n",
    "    input_shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS),\n",
    ")\n",
    "baseline.build((None, IMG_ROWS, IMG_COLS, IMG_CHANNELS))\n",
    "BASE_FLOPS, _ = get_flops_and_profile(baseline)\n",
    "\n",
    "bn_layers_base = [\n",
    "    candidate for candidate in baseline.layers\n",
    "    if isinstance(candidate, tf.keras.layers.BatchNormalization)\n",
    "]\n",
    "# `.filters` is read from the layer that directly produces each BatchNorm's input\n",
    "DEFAULT_COUNTS = [bn_base.input._keras_history[0].filters for bn_base in bn_layers_base]\n",
    "\n",
    "# helper funs\n",
    "def nearest_conv(layer):\n",
    "    \"\"\"Return the first Conv2D ancestor of `layer` in the functional graph.\n",
    "\n",
    "    The walk follows one producer at a time. When a layer has several inputs\n",
    "    (for example a residual add) its first input is followed.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the model input is reached without meeting a Conv2D\n",
    "            (the previous version looped forever in that case).\n",
    "    \"\"\"\n",
    "    current = layer\n",
    "    while True:\n",
    "        t = current.input\n",
    "        if isinstance(t, (list, tuple)):\n",
    "            t = t[0]  # multi-input layer: follow the first branch\n",
    "        parent = t._keras_history[0]\n",
    "        if isinstance(parent, tf.keras.layers.Conv2D):\n",
    "            return parent\n",
    "        if parent is current or isinstance(parent, tf.keras.layers.InputLayer):\n",
    "            raise ValueError(f\"No Conv2D ancestor found for layer {layer.name!r}\")\n",
    "        current = parent\n",
    "\n",
    "def bf_from_counts(counts, stem=64):\n",
    "    \"\"\"Arrange per-layer filter counts into the nested `[stem, groups]` layout.\n",
    "\n",
    "    `counts` is indexed by layer position (indices 1-19 are read). Every count\n",
    "    is raised to at least 1. In three-entry blocks the middle entry is\n",
    "    overwritten with the last one (\"keep proj = out\").\n",
    "\n",
    "    Returns:\n",
    "        `[stem, groups]`, where `groups` holds four groups of two blocks each.\n",
    "    \"\"\"\n",
    "    layout = (\n",
    "        ((1, 2), (3, 4)),\n",
    "        ((5, 6, 7), (8, 9)),\n",
    "        ((10, 11, 12), (13, 14)),\n",
    "        ((15, 16, 17), (18, 19)),\n",
    "    )\n",
    "\n",
    "    def _block(ids):\n",
    "        widths = [max(1, counts[pos]) for pos in ids]\n",
    "        if len(widths) == 3:\n",
    "            widths[1] = widths[2]\n",
    "        return widths\n",
    "\n",
    "    return [stem, [[_block(first), _block(second)] for first, second in layout]]\n",
    "\n",
    "# loop over lambdas\n",
    "for LA_GAMMA in LAMBDA_GRID:\n",
    "    fmt_gam = f\"{LA_GAMMA:.1e}\"\n",
    "    RUN_NAME = (\n",
    "        f\"{MODEL_TAG}-lagamma{fmt_gam}-preopt-{PRETRAIN_OPT}\"\n",
    "        f\"-{EPOCHS}eps-{LR_SCHEDULE}-bs{BATCH_SIZE}\"\n",
    "    )\n",
    "    RUN_PATH = Path(RESNET_FILE_PATH) / RUN_NAME\n",
    "    RUN_PATH.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    print(f\"\\nγ‑regularisation λ = {fmt_gam} → {RUN_PATH}\")\n",
    "\n",
    "    # Reseed every RNG (Python's `random` included, as in the L21 experiment)\n",
    "    np.random.seed(SEED)\n",
    "    random.seed(SEED)\n",
    "    tf.random.set_seed(SEED)\n",
    "\n",
    "    # Train with the L1 penalty on the BatchNorm gammas only (la=0 disables group lasso)\n",
    "    net = vanilla_resnet18_reg(\n",
    "        la=0.0, la_gamma=LA_GAMMA, use_bias=False,\n",
    "        input_shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS), n_classes=CLASS_NUM\n",
    "    )\n",
    "    net.compile(\n",
    "        optimizer=get_optimizer(\n",
    "            lr_schedule=LR_SCHEDULE, init_lr=0.2, epochs=EPOCHS,\n",
    "            dat=X_train, batch_size=BATCH_SIZE, opt=PRETRAIN_OPT,\n",
    "            momentum=MOMENTUM, alpha=0, warmup=WARMUP\n",
    "        ),\n",
    "        loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=LABEL_SMOOTHING),\n",
    "        metrics=[\"accuracy\"]\n",
    "    )\n",
    "    net.fit(\n",
    "        trainflow_aug,\n",
    "        # Keras needs an integer step count; round up so a fractional value\n",
    "        # still covers the last partial batch.\n",
    "        steps_per_epoch=int(np.ceil(steps_per_ep)),\n",
    "        validation_data=val_flow,\n",
    "        # One full pass over the validation iterator. Floor division silently\n",
    "        # skipped the final partial batch.\n",
    "        validation_steps=len(val_flow),\n",
    "        epochs=EPOCHS, verbose=1,\n",
    "        callbacks=[\n",
    "            tf.keras.callbacks.EarlyStopping(\n",
    "                monitor=\"val_accuracy\", patience=PAT, restore_best_weights=True),\n",
    "            PrintLRCallback(),\n",
    "            tf.keras.callbacks.TerminateOnNaN(),\n",
    "            TerminateBadRuns(grace=GRACE, minacc=MINACC)\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    # Pair each BatchNorm with its upstream conv and snapshot the trained weights\n",
    "    bn_layers   = [l for l in net.layers if isinstance(l, tf.keras.layers.BatchNormalization)]\n",
    "    conv_layers = [nearest_conv(bn) for bn in bn_layers]\n",
    "    orig_conv_w = [cv.get_weights() for cv in conv_layers]\n",
    "    orig_bn_w   = [bn.get_weights() for bn in bn_layers]\n",
    "\n",
    "    # Long table with one row per (BatchNorm layer, channel) and its gamma\n",
    "    df_gamma = pd.concat([\n",
    "        pd.DataFrame({\n",
    "            \"bn\": i,\n",
    "            \"ch\": np.arange(int(bn.gamma.shape[0])),\n",
    "            \"g\" : bn.gamma.numpy().astype(np.float32)\n",
    "        }) for i, bn in enumerate(bn_layers)\n",
    "    ], ignore_index=True)\n",
    "\n",
    "    total_ch = len(df_gamma)\n",
    "    sweep = []\n",
    "\n",
    "    for ratio in PRUNE_GRID:\n",
    "        n_prune = int(np.floor(ratio * total_ch))\n",
    "        # Global threshold: the largest |gamma| among the n_prune smallest\n",
    "        thresh  = df_gamma[\"g\"].abs().nsmallest(n_prune).max() if n_prune else 0.0\n",
    "        df_gamma[\"keep\"] = df_gamma[\"g\"].abs() > thresh\n",
    "\n",
    "        # Apply mask: zero the conv filter plus the BatchNorm gamma/beta of pruned channels\n",
    "        for bn, cv in zip(bn_layers, conv_layers):\n",
    "            g, b, m, v = bn.get_weights()\n",
    "            idx = np.where(np.abs(g) <= thresh)[0]\n",
    "            if idx.size:\n",
    "                W, *rest = cv.get_weights()\n",
    "                W[..., idx] = 0\n",
    "                cv.set_weights([W] + rest)\n",
    "                g[idx] = 0\n",
    "                b[idx] = 0\n",
    "                bn.set_weights([g, b, m, v])\n",
    "\n",
    "        loss, acc = net.evaluate(X_test, Y_test, verbose=0)\n",
    "\n",
    "        # Restore the trained weights for the next ratio\n",
    "        for cv, w in zip(conv_layers, orig_conv_w):\n",
    "            cv.set_weights(w)\n",
    "        for bn, w in zip(bn_layers, orig_bn_w):\n",
    "            bn.set_weights(w)\n",
    "\n",
    "        if ratio == 0:\n",
    "            sparse_flops = BASE_FLOPS\n",
    "            bf = bf_from_counts(DEFAULT_COUNTS, DEFAULT_COUNTS[0])\n",
    "        else:\n",
    "            counts = (df_gamma[df_gamma.keep]\n",
    "                      .groupby(\"bn\")[\"ch\"].count()\n",
    "                      .reindex(range(len(bn_layers)), fill_value=0)\n",
    "                      .clip(lower=1)\n",
    "                      .tolist())\n",
    "            # NOTE(review): the stem keeps its unpruned width here, while the L21\n",
    "            # experiment uses the pruned stem count - confirm this is intended.\n",
    "            bf = bf_from_counts(counts, DEFAULT_COUNTS[0])\n",
    "\n",
    "            tf.keras.backend.clear_session()\n",
    "            pruned = vanilla_resnet18_individual(\n",
    "                block_filters=bf,\n",
    "                input_shape=(IMG_ROWS, IMG_COLS, IMG_CHANNELS),\n",
    "                n_classes=CLASS_NUM\n",
    "            )\n",
    "            pruned.build((None, IMG_ROWS, IMG_COLS, IMG_CHANNELS))\n",
    "            sparse_flops, _ = get_flops_and_profile(pruned)\n",
    "\n",
    "        true_sp   = n_prune / total_ch\n",
    "        comp_rate = 1 / (1 - true_sp) if true_sp < 1 else np.inf\n",
    "        speedup   = BASE_FLOPS / sparse_flops\n",
    "\n",
    "        sweep.append({\n",
    "            \"PruneRatio\"        : f\"{ratio:.2f}\",\n",
    "            \"TrueSparsity\"      : f\"{true_sp:.4f}\",\n",
    "            \"CompressionRate\"   : f\"{comp_rate:.2f}\",\n",
    "            \"TestLoss\"          : f\"{loss:.4f}\",\n",
    "            \"TestAcc\"           : f\"{acc:.4f}\",\n",
    "            \"ReducedFLOPs\"      : int(sparse_flops),\n",
    "            \"TheoreticalSpeedup\": f\"{speedup:.4f}\",\n",
    "            \"BlockFilters\"      : str(bf)\n",
    "        })\n",
    "\n",
    "    df_out = pd.DataFrame(sweep)\n",
    "    out_file = RUN_PATH / \"network_slimming_results.csv\"\n",
    "    df_out.to_csv(out_file, index=False)\n",
    "    print(df_out.head())\n",
    "    print(f\"Saved results → {out_file}\")\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
