# imports
import math
import numpy as np
import tensorflow as tf
from layers import HadamardDense, HadamardConv2D, SparseConv2D
from layers import StrHadamardDense, StrHadamardDenseV2, StrConv2D

# Keras callback for printing sparsity metrics for Hadamard layers
class HadamardCallback(tf.keras.callbacks.Callback):
    """Report element-wise sparsity metrics of Hadamard-factorized layers after every epoch.

    For each `HadamardDense`, `HadamardConv2D` or `SparseConv2D` layer of the
    model, the element-wise product of all weight arrays with two or more
    dimensions is formed and the following are measured on it: the fraction of
    entries whose magnitude is below `threshold`, the resulting compression
    rate, the L1 and L2 norms, and the "misalignment" (sum of squared factors
    minus `depth * sum(|product| ** (2 / depth))`).

    Parameters:
    - threshold: magnitude below which an entry counts as zero.
    - save_metrics: if True, append one row per layer object to `metrics_data`
      and one row per epoch to `total_metrics_data`.
    - verbose: 0 = silent, 1 = epoch totals only, 2 = totals plus one line per layer.

    Saved columns:
    - metrics_data: epoch (1-based), layer index, object flag (1 = weights,
      0 = biases), sparsity, compression rate, L1, L2, misalignment,
      squared L2 of the factors.
    - total_metrics_data: epoch (1-based), sparsity, compression rate,
      misalignment, L2, squared L2 of the factors.

    Raises:
    - ValueError: if `verbose` is not 0, 1 or 2.
    """

    def __init__(self, threshold=np.finfo(np.float32).eps, save_metrics=False, verbose=2):
        super().__init__()
        self.threshold = threshold
        self.overall_num_small_values = 0
        self.overall_total_elements = 0
        self.overall_misalignment = 0
        self.overall_l2_sq = 0
        self.overall_l2_sq_factors = 0
        self.save_metrics = save_metrics
        self.metrics_data = np.empty((0, 9))
        self.total_metrics_data = np.empty((0, 6))
        self.verbose = verbose

        if self.verbose not in [0, 1, 2]:
            raise ValueError("verbose must be one of 0,1,2")

    def on_epoch_end(self, epoch, logs=None):
        """Compute, optionally print and optionally store the metrics for this epoch."""
        layers = self.model.layers
        if self.verbose != 0:
            print(f"\n###\nEpoch {epoch+1}:")

        # The totals describe a single epoch, so restart them every time.
        self.overall_num_small_values = 0
        self.overall_total_elements = 0
        self.overall_misalignment = 0
        self.overall_l2_sq = 0
        self.overall_l2_sq_factors = 0
        # Remember whether any tracked layer is a SparseConv2D; the overall
        # misalignment is reported as the sentinel -1 in that case.
        saw_sparse_conv = False

        for layer_index, layer in enumerate(layers):
            if not isinstance(layer, (HadamardDense, HadamardConv2D, SparseConv2D)):
                continue
            is_sparse_conv = isinstance(layer, SparseConv2D)
            saw_sparse_conv = saw_sparse_conv or is_sparse_conv

            layer_weights = layer.get_weights()
            # Arrays with >= 2 dimensions are treated as weight factors,
            # 1-D arrays (further below) as bias factors.
            weight_factors = [w for w in layer_weights if len(w.shape) >= 2]
            depth = len(weight_factors)
            reconstructed_matrix = np.ones_like(weight_factors[0])
            for w in weight_factors:
                reconstructed_matrix *= w

            (num_small_values_w, total_elements_w, sparsity_ratio_w,
             compression_rate_w, L1_w, L2_w) = self._compute_sparsity_compression_norms(reconstructed_matrix)
            quasinorm_exponent = 2 / depth
            # By the AM-GM inequality the sum of squared factors is at least
            # depth * |product| ** (2 / depth) entry-wise, so the difference
            # below is zero exactly when all factors have equal magnitude.
            min_weight_penalty = depth * np.sum(np.power(np.abs(reconstructed_matrix), quasinorm_exponent))
            weight_l2_penalty = sum(np.sum(np.square(w)) for w in weight_factors)
            misalignment_w = weight_l2_penalty - min_weight_penalty

            self.overall_num_small_values += num_small_values_w
            self.overall_total_elements += total_elements_w
            self.overall_misalignment += misalignment_w
            self.overall_l2_sq += np.square(L2_w)
            self.overall_l2_sq_factors += weight_l2_penalty

            if self.save_metrics:
                # NOTE(review): the -1 sentinel is only substituted when
                # save_metrics is on, so the value printed below depends on
                # that flag. Kept as in the original.
                if is_sparse_conv:
                    misalignment_w = -1
                # Object flag 1 = weights.
                self.metrics_data = np.vstack((self.metrics_data, [
                    epoch + 1, layer_index, 1, sparsity_ratio_w, compression_rate_w,
                    L1_w, L2_w, misalignment_w, weight_l2_penalty]))

            if layer.use_bias and layer.factorize_bias:
                bias_factors = [w for w in layer_weights if len(w.shape) == 1]
                reconstructed_bias = np.ones_like(bias_factors[0])
                for b in bias_factors:
                    reconstructed_bias *= b

                (num_small_values_b, total_elements_b, sparsity_ratio_b,
                 compression_rate_b, L1_b, L2_b) = self._compute_sparsity_compression_norms(reconstructed_bias)
                # The bias penalty reuses the depth of the weight factorization.
                min_bias_penalty = depth * np.sum(np.power(np.abs(reconstructed_bias), quasinorm_exponent))
                bias_l2_penalty = sum(np.sum(np.square(b)) for b in bias_factors)
                misalignment_b = bias_l2_penalty - min_bias_penalty

                self.overall_num_small_values += num_small_values_b
                self.overall_total_elements += total_elements_b
                self.overall_misalignment += misalignment_b
                self.overall_l2_sq += np.square(L2_b)
                self.overall_l2_sq_factors += bias_l2_penalty

                if self.save_metrics:
                    if is_sparse_conv:
                        bias_l2_penalty = 0
                        misalignment_b = 0
                    # Object flag 0 = biases.
                    self.metrics_data = np.vstack((self.metrics_data, [
                        epoch + 1, layer_index, 0, sparsity_ratio_b, compression_rate_b,
                        L1_b, L2_b, misalignment_b, bias_l2_penalty]))

                total_elements = total_elements_w + total_elements_b
                num_small_values = num_small_values_w + num_small_values_b
                sparsity_ratio, compression_rate = self._compute_overall_metrics(num_small_values, total_elements)
                compression_rate_str = f"{compression_rate:.2f}" if compression_rate != 0 else 'NA'
                total_misalignment = misalignment_w + misalignment_b
                total_l2_penalty_factors = weight_l2_penalty + bias_l2_penalty
                if is_sparse_conv:
                    total_misalignment = -1
                if self.verbose == 2:
                    print(f"{layer.name}: sparsity weights = {sparsity_ratio_w * 100:.2f}%, biases = {sparsity_ratio_b * 100:.2f}%, joint = {sparsity_ratio * 100:.2f}%, CR = {compression_rate_str}, misalignment weights = {misalignment_w:.3e}, biases = {misalignment_b:.3e}, joint = {total_misalignment:.3e}, l2 pen factors = {total_l2_penalty_factors:.3e}\n")
            elif self.verbose == 2:
                compression_rate_w_str = f"{compression_rate_w:.2f}" if compression_rate_w != 0 else 'NA'
                print(f"{layer.name}: sparsity weights = {sparsity_ratio_w * 100:.2f}%, biases = NA, joint = {sparsity_ratio_w * 100:.2f}%, CR = {compression_rate_w_str}, misalignment weights = {misalignment_w:.3e}, biases = NA, l2 pen factors = {weight_l2_penalty:.3e}\n")

        # Compute and output overall metrics
        overall_sparsity, overall_compression_rate = self._compute_overall_metrics(
            self.overall_num_small_values, self.overall_total_elements)
        # Decide on the sentinel from the layers that were actually tracked,
        # not from whichever layer the loop variable happened to end on.
        if saw_sparse_conv:
            self.overall_misalignment = -1
        overall_misalignment = self.overall_misalignment
        overall_l2 = np.sqrt(self.overall_l2_sq)
        overall_l2_sq_factors = self.overall_l2_sq_factors

        if self.save_metrics:
            self.total_metrics_data = np.vstack((self.total_metrics_data, [
                epoch + 1, overall_sparsity, overall_compression_rate,
                overall_misalignment, overall_l2, overall_l2_sq_factors]))

        if self.verbose != 0:
            overall_compression_rate_str = f"{overall_compression_rate:.2f}" if overall_compression_rate != 0 else 'NA'
            print(f"Total sparsity = {overall_sparsity * 100:.2f}%, Total Compression rate = {overall_compression_rate_str}, Total L2 norm = {overall_l2:.3e}, Total misalignment = {overall_misalignment:.3e}, Total sq. L2 factors = {overall_l2_sq_factors:.3e}\n###")

    def _compute_sparsity_compression_norms(self, reconstructed_array):
        """Return (num small, num total, sparsity, compression rate, L1, L2) of an array.

        An entry is "small" when its absolute value is below `self.threshold`.
        """
        num_small_values = np.sum(np.abs(reconstructed_array) < self.threshold)
        total_elements = reconstructed_array.size
        sparsity_ratio, compression_rate = self._compute_overall_metrics(num_small_values, total_elements)

        # Compute L1 and L2 Norms
        L1_norm = np.sum(np.abs(reconstructed_array))
        L2_norm = np.sqrt(np.sum(np.square(reconstructed_array)))

        return num_small_values, total_elements, sparsity_ratio, compression_rate, L1_norm, L2_norm

    def _compute_overall_metrics(self, overall_num_small_values, overall_total_elements):
        """Return (sparsity, compression rate) for the given counts.

        Sparsity is `inf` when there are no elements at all; the compression
        rate is 0 (printed as 'NA' by the caller) when every element is small.
        """
        if overall_total_elements > 0:
            overall_sparsity = overall_num_small_values / overall_total_elements
        else:
            overall_sparsity = float('inf')
        if overall_num_small_values < overall_total_elements:
            overall_compression_rate = overall_total_elements / (overall_total_elements - overall_num_small_values)
        else:
            overall_compression_rate = 0
        return overall_sparsity, overall_compression_rate
    
