import os
import tensorflow as tf
import numpy as np
import pandas as pd
import math
import matplotlib.pyplot as plt
from layers import HadamardDense, HadamardConv2D, StrHadamardDenseV2

# Color normalization pre-processing function (hardcoded cifar-10)
def color_preprocessing_old(X_train,X_test,X_val):
    """Standardize three image batches with fixed per-channel statistics.

    Each input is copied to float32, then channels 0, 1 and 2 (last axis of a
    4-D array) are shifted and scaled by the hard-coded mean/std pairs below.
    The inputs themselves are left untouched.

    Returns:
        Tuple ``(X_train, X_test, X_val)`` of normalized float32 arrays.
    """
    # (mean, std) per channel, in channel order
    channel_stats = ((125.3, 63.0), (123.0, 62.1), (113.9, 66.7))

    def _standardize(images):
        scaled = images.astype('float32')
        for channel, (mu, sigma) in enumerate(channel_stats):
            scaled[:, :, :, channel] = (scaled[:, :, :, channel] - mu) / sigma
        return scaled

    return _standardize(X_train), _standardize(X_test), _standardize(X_val)

def color_preprocessing(X_train, X_test, X_val):
    """Standardize image batches channel-wise using training-set statistics.

    All three inputs are copied to float32. The per-channel mean and standard
    deviation are computed over axes (0, 1, 2) of ``X_train`` only and then
    applied to every split, so no statistics of the test/validation data are
    used. Any number of channels on the last axis is supported.

    A channel that is constant in ``X_train`` has zero standard deviation; its
    divisor is replaced by 1 so the result is 0 instead of NaN/inf.

    Args:
        X_train: 4-D array whose last axis indexes channels.
        X_test: array with the same channel count as ``X_train``.
        X_val: array with the same channel count as ``X_train``.

    Returns:
        Tuple ``(X_train, X_test, X_val)`` of normalized float32 arrays.
    """
    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    X_val = X_val.astype('float32')

    # Channel-wise statistics of the training images
    mean = np.mean(X_train, axis=(0, 1, 2))
    std = np.std(X_train, axis=(0, 1, 2))
    # Guard against division by zero for constant channels
    std[std == 0] = 1.0

    # The arrays above are private copies, so normalizing in place is safe
    # and avoids extra full-size temporaries.
    for images in (X_train, X_test, X_val):
        images -= mean
        images /= std

    return X_train, X_test, X_val

# Compute sparsity of StrHadamardDenseV2 layer (group-sparsity for all weights outgoing from previous layer units (eg inputs, hidden units)
def compute_input_sparsity(model, depth):
    """Measure row-wise group sparsity of the layer at ``model.layers[1]``.

    The layer's weight list is collapsed as ``diag(prod(weights[1:])) @ weights[0]``
    and the L2 norm of every row of that matrix is computed. A row counts as
    sparse when its norm is below float32 machine epsilon. Summary statistics
    are printed.

    Args:
        model: object exposing ``layers``; ``layers[1]`` must be a
            StrHadamardDenseV2 without bias.
        depth: not used by this function.

    Returns:
        Tuple ``(df, overall_sparsity)`` where ``df`` has the columns
        ``input``, ``l2_norm`` and ``sparse``.

    Raises:
        TypeError: if the layer is not a StrHadamardDenseV2.
        ValueError: if the layer uses a bias or holds at most one weight array.
    """
    target = model.layers[1]
    print(f"First layer is named {target.name}")
    if not isinstance(target, StrHadamardDenseV2):
        raise TypeError("Layer is not of type StrHadamardDenseV2")

    if getattr(target, 'use_bias', False):
        raise ValueError("Layer uses biases, which is not supported for this computation.")

    factors = target.get_weights()
    if len(factors) <= 1:
        raise ValueError("Layer does not have the expected number of weights for given depth.")

    # Elementwise product of all factors after the first one
    diag_stack = tf.stack([tf.convert_to_tensor(f) for f in factors[1:]], axis=0)
    scale = tf.reduce_prod(diag_stack, axis=0)
    print(f"Shape of collapsed diag layers: {scale.shape}")

    # Scale the rows of the first factor by the collapsed diagonal
    U1 = tf.convert_to_tensor(factors[0])
    W_reconstructed = tf.linalg.matmul(tf.linalg.diag(scale), U1)
    print(f"Shape of reconstructed input weights: {W_reconstructed.shape}")

    # One L2 norm per row of the reconstructed matrix
    row_norms = tf.norm(W_reconstructed, axis=1)

    labels = [f'input{k}' for k in range(1, U1.shape[0] + 1)]
    df = pd.DataFrame({'input': labels, 'l2_norm': row_norms.numpy()})

    # Binary sparsity indicator and aggregate metrics
    eps = np.finfo(np.float32).eps
    df['sparse'] = (df['l2_norm'] < eps).astype(int)
    overall_sparsity = df['sparse'].mean()
    overall_cr = 1 / (1 - overall_sparsity)

    print(f"Overall Input Sparsity: {overall_sparsity}")
    print(f"Overall Input CR: {overall_cr}")
    print(f"Min L2 Norm: {df['l2_norm'].min()}")
    print(f"Max L2 Norm: {df['l2_norm'].max()}")

    return df, overall_sparsity

