# imports
import math
import numpy as np
import tensorflow as tf

    
class TwiceTruncatedNormalInitializer(tf.keras.initializers.Initializer):
    """Zero-mean normal initializer truncated from below and above in magnitude.

    Every entry ``w`` of the generated tensor satisfies
    ``threshold <= |w| <= threshold_max`` with

    * ``threshold = minprod ** (1 / depth)``
    * ``threshold_max = min(1, 2 * (gain / sqrt(n_input)) ** (1 / depth))``

    Entries drawn outside that band are redrawn (rejection sampling) until
    all of them lie inside it.

    Args:
        minprod: Positive value whose ``depth``-th root is the smallest
            magnitude an entry may have.
        depth: Number of factor weight matrices.
        n_input: Number of input units. When ``None``, the fan-in is derived
            from the requested shape (product of all but the last dimension).
        gain: Scale used for the default standard deviation and for the upper
            truncation bound. Defaults to ``sqrt(2)``.
        std: Standard deviation of the normal draws. When ``None`` it is
            ``sqrt(gain ** 2 / n_input)``.
    """

    def __init__(self, minprod, depth, n_input=None, gain=np.sqrt(2), std=None):
        self.minprod = minprod
        self.depth = depth
        self.gain = gain
        self.std = std
        self.n_input = n_input

    @staticmethod
    def _outside_band(values, lower, upper):
        """Return a boolean mask of entries whose magnitude is outside [lower, upper]."""
        magnitude = np.abs(values)
        return (magnitude < lower) | (magnitude > upper)

    def __call__(self, shape, dtype=None):
        """Sample a tensor of the given shape.

        Args:
            shape: Shape of the tensor to generate.
            dtype: Optional dtype forwarded to ``tf.convert_to_tensor``.

        Returns:
            A tensor whose entries all have magnitude inside the band.

        Raises:
            ValueError: If the band is empty (``threshold >= threshold_max``),
                in which case redrawing could never succeed.
        """
        n_input = self.n_input
        if n_input is None:
            # n_input is optional in the constructor but needed for the upper
            # bound below, so fall back to the fan-in implied by the shape.
            n_input = int(np.prod(tuple(shape)[:-1]))

        if self.std is not None:
            std = self.std
        else:
            std = np.sqrt(np.square(self.gain) / n_input)

        threshold = self.minprod ** (1 / self.depth)
        threshold_max = np.minimum(1, 2 * (self.gain / np.sqrt(n_input)) ** (1 / self.depth))

        matrix = np.random.normal(0, std, size=shape)
        redraw = self._outside_band(matrix, threshold, threshold_max)

        if np.any(redraw) and threshold >= threshold_max:
            # With an empty band no redraw can ever be accepted; fail loudly
            # instead of spinning forever.
            raise ValueError(
                f"Empty truncation band: lower bound {threshold} is not below "
                f"upper bound {threshold_max}. Check minprod, depth, gain and n_input."
            )

        while np.any(redraw):
            # Only the rejected entries are replaced; accepted ones are kept.
            matrix[redraw] = np.random.normal(0, std, size=np.sum(redraw))
            redraw = self._outside_band(matrix, threshold, threshold_max)

        return tf.convert_to_tensor(matrix, dtype=dtype)

    def get_config(self):
        """Return the constructor arguments as a dict, to support serialization."""
        return {"minprod": self.minprod, "depth": self.depth, "std": self.std,
                "n_input": self.n_input, "gain": self.gain}
    
