import os
import tensorflow as tf
import numpy as np
import pandas as pd
import math
import matplotlib.pyplot as plt
from layers import HadamardDense, HadamardConv2D, StrHadamardDenseV2

# Color normalization pre-processing function (hardcoded cifar-10)
def color_preprocessing_old(X_train,X_test,X_val):
    """Standardize the first three channels with fixed per-channel statistics.

    Each input is copied to float32 (the caller's arrays are not modified) and
    channel ``i`` of the last axis is replaced by ``(x - mean[i]) / std[i]``
    using the hard-coded constants below (labelled CIFAR-10 by the comment
    preceding this function).

    Returns:
        Tuple ``(X_train, X_test, X_val)`` of normalized float32 arrays.
    """
    # (mean, std) for channels 0, 1 and 2 of the last axis.
    channel_stats = ((125.3, 63.0), (123.0, 62.1), (113.9, 66.7))

    splits = [arr.astype('float32') for arr in (X_train, X_test, X_val)]
    for channel, (mu, sigma) in enumerate(channel_stats):
        for arr in splits:
            arr[:, :, :, channel] = (arr[:, :, :, channel] - mu) / sigma

    train, test, val = splits
    return train, test, val

def color_preprocessing(X_train, X_test, X_val):
    """Standardize all three splits with the training set's per-channel statistics.

    The inputs are copied to float32 (the caller's arrays are not modified).
    Mean and standard deviation are computed over axes ``(0, 1, 2)`` of
    ``X_train`` only and then applied to every split, so no statistics leak
    from the test/validation data.

    Unlike a fixed ``range(3)`` loop, the normalization broadcasts over the
    last axis and therefore works for any number of channels (e.g. grayscale).
    Channels that are constant in ``X_train`` (std == 0) are only centered,
    which avoids NaN/inf from a division by zero.

    Args:
        X_train, X_test, X_val: arrays whose last axis is the channel axis and
            whose first three axes are reduced for the statistics.

    Returns:
        Tuple ``(X_train, X_test, X_val)`` of normalized float32 arrays.
    """
    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    X_val = X_val.astype('float32')

    # Channel-wise statistics of the training data only.
    mean = np.mean(X_train, axis=(0, 1, 2))
    std = np.std(X_train, axis=(0, 1, 2))
    # A constant channel has std == 0; divide by 1 instead so it maps to 0.
    safe_std = np.where(std == 0, np.float32(1.0), std)

    # In-place on the float32 copies made above; broadcasts over the channel axis.
    for split in (X_train, X_test, X_val):
        split -= mean
        split /= safe_std

    return X_train, X_test, X_val

# Compute sparsity of a StrHadamardDenseV2 layer: group sparsity over all weights outgoing from each previous-layer unit (e.g. inputs, hidden units)
def compute_input_sparsity(model, depth):
    """Report row-wise group sparsity of the layer at ``model.layers[1]``.

    The layer must be a ``StrHadamardDenseV2`` without bias. Its first weight
    array is left-multiplied by a diagonal matrix built from the elementwise
    product of all remaining weight arrays; the L2 norm of every row of the
    result is then compared against float32 machine epsilon.

    Args:
        model: model whose ``layers[1]`` is inspected.
        depth: not used by this function.

    Returns:
        ``(df, overall_sparsity)`` where ``df`` has the columns ``input``,
        ``l2_norm`` and ``sparse`` (1 if the row norm is below float32 eps)
        and ``overall_sparsity`` is the mean of the ``sparse`` column.

    Raises:
        TypeError: if the layer is not a ``StrHadamardDenseV2``.
        ValueError: if the layer uses a bias or has at most one weight array.
    """
    first_layer = model.layers[1]
    print(f"First layer is named {first_layer.name}")
    if not isinstance(first_layer, StrHadamardDenseV2):
        raise TypeError("Layer is not of type StrHadamardDenseV2")

    if getattr(first_layer, 'use_bias', False):
        raise ValueError("Layer uses biases, which is not supported for this computation.")

    factor_arrays = first_layer.get_weights()
    if len(factor_arrays) <= 1:
        raise ValueError("Layer does not have the expected number of weights for given depth.")

    # First array is the matrix factor, the rest are multiplied elementwise.
    kernel, *diag_factors = factor_arrays
    diag_product = tf.reduce_prod(
        tf.stack([tf.convert_to_tensor(factor) for factor in diag_factors], axis=0),
        axis=0,
    )
    print(f"Shape of collapsed diag layers: {diag_product.shape}")
    U1 = tf.convert_to_tensor(kernel)
    effective_weights = tf.linalg.matmul(tf.linalg.diag(diag_product), U1)
    print(f"Shape of reconstructed input weights: {effective_weights.shape}")

    # One L2 norm per row of the collapsed weight matrix.
    row_norms = tf.norm(effective_weights, axis=1)

    n_rows = U1.shape[0]
    df = pd.DataFrame({
        'input': [f'input{idx}' for idx in range(1, n_rows + 1)],
        'l2_norm': row_norms.numpy()
    })

    # A row counts as pruned when its norm is numerically zero.
    eps = np.finfo(np.float32).eps
    df['sparse'] = (df['l2_norm'] < eps).astype(int)
    overall_sparsity = df['sparse'].mean()
    overall_cr = 1 / (1 - overall_sparsity)
    norm_column = df['l2_norm']
    min_l2_norm = norm_column.min()
    max_l2_norm = norm_column.max()

    print(f"Overall Input Sparsity: {overall_sparsity}")
    print(f"Overall Input CR: {overall_cr}")
    print(f"Min L2 Norm: {min_l2_norm}")
    print(f"Max L2 Norm: {max_l2_norm}")

    return df, overall_sparsity

