# imports
import math
import numpy as np
import tensorflow as tf
    
class TwiceTruncatedNormalInitializer(tf.keras.initializers.Initializer):
    """Zero-mean normal initializer truncated in magnitude from below and above.

    Every entry ``w`` is drawn from ``N(0, std)`` and redrawn until
    ``threshold <= |w| <= threshold_max`` holds, where

    * ``threshold = minprod ** (1 / depth)``, so the product of ``depth`` such
      entries has magnitude at least ``minprod``;
    * ``threshold_max = min(1, 2 * (gain / sqrt(n_input)) ** (1 / depth))``.

    Args:
        minprod: Lower bound for the magnitude of a product of ``depth`` entries.
        depth: Number of factors; the root taken in both thresholds.
        n_input: Number of input units. If ``None``, the fan-in of the requested
            shape is used (product of all dimensions except the last).
        gain: Gain used for the default ``std`` and for ``threshold_max``.
        std: Standard deviation of the underlying normal. If ``None``,
            ``sqrt(gain**2 / n_input)`` is used.
    """

    def __init__(self, minprod, depth, n_input=None, gain=np.sqrt(2), std=None):
        self.minprod = minprod
        self.depth = depth
        self.gain = gain
        self.std = std
        self.n_input = n_input

    @staticmethod
    def _fan_in(shape):
        """Return the fan-in of ``shape`` (same convention as Keras' fan computation)."""
        if len(shape) == 0:
            return 1
        if len(shape) == 1:
            return shape[0]
        # Dense (in, out) -> in; conv (kh, kw, in_ch, out) -> kh * kw * in_ch.
        return int(np.prod(shape[:-1]))

    def __call__(self, shape, dtype=None):
        """Sample a tensor of ``shape`` whose entries all lie inside the truncation band.

        Args:
            shape: Shape of the tensor to create.
            dtype: Passed to ``tf.convert_to_tensor``; ``None`` keeps NumPy's float64.

        Returns:
            A tensor with every entry ``w`` satisfying ``threshold <= |w| <= threshold_max``.

        Raises:
            ValueError: If the standard deviation is not positive or the band is
                empty. Either condition would make the rejection loop run forever.
        """
        shape = tuple(int(dim) for dim in shape)
        n_input = self.n_input if self.n_input is not None else self._fan_in(shape)
        std = self.std if self.std is not None else np.sqrt(np.square(self.gain) / n_input)

        threshold = self.minprod ** (1 / self.depth)
        # Cap at twice the per-factor He-style scale (gain / sqrt(n_input)) ** (1 / depth), and at 1.
        threshold_max = np.minimum(1, 2 * (self.gain / np.sqrt(n_input)) ** (1 / self.depth))

        # "not x > y" also rejects NaN values.
        if not std > 0:
            raise ValueError(f"std must be positive, got {std}.")
        if not threshold < threshold_max:
            raise ValueError(
                f"Empty truncation band: lower threshold {threshold} >= upper threshold "
                f"{threshold_max}; sampling would never terminate."
            )

        matrix = np.random.normal(0, std, size=shape)
        magnitude = np.abs(matrix)
        redraw = (magnitude < threshold) | (magnitude > threshold_max)
        while np.any(redraw):
            # Redraw only the out-of-band entries, then re-check.
            matrix[redraw] = np.random.normal(0, std, size=np.sum(redraw))
            magnitude = np.abs(matrix)
            redraw = (magnitude < threshold) | (magnitude > threshold_max)

        return tf.convert_to_tensor(matrix, dtype=dtype)

    def get_config(self):
        """Return constructor arguments so the initializer can be serialized."""
        return {"minprod": self.minprod, "depth": self.depth, "std": self.std,
                "n_input": self.n_input, "gain": self.gain}

