
# wideresnet inspiration from here: https://github.com/zytx121/keras_ensemble_cifar10/blob/master/1.Models/1_Wide_ResNet.py

# imports
import math
import numpy as np
import tensorflow as tf

class TwiceTruncatedNormalInitializer(tf.keras.initializers.Initializer):
    """Normal initializer truncated from below *and* above (by rejection sampling).

    Entries are drawn from ``N(0, std)`` and every entry whose magnitude falls
    outside ``[minprod ** (1 / depth), min(1, 2 * (gain / sqrt(n_input)) ** (1 / depth))]``
    is re-drawn until none remain. The lower bound keeps each of the ``depth``
    Hadamard factors away from zero, so the magnitude of their product is at
    least ``minprod``; the upper bound caps each factor's magnitude.

    Args:
        minprod: Lower bound on the magnitude of the product of ``depth`` factors.
        depth: Number of Hadamard factors.
        n_input: Fan-in; used for the default ``std`` and the upper bound.
            Must be set before the initializer is called.
        gain: Gain of the target distribution (default ``sqrt(2)``, He scaling).
        std: Standard deviation of the sampling normal; defaults to
            ``gain / sqrt(n_input)`` when ``None``.
    """

    def __init__(self, minprod, depth, n_input=None, gain=np.sqrt(2), std=None):
        self.minprod = minprod
        self.depth = depth
        self.gain = gain
        self.std = std
        self.n_input = n_input

    def __call__(self, shape, dtype=None):
        """Return a tensor of ``shape`` whose entries all lie in the truncation window.

        Raises:
            ValueError: if ``n_input`` is unset, or if the window is empty
                (lower bound >= upper bound). Rejection sampling would
                otherwise loop forever.
        """
        if self.n_input is None:
            raise ValueError("n_input must be set to call TwiceTruncatedNormalInitializer.")
        std = self.std if self.std is not None else np.sqrt(np.square(self.gain) / self.n_input)

        threshold = self.minprod ** (1 / self.depth)
        threshold_max = np.minimum(1, 2 * (self.gain / np.sqrt(self.n_input)) ** (1 / self.depth))
        if threshold >= threshold_max:
            raise ValueError(
                f"Empty truncation window: lower bound {threshold:.4g} >= upper bound {threshold_max:.4g}."
            )

        matrix = np.random.normal(0, std, size=shape)
        # Re-draw only the out-of-window entries; same RNG call sequence as before.
        redraw = (np.abs(matrix) < threshold) | (np.abs(matrix) > threshold_max)
        while np.any(redraw):
            matrix[redraw] = np.random.normal(0, std, size=np.sum(redraw))
            redraw = (np.abs(matrix) < threshold) | (np.abs(matrix) > threshold_max)

        return tf.convert_to_tensor(matrix, dtype=dtype)

    def get_config(self):  # To support serialization
        """Return every constructor argument so that ``from_config`` round-trips."""
        return {
            "minprod": self.minprod,
            "depth": self.depth,
            "n_input": self.n_input,
            "gain": float(self.gain),
            "std": self.std,
        }
    
def hadamard_initializer(depth, n_input, initialization):
    """
    Build the initializer for a single factor of a Hadamard-factorized dense weight.

    For the He/LeCun families, each of the ``depth`` factors is scaled so that
    the variance of their element-wise product matches the variance of the
    requested scheme. A factor variance of ``(c / n_input) ** (1 / depth)`` gives
    a product variance of ``c / n_input``.

    Args:
    - depth: Number of factor weight matrices (positive integer).
    - n_input: Fan-in of the reconstructed weight (positive integer).
    - initialization: Instance selecting the scheme. One of
      tf.keras.initializers.HeNormal / HeUniform / LecunNormal / LecunUniform
      (variance-matched factors), Orthogonal (returned unchanged),
      VarianceScaling (Unif[-1/sqrt(fan_in), 1/sqrt(fan_in)] per factor), or
      TwiceTruncatedNormalInitializer (He-matched std, truncated at both ends).

    Returns:
    - An initializer that can be used in the HadamardDense layer.

    Raises:
    - ValueError: for non-positive depth/n_input or an unsupported initializer.

    NOTE(review): tf.keras TruncatedNormal re-draws samples beyond two standard
    deviations, so the realised factor variance is somewhat below stddev**2.
    The HeNormal/LecunNormal branches do not compensate for this.
    """
    if depth <= 0:
        raise ValueError("depth must be a positive integer.")
    if n_input <= 0:
        raise ValueError("n_input must be a positive integer.")

    keras_init = tf.keras.initializers
    accepted = (
        keras_init.HeNormal,
        keras_init.HeUniform,
        keras_init.LecunUniform,
        keras_init.LecunNormal,
        keras_init.Orthogonal,
        keras_init.VarianceScaling,
        TwiceTruncatedNormalInitializer,
    )
    if not isinstance(initialization, accepted):
        raise ValueError("Invalid initialization type. Use implemented tensorflow initializer object.")

    def factor_scale(numerator):
        # Per-factor std whose depth-fold product has variance numerator / n_input.
        return (numerator / n_input) ** (1.0 / (2 * depth))

    # The He/LeCun classes subclass VarianceScaling in Keras, so they must be
    # matched before the generic VarianceScaling branch below.
    if isinstance(initialization, keras_init.HeNormal):
        return keras_init.TruncatedNormal(mean=0.0, stddev=factor_scale(2.0))
    if isinstance(initialization, keras_init.HeUniform):
        # Unif[-a, a] has variance a**2 / 3, hence the sqrt(3) factor.
        bound = math.sqrt(3) * factor_scale(2.0)
        return keras_init.RandomUniform(minval=-bound, maxval=bound)
    if isinstance(initialization, keras_init.LecunNormal):
        return keras_init.TruncatedNormal(mean=0.0, stddev=factor_scale(1.0))
    if isinstance(initialization, keras_init.LecunUniform):
        bound = math.sqrt(3) * factor_scale(1.0)
        return keras_init.RandomUniform(minval=-bound, maxval=bound)
    if isinstance(initialization, keras_init.Orthogonal):
        return keras_init.Orthogonal()
    if isinstance(initialization, keras_init.VarianceScaling):
        # initializer used in the spred paper Unif[-1/sqrt(fan_in),1/sqrt(fan_in)]
        return keras_init.VarianceScaling(scale=1/3, mode="fan_in",
                                          distribution="uniform")
    if isinstance(initialization, TwiceTruncatedNormalInitializer):
        return TwiceTruncatedNormalInitializer(minprod=3e-3, depth=depth, n_input=n_input,
                                               std=factor_scale(2.0))