# Compute sparsity of StrHadamardDense layer (group-sparsity for all weights incoming to hidden unit of layer (from any previous-layer unit)
def compute_strdense_sparsity(model, depth):
    """Measure column-wise group sparsity of the layer at ``model.layers[1]``.

    The layer's weight list is collapsed as ``weights[0] @ diag(prod(weights[1:]))``
    and the L2 norm of every column of that matrix is computed. A column counts
    as sparse when its norm is below float32 machine epsilon. Summary
    statistics are printed. Unlike ``compute_input_sparsity``, no bias check is
    performed.

    Args:
        model: object exposing ``layers``; ``layers[1]`` must be a
            StrHadamardDense.
        depth: not used by this function.

    Returns:
        Tuple ``(df, overall_sparsity)`` where ``df`` has the columns
        ``input``, ``l2_norm`` and ``sparse``. There is one row per column of
        the reconstructed matrix; the ``input`` column name and label prefix
        are kept for compatibility with existing consumers.

    Raises:
        TypeError: if the layer is not a StrHadamardDense.
        ValueError: if the layer holds at most one weight array.
    """
    # NOTE(review): the module-level import from `layers` does not include
    # StrHadamardDense, so the isinstance check below raised NameError.
    # Imported locally on the assumption that `layers` exports it -- confirm.
    from layers import StrHadamardDense

    layer = model.layers[1]
    print(f"First layer is named {layer.name}")
    if not isinstance(layer, StrHadamardDense):
        raise TypeError("Layer is not of type StrHadamardDense")

    weights = layer.get_weights()
    if len(weights) <= 1:
        raise ValueError("Layer does not have the expected number of weights for given depth.")

    # Collapse first layer weight objects
    diag_weights = [tf.convert_to_tensor(weight) for weight in weights[1:]]
    stacked_diag_layers = tf.stack(diag_weights, axis=0)
    collapsed_diag_layers = tf.reduce_prod(stacked_diag_layers, axis=0)
    print(f"Shape of collapsed diag layers: {collapsed_diag_layers.shape}")
    diag_mat = tf.linalg.diag(collapsed_diag_layers)
    U1 = tf.convert_to_tensor(weights[0])
    W_reconstructed = tf.linalg.matmul(U1, diag_mat)
    print(f"Shape of reconstructed input weights: {W_reconstructed.shape}")

    # Compute 2-norm of first-layer weights per hidden unit of layer
    l2_norms = tf.norm(W_reconstructed, axis=0).numpy()

    # DF to save results to. The labels must match the number of norms
    # (columns of the matrix); sizing them by U1.shape[0] broke DataFrame
    # construction whenever the kernel was not square.
    df = pd.DataFrame({
        'input': [f'input{i+1}' for i in range(len(l2_norms))],
        'l2_norm': l2_norms
    })

    # Compute binary indicator for feature sparsity and other metrics
    df['sparse'] = (df['l2_norm'] < np.finfo(np.float32).eps).astype(int)
    overall_sparsity = df['sparse'].mean()
    overall_cr = 1 / (1 - overall_sparsity)
    min_l2_norm = df['l2_norm'].min()
    max_l2_norm = df['l2_norm'].max()

    print(f"Overall Input Sparsity: {overall_sparsity}")
    print(f"Overall Input CR: {overall_cr}")
    print(f"Min L2 Norm: {min_l2_norm}")
    print(f"Max L2 Norm: {max_l2_norm}")

    return df, overall_sparsity