# Extended Callback for Structured Sparsity including SparseConv2D
class StructuredSparsityCallback(tf.keras.callbacks.Callback):
    """Report group-wise (structured) sparsity metrics after every epoch.

    Handles layers whose class is named "SparseConv2D" as well as
    `StrHadamardDense`, `StrHadamardDenseV2` and `StrConv2D` layers; only
    layers exposing both `depth` and `get_weights` are inspected. For each
    such layer the kernel is reconstructed from its factors, grouped along
    one axis, and a group counts as pruned when its L2 norm is below
    `threshold`.

    Parameters:
    - threshold: group norm below which a group counts as zero.
    - save_metrics: if True, append one row per layer object to `metrics_data`
      and one row per epoch to `total_metrics_data`.
    - verbose: 0 = silent, 1 = epoch totals only, 2 = totals plus one line per layer.

    Saved columns:
    - metrics_data: epoch (1-based), layer index, object flag (1 = kernel,
      0 = bias), sparsity, compression rate, L1, L2, misalignment,
      squared L2 of the factors, minimal penalty.
    - total_metrics_data: epoch (1-based), sparsity, compression rate,
      misalignment, L2, squared L2 of the factors, minimal penalty.

    Raises:
    - ValueError: if `verbose` is not 0, 1 or 2.
    """

    def __init__(self, threshold=np.finfo(np.float32).eps, save_metrics=False, verbose=2):
        super().__init__()
        self.threshold = threshold
        self.overall_num_small_groups = 0
        self.overall_total_groups = 0
        self.overall_misalignment = 0
        self.overall_l2_sq = 0
        self.overall_l2_sq_factors = 0
        self.overall_min_penalty = 0.0
        self.save_metrics = save_metrics
        self.metrics_data = np.empty((0, 10))
        self.total_metrics_data = np.empty((0, 7))
        self.verbose = verbose
        if self.verbose not in [0, 1, 2]:
            raise ValueError("verbose must be one of 0,1,2")

    def _compute_group_sparsity_metrics(self, reconstructed, group_axis, depth):
        """Return group sparsity statistics of `reconstructed`.

        Groups are the slices along `group_axis`. The result is
        (num small groups, num groups, sparsity, compression rate,
        L1 norm, L2 norm, per-group L2 norms). `depth` is accepted for
        signature compatibility and is not used.
        """
        # Bring the group axis to the front, then reduce over everything else.
        perm = list(range(reconstructed.ndim))
        perm[0], perm[group_axis] = perm[group_axis], perm[0]
        r_perm = np.transpose(reconstructed, perm)
        group_norms = np.sqrt(np.sum(np.square(r_perm), axis=tuple(range(1, r_perm.ndim))))
        num_small_groups = np.sum(group_norms < self.threshold)
        total_groups = group_norms.size
        sparsity_ratio = num_small_groups / total_groups if total_groups > 0 else float('inf')
        compression_rate = total_groups / (total_groups - num_small_groups) if num_small_groups < total_groups else 0
        L1_norm = np.sum(np.abs(reconstructed))
        L2_norm = np.sqrt(np.sum(np.square(reconstructed)))
        return num_small_groups, total_groups, sparsity_ratio, compression_rate, L1_norm, L2_norm, group_norms

    @staticmethod
    def _resolve_axis(position, ndim):
        """Map a possibly negative axis index onto the range 0..ndim-1."""
        return position if position >= 0 else ndim + position

    def _accumulate_tensor_metrics(self, epoch, layer_index, object_flag,
                                   reconstructed, group_axis, depth, l2_penalty):
        """Measure one reconstructed tensor, add it to the epoch totals and optionally store a row.

        `object_flag` is 1 for a kernel and 0 for a bias. Returns
        (num small groups, num groups, sparsity, compression rate,
        minimal penalty, misalignment).
        """
        (num_small, total, sparsity, compression, l1_norm, l2_norm,
         group_norms) = self._compute_group_sparsity_metrics(reconstructed, group_axis, depth)
        # Minimal penalty: depth * sum of group norms raised to 2 / depth;
        # the misalignment is how far the actual squared-factor sum exceeds it.
        min_penalty = depth * np.sum(np.power(group_norms, 2 / depth))
        misalignment = l2_penalty - min_penalty

        self.overall_num_small_groups += num_small
        self.overall_total_groups += total
        self.overall_misalignment += misalignment
        self.overall_l2_sq += np.square(l2_norm)
        self.overall_l2_sq_factors += l2_penalty
        self.overall_min_penalty += min_penalty

        if self.save_metrics:
            self.metrics_data = np.vstack((self.metrics_data, [
                epoch + 1, layer_index, object_flag, sparsity, compression,
                l1_norm, l2_norm, misalignment, l2_penalty, min_penalty]))
        return num_small, total, sparsity, compression, min_penalty, misalignment

    def on_epoch_end(self, epoch, logs=None):
        """Compute, optionally print and optionally store the metrics for this epoch."""
        layers = self.model.layers
        if self.verbose != 0:
            print(f"\n###\nEpoch {epoch+1}:")

        # The totals describe a single epoch, so restart them every time.
        self.overall_num_small_groups = 0
        self.overall_total_groups = 0
        self.overall_misalignment = 0
        self.overall_l2_sq = 0
        self.overall_l2_sq_factors = 0
        self.overall_min_penalty = 0.0

        for layer_index, layer in enumerate(layers):
            if not (hasattr(layer, 'depth') and hasattr(layer, 'get_weights')):
                continue

            if layer.__class__.__name__ == "SparseConv2D":
                layer_weights = layer.get_weights()
                # Weight layout as indexed here: [kernel, multfac, bias factors...].
                kernel, multfac = layer_weights[0], layer_weights[1]
                depth_kernel = layer.depth
                group_axis = self._resolve_axis(layer.position_sparsity, len(kernel.shape))
                reconstructed_kernel = kernel * (np.abs(multfac) ** (depth_kernel - 1))
                # multfac stands in for depth - 1 factors, hence the weighting.
                weight_l2_penalty = (np.sum(np.square(kernel))
                                     + (depth_kernel - 1) * np.sum(np.square(multfac)))
                bias_offset = 2
            elif isinstance(layer, (StrHadamardDense, StrHadamardDenseV2, StrConv2D)):
                layer_weights = layer.get_weights()
                depth_kernel = layer.depth
                # Weight layout as indexed here: the first `depth` arrays are
                # kernel factors, the bias factors follow.
                kernel_factors = layer_weights[:depth_kernel]
                if isinstance(layer, StrHadamardDense):
                    group_axis = 1
                elif isinstance(layer, StrHadamardDenseV2):
                    group_axis = 0
                else:  # StrConv2D
                    group_axis = self._resolve_axis(layer.position_sparsity, len(kernel_factors[0].shape))

                reconstructed_kernel = kernel_factors[0].copy()
                for w in kernel_factors[1:]:
                    # Later factors are reshaped so their first dimension
                    # lines up with the group axis and broadcasts elsewhere.
                    shape = [1] * len(reconstructed_kernel.shape)
                    shape[group_axis] = w.shape[0]
                    reconstructed_kernel *= w.reshape(shape)
                weight_l2_penalty = sum(np.sum(np.square(w)) for w in kernel_factors)
                bias_offset = depth_kernel
            else:
                continue

            (num_small_groups, total_groups, sparsity_ratio, compression_rate,
             min_weight_penalty, misalignment_kernel) = self._accumulate_tensor_metrics(
                epoch, layer_index, 1, reconstructed_kernel, group_axis, depth_kernel, weight_l2_penalty)

            if layer.use_bias and layer.factorize_bias:
                num_bias = 1 if isinstance(layer, StrHadamardDense) else layer.depth
                bias_factors = layer_weights[bias_offset:bias_offset + num_bias]
                reconstructed_bias = bias_factors[0].copy()
                for b in bias_factors[1:]:
                    reconstructed_bias *= b
                bias_l2_penalty = sum(np.sum(np.square(b)) for b in bias_factors)

                # A bias is 1-D, so every entry forms its own group (axis 0).
                (num_small_groups_b, total_groups_b, sparsity_ratio_b, _,
                 min_bias_penalty, misalignment_bias) = self._accumulate_tensor_metrics(
                    epoch, layer_index, 0, reconstructed_bias, 0, num_bias, bias_l2_penalty)

                total_groups_layer = total_groups + total_groups_b
                num_small_total = num_small_groups + num_small_groups_b
                joint_sparsity = num_small_total / total_groups_layer if total_groups_layer > 0 else float('inf')
                total_misalignment = misalignment_kernel + misalignment_bias
                total_l2_penalty_factors = weight_l2_penalty + bias_l2_penalty
                if self.verbose == 2:
                    print(f"{layer.name}: sparsity kernel = {sparsity_ratio*100:.2f}%, min penalty kernel = {min_weight_penalty:.3e}, "
                          f"bias = {sparsity_ratio_b*100:.2f}%, min penalty bias = {min_bias_penalty:.3e}, joint = {joint_sparsity*100:.2f}%, "
                          f"CR = {compression_rate:.2f}, misalignment kernel = {misalignment_kernel:.3e}, misalignment bias = {misalignment_bias:.3e}, "
                          f"joint misalignment = {total_misalignment:.3e}, l2 pen factors = {total_l2_penalty_factors:.3e}\n")
            elif self.verbose == 2:
                print(f"{layer.name}: sparsity kernel = {sparsity_ratio*100:.2f}%, min penalty kernel = {min_weight_penalty:.3e}, "
                      f"bias = NA, min penalty bias = NA, joint = {sparsity_ratio*100:.2f}%, CR = {compression_rate:.2f}, "
                      f"misalignment kernel = {misalignment_kernel:.3e}, bias = NA, l2 pen factors = {weight_l2_penalty:.3e}\n")

        overall_sparsity = (self.overall_num_small_groups / self.overall_total_groups
                            if self.overall_total_groups > 0 else float('inf'))
        overall_compression_rate = (
            self.overall_total_groups / (self.overall_total_groups - self.overall_num_small_groups)
            if self.overall_num_small_groups < self.overall_total_groups else 0)
        overall_l2 = np.sqrt(self.overall_l2_sq)
        overall_l2_sq_factors = self.overall_l2_sq_factors
        if self.save_metrics:
            self.total_metrics_data = np.vstack((self.total_metrics_data, [
                epoch + 1, overall_sparsity, overall_compression_rate, self.overall_misalignment,
                overall_l2, overall_l2_sq_factors, self.overall_min_penalty]))
        # Honour verbose=0 for the totals as well, matching the epoch header.
        if self.verbose != 0:
            print(f"Total sparsity = {overall_sparsity*100:.2f}%, Total Compression rate = {overall_compression_rate:.2f}, Total L2 norm = {overall_l2:.3e}, "
                  f"Total misalignment = {self.overall_misalignment:.3e}, Total sq. L2 factors = {overall_l2_sq_factors:.3e}, Total min penalty = {self.overall_min_penalty:.3e}\n###")



