# imports
import math
import numpy as np
import tensorflow as tf
from layers import HadamardDense, HadamardConv2D, SparseConv2D

# Keras callback for printing sparsity metrics for Hadamard layers
class HadamardCallback(tf.keras.callbacks.Callback):
    """Report sparsity, compression and misalignment of factorized layers.

    At the end of every epoch each ``HadamardDense``, ``HadamardConv2D`` and
    ``SparseConv2D`` layer of the model is inspected. All weight arrays with
    two or more dimensions are multiplied element-wise into one
    "reconstructed" array, on which sparsity, compression rate and L1/L2
    norms are computed. If the layer has ``use_bias`` and ``factorize_bias``
    set, the same is done for its one-dimensional weight arrays.

    Misalignment is reported as ``-1`` (not applicable) for ``SparseConv2D``
    layers, and for the overall total whenever such a layer was inspected.

    Parameters:
    - threshold: entries with absolute value below this count as zero.
    - save_metrics: if True, append one row per inspected object and epoch to
      ``metrics_data`` and one row per epoch to ``total_metrics_data``.
    - verbose: 0 = silent, 1 = per-epoch totals, 2 = totals and per-layer lines.

    Raises:
    - ValueError: if ``verbose`` is not one of 0, 1, 2.
    """

    def __init__(self, threshold=np.finfo(np.float32).eps, save_metrics=False, verbose=2):
        super().__init__()
        if verbose not in (0, 1, 2):
            raise ValueError("verbose must be one of 0,1,2")

        self.threshold = threshold
        self.save_metrics = save_metrics
        self.verbose = verbose

        # Per-epoch accumulators; reset at the start of every on_epoch_end.
        self.overall_num_small_values = 0
        self.overall_total_elements = 0
        self.overall_misalignment = 0
        self.overall_l2_sq = 0

        # Columns: epoch (1-based), layer index, kind (1 = weights, 0 = bias),
        # sparsity ratio, compression rate, L1 norm, L2 norm, misalignment.
        self.metrics_data = np.empty((0, 8))
        # Columns: epoch (1-based), sparsity, compression rate, misalignment, L2 norm.
        self.total_metrics_data = np.empty((0, 5))

    def on_epoch_end(self, epoch, logs=None):
        """Compute, print and optionally store the metrics for this epoch."""
        if self.verbose != 0:
            print(f"\n###\nEpoch {epoch+1}:")

        self.overall_num_small_values = 0
        self.overall_total_elements = 0
        self.overall_misalignment = 0
        self.overall_l2_sq = 0
        # Tracked explicitly so the overall misalignment does not depend on
        # whichever layer happens to be last in the model.
        has_sparse_conv = False

        for layer_index, layer in enumerate(self.model.layers):
            if not isinstance(layer, (HadamardDense, HadamardConv2D, SparseConv2D)):
                continue

            layer_weights = layer.get_weights()
            weight_factors = [w for w in layer_weights if w.ndim >= 2]
            if not weight_factors:
                # Nothing to reconstruct (e.g. the layer has no weights yet).
                continue

            is_sparse_conv = isinstance(layer, SparseConv2D)
            has_sparse_conv = has_sparse_conv or is_sparse_conv
            depth = len(weight_factors)

            # --- weights ---
            weight_product = self._hadamard_product(weight_factors)
            (num_small_w, total_w, sparsity_w,
             compression_w, l1_w, l2_w) = self._compute_sparsity_compression_norms(weight_product)
            raw_misalignment_w = self._misalignment(weight_factors, weight_product, depth)
            self._add_to_totals(num_small_w, total_w, raw_misalignment_w, l2_w)

            # -1 is the "not applicable" marker, applied regardless of
            # save_metrics so printed and stored values always agree.
            misalignment_w = -1 if is_sparse_conv else raw_misalignment_w

            if self.save_metrics:
                self.metrics_data = np.vstack((
                    self.metrics_data,
                    [epoch + 1, layer_index, 1, sparsity_w, compression_w, l1_w, l2_w, misalignment_w],
                ))

            # --- biases (only when the layer factorizes them) ---
            bias_factors = []
            if layer.use_bias and layer.factorize_bias:
                bias_factors = [w for w in layer_weights if w.ndim == 1]

            if not bias_factors:
                if self.verbose == 2:
                    print(
                        f"{layer.name}: sparsity weights = {sparsity_w * 100:.2f}%, biases = NA, "
                        f"joint = {sparsity_w * 100:.2f}%, CR = {self._format_rate(compression_w)}, "
                        f"misalignment weights = {misalignment_w:.3e}, biases = NA\n"
                    )
                continue

            bias_product = self._hadamard_product(bias_factors)
            (num_small_b, total_b, sparsity_b,
             compression_b, l1_b, l2_b) = self._compute_sparsity_compression_norms(bias_product)
            # The bias penalty uses the same depth as the weight factors.
            raw_misalignment_b = self._misalignment(bias_factors, bias_product, depth)
            self._add_to_totals(num_small_b, total_b, raw_misalignment_b, l2_b)

            misalignment_b = -1 if is_sparse_conv else raw_misalignment_b

            if self.save_metrics:
                self.metrics_data = np.vstack((
                    self.metrics_data,
                    [epoch + 1, layer_index, 0, sparsity_b, compression_b, l1_b, l2_b, misalignment_b],
                ))

            joint_sparsity, joint_compression = self._compute_overall_metrics(
                num_small_w + num_small_b, total_w + total_b
            )
            total_misalignment = -1 if is_sparse_conv else misalignment_w + misalignment_b

            if self.verbose == 2:
                print(
                    f"{layer.name}: sparsity weights = {sparsity_w * 100:.2f}%, "
                    f"biases = {sparsity_b * 100:.2f}%, joint = {joint_sparsity * 100:.2f}%, "
                    f"CR = {self._format_rate(joint_compression)}, "
                    f"misalignment weights = {misalignment_w:.3e}, biases = {misalignment_b:.3e}, "
                    f"joint = {total_misalignment:.3e}\n"
                )

        # Compute and output overall metrics
        overall_sparsity, overall_compression_rate = self._compute_overall_metrics(
            self.overall_num_small_values, self.overall_total_elements
        )
        if has_sparse_conv:
            self.overall_misalignment = -1
        overall_misalignment = self.overall_misalignment
        overall_l2 = np.sqrt(self.overall_l2_sq)

        if self.save_metrics:
            self.total_metrics_data = np.vstack((
                self.total_metrics_data,
                [epoch + 1, overall_sparsity, overall_compression_rate, overall_misalignment, overall_l2],
            ))

        if self.verbose != 0:
            print(
                f"Total sparsity = {overall_sparsity * 100:.2f}%, "
                f"Total Compression rate = {self._format_rate(overall_compression_rate)}, "
                f"Total L2 norm = {overall_l2:.3e}, "
                f"Total misalignment = {overall_misalignment:.3e}\n###"
            )

    def _add_to_totals(self, num_small_values, total_elements, misalignment, l2_norm):
        """Fold one reconstructed array's statistics into the epoch totals."""
        self.overall_num_small_values += num_small_values
        self.overall_total_elements += total_elements
        self.overall_misalignment += misalignment
        self.overall_l2_sq += np.square(l2_norm)

    @staticmethod
    def _hadamard_product(factors):
        """Return the element-wise product of a non-empty list of arrays."""
        product = factors[0].copy()
        for factor in factors[1:]:
            product *= factor
        return product

    @staticmethod
    def _misalignment(factors, product, depth):
        """Return sum of squared factor entries minus depth * sum(|product| ** (2 / depth)).

        When ``depth`` equals the number of factors, the AM-GM inequality
        makes this non-negative up to floating-point rounding.
        """
        l2_penalty = sum(np.sum(np.square(factor)) for factor in factors)
        min_penalty = depth * np.sum(np.power(np.abs(product), 2 / depth))
        return l2_penalty - min_penalty

    @staticmethod
    def _format_rate(compression_rate):
        """Format a compression rate with two decimals; 0 is shown as 'NA'."""
        return "{:.2f}".format(compression_rate) if compression_rate != 0 else 'NA'

    def _compute_sparsity_compression_norms(self, reconstructed_array):
        """Return (num_small, total, sparsity, compression rate, L1, L2) of an array.

        Entries with absolute value below ``self.threshold`` count as small.
        The compression rate is total / (total - num_small), or 0 when every
        entry is small.
        """
        num_small_values = np.sum(np.abs(reconstructed_array) < self.threshold)
        total_elements = np.prod(reconstructed_array.shape)
        sparsity_ratio, compression_rate = self._compute_overall_metrics(num_small_values, total_elements)

        # Compute L1 and L2 Norms
        L1_norm = np.sum(np.abs(reconstructed_array))
        L2_norm = np.sqrt(np.sum(np.square(reconstructed_array)))

        return num_small_values, total_elements, sparsity_ratio, compression_rate, L1_norm, L2_norm

    def _compute_overall_metrics(self, overall_num_small_values, overall_total_elements):
        """Return (sparsity ratio, compression rate) from small and total counts.

        The sparsity is ``inf`` when there are no elements; the compression
        rate is 0 when every element is small.
        """
        overall_sparsity = overall_num_small_values / overall_total_elements if overall_total_elements > 0 else float('inf')
        overall_compression_rate = overall_total_elements / (overall_total_elements - overall_num_small_values) if overall_num_small_values < overall_total_elements else 0
        return overall_sparsity, overall_compression_rate