# Compute thresholds corresponding to compression rates (positive weights)
def compute_thresholds(cr_vals, weight_obj):
    """Translate compression rates into magnitude thresholds.

    A compression rate ``cr`` corresponds to a sparsity of ``1 - 1/cr``; the
    threshold is the percentile of ``weight_obj`` at that sparsity level.

    Args:
        cr_vals: iterable of compression rates.
        weight_obj: array-like of values passed to ``np.percentile``.

    Returns:
        NumPy array holding one threshold per entry of ``cr_vals``.
    """
    def _percentile_rank(cr):
        # sparsity (fraction removed) expressed on the 0..100 scale
        return 100 * (1 - 1/cr)

    ranks = [_percentile_rank(cr) for cr in cr_vals]
    return np.array([np.percentile(weight_obj, rank) for rank in ranks])

# Flatten non-empty weight objects
def flatten_and_filter_weights(weights_list):
    """Collect absolute values of all array entries in a nested weight list.

    Non-empty NumPy arrays contribute the absolute value of each element;
    non-empty lists are descended into recursively. Empty arrays, empty lists
    and entries of any other type are ignored.

    Returns:
        Flat Python list of the collected values, in traversal order.
    """
    collected = []
    for entry in weights_list:
        if isinstance(entry, np.ndarray):
            if entry.size > 0:
                collected.extend(np.abs(entry).ravel())
        elif isinstance(entry, list) and entry:
            collected += flatten_and_filter_weights(entry)
    return collected

# Explicit non-smooth group lasso penalty
class ExplicitGroupLasso(tf.keras.regularizers.Regularizer):
    """Group lasso regularizer: ``la`` times the sum of per-group L2 norms.

    ``group_idx`` is an iterable of index collections; each collection selects
    the slices of the regularized tensor along axis 0 that form one group.
    """

    def __init__(self, la=0, group_idx=None, **kwargs):
        super().__init__(**kwargs)
        self.la = la
        self.group_idx = group_idx
        # Number of indices in every group
        self.group_shapes = [len(members) for members in group_idx]

    def __call__(self, x):
        # The gathered groups are kept on the instance, as before
        self.gathered_inputs = [tf.gather(x, members, axis=0) for members in self.group_idx]
        group_norms = [tf.sqrt(tf.reduce_sum(tf.square(group))) for group in self.gathered_inputs]
        return self.la * tf.reduce_sum(group_norms)