# Custom initializer function for convolutional Hadamard layers
def hadamard_initializer_conv2d(depth, kernel_size, n_channels, initialization):
    """Build the initializer for a single factor of a Hadamard-factorized Conv2D kernel.

    Uses the convolutional fan-in ``kernel_height * kernel_width * n_channels``.
    For the He/LeCun families, each of the ``depth`` factors is scaled so that
    the variance of their element-wise product equals the variance of the
    corresponding standard initializer.

    Args:
        depth: Number of Hadamard factors (positive integer).
        kernel_size: ``(height, width)`` of the kernel, or an int for a square kernel.
        n_channels: Number of input channels (positive integer).
        initialization: Instance selecting the scheme: HeNormal, HeUniform,
            LecunNormal, LecunUniform, Orthogonal, VarianceScaling or
            TwiceTruncatedNormalInitializer.

    Returns:
        A Keras-compatible initializer for one factor.

    Raises:
        ValueError: for non-positive ``depth``/``n_channels`` or an unsupported
            initializer type.
    """
    if depth <= 0:
        raise ValueError("depth must be a positive integer.")
    if n_channels <= 0:
        raise ValueError("n_channels must be a positive integer.")

    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    kernel_height, kernel_width = kernel_size
    n_input = kernel_height * kernel_width * n_channels
    # Exponent that spreads the target variance evenly over the depth factors.
    root = 1.0 / (2 * depth)

    # He/LeCun classes subclass VarianceScaling, so they are matched first.
    if isinstance(initialization, tf.keras.initializers.HeNormal):
        std_dev = (2.0 / n_input) ** root
        return tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=std_dev)
    elif isinstance(initialization, tf.keras.initializers.LecunNormal):
        std_dev = (1.0 / n_input) ** root
        return tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=std_dev)
    elif isinstance(initialization, tf.keras.initializers.HeUniform):
        # Unif[-a, a] has variance a**2 / 3, so a = sqrt(3) * per-factor std.
        limit = math.sqrt(3) * ((2.0 / n_input) ** root)
        return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)
    elif isinstance(initialization, tf.keras.initializers.LecunUniform):
        # LeCun scaling targets a product variance of 1 / fan_in.
        limit = math.sqrt(3) * ((1.0 / n_input) ** root)
        return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)
    elif isinstance(initialization, tf.keras.initializers.Orthogonal):
        return tf.keras.initializers.Orthogonal()
    elif isinstance(initialization, tf.keras.initializers.VarianceScaling):
        # initializer used in the spred paper Unif[-1/sqrt(fan_in),1/sqrt(fan_in)]
        return tf.keras.initializers.VarianceScaling(scale=1/3, mode="fan_in",
                                                     distribution="uniform")
    elif isinstance(initialization, TwiceTruncatedNormalInitializer):
        std_dev = (2.0 / n_input) ** root
        return TwiceTruncatedNormalInitializer(minprod=3e-3, depth=depth, n_input=n_input, std=std_dev)
    else:
        raise ValueError("Invalid initialization type.")
        
        