# LR print callback
class PrintLRCallback(tf.keras.callbacks.Callback):
    """Print the optimizer's learning rate at the start of each epoch and record it."""

    def __init__(self) -> None:
        super().__init__()
        # One row per epoch: (1-based epoch number, learning rate).
        self.lr_history = np.empty((0, 2))

    def on_epoch_begin(self, epoch, logs=None):
        """Read, print and store the current learning rate."""
        optimizer = self.model.optimizer

        # `learning_rate` is the documented attribute name; `lr` is only a
        # legacy alias, kept as a fallback for optimizers that expose nothing else.
        lr = getattr(optimizer, "learning_rate", None)
        if lr is None:
            lr = optimizer.lr

        # A callable learning rate is evaluated at the current optimizer
        # iteration; anything else is used directly.
        current_lr = lr(optimizer.iterations) if callable(lr) else lr

        # Get the value of the learning rate
        eps_lr = tf.keras.backend.get_value(current_lr)

        # Print the learning rate
        print(f"\nEpoch {epoch+1}: Current learning rate = {eps_lr:.3e}")

        # Append current lr
        self.lr_history = np.append(self.lr_history, [[epoch + 1, eps_lr]], axis=0)

    def get_lr_history(self):
        """Return the recorded (epoch, learning rate) rows as an (n, 2) array."""
        return self.lr_history
    