# Custom CosineDecay schedule with linear warmup
class WarmupCosineDecayOld(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup followed by cosine decay (older variant of the schedule)."""

    def __init__(self, init_lr, warmup_steps, warmup_target, total_steps, alpha=0.0):
        """
        Initializes the WarmupCosineDecay scheduler.
        
        Parameters:
        - init_lr: Initial learning rate (value at step 0).
        - warmup_steps: Number of steps to increase the learning rate linearly from init_lr to warmup_target.
        - warmup_target: Learning rate at the end of the linear warmup. 
        - total_steps: Total number of steps in the schedule, warmup included; the cosine
          phase spans the remaining total_steps - warmup_steps steps.
        - alpha: Minimum learning rate value as a fraction of warmup_target.
        """
        super().__init__()
        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.warmup_target = warmup_target
        self.total_steps = total_steps
        self.alpha = alpha

    def __call__(self, step):
        """Return the learning rate for ``step`` as a float32 tensor."""
        step = tf.cast(step, tf.float32)  # Cast step to float
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        init_lr = tf.cast(self.init_lr, tf.float32)
        warmup_target = tf.cast(self.warmup_target, tf.float32)

        def _warmup():
            # Linear interpolation from init_lr to warmup_target
            return init_lr + step * (warmup_target - init_lr) / warmup_steps

        def _cosine():
            # Fraction of the cosine phase completed, capped at 1 so the
            # rate stays at its floor once the schedule is exhausted
            span = tf.cast(self.total_steps, tf.float32) - warmup_steps
            progress = tf.minimum((step - warmup_steps) / span, 1.0)
            cosine_decayed = 0.5 * (1 + tf.cos(np.pi * progress))
            return warmup_target * ((1 - self.alpha) * cosine_decayed + self.alpha)

        # A Python `if` on a tensor cannot be traced in graph mode;
        # tf.cond selects the branch in both eager and graph execution.
        return tf.cond(step < warmup_steps, _warmup, _cosine)

    def get_config(self):
        return {
            "init_lr": self.init_lr,
            "warmup_steps": self.warmup_steps,
            "warmup_target": self.warmup_target,
            "total_steps": self.total_steps,
            "alpha": self.alpha,
        }
    
class WarmupCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Learning rate schedule: linear warmup, then cosine decay.

    For ``step < warmup_steps`` the rate moves linearly from ``initial_lr`` to
    ``warmup_target``. Afterwards it follows a cosine curve over
    ``decay_steps`` steps from ``warmup_target`` down to
    ``alpha * warmup_target`` and stays there.

    Parameters:
    - initial_lr: Learning rate at step 0.
    - decay_steps: Length of the cosine phase in steps (warmup excluded).
    - warmup_steps: Length of the linear warmup in steps.
    - alpha: Final learning rate as a fraction of the warmup target.
    - warmup_target: Learning rate at the end of warmup. When None,
      ``initial_lr`` is used, i.e. the rate is constant during warmup.
    """

    def __init__(self, initial_lr, decay_steps, warmup_steps, alpha=0.0, warmup_target=None):
        super(WarmupCosineDecay, self).__init__()
        self.initial_lr = initial_lr
        self.decay_steps = decay_steps
        self.warmup_steps = warmup_steps
        self.alpha = alpha
        self.warmup_target = warmup_target

    def _peak_lr(self):
        # Rate reached at the end of warmup; a missing target previously
        # caused a TypeError on the first call.
        target = self.initial_lr if self.warmup_target is None else self.warmup_target
        return tf.cast(target, tf.float32)

    def warmup_learning_rate(self, step):
        """Linearly interpolated rate at ``step`` of the warmup phase."""
        step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        initial_lr = tf.cast(self.initial_lr, tf.float32)
        completed_fraction = step / warmup_steps
        total_delta = self._peak_lr() - initial_lr
        return initial_lr + completed_fraction * total_delta

    def decayed_learning_rate(self, step):
        """Cosine-decayed rate, ``step`` counted from the end of warmup."""
        step = tf.cast(step, tf.float32)
        decay_steps = tf.cast(self.decay_steps, tf.float32)
        # Clamp so the cosine does not swing back up after decay_steps
        step = tf.minimum(step, decay_steps)
        cosine_decay = 0.5 * (1 + tf.math.cos(tf.constant(np.pi) * step / decay_steps))
        decayed = (1 - self.alpha) * cosine_decay + self.alpha
        return self._peak_lr() * decayed

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)

        # Combine warm-up and decay; each phase is evaluated inside its branch
        learning_rate = tf.cond(step < warmup_steps,
                                lambda: self.warmup_learning_rate(step),
                                lambda: self.decayed_learning_rate(step - warmup_steps))
        return learning_rate

    def get_config(self):
        return {
            "initial_lr": self.initial_lr,
            "decay_steps": self.decay_steps,
            "warmup_steps": self.warmup_steps,
            "alpha": self.alpha,
            "warmup_target": self.warmup_target
        }