class HadamardDense(tf.keras.layers.Layer):
    """Dense layer whose kernel (and optionally bias) is a Hadamard product of factors.

    The effective kernel is ``U1 * U2 * ... * U_depth`` (element-wise). Each
    factor carries an L2 penalty of ``la / depth``, which normalizes the scale
    of the induced regularizer across depths.

    Initialization modes (``init_type``):
        "ones":    U1 ~ ``init``, the remaining factors ~ ``init_rest``.
        "equivar": every factor ~ ``hadamard_initializer(depth, fan_in, init)``.
        "spred":   every factor ~ ``hadamard_initializer`` with VarianceScaling.
        "root":    draw W ~ ``init`` once, then U1 = sign(W) * |W|**(1/depth)
                   and U_i = |W|**(1/depth), so the product reproduces W.

    Args:
        units: Number of output units.
        activation: Activation name or callable (``tf.keras.activations.get``).
        la: Total L2 regularization strength, split evenly across the factors.
        depth: Number of Hadamard factors (positive integer).
        init: Base initializer; defaults to ``HeUniform()``.
        init_rest: Initializer for factors 2..depth in "ones" mode; defaults to ``Ones()``.
        init_type: One of "ones", "equivar", "root", "spred".
        use_bias: Whether to add a bias.
        factorize_bias: Whether the bias is itself a product of ``depth`` factors.

    Raises:
        ValueError: for ``depth < 1`` or an unknown ``init_type``.
    """

    def __init__(self, units=1, activation='linear', la=0, depth=1, init=None, init_rest=None, init_type="ones",
                 use_bias=True, factorize_bias=False, **kwargs):
        super(HadamardDense, self).__init__(**kwargs)
        if depth < 1:
            raise ValueError("depth must be a positive integer.")
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.factorize_bias = factorize_bias
        self.depth = depth
        # Keep the user-facing strength for get_config; self.la is the per-factor value.
        self._la_total = la
        self.la = la / self.depth  # to normalize scale of induced regularizers
        self.init = init if init is not None else tf.keras.initializers.HeUniform()
        self.init_rest = init_rest if init_rest is not None else tf.keras.initializers.Ones()
        self.init_type = init_type
        self.reg = tf.keras.regularizers.l2(self.la)

        if self.init_type not in ['ones', 'equivar', 'root', 'spred']:
            raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'spred']")

    def build(self, input_shape):
        fan_in = input_shape[1]
        kernel_shape = (fan_in, self.units)
        bias_shape = (self.units,)
        later = range(2, self.depth + 1)  # indices of the factors after the first

        if self.init_type == "ones":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init,
                                      regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape,
                                                  initializer=self.init_rest, regularizer=self.reg, trainable=True)
                                  for i in later]
        elif self.init_type == "equivar":
            self.init_factor = hadamard_initializer(self.depth, fan_in, self.init)
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init_factor,
                                      regularizer=self.reg, trainable=True)
            # A fresh initializer per factor: recent Keras versions return
            # identical values when one unseeded initializer is called repeatedly.
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape,
                                                  initializer=hadamard_initializer(self.depth, fan_in, self.init),
                                                  regularizer=self.reg, trainable=True)
                                  for i in later]
        elif self.init_type == "spred":
            self.init_factor = hadamard_initializer(self.depth, fan_in, tf.keras.initializers.VarianceScaling())
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init_factor,
                                      regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape,
                                                  initializer=hadamard_initializer(
                                                      self.depth, fan_in, tf.keras.initializers.VarianceScaling()),
                                                  regularizer=self.reg, trainable=True)
                                  for i in later]
        elif self.init_type == "root":
            init_matrix = self.init(shape=kernel_shape)
            init_signs_w = tf.math.sign(init_matrix)
            root_abs_init_w = tf.math.pow(tf.math.abs(init_matrix), 1.0 / self.depth)
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer='zeros',
                                      regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape, initializer='zeros',
                                                  regularizer=self.reg, trainable=True)
                                  for i in later]
            # The sign lives on U1; all factors share the depth-th root of |W|.
            self.U1.assign(init_signs_w * root_abs_init_w)
            for weight_var in self.other_weights:
                weight_var.assign(root_abs_init_w)

        if not self.use_bias:
            self.bias = None
            return
        if not self.factorize_bias:
            self.bias = self.add_weight(name='bias', shape=bias_shape, initializer='zeros', trainable=True)
            return

        if self.init_type == "ones":
            # Clone self.init so the bias draw is independent of U1's.
            bias_init_config = self.init.get_config()
            self.bias_init = self.init.__class__(**bias_init_config)
            self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.bias_init,
                                      regularizer=self.reg, trainable=True)
            self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape,
                                                 initializer=self.init_rest, regularizer=self.reg, trainable=True)
                                 for i in later]
        elif self.init_type == "equivar":
            self.init_bias_factor = hadamard_initializer(self.depth, self.units, self.init)
            self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.init_bias_factor,
                                      regularizer=self.reg, trainable=True)
            self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape,
                                                 initializer=hadamard_initializer(self.depth, self.units, self.init),
                                                 regularizer=self.reg, trainable=True)
                                 for i in later]
        elif self.init_type == "spred":
            self.init_bias_factor = hadamard_initializer(self.depth, self.units,
                                                         tf.keras.initializers.VarianceScaling())
            self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.init_bias_factor,
                                      regularizer=self.reg, trainable=True)
            self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape,
                                                 initializer=hadamard_initializer(
                                                     self.depth, self.units, tf.keras.initializers.VarianceScaling()),
                                                 regularizer=self.reg, trainable=True)
                                 for i in later]
        elif self.init_type == "root":
            bias_init_config = self.init.get_config()
            self.bias_init = self.init.__class__(**bias_init_config)
            init_vector = self.bias_init(shape=bias_shape)
            init_signs_b = tf.math.sign(init_vector)
            root_abs_init_b = tf.math.pow(tf.math.abs(init_vector), 1.0 / self.depth)
            self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer='zeros',
                                      regularizer=self.reg, trainable=True)
            self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer='zeros',
                                                 regularizer=self.reg, trainable=True)
                                 for i in later]
            self.B1.assign(init_signs_b * root_abs_init_b)
            for bias_var in self.other_biases:
                bias_var.assign(root_abs_init_b)

    def call(self, inputs):
        # Reconstruct the effective kernel as the element-wise product of its factors.
        W_reconstructed = self.U1
        for weight in self.other_weights:
            W_reconstructed = W_reconstructed * weight

        output = inputs @ W_reconstructed
        if self.use_bias:
            if self.factorize_bias:
                B_reconstructed = self.B1
                for bias in self.other_biases:
                    B_reconstructed = B_reconstructed * bias
                output += B_reconstructed
            else:
                output += self.bias

        return self.activation(output)

    def get_config(self):
        """Return constructor arguments so ``from_config`` rebuilds an equivalent layer."""
        config = super().get_config().copy()
        config.update({
            'units': self.units,
            'activation': tf.keras.activations.serialize(self.activation),
            'init_type': self.init_type,
            # Store the total strength; __init__ divides it by depth again.
            'la': self._la_total,
            'depth': self.depth,
            'use_bias': self.use_bias,
            'factorize_bias': self.factorize_bias,
        })
        return config