# LR print callback
class PrintLRCallback(tf.keras.callbacks.Callback):
    """Print and record the optimizer's learning rate at the start of every epoch.

    Parameters:
    - verbose: the learning rate is printed when this is >= 1; it is recorded
      in `lr_history` regardless.
    """

    def __init__(self, verbose=1):
        super().__init__()
        self.verbose = verbose
        # One row per started epoch: (1-based epoch number, learning rate).
        self.lr_history = np.empty((0, 2))

    def on_epoch_begin(self, epoch, logs=None):
        """Evaluate the current learning rate, print it if requested and store it."""
        optimizer = self.model.optimizer
        # Prefer `learning_rate`; `lr` is only an alias for it and is not
        # guaranteed to exist on every optimizer implementation.
        lr = getattr(optimizer, "learning_rate", None)
        if lr is None:
            lr = optimizer.lr

        # A callable learning rate (e.g. a schedule) is evaluated at the
        # optimizer's current iteration count; anything else is used as is.
        current_lr = lr(optimizer.iterations) if callable(lr) else lr

        # Get the value of the learning rate
        eps_lr = tf.keras.backend.get_value(current_lr)

        if self.verbose >= 1:
            print(f"\nEpoch {epoch+1}: Current learning rate = {eps_lr:.3e}")

        self.lr_history = np.append(self.lr_history, [[epoch + 1, eps_lr]], axis=0)

    def get_lr_history(self):
        """Return the recorded (epoch, learning rate) rows as a 2-column array."""
        return self.lr_history
    