# Optimizer and learning rate scheduler
def get_optimizer(dat, lr_schedule = 'cosine', init_lr = 0.1,
                  opt='sgd', momentum=0.9, nesterov=True, epochs=100, batch_size=256,
                  # cosine-specific args
                  alpha=0.0,
                  # piece-wise specific args
                  lr_decay_fact=0.1, large_lr_start=True,warmup=False, warmup_eps=5, decay_steps=30000, warmup_init_lr=0.1):
    """Build a Keras optimizer together with its learning rate schedule.

    Args:
        dat: training data; only ``dat.shape[0]`` (number of samples) is read.
        lr_schedule: 'constant', 'cosine', 'polynomial' or 'piecewise'.
        init_lr: base learning rate (peak rate of the cosine schedule).
        opt: 'sgd' or 'adam'.
        momentum, nesterov: forwarded to SGD.
        epochs, batch_size: used to derive the total number of steps as
            ``epochs * int(n_samples / batch_size)``.
        alpha: floor fraction for 'cosine'; end rate (plus 1e-6) for
            'polynomial'.
        lr_decay_fact: multiplicative decay between 'piecewise' plateaus.
        large_lr_start: 'piecewise' only; start with ``min(2 * init_lr, 0.5)``
            for the first 10 epochs.
        warmup: enable warmup ('cosine': linear from ``warmup_init_lr``;
            'piecewise': constant 0.001).
        warmup_eps: warmup length in epochs.
        decay_steps: accepted for backward compatibility; the value is not
            used (the cosine decay length is derived from the step counts).
        warmup_init_lr: starting rate of the cosine warmup.

    Returns:
        A ``tf.keras.optimizers.SGD`` or ``tf.keras.optimizers.Adam`` instance.

    Raises:
        ValueError: for an unknown ``lr_schedule`` or ``opt``.
    """
    # Fail fast: an unknown optimizer name used to surface only at the very
    # end as an UnboundLocalError.
    if opt not in ('sgd', 'adam'):
        raise ValueError("Invalid optimizer. Choose 'sgd' or 'adam'.")

    # NOTE(review): steps per epoch are truncated here; confirm this matches
    # the number of batches the training loop actually runs per epoch.
    steps_per_epoch = dat.shape[0] / batch_size
    T_max_steps = epochs * int(steps_per_epoch)
    warmup_steps = int(warmup_eps * steps_per_epoch)

    if lr_schedule == 'constant':
        lr_sched = init_lr

    elif lr_schedule == 'cosine':
        if warmup:
            lr_sched = WarmupCosineDecay(initial_lr=warmup_init_lr, warmup_steps=warmup_steps,
                                         warmup_target=init_lr, decay_steps=T_max_steps - warmup_steps,
                                         alpha=alpha)
        else:
            lr_sched = tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate=init_lr,decay_steps=T_max_steps,alpha=alpha)

    elif lr_schedule == 'polynomial':
        lr_sched = tf.keras.optimizers.schedules.PolynomialDecay(
            initial_learning_rate=init_lr,
            decay_steps=T_max_steps,
            end_learning_rate=alpha+1e-6,
            power=0.35,
            cycle=False,
            name='PolynomialDecay'
            )

    elif lr_schedule == 'piecewise':
        # Boundaries: end of the boosted phase (10 epochs), 50% and 75% of training
        boost_end = int(10 * steps_per_epoch)
        first_drop = int(0.5 * T_max_steps)
        second_drop = int(0.75 * T_max_steps)

        # Plain Python floats keep all schedule values of one type
        boosted_lr = float(min(2 * init_lr, 0.5))
        lr_after_first_drop = init_lr * lr_decay_fact
        lr_after_second_drop = init_lr * lr_decay_fact ** 2

        boundaries = [first_drop, second_drop]
        values = [init_lr, lr_after_first_drop, lr_after_second_drop]
        if large_lr_start:
            boundaries.insert(0, boost_end)
            values.insert(0, boosted_lr)
        if warmup:
            boundaries.insert(0, warmup_steps)
            values.insert(0, 0.001)
        lr_sched = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries, values)
    else:
        raise ValueError("Invalid learning rate schedule configs.")

    # Optimizer selection
    if opt == 'sgd':
        optimizer = tf.keras.optimizers.SGD(learning_rate=lr_sched, momentum=momentum, nesterov=nesterov)
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_sched)

    return optimizer