# Custom Conv2D layer with Hadamard factorization
class HadamardConv2D(tf.keras.layers.Layer):
    """2D convolution whose kernel (and optionally bias) is a Hadamard product of factors.

    The effective kernel ``U1 * U2 * ... * U_depth`` (element-wise) is applied
    with ``tf.nn.conv2d``. Each factor carries an L2 penalty of ``la / depth``.

    Initialization modes (``init_type``):
        "ones":    U1 ~ ``init``, the remaining factors ~ ``init_rest``.
        "equivar": every factor ~ ``hadamard_initializer_conv2d(...)`` for ``init``.
        "spred":   every factor ~ ``hadamard_initializer_conv2d`` with VarianceScaling.
        "root":    draw W ~ ``init`` once, then U1 = sign(W) * |W|**(1/depth)
                   and U_i = |W|**(1/depth), so the product reproduces W.

    Args:
        filters: Number of output channels.
        kernel_size: ``(height, width)`` of the kernel, or an int for a square kernel.
        strides: Strides for ``tf.nn.conv2d`` (int or list); defaults to ``[1, 1, 1, 1]``.
        activation: Activation name or callable (``tf.keras.activations.get``).
        la: Total L2 regularization strength, split evenly across the factors.
        depth: Number of Hadamard factors (positive integer).
        init: Base initializer; defaults to ``HeUniform()``.
        init_rest: Initializer for factors 2..depth in "ones" mode; defaults to ``Ones()``.
        init_type: One of "ones", "equivar", "root", "spred".
        use_bias: Whether to add a bias.
        factorize_bias: Whether the bias is itself a product of ``depth`` factors.
        padding: 'same' or 'valid' (case-insensitive); anything else is treated as 'valid'.

    Raises:
        ValueError: for ``depth < 1`` or an unknown ``init_type``.
    """

    def __init__(self, filters, kernel_size, strides=None, activation='linear', la=0, depth=1,
                 init=None, init_rest=None, init_type="ones", use_bias=True,
                 factorize_bias=True, padding='valid', **kwargs):
        super(HadamardConv2D, self).__init__(**kwargs)
        if depth < 1:
            raise ValueError("depth must be a positive integer.")
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.filters = filters
        self.kernel_size = kernel_size
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.factorize_bias = factorize_bias
        self.depth = depth
        # Keep the user-facing strength for get_config; self.la is the per-factor value.
        self._la_total = la
        self.la = la / self.depth
        self.strides = strides if strides is not None else [1, 1, 1, 1]
        self.init = init if init is not None else tf.keras.initializers.HeUniform()
        self.init_rest = init_rest if init_rest is not None else tf.keras.initializers.Ones()
        self.init_type = init_type
        self.reg = tf.keras.regularizers.l2(self.la)
        # tf.nn.conv2d expects upper-case padding; accept either case so that the
        # value stored by get_config ('SAME'/'VALID') round-trips.
        self.padding = 'SAME' if str(padding).upper() == 'SAME' else 'VALID'

        if self.init_type not in ['ones', 'equivar', 'root', 'spred']:
            raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'spred']")

    def build(self, input_shape):
        n_channels = input_shape[-1]
        kernel_shape = (self.kernel_size[0], self.kernel_size[1], n_channels, self.filters)
        bias_shape = (self.filters,)
        later = range(2, self.depth + 1)  # indices of the factors after the first

        if self.init_type == "ones":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init,
                                      regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape,
                                                  initializer=self.init_rest, regularizer=self.reg, trainable=True)
                                  for i in later]
        elif self.init_type == "equivar":
            self.init_factor = hadamard_initializer_conv2d(self.depth, self.kernel_size, n_channels, self.init)
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init_factor,
                                      regularizer=self.reg, trainable=True)
            # A fresh initializer per factor: recent Keras versions return
            # identical values when one unseeded initializer is called repeatedly.
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape,
                                                  initializer=hadamard_initializer_conv2d(
                                                      self.depth, self.kernel_size, n_channels, self.init),
                                                  regularizer=self.reg, trainable=True)
                                  for i in later]
        elif self.init_type == "spred":
            self.init_factor = hadamard_initializer_conv2d(self.depth, self.kernel_size, n_channels,
                                                           tf.keras.initializers.VarianceScaling())
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init_factor,
                                      regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape,
                                                  initializer=hadamard_initializer_conv2d(
                                                      self.depth, self.kernel_size, n_channels,
                                                      tf.keras.initializers.VarianceScaling()),
                                                  regularizer=self.reg, trainable=True)
                                  for i in later]
        elif self.init_type == "root":
            init_kernel = self.init(shape=kernel_shape)
            init_signs_w = tf.math.sign(init_kernel)
            root_abs_init_w = tf.math.pow(tf.math.abs(init_kernel), 1.0 / self.depth)
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer='zeros',
                                      regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape, initializer='zeros',
                                                  regularizer=self.reg, trainable=True)
                                  for i in later]
            # The sign lives on U1; all factors share the depth-th root of |W|.
            self.U1.assign(init_signs_w * root_abs_init_w)
            for weight_var in self.other_weights:
                weight_var.assign(root_abs_init_w)

        if not self.use_bias:
            self.bias = None
            return
        if not self.factorize_bias:
            self.bias = self.add_weight(name='bias', shape=bias_shape, initializer='zeros', trainable=True)
            return

        if self.init_type == "ones":
            # Clone self.init so the bias draw is independent of U1's.
            bias_init_config = self.init.get_config()
            self.bias_init = self.init.__class__(**bias_init_config)
            self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.bias_init,
                                      regularizer=self.reg, trainable=True)
            self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape,
                                                 initializer=self.init_rest, regularizer=self.reg, trainable=True)
                                 for i in later]
        elif self.init_type == "equivar":
            self.init_bias_factor = hadamard_initializer(self.depth, self.filters, self.init)
            self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.init_bias_factor,
                                      regularizer=self.reg, trainable=True)
            self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape,
                                                 initializer=hadamard_initializer(self.depth, self.filters, self.init),
                                                 regularizer=self.reg, trainable=True)
                                 for i in later]
        elif self.init_type == "spred":
            self.init_bias_factor = hadamard_initializer(self.depth, self.filters,
                                                         tf.keras.initializers.VarianceScaling())
            self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.init_bias_factor,
                                      regularizer=self.reg, trainable=True)
            self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape,
                                                 initializer=hadamard_initializer(
                                                     self.depth, self.filters, tf.keras.initializers.VarianceScaling()),
                                                 regularizer=self.reg, trainable=True)
                                 for i in later]
        elif self.init_type == "root":
            bias_init_config = self.init.get_config()
            self.bias_init = self.init.__class__(**bias_init_config)
            init_vector = self.bias_init(shape=bias_shape)
            init_signs_b = tf.math.sign(init_vector)
            root_abs_init_b = tf.math.pow(tf.math.abs(init_vector), 1.0 / self.depth)
            self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer='zeros',
                                      regularizer=self.reg, trainable=True)
            self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer='zeros',
                                                 regularizer=self.reg, trainable=True)
                                 for i in later]
            self.B1.assign(init_signs_b * root_abs_init_b)
            for bias_var in self.other_biases:
                bias_var.assign(root_abs_init_b)

    def call(self, inputs):
        # Reconstruct the effective kernel as the element-wise product of its factors.
        W_reconstructed = self.U1
        for weight in self.other_weights:
            W_reconstructed = W_reconstructed * weight

        output = tf.nn.conv2d(inputs, W_reconstructed, strides=self.strides, padding=self.padding)

        if self.use_bias:
            if self.factorize_bias:
                B_reconstructed = self.B1
                for bias in self.other_biases:
                    B_reconstructed = B_reconstructed * bias
                output = tf.nn.bias_add(output, B_reconstructed)
            else:
                output = tf.nn.bias_add(output, self.bias)

        return self.activation(output)

    def get_config(self):
        """Return constructor arguments so ``from_config`` rebuilds an equivalent layer."""
        config = super().get_config().copy()
        config.update({
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'padding': self.padding,
            'activation': tf.keras.activations.serialize(self.activation),
            'init_type': self.init_type,
            # Store the total strength; __init__ divides it by depth again.
            'la': self._la_total,
            'depth': self.depth,
            'use_bias': self.use_bias,
            'factorize_bias': self.factorize_bias,
        })
        return config