# Terminate training run if accuracy is not above threshold after grace period
class TerminateBadRuns(tf.keras.callbacks.Callback):
    """Stop training when the best accuracy so far is still too low after a grace period."""

    def __init__(self, grace=5, minacc=0.2, verbose=1):
        """
        Initialize the callback.

        Parameters:
        - grace: Number of epochs to wait before checking for early stopping conditions.
        - minacc: Accuracy threshold; training is stopped after the grace epochs
          when the best accuracy seen so far is not above it.
        - verbose: Grace-period progress is printed when this is >= 1.
        """
        super().__init__()
        self.grace = grace
        self.minacc = minacc
        self.verbose = verbose
        self.stopped_epoch = 0
        self.accuracy_history = []
        self.current_accuracy = None
        self.max_current_accuracy = None
        self._terminated = False

    def on_train_begin(self, logs=None):
        """Reset all per-run state so the callback can be reused across fits."""
        self.stopped_epoch = 0
        self.accuracy_history = []
        self.current_accuracy = None
        self.max_current_accuracy = None
        self._terminated = False

    def on_epoch_begin(self, epoch, logs=None):
        """Announce the grace period while it lasts (1-based epochs 1..grace)."""
        if epoch + 1 <= self.grace:
            self.stopped_epoch = 0
            if self.verbose >= 1:
                print(f"Grace period: epoch {epoch + 1} / {self.grace}\n")

    def on_epoch_end(self, epoch, logs=None):
        """Record the epoch accuracy and stop training if it never exceeded `minacc`.

        Raises:
        - ValueError: if the logs contain no 'accuracy' entry.
        """
        # Keras may call this without logs; treat that like a missing metric
        # instead of failing on None.get.
        logs = logs or {}
        self.current_accuracy = logs.get('accuracy')
        if self.current_accuracy is None:
            raise ValueError("Accuracy metric not available.")

        self.accuracy_history.append(self.current_accuracy)
        self.max_current_accuracy = max(self.accuracy_history)

        # `epoch` is 0-based. The grace period covers 1-based epochs 1..grace
        # (see on_epoch_begin), so the first check belongs to epoch grace + 1.
        grace_over = epoch + 1 > self.grace
        if grace_over and self.max_current_accuracy <= self.minacc:
            self.stopped_epoch = epoch
            self._terminated = True
            self.model.stop_training = True

    def on_train_end(self, logs=None):
        """Print a message if this callback terminated the run."""
        # Use an explicit flag: `stopped_epoch` is 0 both for "not stopped"
        # and for "stopped in the very first epoch".
        if self._terminated:
            print(f"Epoch {self.stopped_epoch + 1}: Early termination due to accuracy not improving enough.")