# Function to format streamed results of HadamardCallback
def process_sparsity_callback(hist, hadamard_cb, lr_cb):
    """Turn training history and callback records into two DataFrames.

    Args:
        hist: object whose ``history`` dict holds 'loss', 'accuracy',
            'val_loss' and 'val_accuracy' sequences.
        hadamard_cb: object with ``total_metrics_data`` (rows of epoch,
            sparsity, cr, misalignment, l2) and ``metrics_data`` (rows of
            epoch, layer, worb, sparsity, cr, l1, l2, misalignment).
        lr_cb: object whose ``lr_history`` is a 2-D array; column 1 is read
            as the learning rate.

    Returns:
        Tuple ``(res_total, res_epochs_lr)``: the aggregated per-epoch table
        and the per-layer table with the learning rate merged in on 'epoch'.
    """
    def _as_column(values):
        return np.array(values).reshape(-1, 1)

    # Training trajectory and learning rates as column vectors
    history = hist.history
    trajectory = [_as_column(history[key]) for key in ('loss', 'accuracy', 'val_loss', 'val_accuracy')]
    lr_column = _as_column(lr_cb.lr_history[:, 1])

    # Aggregated metrics per epoch
    total_block = np.hstack([np.array(hadamard_cb.total_metrics_data), *trajectory, lr_column])
    total_columns = ['epoch', 'sparsity', 'cr', 'misalignment', 'l2', 'train_loss', 'train_acc', 'val_loss', 'val_acc', 'lr']
    res_total = pd.DataFrame(data=total_block, columns=total_columns)
    res_total['epoch'] = res_total['epoch'].astype(int)
    total_decimals = {
        'sparsity': 6, 'cr': 4, 'misalignment': 4, 'l2': 4,
        'train_loss': 4, 'train_acc': 4, 'val_loss': 4, 'val_acc': 4, 'lr': 5,
    }
    for column, digits in total_decimals.items():
        res_total[column] = res_total[column].round(decimals=digits)

    # Metrics per epoch and layer
    layer_columns = ['epoch', 'layer', 'worb', 'sparsity', 'cr', 'l1', 'l2', 'misalignment']
    res_epochs = pd.DataFrame(data=np.array(hadamard_cb.metrics_data), columns=layer_columns)
    for column in ('epoch', 'layer', 'worb'):
        res_epochs[column] = res_epochs[column].astype(int)
    for column, digits in (('sparsity', 6), ('cr', 4)):
        res_epochs[column] = res_epochs[column].round(decimals=digits)

    def _scientific(value):
        return format(value, '.2e')

    for column in ('l1', 'l2', 'misalignment'):
        res_epochs[column] = res_epochs[column].apply(_scientific)

    # Attach the learning rate of each epoch to the per-layer rows
    res_epochs_lr = pd.merge(res_epochs, res_total[['epoch', 'lr']], on='epoch', how='left')

    return res_total, res_epochs_lr