class ExactNormalFactorization(tf.keras.initializers.Initializer):
    """Initializer sampling factor weights as a random sign times an exponential.

    Each entry is ``epsilon * exp(exponent) * std ** (1 / depth)`` where
    ``epsilon`` is uniformly -1 or +1 and ``exponent`` is built from
    Gamma(1 / depth, 1) draws using a series cut off after ``max_terms``
    correction terms (see ``exact_gaussian_factor_sample``).

    NOTE(review): judging by the class name, the element-wise product of
    ``depth`` such factors is meant to follow a normal distribution with
    standard deviation ``std``; because the series is truncated at
    ``max_terms`` this is presumably an approximation - confirm against the
    derivation.

    Args:
        depth: Number of factor weight matrices.
        n_input: Number of input units. When ``None`` and ``std`` is also
            ``None``, the fan-in is derived from the requested shape (product
            of all but the last dimension).
        gain: Scale used for the default standard deviation. Defaults to
            ``sqrt(2)``.
        std: Standard deviation passed to the sampler. When ``None`` it is
            ``sqrt(gain ** 2 / n_input)``.
        max_terms: Number of correction terms summed in the exponent.
    """

    def __init__(self, depth, n_input=None, gain=np.sqrt(2), std=None, max_terms=3):
        self.depth = depth
        self.gain = gain
        self.std = std
        self.n_input = n_input
        self.max_terms = max_terms

    def rademacher(self, size=None):
        """Draw -1 or +1 with equal probability.

        Args:
            size: Optional output shape; ``None`` returns a single scalar.
        """
        return np.random.choice([-1, 1], size=size)

    def gamma_rv(self, shape, size=1):
        """Draw Gamma(shape, scale=1) variates; ``shape`` is the gamma shape parameter."""
        return np.random.gamma(shape, 1, size)

    def exact_gaussian_factor_sample(self, std=1.0, size=1):
        """Sample ``size`` independent factor values.

        Args:
            std: Standard deviation; each sample is scaled by ``std ** (1 / depth)``.
            size: Number (or shape) of samples to draw. The default of 1
                returns an array of shape ``(1,)``.

        Returns:
            NumPy array of samples.
        """
        inv_depth = 1 / self.depth

        epsilon = self.rademacher(size)
        term_1 = np.log(2) / (2 * self.depth)
        G_0 = self.gamma_rv(inv_depth, size)

        # Series correction, truncated after max_terms terms.
        correction_sum = 0.0
        for j in range(1, self.max_terms + 1):
            G_j = self.gamma_rv(inv_depth, size)
            correction_sum = correction_sum + (
                (G_j / (2 * j + 1)) - (np.log(1 + 1 / j) / (2 * self.depth))
            )

        exponent = term_1 - G_0 - correction_sum
        W_1 = epsilon * np.exp(exponent)
        scaling_factor = std ** (1 / self.depth)
        return W_1 * scaling_factor

    def __call__(self, shape, dtype=None):
        """Sample a tensor of the given shape.

        Args:
            shape: Shape of the tensor to generate.
            dtype: Optional dtype forwarded to ``tf.convert_to_tensor``.

        Returns:
            A tensor of independently sampled factor weights.
        """
        if self.std is not None:
            std = self.std
        else:
            n_input = self.n_input
            if n_input is None:
                # n_input is optional in the constructor; fall back to the
                # fan-in implied by the requested shape.
                n_input = int(np.prod(tuple(shape)[:-1]))
            std = np.sqrt(np.square(self.gain) / n_input)

        # Draw every entry in one vectorized call; the entries are i.i.d., so
        # this yields the same distribution as sampling them one at a time
        # while avoiding a Python-level loop over all weights.
        count = int(np.prod(shape))
        matrix = np.asarray(self.exact_gaussian_factor_sample(std, size=count))
        matrix = matrix.reshape(shape)

        return tf.convert_to_tensor(matrix, dtype=dtype)

    def get_config(self):
        """Return the constructor arguments as a dict, to support serialization."""
        return {
            "depth": self.depth,
            "std": self.std,
            "n_input": self.n_input,
            "gain": self.gain,
            "max_terms": self.max_terms
        }

def equivar_initializer(depth, n_input, initialization):
    """
    Build the initializer for the Hadamard factor weight matrices of a dense layer.

    The passed ``initialization`` object only selects the scheme; the returned
    initializer is a new object parameterized by ``depth`` and ``n_input``.

    Args:
    - depth: Number of factor weight matrices.
    - n_input: Number of input units.
    - initialization: A tf.keras.initializers.HeNormal, tf.keras.initializers.HeUniform,
                        tf.keras.initializers.LecunNormal, tf.keras.initializers.LecunUniform,
                        tf.keras.initializers.Orthogonal, TwiceTruncatedNormalInitializer
                        or ExactNormalFactorization object.

    Returns:
    - An initializer that can be used in the HadamardDense layer.

    Raises:
    - ValueError: if depth or n_input is not positive, or if the initialization
                  object is not one of the supported types.
    """
    if depth <= 0:
        raise ValueError("depth must be a positive integer.")
    if n_input <= 0:
        raise ValueError("n_input must be a positive integer.")

    inits = tf.keras.initializers
    supported_types = (
        inits.HeNormal,
        inits.HeUniform,
        inits.LecunUniform,
        inits.LecunNormal,
        inits.Orthogonal,
        TwiceTruncatedNormalInitializer,
        ExactNormalFactorization,
    )
    if not isinstance(initialization, supported_types):
        raise ValueError("Invalid initialization type. Use implemented tensorflow initializer object.")

    # He* schemes: truncated normal with the 2*depth-th root of 2/n_input as
    # stddev, or a uniform whose limit is sqrt(3) times that value.
    if isinstance(initialization, inits.HeNormal):
        he_std = (2.0 / n_input) ** (1.0 / (2 * depth))
        return inits.TruncatedNormal(mean=0.0, stddev=he_std)

    if isinstance(initialization, inits.HeUniform):
        he_limit = math.sqrt(3) * ((2.0 / n_input) ** (1.0 / (2 * depth)))
        return inits.RandomUniform(minval=-he_limit, maxval=he_limit)

    # Lecun* schemes: same construction based on 1/n_input.
    if isinstance(initialization, inits.LecunNormal):
        lecun_std = (1.0 / n_input) ** (1.0 / (2 * depth))
        return inits.TruncatedNormal(mean=0.0, stddev=lecun_std)

    if isinstance(initialization, inits.LecunUniform):
        lecun_limit = math.sqrt(3) * ((1.0 / n_input) ** (1.0 / (2 * depth)))
        return inits.RandomUniform(minval=-lecun_limit, maxval=lecun_limit)

    # Orthogonal is handed back as a fresh default instance.
    if isinstance(initialization, inits.Orthogonal):
        return inits.Orthogonal()

    # Twice-truncated normal: HeNormal-based factor stddev, fixed minprod.
    if isinstance(initialization, TwiceTruncatedNormalInitializer):
        factor_std = (2.0 / n_input) ** (1.0 / (2 * depth))
        return TwiceTruncatedNormalInitializer(
            minprod=3e-3, depth=depth, n_input=n_input, std=factor_std
        )

    # Exact factorization: receives the HeNormal stddev of the product itself.
    if isinstance(initialization, ExactNormalFactorization):
        product_std = np.sqrt(2.0 / n_input)
        return ExactNormalFactorization(depth=depth, n_input=n_input, std=product_std)