# Keras callback for printing sparsity metrics of Hadamard layers
class HadamardCallback(tf.keras.callbacks.Callback):
    """Report, and optionally record, the sparsity of Hadamard layers after each epoch.

    For every top-level ``HadamardDense``/``HadamardConv2D`` layer of the model,
    the factors are multiplied back together. Entries whose magnitude is below
    ``threshold`` count as zero. Sparsity (fraction of such entries) and
    compression rate (total / non-zero entries) are computed per layer and for
    the whole model. A compression rate of 0 is a sentinel meaning "every entry
    is below the threshold" and is printed as "NA".

    When ``save_metrics`` is set, rows are appended to:
        ``metrics_data``:       [epoch, layer index, 1 = weights / 0 = bias,
                                 sparsity, compression rate, L1 norm, L2 norm]
        ``total_metrics_data``: [epoch, overall sparsity, overall compression rate]

    Args:
        threshold: Magnitude below which an entry counts as zero.
        save_metrics: Whether to accumulate the arrays above.
        verbose: 0 = silent, 1 = epoch header and totals, 2 = also per-layer lines.

    Raises:
        ValueError: if ``verbose`` is not 0, 1 or 2.
    """

    def __init__(self, threshold=np.finfo(np.float32).eps, save_metrics=False, verbose=2):
        super(HadamardCallback, self).__init__()
        self.threshold = threshold
        self.overall_num_small_values = 0
        self.overall_total_elements = 0
        self.save_metrics = save_metrics
        self.metrics_data = np.empty((0, 7))
        self.total_metrics_data = np.empty((0, 3))
        self.verbose = verbose
        if verbose not in [0, 1, 2]:
            raise ValueError("verbose must be one of 0,1,2")

    @staticmethod
    def _multiply_factors(factors):
        """Element-wise product of same-shaped arrays, accumulated from ones."""
        product = np.ones_like(factors[0])
        for factor in factors:
            product *= factor
        return product

    @staticmethod
    def _rate_to_str(rate):
        """Format a compression rate; the 0 sentinel becomes 'NA'."""
        return 'NA' if rate == 0 else "{:.2f}".format(rate)

    def on_epoch_end(self, epoch, logs=None):
        candidates = self.model.layers
        if self.verbose != 0:
            print(f"\n###\nEpoch {epoch+1}:")

        self.overall_num_small_values = 0
        self.overall_total_elements = 0

        for position, layer in enumerate(candidates):
            if not isinstance(layer, (HadamardDense, HadamardConv2D)):
                continue
            arrays = layer.get_weights()

            # Kernel factors are the arrays of rank >= 2.
            kernel = self._multiply_factors([a for a in arrays if len(a.shape) >= 2])
            small_w, count_w, ratio_w, rate_w, l1_w, l2_w = self._compute_sparsity_compression_norms(kernel)
            self.overall_num_small_values += small_w
            self.overall_total_elements += count_w
            if self.save_metrics:
                weight_row = [epoch, position, 1, ratio_w, rate_w, l1_w, l2_w]  # 1 = weight object
                self.metrics_data = np.vstack((self.metrics_data, weight_row))

            if not (layer.use_bias and layer.factorize_bias):
                if self.verbose == 2:
                    print(f"{layer.name}: sparsity weights = {ratio_w * 100:.2f}%, biases = NA, "
                          f"joint = {ratio_w * 100:.2f}%, CR = {self._rate_to_str(rate_w)}\n")
                continue

            # Bias factors are the rank-1 arrays.
            bias = self._multiply_factors([a for a in arrays if len(a.shape) == 1])
            small_b, count_b, ratio_b, rate_b, l1_b, l2_b = self._compute_sparsity_compression_norms(bias)
            self.overall_num_small_values += small_b
            self.overall_total_elements += count_b
            if self.save_metrics:
                bias_row = [epoch, position, 0, ratio_b, rate_b, l1_b, l2_b]  # 0 = bias object
                self.metrics_data = np.vstack((self.metrics_data, bias_row))

            joint_ratio, joint_rate = self._compute_overall_metrics(small_w + small_b, count_w + count_b)
            if self.verbose == 2:
                print(f"{layer.name}: sparsity weights = {ratio_w * 100:.2f}%, biases = {ratio_b * 100:.2f}%, "
                      f"joint = {joint_ratio * 100:.2f}%, CR = {self._rate_to_str(joint_rate)}")

        overall_sparsity, overall_rate = self._compute_overall_metrics(
            self.overall_num_small_values, self.overall_total_elements)

        if self.save_metrics:
            self.total_metrics_data = np.vstack(
                (self.total_metrics_data, [epoch, overall_sparsity, overall_rate]))

        if self.verbose != 0:
            print(f"Total sparsity = {overall_sparsity * 100:.2f}%, "
                  f"Total Compression rate = {self._rate_to_str(overall_rate)}\n###")

    def _compute_sparsity_compression_norms(self, reconstructed_array):
        """Return (n_small, n_total, sparsity, compression rate, L1 norm, L2 norm)."""
        magnitudes = np.abs(reconstructed_array)
        num_small_values = np.sum(magnitudes < self.threshold)
        total_elements = np.prod(reconstructed_array.shape)
        sparsity_ratio, compression_rate = self._compute_overall_metrics(num_small_values, total_elements)
        L1_norm = np.sum(magnitudes)
        L2_norm = np.sqrt(np.sum(np.square(reconstructed_array)))
        return num_small_values, total_elements, sparsity_ratio, compression_rate, L1_norm, L2_norm

    def _compute_overall_metrics(self, overall_num_small_values, overall_total_elements):
        """Return (sparsity, compression rate); inf / 0 mark the undefined cases."""
        if overall_total_elements > 0:
            sparsity = overall_num_small_values / overall_total_elements
        else:
            sparsity = float('inf')
        if overall_num_small_values < overall_total_elements:
            rate = overall_total_elements / (overall_total_elements - overall_num_small_values)
        else:
            rate = 0
        return sparsity, rate