# Take output of process_sparsity_callback and creates plot of training metrics
def create_and_save_trajectory_plot(cb_total, run_path, max_loss=10, out_name='pre_trajectory.pdf', show=True):
    """Plot the training trajectory on one figure and save it to disk.

    Accuracy and sparsity share the primary y-axis; losses, compression rate,
    misalignment, learning rate and L2 norm each get their own right-hand
    axis. The figure is written to ``os.path.join(run_path, out_name)``.

    Args:
        cb_total: DataFrame with the columns 'epoch', 'sparsity', 'cr',
            'misalignment', 'l2', 'train_loss', 'train_acc', 'val_loss',
            'val_acc' and 'lr'.
        run_path: directory the figure is saved to.
        max_loss: accepted for backward compatibility; not used.
        out_name: file name of the saved figure.
        show: display the figure after saving. When False the figure is
            closed so repeated calls do not accumulate open figures.
    """
    df = cb_total

    # Upper limit of the loss axis: largest validation loss during the final
    # 70% of training (the first phase is discarded), plus a 25% margin
    total_epochs = df['epoch'].max() + 1
    start_epoch = int(total_epochs * 0.3)
    df_subsampled = df[df['epoch'] >= start_epoch]
    modified_loss_ylim = df_subsampled['val_loss'].max() * 1.25
    if math.isnan(modified_loss_ylim) or math.isinf(modified_loss_ylim):
        modified_loss_ylim = 4

    # Define size params
    legend_font_size = 24
    tick_size = 22
    label_size = 28
    linewidth = 3

    # Plotting Setup
    fig, ax1 = plt.subplots(figsize=(36, 20))

    # Epoch on x-axis
    ax1.set_xlabel('epoch', size=label_size)
    max_epoch = int(df['epoch'].max())
    ax1.set_xticks(range(0, max_epoch+1, 10))
    ax1.set_yticks([0.05 * i for i in range(0, 21)])

    # Add 90% accuracy mark
    ax1.axhline(y=0.9, color='grey', linestyle='--')

    # Plotting train_acc, val_acc and sparsity on primary y-axis
    ax1.set_ylabel('accuracy, sparsity', color='black', size=label_size)
    line1, = ax1.plot(df['epoch'], df['train_acc'], color='xkcd:blue', label='train acc', linewidth=linewidth, alpha=0.7, linestyle='--')
    line2, = ax1.plot(df['epoch'], df['val_acc'], color='xkcd:azure', label='val acc', linewidth=linewidth, alpha=0.7)
    line0, = ax1.plot(df['epoch'], df['sparsity'], color='xkcd:turquoise', label='sparsity', linewidth=linewidth, alpha=0.5)
    ax1.tick_params(axis='y', labelcolor='black')

    # 2nd y-axis for train_loss and val_loss
    ax2 = ax1.twinx()
    ax2.set_ylabel('train loss, val loss', color='black', size=label_size)
    line5, = ax2.plot(df['epoch'], df['train_loss'], color='xkcd:red', label='train loss', linewidth=linewidth, alpha=0.7, linestyle='--')
    line6, = ax2.plot(df['epoch'], df['val_loss'], color='xkcd:salmon', label='val loss', linewidth=linewidth, alpha=0.7)
    ax2.set_ylim(0, modified_loss_ylim)  # Limit range of loss vals
    ax2.tick_params(axis='y', labelcolor='black')

    # Third y-axis for compression rate
    ax3 = ax1.twinx()
    ax3.spines['right'].set_position(('outward', 100))
    ax3.set_ylabel('compression rate', color='black', size=label_size)
    line3, = ax3.plot(df['epoch'], df['cr'], color='xkcd:green', label='compression rate', linewidth=1.2 * linewidth)
    ax3.tick_params(axis='y', labelcolor='black')

    # Fourth y-axis for misalignment
    ax4 = ax1.twinx()
    ax4.spines['right'].set_position(('outward', 195))
    ax4.set_ylabel('misalignment', color='black', size=label_size)
    line4, = ax4.plot(df['epoch'], df['misalignment'], color='xkcd:fuchsia', label='misalignment', linewidth=1.2*linewidth)
    ax4.set_ylim(0, 10000)  # Limit range of misalignment values
    ax4.tick_params(axis='y', labelcolor='black')

    # Fifth y-axis for learning rate
    ax5 = ax1.twinx()
    ax5.spines['right'].set_position(('outward', 320))
    ax5.set_ylabel('learning rate', color='black', size=label_size)
    line7, = ax5.plot(df['epoch'], df['lr'], color='xkcd:gray', label='learning rate', linewidth=linewidth, alpha=0.6, linestyle='dashdot')
    ax5.tick_params(axis='y', labelcolor='black')
    max_lr = df['lr'].max()
    # math.ceil raises on NaN/inf; keep matplotlib's default ticks in that case
    if math.isfinite(max_lr):
        max_lr_rounded = math.ceil(max_lr / 0.05) * 0.05
        ax5.set_yticks(list(np.arange(0, max_lr_rounded + 0.05, 0.05)))

    # Sixth y-axis for model l2 norm
    ax6 = ax1.twinx()
    ax6.spines['right'].set_position(('outward', 420))
    ax6.set_ylabel('L2 norm', color='black', size=label_size)
    line8, = ax6.plot(df['epoch'], df['l2'], color='xkcd:orange', label='L2 norm', linewidth=1.2 * linewidth)
    ax6.tick_params(axis='y', labelcolor='black')

    # Tick sizes
    ax1.tick_params(axis='both', labelsize=tick_size)
    for axis in (ax2, ax3, ax4, ax5, ax6):
        axis.tick_params(axis='y', labelsize=tick_size)

    # Increase plot size at bottom
    plt.subplots_adjust(bottom=0.5)

    # Create legend (one column per entry so it forms a single row)
    lines = [line7, line1, line2, line5, line6, line0, line3, line4, line8]
    labels = [l.get_label() for l in lines]
    ax1.legend(lines, labels,
               loc='upper center',
               bbox_to_anchor=(0.5, -0.07),  # Adjust as needed
               fancybox=True, framealpha=1,
               fontsize=legend_font_size,
               ncol=len(lines))

    plt.tight_layout()

    # Save plot
    output_file = os.path.join(run_path, out_name)
    plt.savefig(output_file, bbox_inches='tight')

    # Display plot, or release the figure when it is not shown
    if show:
        plt.show()
    else:
        plt.close(fig)

    print(f"Saving figure {out_name} was successful.")