# Custom initializer function for convolutional Hadamard layers
def equivar_initializer_conv2d(depth, kernel_size, n_channels, initialization):
    """
    Build the initializer for the Hadamard factor kernels of a 2D convolution.

    The fan-in is ``kernel_height * kernel_width * n_channels``. The passed
    ``initialization`` object only selects the scheme; the returned initializer
    is a new object parameterized by ``depth`` and the fan-in.

    Args:
    - depth: Number of factor weight matrices.
    - kernel_size: Pair ``(kernel_height, kernel_width)``.
    - n_channels: Number of input channels.
    - initialization: A tf.keras.initializers.HeNormal, tf.keras.initializers.HeUniform,
                        tf.keras.initializers.LecunNormal, tf.keras.initializers.LecunUniform,
                        tf.keras.initializers.Orthogonal, TwiceTruncatedNormalInitializer
                        or ExactNormalFactorization object.

    Returns:
    - An initializer for the factor kernels.

    Raises:
    - ValueError: if depth or n_channels is not positive, or if the
                  initialization object is not one of the supported types.
    """
    if depth <= 0:
        raise ValueError("depth must be a positive integer.")
    if n_channels <= 0:
        raise ValueError("n_channels must be a positive integer.")

    kernel_height, kernel_width = kernel_size
    n_input = kernel_height * kernel_width * n_channels

    if isinstance(initialization, tf.keras.initializers.HeNormal):
        std_dev = (2.0 / n_input) ** (1.0 / (2 * depth))
        return tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=std_dev)
    elif isinstance(initialization, tf.keras.initializers.LecunNormal):
        std_dev = (1.0 / n_input) ** (1.0 / (2 * depth))
        return tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=std_dev)
    elif isinstance(initialization, tf.keras.initializers.HeUniform):
        # A uniform on [-limit, limit] has stddev limit / sqrt(3), so this
        # limit gives the same factor stddev as the HeNormal branch.
        limit = math.sqrt(3) * ((2.0 / n_input) ** (1.0 / (2 * depth)))
        return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)
    elif isinstance(initialization, tf.keras.initializers.LecunUniform):
        # Base is 1/n_input (as in the LecunNormal branch); the sqrt(3)
        # factor already accounts for the uniform distribution.
        limit = math.sqrt(3) * ((1.0 / n_input) ** (1.0 / (2 * depth)))
        return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)
    elif isinstance(initialization, tf.keras.initializers.Orthogonal):
        return tf.keras.initializers.Orthogonal()
    elif isinstance(initialization, TwiceTruncatedNormalInitializer):
        # Calculate equivar standard deviation based on HeNormal
        std_dev = (2.0 / n_input) ** (1.0 / (2 * depth))
        return TwiceTruncatedNormalInitializer(minprod=3e-3, depth=depth, n_input=n_input, std=std_dev)
    elif isinstance(initialization, ExactNormalFactorization):
        # Calculate target product standard deviation based on HeNormal
        std_dev = np.sqrt(2.0 / n_input)
        return ExactNormalFactorization(depth=depth, n_input=n_input, std=std_dev)
    else:
        raise ValueError("Invalid initialization type.")