# Color normalization pre-processing function
def color_preprocessing(X_train, X_test, X_val, mean=(125.3, 123.0, 113.9), std=(63.0, 62.1, 66.7)):
    """Standardize image arrays channel-wise and return float32 copies.

    The last axis is treated as the channel axis. For each channel ``i <
    len(mean)``, the result is ``(x - mean[i]) / std[i]``. Any further channels
    are only cast to float32. The defaults are the per-channel statistics
    previously hard-coded here (CIFAR-style 0-255 RGB images). The input
    arrays are not modified.

    Args:
        X_train, X_test, X_val: Arrays shaped ``(..., channels)``.
        mean: Per-channel means to subtract.
        std: Per-channel standard deviations to divide by (same length as ``mean``).

    Returns:
        Tuple ``(X_train, X_test, X_val)`` of normalized float32 arrays.
    """
    mean = np.asarray(mean, dtype='float32')
    std = np.asarray(std, dtype='float32')
    n_norm = len(mean)

    def _normalize(X):
        X = X.astype('float32')  # astype copies, so the caller's array is untouched
        X[..., :n_norm] = (X[..., :n_norm] - mean) / std
        return X

    return _normalize(X_train), _normalize(X_test), _normalize(X_val)

# ResNet-18 implementation
# ResNet-34 has block_layers=[3, 4, 6, 3]