# Terminate training run if accuracy is not above threshold after grace period
class TerminateBadRuns(tf.keras.callbacks.Callback):
    """Stop training when accuracy has not exceeded a threshold after a grace period."""

    def __init__(self, grace=5, minacc=0.2):
        """
        Initialize the callback.

        Parameters:
        - grace: Number of epochs to wait before checking for early stopping conditions.
        - minacc: Accuracy threshold; training is stopped once the grace period is
          over if the best accuracy seen so far is not above it.
        """
        super().__init__()
        self.grace = grace
        self.minacc = minacc
        self._terminated = False

    def on_train_begin(self, logs=None):
        """Reset all per-run state."""
        self.stopped_epoch = 0
        # Separate flag so a stop at epoch index 0 is still reported.
        self._terminated = False
        self.accuracy_history = []
        self.current_accuracy = None
        self.max_current_accuracy = None

    def on_epoch_begin(self, epoch, logs=None):
        """Announce epochs that fall inside the grace period."""
        if epoch + 1 <= self.grace:
            self.stopped_epoch = 0
            print(f"Grace period: epoch {epoch + 1} / {self.grace}\n")

    def on_epoch_end(self, epoch, logs=None):
        """Record the epoch accuracy and stop training if the run is bad.

        Raises:
        - ValueError: if ``logs`` carries no 'accuracy' entry.
        """
        self.current_accuracy = (logs or {}).get('accuracy')
        if self.current_accuracy is None:
            raise ValueError("Accuracy metric not available.")

        self.accuracy_history.append(self.current_accuracy)
        if self.max_current_accuracy is None or self.current_accuracy > self.max_current_accuracy:
            self.max_current_accuracy = self.current_accuracy

        # Same boundary as on_epoch_begin: epochs with epoch + 1 <= grace are
        # the grace period, so checking starts at the first epoch after it.
        if epoch + 1 > self.grace and self.max_current_accuracy <= self.minacc:
            self.stopped_epoch = epoch
            self._terminated = True
            self.model.stop_training = True

    def on_train_end(self, logs=None):
        """Print a message if this callback terminated the run."""
        if self._terminated:
            print(f"Epoch {self.stopped_epoch + 1}: Early termination due to accuracy not improving enough.")