def equivar_initializer(depth, n_input, initialization):
    """
    Custom initializer for Hadamard factor weight matrices.

    Picks a per-factor distribution so that the elementwise product of ``depth``
    independent zero-mean factors targets the variance of the requested standard
    initializer: He -> 2 / n_input, LeCun -> 1 / n_input. Each factor therefore
    uses the scale ``(target_variance) ** (1 / (2 * depth))``. Uniform variants
    use limit ``sqrt(3) * scale``, which gives the factor exactly variance ``scale**2``.

    Args:
    - depth: Number of factor weight matrices (must be positive).
    - n_input: Number of input units (must be positive).
    - initialization: An instance of tf.keras.initializers.HeNormal, HeUniform,
                        LecunNormal, LecunUniform, Orthogonal, or
                        TwiceTruncatedNormalInitializer. Its type selects the scheme.
                        For TwiceTruncatedNormalInitializer, its ``minprod`` is reused
                        (3e-3 if it is None).

    Returns:
    - An initializer that can be used in the HadamardDense layer.

    Raises:
    - ValueError: If depth or n_input is not positive, or the initializer type is
      not supported.
    """
    if depth <= 0:
        raise ValueError("depth must be a positive integer.")
    if n_input <= 0:
        raise ValueError("n_input must be a positive integer.")

    supported = (
        tf.keras.initializers.HeNormal,
        tf.keras.initializers.HeUniform,
        tf.keras.initializers.LecunUniform,
        tf.keras.initializers.LecunNormal,
        tf.keras.initializers.Orthogonal,
        TwiceTruncatedNormalInitializer,
    )
    if not isinstance(initialization, supported):
        raise ValueError("Invalid initialization type. Use implemented tensorflow initializer object.")

    # Per-factor scale: the depth-th root of the reconstructed weight's target std.
    he_scale = (2.0 / n_input) ** (1.0 / (2 * depth))
    lecun_scale = (1.0 / n_input) ** (1.0 / (2 * depth))

    # NOTE(review): Keras' TruncatedNormal discards samples beyond 2 stddevs, so the
    # realized per-factor variance is below stddev**2 (Keras' own HeNormal divides its
    # stddev by ~0.8796 to compensate). No such correction is applied here; confirm
    # whether this is intended.
    if isinstance(initialization, tf.keras.initializers.HeNormal):
        return tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=he_scale)

    elif isinstance(initialization, tf.keras.initializers.HeUniform):
        limit = math.sqrt(3) * he_scale
        return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)

    elif isinstance(initialization, tf.keras.initializers.LecunNormal):
        return tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=lecun_scale)

    elif isinstance(initialization, tf.keras.initializers.LecunUniform):
        limit = math.sqrt(3) * lecun_scale
        return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)

    elif isinstance(initialization, tf.keras.initializers.Orthogonal):
        # Orthogonal init is returned unchanged (no depth-dependent rescaling).
        return tf.keras.initializers.Orthogonal()

    elif isinstance(initialization, TwiceTruncatedNormalInitializer):
        # Honor the caller's configured minprod; fall back to the historical default.
        minprod = initialization.minprod if initialization.minprod is not None else 3e-3
        # He-based per-factor std, redrawn into the twice-truncated band.
        return TwiceTruncatedNormalInitializer(minprod=minprod, depth=depth, n_input=n_input, std=he_scale)


# Custom initializer function for convolutional Hadamard layers
def equivar_initializer_conv2d(depth, kernel_size, n_channels, initialization):
    """
    Custom initializer for the factor kernels of a convolutional Hadamard layer.

    The fan-in is ``n_input = kernel_height * kernel_width * n_channels``. Each
    factor uses the scale ``(target_variance) ** (1 / (2 * depth))``, with the
    target variance being 2 / n_input for He and 1 / n_input for LeCun. The
    elementwise product of ``depth`` independent factors then targets that variance.
    Uniform variants use limit ``sqrt(3) * scale``.

    Args:
    - depth: Number of factor kernels (must be positive).
    - kernel_size: ``(kernel_height, kernel_width)``, or a single int for a square kernel.
    - n_channels: Number of input channels (must be positive).
    - initialization: An instance of tf.keras.initializers.HeNormal, HeUniform,
                        LecunNormal, LecunUniform, Orthogonal, or
                        TwiceTruncatedNormalInitializer. For
                        TwiceTruncatedNormalInitializer, its ``minprod`` is reused
                        (3e-3 if it is None).

    Returns:
    - An initializer for each factor kernel.

    Raises:
    - ValueError: If depth, n_channels or a kernel dimension is not positive, or
      the initializer type is not supported.
    """
    if depth <= 0:
        raise ValueError("depth must be a positive integer.")
    if n_channels <= 0:
        raise ValueError("n_channels must be a positive integer.")

    # Accept a single int for square kernels, as Keras Conv2D does.
    if isinstance(kernel_size, (int, np.integer)):
        kernel_size = (kernel_size, kernel_size)
    kernel_height, kernel_width = kernel_size
    if kernel_height <= 0 or kernel_width <= 0:
        raise ValueError("kernel_size entries must be positive integers.")
    n_input = kernel_height * kernel_width * n_channels

    # Per-factor scale: the depth-th root of the reconstructed kernel's target std.
    he_scale = (2.0 / n_input) ** (1.0 / (2 * depth))
    lecun_scale = (1.0 / n_input) ** (1.0 / (2 * depth))

    # NOTE(review): TruncatedNormal is used without Keras' ~0.8796 stddev correction,
    # so the realized variance is below the target; confirm this is intended.
    if isinstance(initialization, tf.keras.initializers.HeNormal):
        return tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=he_scale)
    elif isinstance(initialization, tf.keras.initializers.LecunNormal):
        return tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=lecun_scale)
    elif isinstance(initialization, tf.keras.initializers.HeUniform):
        # Previously this branch fell through and returned None.
        limit = math.sqrt(3) * he_scale
        return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)
    elif isinstance(initialization, tf.keras.initializers.LecunUniform):
        # LeCun target variance is 1 / n_input (previously 3 / n_input by mistake).
        limit = math.sqrt(3) * lecun_scale
        return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)
    elif isinstance(initialization, tf.keras.initializers.Orthogonal):
        return tf.keras.initializers.Orthogonal()
    elif isinstance(initialization, TwiceTruncatedNormalInitializer):
        minprod = initialization.minprod if initialization.minprod is not None else 3e-3
        return TwiceTruncatedNormalInitializer(minprod=minprod, depth=depth, n_input=n_input, std=he_scale)
    else:
        raise ValueError("Invalid initialization type.")