def identity_block(x, filter, depth=1, init_type='equivar', init=None, la=0):
    """ResNet basic block with an identity shortcut, built from HadamardConv2D layers.

    The block applies conv3x3 -> BN -> ReLU -> conv3x3 -> BN, adds the
    unchanged input, and applies a final ReLU. Both convolutions use stride 1,
    'same' padding and no bias. Because of the addition, ``x`` must already
    have ``filter`` channels.

    Args:
        x: Input tensor.
        filter: Number of output channels of both convolutions.
        depth: Hadamard factorization depth.
        init_type: Factor initialization scheme passed to HadamardConv2D.
        init: Base initializer. The default ``None`` lets every convolution
            create its own HeUniform, rather than all layers sharing one
            default instance created at import time.
        la: L2 regularization strength.

    Returns:
        Output tensor of the block.
    """
    x_skip = x
    # Layer 1
    x = HadamardConv2D(filters=filter, kernel_size=(3, 3), depth=depth, init=init, init_type=init_type,
                       padding='same', la=la, use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization(axis=3)(x)
    x = tf.keras.layers.Activation('relu')(x)
    # Layer 2
    x = HadamardConv2D(filters=filter, kernel_size=(3, 3), depth=depth, init=init, init_type=init_type,
                       padding='same', la=la, use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization(axis=3)(x)
    # Residual connection
    x = tf.keras.layers.Add()([x, x_skip])
    x = tf.keras.layers.Activation('relu')(x)
    return x

def convolutional_block(x, filter, depth=1, init_type='equivar', init=None, la=0):
    """ResNet downsampling block with a projected shortcut, built from HadamardConv2D layers.

    The main path applies conv3x3 with stride 2 -> BN -> ReLU -> conv3x3 with
    stride 1 -> BN. The shortcut is a 1x1 stride-2 HadamardConv2D followed by
    BN, so both paths halve the spatial size and output ``filter`` channels.
    The paths are summed and passed through ReLU. All convolutions use 'same'
    padding and no bias.

    Args:
        x: Input tensor.
        filter: Number of output channels.
        depth: Hadamard factorization depth.
        init_type: Factor initialization scheme passed to HadamardConv2D.
        init: Base initializer. The default ``None`` lets every convolution
            create its own HeUniform, rather than all layers sharing one
            default instance created at import time.
        la: L2 regularization strength.

    Returns:
        Output tensor of the block.
    """
    x_skip = x
    # Layer 1 (downsampling)
    x = HadamardConv2D(filters=filter, kernel_size=(3, 3), depth=depth, init=init, init_type=init_type,
                       strides=2, padding='same', la=la, use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization(axis=3)(x)
    x = tf.keras.layers.Activation('relu')(x)
    # Layer 2
    x = HadamardConv2D(filters=filter, kernel_size=(3, 3), depth=depth, init=init, init_type=init_type,
                       strides=1, padding='same', la=la, use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization(axis=3)(x)
    # Project the shortcut to the new shape with a strided 1x1 convolution
    x_skip = HadamardConv2D(filters=filter, kernel_size=(1, 1), strides=2, padding='same',
                            depth=depth, init=init, init_type=init_type, la=la, use_bias=False)(x_skip)
    x_skip = tf.keras.layers.BatchNormalization()(x_skip)
    # Residual connection
    x = tf.keras.layers.Add()([x, x_skip])
    x = tf.keras.layers.Activation('relu')(x)
    return x

def ResNet_18(shape=(32, 32, 3), classes=10, depth=1, init_type='equivar', init=None, la=0,
              use_bias=True, factorize_bias=True):
    """Build a ResNet-18-style Keras model from Hadamard-factorized layers.

    Architecture: a 3x3 stem conv with 64 filters -> BN -> ReLU -> 2x2 max-pool,
    then four stages of two residual blocks each (64, 128, 256, 512 filters).
    Stages 2-4 start with a downsampling ``convolutional_block``. The head is
    2x2 average pooling -> flatten -> a softmax HadamardDense classifier.

    Args:
        shape: Input shape (height, width, channels).
        classes: Number of output classes.
        depth: Hadamard factorization depth for every layer.
        init_type: Factor initialization scheme for every layer.
        init: Base initializer. The default ``None`` lets every layer create
            its own HeUniform, rather than all layers sharing one default
            instance created at import time.
        la: L2 regularization strength.
        use_bias: Whether the final dense layer has a bias.
        factorize_bias: Whether that bias is Hadamard-factorized.

    Returns:
        An uncompiled ``tf.keras.Model`` named "ResNet18".
    """
    x_input = tf.keras.layers.Input(shape)
    # Stem: 3x3 conv with stride 1 (CIFAR-style, instead of the ImageNet 7x7/2 stem)
    x = HadamardConv2D(filters=64, kernel_size=(3, 3), depth=depth, init=init, init_type=init_type, strides=1,
                       padding='same', la=la, use_bias=False)(x_input)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=2, padding='same')(x)
    # Residual stages (ResNet-34 would use [3, 4, 6, 3])
    block_layers = [2, 2, 2, 2]
    filter_size = 64
    for i in range(4):
        if i == 0:
            # Channel count already matches, so only identity blocks are needed
            for j in range(block_layers[i]):
                x = identity_block(x, filter_size, depth=depth, init_type=init_type, init=init, la=la)
        else:
            # Double the filters and downsample, then continue with identity blocks
            filter_size = filter_size * 2
            x = convolutional_block(x, filter_size, depth=depth, init_type=init_type, init=init, la=la)
            for j in range(block_layers[i] - 1):
                x = identity_block(x, filter_size, depth=depth, init_type=init_type, init=init, la=la)
    # Classification head
    x = tf.keras.layers.AveragePooling2D((2, 2), padding='same')(x)
    x = tf.keras.layers.Flatten()(x)
    x = HadamardDense(units=classes, activation='softmax', depth=depth, init=init, la=la,
                      init_type=init_type, use_bias=use_bias, factorize_bias=factorize_bias)(x)
    model = tf.keras.models.Model(inputs=x_input, outputs=x, name="ResNet18")
    return model

# Hadamard WideResNet function (WRN-16-8 has dep=16 k=8)

def WideResNet(dep=16, k=8, shape = (32, 32, 3), classes = 10, depth=1, init_type='equivar', init = tf.keras.initializers.HeUniform(), la=0,
              use_bias=True, factorize_bias=True):
    """Build a pre-activation Wide ResNet (WRN-dep-k) from Hadamard layers.

    Every convolution is a ``HadamardConv2D`` (always ``use_bias=False``).
    The classifier head is a softmax ``HadamardDense``. The arguments
    ``depth``, ``init_type``, ``init`` and ``la`` are forwarded unchanged to
    every one of those layers.

    Args:
        dep: Total network depth. Each of the three stages gets
            ``(dep - 4) // 6`` residual blocks, but at least one block is
            always built per stage. WRN-16-8 is ``dep=16, k=8``.
        k: Widening factor; the stage widths are ``16*k``, ``32*k``, ``64*k``
            after an initial 16-filter 3x3 convolution.
        shape: Input shape passed to ``tf.keras.layers.Input``.
        classes: Number of units in the final softmax ``HadamardDense``.
        depth: Forwarded as ``depth`` to every Hadamard layer.
        init_type: Forwarded as ``init_type`` to every Hadamard layer.
        init: Keras initializer forwarded as ``init`` to every Hadamard layer.
            NOTE: the default instance is created once, when this function is
            defined, and is shared by all calls that rely on the default.
        la: Forwarded as ``la`` to every Hadamard layer (presumably a
            regularization strength -- confirm against HadamardConv2D).
        use_bias: ``use_bias`` of the final ``HadamardDense`` only.
        factorize_bias: ``factorize_bias`` of the final ``HadamardDense`` only.

    Returns:
        A ``tf.keras.Model`` named ``"WideResNet"`` mapping the input to
        class probabilities.

    Note:
        Stages 2 and 3 each downsample by stride 2. The fixed ``(8, 8)``
        average pooling therefore assumes the final feature map is at least
        8x8, e.g. 32x32 inputs.
    """
    print('Wide-Resnet %dx%d' %(dep, k))
    n_filters = [16, 16 * k, 32 * k, 64 * k]
    n_stack = int((dep - 4) // 6)

    def hadamard_conv(x, filters, kernel_size=(3, 3), strides=1):
        # All convolutions in this network are bias-free 'same'-padded
        # Hadamard convolutions sharing the model-level settings.
        return HadamardConv2D(filters=filters, kernel_size=kernel_size, strides=strides,
                              padding='same', use_bias=False, depth=depth,
                              init_type=init_type, init=init, la=la)(x)

    def bn_relu(x):
        x = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
        return tf.keras.layers.Activation('relu')(x)

    def residual_block(x, out_filters, increase=False):
        # Pre-activation block: BN-ReLU -> conv -> BN-ReLU -> conv, plus shortcut.
        stride = 2 if increase else 1

        o1 = bn_relu(x)
        conv_1 = hadamard_conv(o1, out_filters, strides=stride)
        o2 = bn_relu(conv_1)
        conv_2 = hadamard_conv(o2, out_filters, strides=1)

        # The identity shortcut only works when neither the spatial size
        # (stride 2) nor the channel count changes. Otherwise project with a
        # 1x1 conv. The channel count is read from the incoming tensor itself
        # rather than from mutable global state, so repeated calls to
        # WideResNet always build the same architecture.
        if increase or x.shape[-1] != out_filters:
            shortcut = hadamard_conv(o1, out_filters, kernel_size=(1, 1), strides=stride)
        else:
            shortcut = x
        return tf.keras.layers.add([conv_2, shortcut])

    def wide_residual_layer(x, out_filters, increase=False):
        # The first block may downsample/widen; the remaining blocks keep the shape.
        x = residual_block(x, out_filters, increase)
        for _ in range(1, n_stack):
            x = residual_block(x, out_filters)
        return x

    x_input = tf.keras.layers.Input(shape)
    x = hadamard_conv(x_input, n_filters[0])
    x = wide_residual_layer(x, n_filters[1])
    x = wide_residual_layer(x, n_filters[2], increase=True)
    x = wide_residual_layer(x, n_filters[3], increase=True)
    x = bn_relu(x)
    x = tf.keras.layers.AveragePooling2D((8,8))(x)
    x = tf.keras.layers.Flatten()(x)
    x = HadamardDense(units=classes, activation='softmax', depth=depth, use_bias=use_bias,
                      factorize_bias=factorize_bias, init_type=init_type, init=init, la=la)(x)

    # define model
    model = tf.keras.models.Model(inputs = x_input, outputs = x, name = "WideResNet")
    return model