# Compute sparsity of a StrHadamardDense layer: group sparsity over all weights incoming to each hidden unit of the layer (from any previous-layer unit)
def compute_strdense_sparsity(model, depth):
    """Report column-wise group sparsity of the layer at ``model.layers[1]``.

    The layer's first weight array is right-multiplied by a diagonal matrix
    built from the elementwise product of all remaining weight arrays; the L2
    norm of every column of the result is then compared against float32
    machine epsilon.

    Args:
        model: model whose ``layers[1]`` is inspected.
        depth: not used by this function.

    Returns:
        ``(df, overall_sparsity)`` where ``df`` has one row per column of the
        collapsed weight matrix with the columns ``input`` (a 1-based label;
        the name is kept for backward compatibility), ``l2_norm`` and
        ``sparse``, and ``overall_sparsity`` is the mean of ``sparse``.

    Raises:
        TypeError: if the layer is not a ``StrHadamardDense``.
        ValueError: if the layer has at most one weight array.
    """
    layer = model.layers[1]
    print(f"First layer is named {layer.name}")
    # NOTE(review): StrHadamardDense is not among the names imported from
    # `layers` at the top of this file -- confirm it is in scope at call time.
    if not isinstance(layer, StrHadamardDense):
        raise TypeError("Layer is not of type StrHadamardDense")

    weights = layer.get_weights()
    if len(weights) <= 1:
        raise ValueError("Layer does not have the expected number of weights for given depth.")

    # Collapse first layer weight objects
    diag_weights = [tf.convert_to_tensor(weight) for weight in weights[1:]]
    stacked_diag_layers = tf.stack(diag_weights, axis=0)
    collapsed_diag_layers = tf.reduce_prod(stacked_diag_layers, axis=0)
    print(f"Shape of collapsed diag layers: {collapsed_diag_layers.shape}")
    diag_mat = tf.linalg.diag(collapsed_diag_layers)
    U1 = tf.convert_to_tensor(weights[0])
    W_reconstructed = tf.linalg.matmul(U1, diag_mat)
    print(f"Shape of reconstructed input weights: {W_reconstructed.shape}")

    # Compute 2-norm of first-layer weights per hidden unit (column) of layer
    l2_norms = tf.norm(W_reconstructed, axis=0).numpy()

    # One label per norm. The norms run over columns, so their count is
    # U1.shape[1]; labelling with U1.shape[0] entries made DataFrame
    # construction fail for any non-square weight matrix.
    df = pd.DataFrame({
        'input': [f'input{i+1}' for i in range(len(l2_norms))],
        'l2_norm': l2_norms
    })

    # Compute binary indicator for group sparsity and other metrics
    df['sparse'] = (df['l2_norm'] < np.finfo(np.float32).eps).astype(int)
    overall_sparsity = df['sparse'].mean()
    overall_cr = 1 / (1 - overall_sparsity)
    min_l2_norm = df['l2_norm'].min()
    max_l2_norm = df['l2_norm'].max()

    print(f"Overall Input Sparsity: {overall_sparsity}")
    print(f"Overall Input CR: {overall_cr}")
    print(f"Min L2 Norm: {min_l2_norm}")
    print(f"Max L2 Norm: {max_l2_norm}")

    return df, overall_sparsity


# Compute thresholds corresponding to compression rates (positive weights)
def compute_thresholds(cr_vals, weight_obj):
    """Map compression rates to percentile thresholds of ``weight_obj``.

    A compression rate ``cr`` corresponds to a sparsity of ``1 - 1/cr``; the
    threshold returned for it is the ``100 * (1 - 1/cr)``-th percentile of
    ``weight_obj`` as computed by ``np.percentile``.

    Args:
        cr_vals: iterable of compression rates.
        weight_obj: array-like of values to take percentiles of.

    Returns:
        ``np.ndarray`` with one threshold per entry of ``cr_vals``.
    """
    percentile_ranks = []
    for rate in cr_vals:
        target_sparsity = (1 - 1/rate)
        percentile_ranks.append(100 * target_sparsity)

    cutoffs = []
    for rank in percentile_ranks:
        cutoffs.append(np.percentile(weight_obj, rank))

    return np.array(cutoffs)

# Flatten non-empty weight objects
def flatten_and_filter_weights(weights_list):
    """Collect absolute values of all non-empty arrays in a nested list.

    Non-empty ``np.ndarray`` entries contribute the absolute values of their
    flattened elements; non-empty ``list`` entries are processed recursively.
    Anything else (empty arrays, empty lists, other types) is ignored.

    Returns:
        Flat Python list of the collected absolute values.
    """
    collected = []
    for entry in weights_list:
        if isinstance(entry, np.ndarray):
            if entry.size > 0:
                collected.extend(np.abs(entry).ravel())
        elif isinstance(entry, list) and entry:
            # Nested list: descend and splice its values in place.
            collected.extend(flatten_and_filter_weights(entry))
    return collected

# Explicit non-smooth group lasso penalty
class ExplicitGroupLasso(tf.keras.regularizers.Regularizer):
    """Group-lasso penalty over explicitly listed index groups.

    For every index collection in ``group_idx`` the matching slices along
    axis 0 of the regularized tensor are gathered and their (unsquared) L2
    norm is computed; the penalty is ``la`` times the sum of these norms.

    Args:
        la: penalty strength.
        group_idx: iterable of index collections, one per group. ``None``
            (the default) means "no groups" and yields a zero penalty.
    """
    def __init__(self, la=0, group_idx=None, **kwargs):
        super(ExplicitGroupLasso, self).__init__(**kwargs)
        self.la = la
        # The default None used to crash in the comprehension below.
        self.group_idx = [] if group_idx is None else group_idx
        self.group_shapes = [len(gii) for gii in self.group_idx]

    def __call__(self, x):
        # Kept as an attribute because the original exposed it after each call.
        self.gathered_inputs = [tf.gather(x, ind, axis=0) for ind in self.group_idx]
        # NOTE: sqrt is not differentiable at 0, i.e. for an all-zero group.
        group_norms = [tf.sqrt(tf.reduce_sum(tf.square(group)))
                       for group in self.gathered_inputs]
        return self.la * tf.reduce_sum(group_norms)

    def get_config(self):
        """Return the constructor arguments so Keras can serialize/clone the regularizer."""
        return {
            "la": self.la,
            "group_idx": [np.asarray(ind).tolist() for ind in self.group_idx],
        }
    
class GroupLassoRegularizer(tf.keras.regularizers.Regularizer):
    """Group-lasso penalty whose groups are defined by reducing over ``axis``.

    The penalty is ``lam * sum(sqrt(sum(x**2, axis=axis)))``: the squared
    entries are summed over ``axis`` and the resulting L2 norms are added up.

    Args:
        lam: penalty strength.
        axis: axis (or axes) passed to ``tf.reduce_sum`` when forming the
            per-group squared norms.
    """
    def __init__(self, lam=0, axis=0, **kwargs):
        super(GroupLassoRegularizer, self).__init__(**kwargs)
        self.lam = lam
        self.axis = axis

    def __call__(self, x):
        # For dense layers: groups are incoming features (default axis=0)
        # For conv2d layers [kernel_height, kernel_width, input_channels, output_channels]:
        # - Use axis=2 to regularize input channel groups
        # - Use axis=3 to regularize output channel (filter) groups
        # NOTE: sqrt is not differentiable at 0, i.e. for an all-zero group.
        group_norms = tf.sqrt(tf.reduce_sum(tf.square(x), axis=self.axis))
        return self.lam * tf.reduce_sum(group_norms)

    def get_config(self):
        """Return the constructor arguments so Keras can serialize/clone the regularizer."""
        return {"lam": self.lam, "axis": self.axis}