def threshold_model_weights(model, threshold=100, mode="model"):
    """
    Threshold the weights (and factorized biases) of Hadamard layers in a model.
    
    For every HadamardDense / HadamardConv2D layer the weight factors (arrays
    with at least 2 dimensions) are multiplied elementwise. Wherever the
    absolute product is below ``threshold``, all factors are set to zero at
    that position. Factorized biases (1-D arrays) are treated the same way.
    All other layers keep their weights. Overall sparsity and compression rate
    of the processed weights are printed.
    
    Args:
    - model: The model whose weights are to be modified.
    - threshold: The threshold value used for determining small weights.
    - mode: "model" writes the thresholded weights back with ``set_weights``
      and returns the model; "weights" leaves the model untouched and returns
      the thresholded weight lists.
    
    Returns:
    - Tuple ``(result, reconstructed_weights_list, overall_compression_rate)``.
      ``result`` is the model (mode "model") or a list with one weight list per
      layer (mode "weights"). ``reconstructed_weights_list`` has one entry per
      Hadamard layer holding the un-thresholded products: the weight array
      alone when no bias was collected, otherwise the list ``[weight, bias]``.
      ``overall_compression_rate`` is 0 when nothing remains or nothing was
      processed.
    
    Raises:
    - ValueError: if ``mode`` is neither "model" nor "weights".
    """
    # Reject a bad mode before doing any work
    if mode not in ("model", "weights"):
        raise ValueError("Invalid mode. Choose 'model' or 'weights'.")

    modified_weights_list = []
    reconstructed_weights_list = []
    overall_num_small_values = 0
    overall_total_elements = 0

    def _collapse(factors):
        # Elementwise product of equally shaped factor arrays
        product = np.ones_like(factors[0])
        for factor in factors:
            product *= factor
        return product

    def _threshold(factors, product):
        # Zero every factor where the collapsed product is below the threshold;
        # returns the new factors and the number of zeroed positions
        small = np.abs(product) < threshold
        return [np.where(small, np.zeros_like(f), f) for f in factors], np.sum(small)

    for layer in model.layers:
        if not isinstance(layer, (HadamardDense, HadamardConv2D)):
            # For layers that are not Hadamard layers, keep original weights
            modified_weights_list.append(layer.get_weights())
            continue

        layer_weights = layer.get_weights()
        weight_factors = [w for w in layer_weights if len(w.shape) >= 2]

        reconstructed_weight = _collapse(weight_factors)
        modified_weights, num_small = _threshold(weight_factors, reconstructed_weight)
        overall_num_small_values += num_small
        overall_total_elements += reconstructed_weight.size

        # Collected per layer as a list. The previous `+=` on the ndarray
        # numerically added the bias into the weight matrix (broadcast)
        # instead of appending it.
        reconstructed_parts = [reconstructed_weight]

        if layer.use_bias:
            bias_factors = [w for w in layer_weights if len(w.shape) == 1]

            if layer.factorize_bias and len(bias_factors) > 0:
                reconstructed_bias = _collapse(bias_factors)
                thresholded_bias, num_small_bias = _threshold(bias_factors, reconstructed_bias)
                modified_weights += thresholded_bias
                reconstructed_parts.append(reconstructed_bias)
                overall_num_small_values += num_small_bias
                overall_total_elements += reconstructed_bias.size

            elif len(bias_factors) == 1:
                # Bias not factorized: pass it through unchanged and leave it
                # out of the sparsity statistics
                modified_weights += bias_factors
                reconstructed_parts.append(bias_factors[0])

        modified_weights_list.append(modified_weights)
        # A bare array is kept for layers without a collected bias so those
        # entries look exactly as before
        reconstructed_weights_list.append(
            reconstructed_parts[0] if len(reconstructed_parts) == 1 else reconstructed_parts
        )

    # Compute overall metrics
    overall_sparsity = overall_num_small_values / overall_total_elements if overall_total_elements > 0 else float('inf')
    overall_compression_rate = overall_total_elements / (overall_total_elements - overall_num_small_values) if overall_num_small_values < overall_total_elements else 0

    print(f"Overall model sparsity: {overall_sparsity * 100:.4f}%")
    print(f"Overall model CR: {overall_compression_rate:.2f}")

    if mode == "model":
        for layer, new_weights in zip(model.layers, modified_weights_list):
            layer.set_weights(new_weights)
        return model, reconstructed_weights_list, overall_compression_rate
    return modified_weights_list, reconstructed_weights_list, overall_compression_rate