# Custom CosineDecay schedule with linear warmup
class WarmupCosineDecayOld(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, init_lr, warmup_steps, warmup_target, total_steps, alpha=0.0):
        """
        Initializes the WarmupCosineDecay scheduler.

        Parameters:
        - init_lr: Initial learning rate.
        - warmup_steps: Number of steps to increase the learning rate linearly from init_lr to warmup_target.
        - warmup_target: Learning rate at the end of the linear warmup.
        - total_steps: Step index at which the cosine decay is complete, counted
          from step 0 (i.e. including the warmup steps); the decay itself spans
          total_steps - warmup_steps steps.
        - alpha: Minimum learning rate value as a fraction of warmup_target.
        """
        super().__init__()
        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.warmup_target = warmup_target
        self.total_steps = total_steps
        self.alpha = alpha

    def __call__(self, step):
        """Return the learning rate for ``step`` (linear warmup, then cosine decay)."""
        step = tf.cast(step, tf.float32)  # Cast step to float
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)

        def _warmup_lr():
            return self.init_lr + step * (self.warmup_target - self.init_lr) / self.warmup_steps

        def _cosine_lr():
            # Fraction of the decay phase completed, capped at 1 so the rate
            # stays at its minimum once total_steps is exceeded.
            progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            progress = tf.minimum(progress, 1.0)
            cosine_decayed = 0.5 * (1 + tf.cos(np.pi * progress))
            decayed = (1 - self.alpha) * cosine_decayed + self.alpha
            return self.warmup_target * decayed

        # tf.cond instead of a Python `if`: a symbolic tensor cannot be used
        # as a Python bool when the schedule is traced in graph mode.
        return tf.cond(step < warmup_steps, _warmup_lr, _cosine_lr)

    def get_config(self):
        return {
            "init_lr": self.init_lr,
            "warmup_steps": self.warmup_steps,
            "warmup_target": self.warmup_target,
            "total_steps": self.total_steps,
            "alpha": self.alpha,
        }
    
class WarmupCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Cosine learning-rate decay preceded by a linear warmup.

    For ``step < warmup_steps`` the rate moves linearly from ``initial_lr``
    to ``warmup_target``. Afterwards it follows a cosine curve from
    ``warmup_target`` down to ``alpha * warmup_target`` over ``decay_steps``
    steps and then stays at that minimum.

    Args:
        initial_lr: learning rate at step 0.
        decay_steps: number of steps of the cosine phase (after warmup).
        warmup_steps: number of linear warmup steps.
        alpha: minimum learning rate as a fraction of ``warmup_target``.
        warmup_target: learning rate reached at the end of warmup. ``None``
            (the default) is treated as ``initial_lr``, i.e. a flat warmup.
    """
    def __init__(self, initial_lr, decay_steps, warmup_steps, alpha=0.0, warmup_target=None):
        super(WarmupCosineDecay, self).__init__()
        self.initial_lr = initial_lr
        self.decay_steps = decay_steps
        self.warmup_steps = warmup_steps
        self.alpha = alpha
        self.warmup_target = warmup_target

    def _peak_lr(self):
        # Resolve the None default lazily so get_config() round-trips the
        # value exactly as it was passed in.
        return self.initial_lr if self.warmup_target is None else self.warmup_target

    def warmup_learning_rate(self, step):
        """Linear interpolation between ``initial_lr`` and the warmup target."""
        step = tf.cast(step, tf.float32)
        # Guard against warmup_steps == 0 (division by zero -> NaN).
        warmup_steps = tf.maximum(tf.cast(self.warmup_steps, tf.float32), 1.0)
        completed_fraction = step / warmup_steps
        total_delta = self._peak_lr() - self.initial_lr
        return self.initial_lr + completed_fraction * total_delta

    def decayed_learning_rate(self, step):
        """Cosine-decayed rate, with ``step`` counted from the end of warmup."""
        step = tf.cast(step, tf.float32)
        decay_steps = tf.cast(self.decay_steps, tf.float32)
        # Without this clamp the cosine keeps oscillating and the learning
        # rate climbs again once training runs past decay_steps.
        step = tf.minimum(step, decay_steps)
        cosine_decay = 0.5 * (1 + tf.math.cos(tf.constant(np.pi) * step / decay_steps))
        decayed = (1 - self.alpha) * cosine_decay + self.alpha
        return self._peak_lr() * decayed

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)

        # Warm-up phase
        warmup_lr = self.warmup_learning_rate(step)

        # Decay phase
        decay_lr = self.decayed_learning_rate(step - warmup_steps)

        # Combine warm-up and decay
        learning_rate = tf.cond(step < warmup_steps,
                                lambda: warmup_lr,
                                lambda: decay_lr)
        return learning_rate

    def get_config(self):
        return {
            "initial_lr": self.initial_lr,
            "decay_steps": self.decay_steps,
            "warmup_steps": self.warmup_steps,
            "alpha": self.alpha,
            "warmup_target": self.warmup_target
        }


# Optimizer and learning rate scheduler
def get_optimizer(dat, lr_schedule = 'cosine', init_lr = 0.1,
                  opt='sgd', momentum=0.9, nesterov=True, epochs=100, batch_size=256,
                  # cosine-specific args
                  alpha=0.0,
                  # piece-wise specific args
                  lr_decay_fact=0.01, large_lr_start=True, first_decay=0.5,
                  # warmup args
                  warmup=False, warmup_eps=5, warmup_init_lr=0.1):
    """Build a Keras optimizer together with its learning-rate schedule.

    Args:
        dat: training data; only ``dat.shape[0]`` (number of samples) is used.
        lr_schedule: ``'constant'``, ``'cosine'``, ``'polynomial'`` or ``'piecewise'``.
        init_lr: base learning rate of the schedule.
        opt: ``'sgd'`` or ``'adam'``.
        momentum, nesterov: passed to SGD.
        epochs, batch_size: used to convert epochs into optimizer steps
            (``int(n_samples / batch_size)`` steps per epoch).
        alpha: minimum-rate fraction for cosine; ``alpha + 1e-6`` is the end
            learning rate for polynomial.
        lr_decay_fact: multiplicative decay per stage of the piecewise schedule.
        large_lr_start: piecewise only; start with ``min(2 * init_lr, 0.5)``
            for the first 10 epochs.
        first_decay: piecewise only; fraction of training after which the
            first decay happens.
        warmup: enable warmup (cosine: linear ramp; piecewise: constant stage).
        warmup_eps: warmup length in epochs.
        warmup_init_lr: learning rate at the start of / during warmup.

    Returns:
        A ``tf.keras.optimizers.SGD`` or ``tf.keras.optimizers.Adam`` instance.

    Raises:
        ValueError: for an unknown ``lr_schedule`` or ``opt``.
    """
    # Fail early; an unknown optimizer used to fall through both branches at
    # the end and surface as an UnboundLocalError on `optimizer`.
    if opt not in ('sgd', 'adam'):
        raise ValueError(f"Invalid optimizer '{opt}'. Choose 'sgd' or 'adam'.")

    steps_per_epoch = dat.shape[0] / batch_size
    T_max_steps = epochs * int(steps_per_epoch)
    warmup_steps = int(warmup_eps * steps_per_epoch)

    if lr_schedule == 'constant':
        lr_sched = init_lr

    elif lr_schedule == 'cosine':
        if warmup:
            decay_steps = T_max_steps - warmup_steps
            lr_sched = WarmupCosineDecay(initial_lr=warmup_init_lr, warmup_steps=warmup_steps,
                                         warmup_target=init_lr, decay_steps=decay_steps, alpha=alpha)
        else:
            lr_sched = tf.keras.optimizers.schedules.CosineDecay(
                initial_learning_rate=init_lr, decay_steps=T_max_steps, alpha=alpha)

    elif lr_schedule == 'polynomial':
        lr_sched = tf.keras.optimizers.schedules.PolynomialDecay(
            initial_learning_rate=init_lr,
            decay_steps=T_max_steps,
            end_learning_rate=alpha+1e-6,
            power=0.35,
            cycle=False,
            name='PolynomialDecay'
            )

    elif lr_schedule == 'piecewise':
        # Decay points as shares of total training: the first at `first_decay`,
        # the second halfway through what remains after it.
        share1 = first_decay
        share2 = first_decay + 0.5 * (1 - first_decay)
        decay_step_0 = int(10 * steps_per_epoch)  # end of the optional large-LR phase
        decay_step_1 = int(share1 * T_max_steps)
        decay_step_2 = int(share2 * T_max_steps)

        lr_decay0 = np.float32(np.minimum(2 * init_lr, 0.5))
        lr_decay1 = init_lr * lr_decay_fact
        lr_decay2 = init_lr * tf.math.pow(lr_decay_fact, 2.0)

        # Core schedule, optionally prefixed by the large-LR and warmup stages.
        boundaries = [decay_step_1, decay_step_2]
        values = [init_lr, lr_decay1, lr_decay2]
        if large_lr_start:
            boundaries.insert(0, decay_step_0)
            values.insert(0, lr_decay0)
        if warmup:
            boundaries.insert(0, warmup_steps)
            values.insert(0, warmup_init_lr)
        lr_sched = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries, values)
    else:
        raise ValueError("Invalid learning rate schedule configs.")

    # Optimizer selection
    if opt == 'sgd':
        optimizer = tf.keras.optimizers.SGD(learning_rate=lr_sched, momentum=momentum, nesterov=nesterov)
    else:  # opt == 'adam' (validated above)
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_sched)

    return optimizer


# Function to format streamed results of HadamardCallback
def process_sparsity_callback(hist, hadamard_cb, lr_cb):
    """Turn training history and callback logs into two result DataFrames.

    Args:
        hist: object whose ``history`` mapping holds the per-epoch lists
            ``loss``, ``accuracy``, ``val_loss`` and ``val_accuracy``.
        hadamard_cb: object with ``total_metrics_data`` (seven values per
            epoch) and ``metrics_data`` (ten values per row).
        lr_cb: object whose ``lr_history`` is a 2-D array; column 1 is used
            as the learning rate.

    Returns:
        ``(res_total, res_epochs_lr)``: the per-epoch aggregate table and the
        per-layer table left-joined with the ``lr`` column on ``epoch``.
    """
    def _as_column(values):
        return np.array(values).reshape(-1, 1)

    # Per-epoch trajectory and learning rate as column vectors.
    history = hist.history
    trajectory = [_as_column(history[key])
                  for key in ('loss', 'accuracy', 'val_loss', 'val_accuracy')]
    lr_column = _as_column(lr_cb.lr_history[:, 1])

    # Aggregated (whole-model) metrics per epoch.
    total_table = np.hstack([np.array(hadamard_cb.total_metrics_data), *trajectory, lr_column])
    total_columns = ['epoch', 'sparsity', 'cr', 'misalignment', 'l2', 'l2_sq_factors', 'min_penalty',
                     'train_loss', 'train_acc', 'val_loss', 'val_acc', 'lr']
    res_total = pd.DataFrame(data=total_table, columns=total_columns)
    res_total['epoch'] = res_total['epoch'].astype(int)
    total_decimals = {
        'sparsity': 6, 'cr': 4, 'misalignment': 20, 'l2': 4,
        'l2_sq_factors': 20, 'min_penalty': 20,
        'train_loss': 4, 'train_acc': 4, 'val_loss': 4, 'val_acc': 4, 'lr': 5,
    }
    for column, digits in total_decimals.items():
        res_total[column] = res_total[column].round(decimals=digits)

    # Per-layer metrics.
    layer_columns = ['epoch', 'layer', 'worb', 'sparsity', 'cr', 'l1', 'l2', 'misalignment', 'l2_sq_factors', 'min_penalty']
    res_epochs = pd.DataFrame(data=np.array(hadamard_cb.metrics_data), columns=layer_columns)
    for column in ('epoch', 'layer', 'worb'):
        res_epochs[column] = res_epochs[column].astype(int)
    for column, digits in {'sparsity': 6, 'cr': 4}.items():
        res_epochs[column] = res_epochs[column].round(decimals=digits)
    # Norms are stored as strings in scientific notation.
    for column in ('l1', 'l2'):
        res_epochs[column] = res_epochs[column].apply('{:.2e}'.format)
    for column in ('misalignment', 'l2_sq_factors', 'min_penalty'):
        res_epochs[column] = res_epochs[column].round(decimals=20)

    # Attach the epoch's learning rate to every per-layer row.
    res_epochs_lr = pd.merge(res_epochs, res_total[['epoch', 'lr']], on='epoch', how='left')

    return res_total, res_epochs_lr

# Legacy variant of process_sparsity_callback: formats streamed results of HadamardCallback without the 'min_penalty' column
def process_sparsity_callbackOld(hist, hadamard_cb, lr_cb):
    """Turn training history and callback logs into two result DataFrames.

    Args:
        hist: object whose ``history`` mapping holds the per-epoch lists
            ``loss``, ``accuracy``, ``val_loss`` and ``val_accuracy``.
        hadamard_cb: object with ``total_metrics_data`` (six values per
            epoch) and ``metrics_data`` (nine values per row).
        lr_cb: object whose ``lr_history`` is a 2-D array; column 1 is used
            as the learning rate.

    Returns:
        ``(res_total, res_epochs_lr)``: the per-epoch aggregate table and the
        per-layer table left-joined with the ``lr`` column on ``epoch``.
    """
    def _as_column(values):
        return np.array(values).reshape(-1, 1)

    # Per-epoch trajectory and learning rate as column vectors.
    history = hist.history
    trajectory = [_as_column(history[key])
                  for key in ('loss', 'accuracy', 'val_loss', 'val_accuracy')]
    lr_column = _as_column(lr_cb.lr_history[:, 1])

    # Aggregated (whole-model) metrics per epoch.
    total_table = np.hstack([np.array(hadamard_cb.total_metrics_data), *trajectory, lr_column])
    total_columns = ['epoch', 'sparsity', 'cr', 'misalignment', 'l2', 'l2_sq_factors', 'train_loss', 'train_acc', 'val_loss', 'val_acc', 'lr']
    res_total = pd.DataFrame(data=total_table, columns=total_columns)
    res_total['epoch'] = res_total['epoch'].astype(int)
    total_decimals = {
        'sparsity': 6, 'cr': 4, 'misalignment': 4, 'l2': 4, 'l2_sq_factors': 4,
        'train_loss': 4, 'train_acc': 4, 'val_loss': 4, 'val_acc': 4, 'lr': 5,
    }
    for column, digits in total_decimals.items():
        res_total[column] = res_total[column].round(decimals=digits)

    # Per-layer metrics.
    layer_columns = ['epoch', 'layer', 'worb', 'sparsity', 'cr', 'l1', 'l2', 'misalignment', 'l2_sq_factors']
    res_epochs = pd.DataFrame(data=np.array(hadamard_cb.metrics_data), columns=layer_columns)
    for column in ('epoch', 'layer', 'worb'):
        res_epochs[column] = res_epochs[column].astype(int)
    for column, digits in {'sparsity': 6, 'cr': 4}.items():
        res_epochs[column] = res_epochs[column].round(decimals=digits)
    # Norms and misalignment are stored as strings in scientific notation.
    for column in ('l1', 'l2', 'misalignment'):
        res_epochs[column] = res_epochs[column].apply('{:.2e}'.format)
    res_epochs['l2_sq_factors'] = res_epochs['l2_sq_factors'].round(decimals=4)

    # Attach the epoch's learning rate to every per-layer row.
    res_epochs_lr = pd.merge(res_epochs, res_total[['epoch', 'lr']], on='epoch', how='left')

    return res_total, res_epochs_lr

# Take output of process_sparsity_callback and creates plot of training metrics
def create_and_save_trajectory_plot(cb_total, run_path, max_loss=10, out_name='pre_trajectory.pdf', show=True):
    """Plot the training trajectory on six stacked y-axes and save it.

    Args:
        cb_total: DataFrame with the columns ``epoch``, ``train_acc``,
            ``val_acc``, ``sparsity``, ``train_loss``, ``val_loss``, ``cr``,
            ``misalignment``, ``lr`` and ``l2``.
        run_path: directory the figure is written to.
        max_loss: not used by this function.
        out_name: file name of the saved figure.
        show: if True, also display the figure via ``plt.show()``.

    Returns:
        None. The figure is saved to ``os.path.join(run_path, out_name)`` and
        closed afterwards so repeated calls do not accumulate open figures.
    """
    # Load data
    df = cb_total

    # Max val loss during final 70% of training
    total_epochs = df['epoch'].max() + 1
    start_epoch = int(total_epochs * 0.3)  # discard first training phase
    df_subsampled = df[df['epoch'] >= start_epoch]

    # Increase y-axis limit by 25% margin
    max_val_loss_subsampled = df_subsampled['val_loss'].max()
    modified_loss_ylim = max_val_loss_subsampled * 1.25
    if math.isnan(modified_loss_ylim) or math.isinf(modified_loss_ylim):
        # Diverged or missing validation loss: fall back to a fixed range.
        modified_loss_ylim = 4

    # Define size params
    num_legend_entries=9
    legend_font_size = 24
    tick_size = 22
    label_size = 28
    linewidth = 3

    # Plotting Setup
    fig, ax1 = plt.subplots(figsize=(36, 20))

    # Epoch on x-axis
    ax1.set_xlabel('epoch', size=label_size)
    max_epoch = df['epoch'].max()
    ax1.set_xticks(range(0, max_epoch+1, 10))
    ax1.set_yticks([0.05 * i for i in range(0, 21)])

    # Add 90% accuracy mark
    ax1.axhline(y=0.9, color='grey', linestyle='--')

    # Plotting train_acc and val_acc on primary y-axis
    ax1.set_ylabel('accuracy, sparsity', color='black', size=label_size)
    line1, = ax1.plot(df['epoch'], df['train_acc'], color='xkcd:blue', label='train acc', linewidth=linewidth, alpha=0.7, linestyle='--')
    line2, = ax1.plot(df['epoch'], df['val_acc'], color='xkcd:azure', label='val acc', linewidth=linewidth, alpha=0.7)
    line0, = ax1.plot(df['epoch'], df['sparsity'], color='xkcd:turquoise', label='sparsity', linewidth=linewidth, alpha=0.5)
    ax1.tick_params(axis='y', labelcolor='black')

    # 2nd y-axis for train_loss and val_loss
    ax2 = ax1.twinx()
    ax2.set_ylabel('train loss, val loss', color='black', size=label_size)
    line5, = ax2.plot(df['epoch'], df['train_loss'], color='xkcd:red', label='train loss', linewidth=linewidth, alpha=0.7, linestyle='--')
    line6, = ax2.plot(df['epoch'], df['val_loss'], color='xkcd:salmon', label='val loss', linewidth=linewidth, alpha=0.7)

    ax2.set_ylim(0, modified_loss_ylim)  # Limit range of loss vals
    ax2.tick_params(axis='y', labelcolor='black')

    # Third y-axis for compression rate
    ax3 = ax1.twinx()
    ax3.spines['right'].set_position(('outward', 100))
    ax3.set_ylabel('compression rate', color='black', size=label_size)
    line3, = ax3.plot(df['epoch'], df['cr'], color='xkcd:green', label='compression rate', linewidth=1.2 * linewidth)
    ax3.tick_params(axis='y', labelcolor='black')

    # Fourth y-axis for misalignment
    ax4 = ax1.twinx()
    ax4.spines['right'].set_position(('outward', 195))
    ax4.set_ylabel('misalignment', color='black', size=label_size)
    line4, = ax4.plot(df['epoch'], df['misalignment'], color='xkcd:fuchsia', label='misalignment', linewidth=1.2*linewidth)
    ax4.set_ylim(0, 4)  # Limit range of misalignment values
    ax4.tick_params(axis='y', labelcolor='black')

    # Fifth y-axis for learning rate
    ax5 = ax1.twinx()
    ax5.spines['right'].set_position(('outward', 320))
    ax5.set_ylabel('learning rate', color='black', size=label_size)
    line7, = ax5.plot(df['epoch'], df['lr'], color='xkcd:gray', label='learning rate', linewidth=linewidth, alpha=0.6, linestyle='dashdot')
    ax5.tick_params(axis='y', labelcolor='black')
    max_lr = df['lr'].max()
    # Ticks every 0.05 up to the next multiple of 0.05 above the largest rate.
    max_lr_rounded = math.ceil(max_lr / 0.05) * 0.05
    ax5.set_yticks([i for i in np.arange(0, max_lr_rounded + 0.05, 0.05)])

    # Sixth y-axis for model l2 norm
    ax6 = ax1.twinx()
    ax6.spines['right'].set_position(('outward', 420))
    ax6.set_ylabel('L2 norm', color='black', size=label_size)
    line8, = ax6.plot(df['epoch'], df['l2'], color='xkcd:orange', label='L2 norm', linewidth=1.2 * linewidth)
    ax6.tick_params(axis='y', labelcolor='black')

    # Tick sizes
    ax1.tick_params(axis='both', labelsize=tick_size)
    ax2.tick_params(axis='y', labelsize=tick_size)
    ax3.tick_params(axis='y', labelsize=tick_size)
    ax4.tick_params(axis='y', labelsize=tick_size)
    ax5.tick_params(axis='y', labelsize=tick_size)
    ax6.tick_params(axis='y', labelsize=tick_size)

    # Increase plot size at bottom
    plt.subplots_adjust(bottom=0.5)

    # Create legend
    lines = [line7, line1, line2, line5, line6, line0, line3, line4, line8]
    labels = [l.get_label() for l in lines]
    ax1.legend(lines, labels,
               loc='upper center',
               bbox_to_anchor=(0.5, -0.07),  # Adjust as needed
               fancybox=True, framealpha=1,
               fontsize=legend_font_size,
               ncol=num_legend_entries)

    plt.tight_layout()

    # Save plot
    output_file = os.path.join(run_path, out_name)
    plt.savefig(output_file, bbox_inches='tight')

    # Display plot
    if show:
        plt.show()

    # Release the figure; pyplot otherwise keeps every figure alive until it
    # is explicitly closed, which leaks memory across repeated calls.
    plt.close(fig)

    print(f"Saving figure {out_name} was successful.")
    
    
# Threshold models with vanilla layers (no unfactorization beforehand)
# Function to prune vanilla models at specific thresholds and returning pruned model and CR

def threshold_vanilla_weights(model, threshold=np.finfo(np.float32).eps, mode="model"):
    """
    Prune all weights and biases of a TensorFlow model using a threshold.

    Every entry whose absolute value is not strictly greater than ``threshold``
    is set to zero. The input model is never modified.

    Args:
    - model: The TensorFlow model to be pruned
    - threshold: The threshold value for pruning
    - mode: "model" or "weights" (default: "model")

    Returns:
    - mode == "model": ``(pruned_model, pruned_weights, compression_ratio, sparsity)``
      where ``pruned_model`` is a clone of ``model`` carrying the pruned weights.
    - mode == "weights": ``(pruned_weights, compression_ratio)``; no clone is made.

    ``compression_ratio`` is ``1 / (1 - sparsity)`` (``inf`` if everything was
    pruned); ``sparsity`` is the share of exactly-zero entries after pruning.

    Raises:
    - ValueError: if ``mode`` is neither "model" nor "weights".
    """
    # Validate before doing any work (cloning a model is expensive).
    if mode not in ("model", "weights"):
        raise ValueError("Invalid mode. Choose 'model' or 'weights'.")

    # get_weights() returns copies as numpy arrays, so the original model
    # stays untouched by the masking below.
    weights = model.get_weights()

    total_params = 0
    zero_params = 0
    pruned_weights = []

    for w in weights:
        total_params += w.size
        mask = np.abs(w) > threshold
        pruned_w = w * mask
        zero_params += np.sum(pruned_w == 0)
        pruned_weights.append(pruned_w)

    # Calculate sparsity and compression rate (a model without parameters
    # counts as not sparse instead of dividing by zero).
    sparsity = zero_params / total_params if total_params > 0 else 0.0
    compression_ratio = 1 / (1 - sparsity) if sparsity != 1 else float('inf')

    print(f'Number of zero params ={zero_params}')
    print(f'Total number of params ={total_params}')
    print(f'Sparsity ={sparsity}')
    print(f'Compression ratio = {compression_ratio}')

    if mode == "weights":
        return pruned_weights, compression_ratio

    # mode == "model": clone the architecture and load the pruned weights.
    pruned_model = tf.keras.models.clone_model(model)
    pruned_model.set_weights(pruned_weights)
    return pruned_model, pruned_weights, compression_ratio, sparsity


#Reconstruct original vanilla model by multiplying factors and pruning resulting model to threshold
def threshold_model_weights(model, threshold=100, mode="model"):
    """
    Threshold the factorized weights (and factorized biases) of Hadamard layers.

    For every layer that is an instance of HadamardDense or HadamardConv2D, all
    weight arrays with two or more dimensions are treated as factors and
    multiplied elementwise into a reconstructed weight. Wherever the absolute
    reconstructed value is below ``threshold``, every factor is set to zero at
    that position. If the layer uses a factorized bias, its 1-D arrays are
    handled the same way. Other layers are passed through unchanged.

    Args:
    - model: The model whose weights are to be modified.
    - threshold: The threshold value used for determining small weights.
    - mode: "model" writes the thresholded factors back into ``model`` (the
      passed model is modified in place) and returns it; "weights" leaves the
      model untouched and returns the thresholded factors instead.

    Returns:
    - mode == "model": ``(model, reconstructed_weights_list, overall_compression_rate)``
    - mode == "weights": ``(modified_weights_list, reconstructed_weights_list, overall_compression_rate)``

    ``modified_weights_list`` has one entry per layer of the model.
    ``reconstructed_weights_list`` has one entry per Hadamard layer: the
    reconstructed (un-thresholded) weight array for layers without a bias, or
    the list ``[weight, bias]`` for layers with one. The overall compression
    rate counts reconstructed weights and factorized biases only and is 0 when
    every counted entry falls below the threshold.

    Raises:
    - ValueError: if ``mode`` is neither "model" nor "weights".
    """
    # Validate up front instead of after all the work (and prints) are done.
    if mode not in ("model", "weights"):
        raise ValueError("Invalid mode. Choose 'model' or 'weights'.")

    modified_weights_list = []
    reconstructed_weights_list = []
    overall_num_small_values = 0
    overall_total_elements = 0

    def _compute_sparsity_cr(array):
        num_small_values = np.sum(np.abs(array) < threshold)
        total_elements = np.prod(array.shape)
        sparsity_ratio = num_small_values / total_elements if total_elements > 0 else float('inf')
        compression_rate = total_elements / (total_elements - num_small_values) if num_small_values < total_elements else 0

        return num_small_values, total_elements, sparsity_ratio, compression_rate

    for layer in model.layers:
        if not isinstance(layer, (HadamardDense, HadamardConv2D)):
            # For layers that are not Hadamard layers, append original weights
            modified_weights_list.append(layer.get_weights())
            continue

        layer_weights = layer.get_weights()
        weight_factors = [w for w in layer_weights if len(w.shape) >= 2]

        # Compute reconstructed weights (elementwise product of all factors)
        reconstructed_weight = np.ones_like(weight_factors[0])
        for w in weight_factors:
            reconstructed_weight *= w

        # Threshold weights: zero every factor where the product is small
        small_weight_indices = np.abs(reconstructed_weight) < threshold
        modified_weights = [np.where(small_weight_indices, np.zeros_like(w), w) for w in weight_factors]

        # Reconstructed weights of this layer; becomes [weight, bias] below
        # if the layer carries a bias.
        reconstructed_layer_weights = reconstructed_weight

        # Compute sparsity + compression rates for weights
        w_stats = _compute_sparsity_cr(reconstructed_weight)
        overall_num_small_values += w_stats[0]
        overall_total_elements += w_stats[1]

        if layer.use_bias:
            bias_factors = [w for w in layer_weights if len(w.shape) == 1]

            if layer.factorize_bias and len(bias_factors) > 0:
                # Compute reconstructed bias
                reconstructed_bias = np.ones_like(bias_factors[0])
                for b in bias_factors:
                    reconstructed_bias *= b

                small_bias_indices = np.abs(reconstructed_bias) < threshold
                bias_factors = [np.where(small_bias_indices, np.zeros_like(b), b) for b in bias_factors]

                # Append thresholded bias factors to modified_weights
                modified_weights += bias_factors

                # Keep the bias next to the weight. The previous in-place
                # `ndarray += bias` broadcast-ADDED the bias onto the weight
                # matrix instead of appending it.
                reconstructed_layer_weights = [reconstructed_weight, reconstructed_bias]

                # Compute sparsity + compression rates for bias
                b_stats = _compute_sparsity_cr(reconstructed_bias)
                overall_num_small_values += b_stats[0]
                overall_total_elements += b_stats[1]

            elif len(bias_factors) == 1:
                # If bias not factorized, pass the original bias through
                # (it is not thresholded and not counted in the statistics).
                modified_weights += bias_factors
                reconstructed_layer_weights = [reconstructed_weight, bias_factors[0]]

        # Append modified_weights of current layer to modified_weights_list
        modified_weights_list.append(modified_weights)

        # Append reconstructed layer weights to reconstructed_weights_list
        reconstructed_weights_list.append(reconstructed_layer_weights)

    # Compute overall metrics
    overall_sparsity = overall_num_small_values / overall_total_elements if overall_total_elements > 0 else float('inf')
    overall_compression_rate = overall_total_elements / (overall_total_elements - overall_num_small_values) if overall_num_small_values < overall_total_elements else 0

    print(f"Overall model sparsity: {overall_sparsity * 100:.4f}%")
    print(f"Overall model CR: {overall_compression_rate:.2f}")

    if mode == "weights":
        return modified_weights_list, reconstructed_weights_list, overall_compression_rate

    # mode == "model": write the thresholded factors back, layer by layer.
    for layer, new_weights in zip(model.layers, modified_weights_list):
        layer.set_weights(new_weights)
    return model, reconstructed_weights_list, overall_compression_rate


# Calculates total FLOPs for vanilla model (use threshold_model_weights first to reduce Hadamard model to vanilla model)
def get_flops_and_profile(model, batch_size=1):
    """Count the float operations of one forward pass of ``model``.

    The model call is traced into a concrete function, its variables are
    frozen into constants, and the resulting graph is handed to the TF1
    profiler configured with the ``float_operation`` option set.

    Args:
        model: Model exposing ``inputs``; only ``model.inputs[0]`` is used to
            build the input signature (its leading dimension is replaced by
            ``batch_size``).
        batch_size: Leading dimension of the traced input. Defaults to 1.

    Returns:
        Tuple ``(total_float_ops, profile)`` where ``profile`` is the object
        returned by ``tf.compat.v1.profiler.profile``.
    """
    from tensorflow.python.framework import convert_to_constants

    # Input signature: first model input with the batch dimension pinned.
    first_input = model.inputs[0]
    input_spec = tf.TensorSpec([batch_size] + first_input.shape[1:], first_input.dtype)

    # Trace the forward pass, then freeze variables so the graph is self-contained.
    traced = tf.function(lambda inputs: model(inputs)).get_concrete_function([input_spec])
    _, frozen_graph_def = convert_to_constants.convert_variables_to_constants_v2_as_graph(traced)

    profiler = tf.compat.v1.profiler
    with tf.Graph().as_default() as flops_graph:
        tf.graph_util.import_graph_def(frozen_graph_def, name='')
        flops_profile = profiler.profile(
            graph=flops_graph,
            run_meta=tf.compat.v1.RunMetadata(),
            cmd='op',
            options=profiler.ProfileOptionBuilder.float_operation(),
        )

    total_flops = flops_profile.total_float_ops
    print(f"Total float operations: {total_flops:.2e}")
    return total_flops, flops_profile

# Calculate FLOPs and get profile
#total_flops, profile = get_flops_and_profile(pruned_model)

################## ProxSGD TF optimizer 
# credit: https://github.com/optyang/proxsgd/

from tensorflow.keras.optimizers import Optimizer
import tensorflow.keras.backend as K

#----------------------------------------------------------

import tensorflow as tf
from tensorflow.keras.optimizers import Optimizer
from tensorflow.keras import backend as K

class ProxSGD(Optimizer):
    """ProxSGD optimizer (tailored for L1-norm regularization and bound constraint), proposed in
    ProxSGD: Training Structured Neural Networks under Regularization and Constraints, ICLR 2020
    URL: https://openreview.net/forum?id=HygpthEtvr

    For every variable two slots are kept: ``v`` (running average of the
    gradients) and ``r`` (running average of the squared gradients). Each step
    solves a per-element proximal problem (soft-thresholding for the L1 term,
    optional clipping for the bound constraint) and moves the variable a
    fraction ``epsilon`` towards that solution.

    NOTE(review): the class relies on ``_set_hyper`` / ``_get_hyper`` /
    ``add_slot`` / ``_serialize_hyperparameter`` of the base ``Optimizer`` --
    presumably the legacy Keras OptimizerV2 API; confirm against the
    installed TensorFlow version.
    """

    def __init__(self,
                 epsilon_initial=0.06,
                 epsilon_decay=0.5,
                 rho_initial=0.9,
                 rho_decay=0.5,
                 beta=0.999,
                 mu=1e-4,
                 clip_bounds=None,
                 name='ProxSGD',
                 **kwargs):
        """Initialize the optimizer.

        Args:
            epsilon_initial (float): Initial learning rate for weights.
            epsilon_decay (float): Exponent of the decay applied to the weight
                learning rate, ``epsilon_initial / (t + 4) ** epsilon_decay``.
            rho_initial (float): Initial learning rate for momentum.
            rho_decay (float): Exponent of the decay applied to the momentum
                learning rate, ``rho_initial / (t + 4) ** rho_decay``.
            beta (float): Second momentum parameter.
            mu (float or None): Regularization parameter for L1 norm. ``None``
                disables the soft-thresholding step.
            clip_bounds (tuple or None): A tuple ``(lower, upper)`` for
                clipping; either entry may be ``None`` for "no bound on that
                side". ``None`` disables clipping altogether.
            name (str): Name of the optimizer.
            **kwargs: Forwarded to the base ``Optimizer`` constructor.
        """
        super().__init__(name, **kwargs)
        self._set_hyper('epsilon_initial', epsilon_initial)
        self._set_hyper('epsilon_decay', epsilon_decay)
        self._set_hyper('rho_initial', rho_initial)
        self._set_hyper('rho_decay', rho_decay)
        self._set_hyper('beta', beta)
        # The hyperparameter read back in the update step is never None, so
        # whether L1 regularization is active must be remembered here as a
        # plain Python flag. It is fixed at construction time. A placeholder
        # of 0.0 keeps the 'mu' hyperparameter well-defined when disabled.
        self._l1_enabled = mu is not None
        self._set_hyper('mu', mu if self._l1_enabled else 0.0)
        self.clip_bounds = clip_bounds

    def _create_slots(self, var_list):
        """Create the ``v`` (gradient average) and ``r`` (squared-gradient average) slots."""
        for var in var_list:
            self.add_slot(var, 'v')
            self.add_slot(var, 'r')

    @tf.function
    def _resource_apply_dense(self, grad, var, apply_state=None):
        """Apply one ProxSGD update to ``var`` given its dense gradient.

        Args:
            grad: Gradient tensor for ``var``.
            var: Variable to update in place.
            apply_state: Unused; accepted for base-class compatibility.

        Returns:
            A ``tf.no_op()``; the updates happen through the ``assign`` calls
            on ``var`` and its slots.
        """
        var_dtype = var.dtype.base_dtype
        epsilon_initial = self._get_hyper('epsilon_initial', var_dtype)
        epsilon_decay = self._get_hyper('epsilon_decay', var_dtype)
        rho_initial = self._get_hyper('rho_initial', var_dtype)
        rho_decay = self._get_hyper('rho_decay', var_dtype)
        beta = self._get_hyper('beta', var_dtype)

        v = self.get_slot(var, 'v')
        r = self.get_slot(var, 'r')

        # 1-based step counter; both step sizes decay polynomially in (t + 4).
        iteration = tf.cast(self.iterations + 1, var_dtype)
        epsilon = epsilon_initial / (tf.pow(iteration + 4, epsilon_decay))
        rho = rho_initial / (tf.pow(iteration + 4, rho_decay))
        delta = tf.constant(1e-7, dtype=var_dtype)

        # Running averages of the gradient and of the squared gradient.
        v_new = (1 - rho) * v + rho * grad
        r_new = beta * r + (1 - beta) * tf.square(grad)
        # Per-element scaling: bias-corrected root of r, offset by delta so
        # the divisions below never hit zero.
        tau = tf.sqrt(r_new / (1 - tf.pow(beta, iteration))) + delta

        x_tmp = var - v_new / tau

        if self._l1_enabled:
            # Soft-thresholding with a per-element threshold mu / tau.
            mu = self._get_hyper('mu', var_dtype)
            mu_normalized = mu / tau
            x_hat = tf.maximum(x_tmp - mu_normalized, 0) - tf.maximum(-x_tmp - mu_normalized, 0)
        else:
            x_hat = x_tmp

        if self.clip_bounds is not None:
            low, up = self.clip_bounds
            # A missing bound on one side means "unbounded" on that side.
            low = float('-inf') if low is None else low
            up = float('inf') if up is None else up
            x_hat = tf.clip_by_value(x_hat, low, up)

        # Move a fraction epsilon from the current point towards x_hat.
        var_new = var + epsilon * (x_hat - var)

        v.assign(v_new)
        r.assign(r_new)
        var.assign(var_new)

        return tf.no_op()

    def get_config(self):
        """Return the optimizer configuration, including all ProxSGD hyperparameters."""
        config = super().get_config()
        config.update({
            'epsilon_initial': self._serialize_hyperparameter('epsilon_initial'),
            'epsilon_decay': self._serialize_hyperparameter('epsilon_decay'),
            'rho_initial': self._serialize_hyperparameter('rho_initial'),
            'rho_decay': self._serialize_hyperparameter('rho_decay'),
            'beta': self._serialize_hyperparameter('beta'),
            # Report None (not the 0.0 placeholder) when L1 is disabled so the
            # config round-trips through __init__.
            'mu': self._serialize_hyperparameter('mu') if self._l1_enabled else None,
            'clip_bounds': self.clip_bounds
        })
        return config