import math
import numpy as np
import tensorflow as tf
from initializers import TwiceTruncatedNormalInitializer, ExactNormalFactorization, equivar_initializer, equivar_initializer_conv2d

# Dense layer with hadamard product parametrization of depth 'depth'
class HadamardDense(tf.keras.layers.Layer):
    """Dense layer with a Hadamard-product parametrization of depth ``depth``.

    The effective kernel is the element-wise product ``W = U1 * U2 * ... * U_depth``
    of ``depth`` trainable factors, each of shape ``(input_shape[1], units)``.
    With ``factorize_bias=True`` the bias is ``B1 * ... * B_depth`` built the same way;
    otherwise a single zero-initialized, unregularized bias vector is used.
    Every factor carries an L2 penalty of strength ``la / depth``.

    Args:
        units: Number of output units.
        activation: Activation identifier or callable (resolved by ``tf.keras.activations.get``).
        la: Total L2 strength; each factor is penalized with ``la / depth``.
        depth: Number of Hadamard factors.
        init: Initializer for the first factor (default ``HeNormal()``); also the
            distribution sampled/wrapped by the "root", "equivar" and "vanilla" schemes.
        init_rest: Initializer for factors ``2..depth`` under ``init_type="ones"``
            (default ``Ones()``).
        init_type: Factor initialization scheme:
            ``'ones'``    U1 from ``init``, remaining factors from ``init_rest``;
            ``'equivar'`` every factor from ``equivar_initializer(depth, fan, init)``;
            ``'vanilla'`` every factor drawn independently from ``init``;
            ``'root'``    one sample W0 from ``init`` split as
                          U1 = sign(W0) * |W0|**(1/depth), others |W0|**(1/depth);
            ``'exact'``   ``equivar_initializer`` around ``ExactNormalFactorization``.
        use_bias: Whether to add a bias.
        factorize_bias: Whether the bias is itself a Hadamard product of ``depth`` factors.

    Raises:
        ValueError: If ``init_type`` is not one of the schemes above.
    """

    def __init__(self, units=1, activation='linear', la=0, depth=2, init=None, init_rest=None, init_type="ones",
                 use_bias=True, factorize_bias=False, **kwargs):
        super(HadamardDense, self).__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.factorize_bias = factorize_bias
        self.depth = depth
        # Keep the user-supplied strength for get_config(): returning the normalized
        # value would make from_config() divide by depth a second time.
        self._la_total = la
        self.la = la / self.depth  # to normalize scale of induced regularizers
        self.init = init if init is not None else tf.keras.initializers.HeNormal()  # TODO: should init not be passed as class not instance?
        self.init_rest = init_rest if init_rest is not None else tf.keras.initializers.Ones()
        self.init_type = init_type
        self.reg = tf.keras.regularizers.l2(self.la)

        if self.init_type not in ['ones', 'equivar', 'root', 'vanilla', 'exact']:
            raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'vanilla', 'exact']")

    def build(self, input_shape):
        """Create the kernel factors ``U1..U{depth}`` and, if enabled, the bias parameters.

        ``input_shape[1]`` is used as the input dimension, so inputs are expected to be
        2-D ``(batch, input_dim)``.
        """
        n_in = input_shape[1]
        kernel_shape = (n_in, self.units)

        if self.init_type == "ones":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init, regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape, initializer=self.init_rest,
                                                  regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
        elif self.init_type == "equivar":
            self.init_factor = equivar_initializer(self.depth, n_in, self.init)  # TODO: replace by individual instances or if identical inits are created
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init_factor, regularizer=self.reg, trainable=True)
            # A separate initializer instance per factor.
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape, initializer=equivar_initializer(self.depth, n_in, self.init),
                                                  regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
        elif self.init_type == "vanilla":
            self.init_factor = self.init
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init, regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape, initializer=self.init,
                                                  regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
        elif self.init_type == "root":
            # Split ONE sample W0 exactly: only U1 carries the sign, so
            # U1 * U2 * ... * U_depth == sign(W0) * |W0| == W0 at initialization.
            init_matrix = self.init(shape=kernel_shape)
            init_signs_w = tf.math.sign(init_matrix)
            root_abs_init_w = tf.math.pow(tf.math.abs(init_matrix), 1.0 / self.depth)
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer='zeros', regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape, initializer='zeros',
                                                  regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
            self.U1.assign(init_signs_w * root_abs_init_w)
            for weight_var in self.other_weights:
                weight_var.assign(root_abs_init_w)
        elif self.init_type == "exact":
            self.init_factor = equivar_initializer(self.depth, n_in, ExactNormalFactorization(depth=self.depth))
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init_factor, regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape,
                                                  initializer=equivar_initializer(self.depth, n_in, ExactNormalFactorization(depth=self.depth)),
                                                  regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]

        if self.use_bias:
            if self.factorize_bias:
                bias_shape = (self.units,)
                if self.init_type == "ones":
                    # Fresh initializer instance with the same config as `init`.
                    bias_init_config = self.init.get_config()
                    self.bias_init = self.init.__class__(**bias_init_config)
                    self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.bias_init, regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer=self.init_rest,
                                                         regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
                elif self.init_type == "equivar":
                    self.init_bias_factor = equivar_initializer(self.depth, self.units, self.init)
                    self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.init_bias_factor, regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer=equivar_initializer(self.depth, self.units, self.init),
                                                         regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
                elif self.init_type == "vanilla":
                    self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.init, regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer=self.init,
                                                         regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
                elif self.init_type == "root":
                    # Same exact split as for the kernel, on a bias-shaped sample.
                    bias_init_config = self.init.get_config()
                    self.bias_init = self.init.__class__(**bias_init_config)
                    init_vector = self.bias_init(shape=bias_shape)
                    init_signs_b = tf.math.sign(init_vector)
                    root_abs_init_b = tf.math.pow(tf.math.abs(init_vector), 1.0 / self.depth)
                    self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer='zeros', regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer='zeros',
                                                         regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
                    self.B1.assign(init_signs_b * root_abs_init_b)
                    for bias_var in self.other_biases:
                        bias_var.assign(root_abs_init_b)
                elif self.init_type == "exact":
                    self.init_bias_factor = equivar_initializer(self.depth, self.units, ExactNormalFactorization(depth=self.depth))
                    self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.init_bias_factor, regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape,
                                                         initializer=equivar_initializer(self.depth, self.units, ExactNormalFactorization(depth=self.depth)),
                                                         regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
            else:
                # Plain, unregularized bias.
                self.bias = self.add_weight(name='bias', shape=(self.units,), initializer='zeros', trainable=True)
        else:
            self.bias = None

    def call(self, inputs):
        """Return ``activation(inputs @ (U1 * ... * U_depth) + bias)``."""
        # Reconstruct weights
        W_reconstructed = self.U1
        for weight in self.other_weights:
            W_reconstructed = W_reconstructed * weight

        # Reconstruct bias if factorize_bias=True
        if self.use_bias and self.factorize_bias:
            B_reconstructed = self.B1
            for bias in self.other_biases:
                B_reconstructed = B_reconstructed * bias

        # Compute output
        output = inputs @ W_reconstructed
        if self.use_bias:
            if self.factorize_bias:
                output += B_reconstructed
            else:
                output += self.bias

        return self.activation(output)

    def get_config(self):
        """Return constructor arguments so that ``from_config(get_config())`` rebuilds the layer."""
        config = super().get_config().copy()
        config.update({
            'units': self.units,
            'activation': tf.keras.activations.serialize(self.activation),
            'la': self._la_total,  # un-normalized; __init__ divides by depth itself
            'depth': self.depth,
            'init': self.init,
            'init_rest': self.init_rest,
            'init_type': self.init_type,
            'use_bias': self.use_bias,
            'factorize_bias': self.factorize_bias
        })
        return config

# refactored hadamard dense layer with outsourced initialization clutter
class HadamardDenseRefac(tf.keras.layers.Layer):
    """Dense layer with Hadamard-product kernel; initializers are chosen by ``get_initializer``.

    The effective kernel is ``W = U1 * ... * U_depth`` (element-wise), each factor of shape
    ``(input_shape[1], units)``. With ``factorize_bias=True`` the bias is
    ``B1 * ... * B_depth``; otherwise a single zero-initialized, unregularized bias is used.
    Every factor is L2-regularized with ``la / depth``.

    Args:
        units: Number of output units.
        activation: Activation identifier or callable.
        la: Total L2 strength; each factor is penalized with ``la / depth``.
        depth: Number of Hadamard factors.
        init: Initializer for the first factor (default ``HeNormal()``).
        init_rest: Initializer for later factors under ``init_type="ones"`` (default ``Ones()``).
        init_type: One of 'ones', 'equivar', 'root', 'vanilla', 'exact' (see ``get_initializer``).
        use_bias: Whether to add a bias.
        factorize_bias: Whether the bias is a Hadamard product of ``depth`` factors.

    Raises:
        ValueError: If ``init_type`` is not a supported scheme.
    """

    def __init__(self, units=1, activation='linear', la=0, depth=2, init=None, init_rest=None, init_type="ones", use_bias=True, factorize_bias=False, **kwargs):
        super(HadamardDenseRefac, self).__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.factorize_bias = factorize_bias
        self.depth = depth
        self._la_total = la  # user-facing value, returned by get_config()
        self.la = la / self.depth  # per-factor strength
        self.init = init if init is not None else tf.keras.initializers.HeNormal()
        self.init_rest = init_rest if init_rest is not None else tf.keras.initializers.Ones()
        self.init_type = init_type
        self.reg = tf.keras.regularizers.l2(self.la)

        # Fail early instead of letting get_initializer() hand add_weight() nothing usable.
        if self.init_type not in ['ones', 'equivar', 'root', 'vanilla', 'exact']:
            raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'vanilla', 'exact']")

    def build(self, input_shape):
        """Create kernel factors of shape ``(input_shape[1], units)`` and the bias parameters."""
        n_in = input_shape[1]
        kernel_shape = (n_in, self.units)
        is_root = self.init_type == "root"

        # "root" must split ONE sample of `init` across all factors so that their product
        # reproduces it; draw the sample once here and share it with every factor.
        kernel_base = self.init(shape=kernel_shape) if is_root else None
        self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.get_initializer(n_in, 1, base=kernel_base), regularizer=self.reg, trainable=True)
        self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape, initializer=self.get_initializer(n_in, i, base=kernel_base),
                                              regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]

        if self.use_bias:
            if self.factorize_bias:
                bias_shape = (self.units,)
                bias_base = self.init(shape=bias_shape) if is_root else None
                self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.get_initializer(self.units, 1, is_bias=True, base=bias_base),
                                          regularizer=self.reg, trainable=True)
                self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer=self.get_initializer(self.units, i, is_bias=True, base=bias_base),
                                                     regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
            else:
                self.bias = self.add_weight(name='bias', shape=(self.units,), initializer='zeros', trainable=True)
        else:
            self.bias = None

    def call(self, inputs):
        """Return ``activation(inputs @ (U1 * ... * U_depth) + bias)``."""
        # Plain `x = x * y` rather than `*=`: in-place operators are not supported on
        # tf.Variable, and the first operand here is the variable U1/B1 itself.
        W_reconstructed = self.U1
        for weight in self.other_weights:
            W_reconstructed = W_reconstructed * weight

        output = inputs @ W_reconstructed
        if self.use_bias:
            if self.factorize_bias:
                B_reconstructed = self.B1
                for bias in self.other_biases:
                    B_reconstructed = B_reconstructed * bias
                output += B_reconstructed
            else:
                output += self.bias

        return self.activation(output)

    def get_initializer(self, dim, factor_number, is_bias=False, base=None):
        """Return the initializer for factor ``factor_number`` (1-based) of the kernel or bias.

        Args:
            dim: Input dimension for kernel factors; vector length (``units``) for bias factors.
            factor_number: 1 for the first factor, ``2..depth`` for the others.
            is_bias: Whether the factor is a bias vector of shape ``(dim,)``.
            base: For ``init_type="root"``, a tensor sampled once from ``init`` and shared
                by all factors. If None, a fresh sample is drawn here, in which case the
                factors do not multiply back to a common target.

        Returns:
            An initializer, or a callable returning a fixed tensor, usable with ``add_weight``.

        Raises:
            ValueError: If ``init_type`` is not a supported scheme.
        """
        if self.init_type == "root":
            if base is None:
                base = self.init(shape=(dim,) if is_bias else (dim, self.units))
            root_abs = tf.math.pow(tf.math.abs(base), 1.0 / self.depth)
            if factor_number == 1:
                # Only the first factor carries the sign, so the full product equals `base`.
                signed_root = tf.math.sign(base) * root_abs
                return lambda *args, **kwargs: signed_root
            return lambda *args, **kwargs: root_abs
        if self.init_type == "ones":
            return self.init if factor_number == 1 else self.init_rest
        if self.init_type == "equivar":
            return equivar_initializer(self.depth, dim, self.init)
        if self.init_type == "exact":
            return equivar_initializer(self.depth, dim, ExactNormalFactorization(depth=self.depth))
        if self.init_type == "vanilla":
            return self.init
        raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'vanilla', 'exact']")

    def get_config(self):
        """Return constructor arguments so that ``from_config(get_config())`` rebuilds the layer."""
        config = super().get_config().copy()
        config.update({
            'units': self.units,
            'activation': tf.keras.activations.serialize(self.activation),
            'la': self._la_total,  # un-normalized; __init__ divides by depth itself
            'depth': self.depth,
            'init': self.init,
            'init_rest': self.init_rest,
            'init_type': self.init_type,
            'use_bias': self.use_bias,
            'factorize_bias': self.factorize_bias
        })
        return config


# Conv2D layer with unstructured Hadamard product parametrization of depth 'depth'
class HadamardConv2D(tf.keras.layers.Layer):
    """2-D convolution whose kernel is an unstructured Hadamard product of ``depth`` factors.

    The effective kernel is ``W = U1 * ... * U_depth`` (element-wise), each factor of shape
    ``(kernel_size[0], kernel_size[1], in_channels, filters)``, applied with ``tf.nn.conv2d``.
    With ``factorize_bias=True`` (the default) the bias is ``B1 * ... * B_depth``; otherwise
    a single zero-initialized, unregularized bias is used. Every factor is L2-regularized
    with ``la / depth``.

    Args:
        filters: Number of output channels.
        kernel_size: ``(kh, kw)`` pair, or a single int for a square kernel.
        strides: Strides passed to ``tf.nn.conv2d`` (default ``[1, 1, 1, 1]``).
        activation: Activation identifier or callable.
        la: Total L2 strength; each factor is penalized with ``la / depth``.
        depth: Number of Hadamard factors.
        init: Initializer for the first factor (default ``HeNormal()``).
        init_rest: Initializer for later factors under ``init_type="ones"`` (default ``Ones()``).
        init_type: One of 'ones', 'equivar', 'root', 'vanilla', 'exact'.
        use_bias: Whether to add a bias.
        factorize_bias: Whether the bias is a Hadamard product of ``depth`` factors.
        padding: ``'same'`` or ``'valid'``, case-insensitive; anything else maps to ``'VALID'``.

    Raises:
        ValueError: If ``init_type`` is not a supported scheme.
    """

    def __init__(self, filters, kernel_size, strides=None, activation='linear', la=0, depth=2,
                 init=None, init_rest=None, init_type="ones", use_bias=True,
                 factorize_bias=True, padding='valid', **kwargs):
        super(HadamardConv2D, self).__init__(**kwargs)
        self.filters = filters
        # build() indexes kernel_size[0] and [1]; accept a bare int for square kernels.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.kernel_size = kernel_size
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.factorize_bias = factorize_bias
        self.depth = depth
        self._la_total = la  # user-facing value, returned by get_config()
        self.la = la / self.depth
        self.strides = strides if strides is not None else [1, 1, 1, 1]
        self.init = init if init is not None else tf.keras.initializers.HeNormal()
        self.init_rest = init_rest if init_rest is not None else tf.keras.initializers.Ones()
        self.init_type = init_type
        self.reg = tf.keras.regularizers.l2(self.la)
        # Case-insensitive so that the 'SAME'/'VALID' stored in get_config() survives from_config().
        self.padding = 'SAME' if str(padding).lower() == 'same' else 'VALID'

        if self.init_type not in ['ones', 'equivar', 'root', 'vanilla', 'exact']:
            raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'vanilla', 'exact']")

    def build(self, input_shape):
        """Create the kernel factors (channels taken from ``input_shape[-1]``) and bias parameters."""
        n_channels = input_shape[-1]
        kernel_shape = (self.kernel_size[0], self.kernel_size[1], n_channels, self.filters)

        if self.init_type == "ones":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init, regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape, initializer=self.init_rest,
                                                  regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
        elif self.init_type == "equivar":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=equivar_initializer_conv2d(self.depth, self.kernel_size, n_channels, self.init),
                                      regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape, initializer=equivar_initializer_conv2d(self.depth, self.kernel_size, n_channels, self.init),
                                                  regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
        elif self.init_type == "vanilla":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init, regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape, initializer=self.init,
                                                  regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
        elif self.init_type == "root":
            # Split ONE sample W0 exactly: U1 = sign(W0)*|W0|^(1/depth), others |W0|^(1/depth),
            # so the product of all factors equals W0 at initialization.
            init_kernel = self.init(shape=kernel_shape)
            init_signs_w = tf.math.sign(init_kernel)
            root_abs_init_w = tf.math.pow(tf.math.abs(init_kernel), 1.0 / self.depth)
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer='zeros', regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape, initializer='zeros',
                                                  regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
            self.U1.assign(init_signs_w * root_abs_init_w)
            for weight_var in self.other_weights:
                weight_var.assign(root_abs_init_w)
        elif self.init_type == "exact":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape,
                                      initializer=equivar_initializer_conv2d(self.depth, self.kernel_size, n_channels, ExactNormalFactorization(depth=self.depth)),
                                      regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape,
                                                  initializer=equivar_initializer_conv2d(self.depth, self.kernel_size, n_channels, ExactNormalFactorization(depth=self.depth)),
                                                  regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]

        if self.use_bias:
            if self.factorize_bias:
                bias_shape = (self.filters,)
                if self.init_type == "ones":
                    # Fresh initializer instance with the same config as `init`.
                    bias_init_config = self.init.get_config()
                    self.bias_init = self.init.__class__(**bias_init_config)
                    self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.bias_init, regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer=self.init_rest,
                                                         regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
                elif self.init_type == "equivar":
                    # Bias factors are 1-D, so the dense equivar_initializer is used.
                    self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=equivar_initializer(self.depth, self.filters, self.init),
                                              regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer=equivar_initializer(self.depth, self.filters, self.init),
                                                         regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
                elif self.init_type == "vanilla":
                    self.init_bias_factor = self.init
                    self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.init_bias_factor, regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer=self.init,
                                                         regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
                elif self.init_type == "root":
                    bias_init_config = self.init.get_config()
                    self.bias_init = self.init.__class__(**bias_init_config)
                    init_vector = self.bias_init(shape=bias_shape)
                    init_signs_b = tf.math.sign(init_vector)
                    root_abs_init_b = tf.math.pow(tf.math.abs(init_vector), 1.0 / self.depth)
                    self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer='zeros', regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer='zeros',
                                                         regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
                    self.B1.assign(init_signs_b * root_abs_init_b)
                    for bias_var in self.other_biases:
                        bias_var.assign(root_abs_init_b)
                elif self.init_type == "exact":
                    self.B1 = self.add_weight(name='B1', shape=bias_shape,
                                              initializer=equivar_initializer(self.depth, self.filters, ExactNormalFactorization(depth=self.depth)),
                                              regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape,
                                                         initializer=equivar_initializer(self.depth, self.filters, ExactNormalFactorization(depth=self.depth)),
                                                         regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
            else:
                self.bias = self.add_weight(name='bias', shape=(self.filters,), initializer='zeros', trainable=True)
        else:
            self.bias = None

    def call(self, inputs):
        """Return ``activation(conv2d(inputs, U1 * ... * U_depth) + bias)``."""
        W_reconstructed = self.U1
        for weight in self.other_weights:
            W_reconstructed = W_reconstructed * weight

        output = tf.nn.conv2d(inputs, W_reconstructed, strides=self.strides, padding=self.padding)

        if self.use_bias:
            if self.factorize_bias:
                B_reconstructed = self.B1
                for bias in self.other_biases:
                    B_reconstructed = B_reconstructed * bias
                output = tf.nn.bias_add(output, B_reconstructed)
            else:
                output = tf.nn.bias_add(output, self.bias)

        return self.activation(output)

    def get_config(self):
        """Return constructor arguments so that ``from_config(get_config())`` rebuilds the layer."""
        config = super().get_config().copy()
        config.update({
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'activation': tf.keras.activations.serialize(self.activation),
            'la': self._la_total,  # un-normalized; __init__ divides by depth itself
            'depth': self.depth,
            'init': self.init,
            'init_rest': self.init_rest,
            'init_type': self.init_type,
            'use_bias': self.use_bias,
            'factorize_bias': self.factorize_bias,
            'padding': self.padding
        })
        return config


class HadamardConv2DRefac(tf.keras.layers.Layer):
    """Hadamard-factorized 2-D convolution with initializers chosen by ``get_initializer``.

    The effective kernel is ``W = U1 * ... * U_depth`` (element-wise), each factor of shape
    ``(kernel_size[0], kernel_size[1], in_channels, filters)``, applied with ``tf.nn.conv2d``.
    With ``factorize_bias=True`` (the default) the bias is ``B1 * ... * B_depth``; otherwise
    a single zero-initialized, unregularized bias is used. Every factor is L2-regularized
    with ``la / depth``.

    Args:
        filters: Number of output channels.
        kernel_size: ``(kh, kw)`` pair, or a single int for a square kernel.
        strides: Strides passed to ``tf.nn.conv2d`` (default ``(1, 1)``).
        activation: Activation identifier or callable.
        la: Total L2 strength; each factor is penalized with ``la / depth``.
        depth: Number of Hadamard factors (default 1, i.e. an ordinary kernel).
        init: Initializer for the first factor (default ``HeNormal()``).
        init_rest: Initializer for later factors under ``init_type="ones"`` (default ``Ones()``).
        init_type: One of 'ones', 'equivar', 'root', 'vanilla', 'exact' (see ``get_initializer``).
        use_bias: Whether to add a bias.
        factorize_bias: Whether the bias is a Hadamard product of ``depth`` factors.
        padding: ``'same'`` or ``'valid'``, case-insensitive; anything else maps to ``'VALID'``.

    Raises:
        ValueError: If ``init_type`` is not a supported scheme.
    """

    def __init__(self, filters, kernel_size, strides=None, activation='linear', la=0, depth=1,
                 init=None, init_rest=None, init_type="ones", use_bias=True,
                 factorize_bias=True, padding='valid', **kwargs):
        super(HadamardConv2DRefac, self).__init__(**kwargs)
        self.filters = filters
        # build() indexes kernel_size[0] and [1]; accept a bare int for square kernels.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.kernel_size = kernel_size
        self.strides = strides if strides is not None else (1, 1)
        self.activation = tf.keras.activations.get(activation)
        # depth is needed to normalize la, so it must be assigned first.
        self.depth = depth
        self._la_total = la  # user-facing value, returned by get_config()
        self.la = la / self.depth
        self.init = init if init is not None else tf.keras.initializers.HeNormal()
        self.init_rest = init_rest if init_rest is not None else tf.keras.initializers.Ones()
        self.init_type = init_type
        self.use_bias = use_bias
        self.factorize_bias = factorize_bias
        self.padding = 'SAME' if str(padding).lower() == 'same' else 'VALID'
        self.reg = tf.keras.regularizers.l2(self.la)

        if self.init_type not in ['ones', 'equivar', 'root', 'vanilla', 'exact']:
            raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'vanilla', 'exact']")

    def build(self, input_shape):
        """Create the kernel factors (channels from ``input_shape[-1]``) and bias parameters."""
        n_channels = input_shape[-1]
        kernel_shape = (self.kernel_size[0], self.kernel_size[1], n_channels, self.filters)
        is_root = self.init_type == "root"

        # "root" must split ONE sample of `init` across all factors so that their product
        # reproduces it; draw the sample once here and share it with every factor.
        kernel_base = self.init(shape=kernel_shape) if is_root else None
        self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.get_initializer(kernel_shape, 1, base=kernel_base), regularizer=self.reg, trainable=True)
        self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape, initializer=self.get_initializer(kernel_shape, i, base=kernel_base),
                                              regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]

        if self.use_bias:
            bias_shape = (self.filters,)
            if self.factorize_bias:
                bias_base = self.init(shape=bias_shape) if is_root else None
                self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.get_initializer(bias_shape, 1, True, base=bias_base),
                                          regularizer=self.reg, trainable=True)
                self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer=self.get_initializer(bias_shape, i, True, base=bias_base),
                                                     regularizer=self.reg, trainable=True) for i in range(2, self.depth + 1)]
            else:
                self.bias = self.add_weight(name='bias', shape=bias_shape, initializer='zeros', trainable=True)
        else:
            self.bias = None

    def call(self, inputs):
        """Return ``activation(conv2d(inputs, U1 * ... * U_depth) + bias)``."""
        # Plain `x = x * y` rather than `*=`: in-place operators are not supported on
        # tf.Variable, and the first operand here is the variable U1/B1 itself.
        W_reconstructed = self.U1
        for weight in self.other_weights:
            W_reconstructed = W_reconstructed * weight

        output = tf.nn.conv2d(inputs, W_reconstructed, strides=self.strides, padding=self.padding)

        if self.use_bias:
            if self.factorize_bias:
                B_reconstructed = self.B1
                for bias in self.other_biases:
                    B_reconstructed = B_reconstructed * bias
                output = tf.nn.bias_add(output, B_reconstructed)
            else:
                output = tf.nn.bias_add(output, self.bias)

        return self.activation(output)

    def get_initializer(self, shape, factor_number, is_bias=False, base=None):
        """Return the initializer for factor ``factor_number`` (1-based) of the kernel or bias.

        Args:
            shape: Factor shape; ``(kh, kw, in_channels, filters)`` for the kernel,
                ``(filters,)`` for the bias.
            factor_number: 1 for the first factor, ``2..depth`` for the others.
            is_bias: Whether the factor belongs to the bias.
            base: For ``init_type="root"``, a tensor of shape ``shape`` sampled once from
                ``init`` and shared by all factors. If None, a fresh sample is drawn here,
                in which case the factors do not multiply back to a common target.

        Returns:
            An initializer, or a callable returning a fixed tensor, usable with ``add_weight``.

        Raises:
            ValueError: If ``init_type`` is not a supported scheme.
        """
        if self.init_type == "root":
            if base is None:
                base = self.init(shape=shape)
            root_abs = tf.math.pow(tf.math.abs(base), 1.0 / self.depth)
            if factor_number == 1:
                # Only the first factor carries the sign, so the full product equals `base`.
                signed_root = tf.math.sign(base) * root_abs
                return lambda *args, **kwargs: signed_root
            return lambda *args, **kwargs: root_abs
        if self.init_type == "ones":
            return self.init if factor_number == 1 else self.init_rest
        if self.init_type == "vanilla":
            return self.init
        if self.init_type in ("equivar", "exact"):
            inner = self.init if self.init_type == "equivar" else ExactNormalFactorization(depth=self.depth)
            if is_bias:
                # A 1-D bias has no shape[-2]; use the dense variant with the bias length.
                return equivar_initializer(self.depth, shape[0], inner)
            return equivar_initializer_conv2d(self.depth, self.kernel_size, shape[-2], inner)
        raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'vanilla', 'exact']")

    def get_config(self):
        """Return constructor arguments so that ``from_config(get_config())`` rebuilds the layer."""
        config = super().get_config().copy()
        config.update({
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'activation': tf.keras.activations.serialize(self.activation),
            'la': self._la_total,  # un-normalized; __init__ divides by depth itself
            'depth': self.depth,
            'init': self.init,
            'init_rest': self.init_rest,
            'init_type': self.init_type,
            'use_bias': self.use_bias,
            'factorize_bias': self.factorize_bias,
            'padding': self.padding
        })
        return config


# Dense layer with shared Hadamard product parametrization of depth 'depth'
class HadamardDenseShared(tf.keras.layers.Layer):
    """Dense layer with a shared Hadamard parametrization ``W = U * V**(depth - 1)``.

    Instead of ``depth`` independent factors, factors ``2..depth`` are one shared tensor
    ``V`` (element-wise power in ``call``). ``U`` is L2-regularized with ``la / depth``
    and ``V`` with ``(depth - 1) * la / depth``. With ``factorize_bias=True`` the bias is
    ``BU * BV**(depth - 1)`` with the same two regularizers; otherwise a single
    zero-initialized, unregularized bias is used.

    Args:
        units: Number of output units.
        activation: Activation identifier or callable.
        la: Total L2 strength before normalization by ``depth``.
        depth: Integer depth; integral floats such as ``2.0`` are accepted.
        init: Initializer for ``U`` (default ``HeNormal()``).
        init_rest: Initializer for ``V`` under ``init_type="ones"`` (default ``Ones()``).
        init_type: One of 'ones', 'equivar', 'root', 'vanilla'.
        use_bias: Whether to add a bias.
        factorize_bias: Whether the bias is factorized as ``BU * BV**(depth - 1)``.

    Raises:
        ValueError: If ``depth`` is not integral or ``init_type`` is unsupported.
    """

    def __init__(self, units=1, activation='linear', la=0, depth=2, init=None, init_rest=None, init_type="ones",
                 use_bias=True, factorize_bias=False, **kwargs):
        super(HadamardDenseShared, self).__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.factorize_bias = factorize_bias
        self.init = init if init is not None else tf.keras.initializers.HeNormal()
        self.init_rest = init_rest if init_rest is not None else tf.keras.initializers.Ones()
        self.init_type = init_type
        if not isinstance(depth, int) and (not isinstance(depth, float) or depth % 1 != 0):
            raise ValueError("depth must be an integer.")
        self.depth = depth
        self._la_total = la  # user-facing value, returned by get_config()
        self.la = la / self.depth  # to normalize scale of induced regularizers
        self.regu = tf.keras.regularizers.l2(self.la)
        # V stands in for depth-1 identical factors, so it gets depth-1 times the per-factor penalty.
        self.regv = tf.keras.regularizers.l2((self.depth - 1) * self.la)

        if self.init_type not in ['ones', 'equivar', 'root', 'vanilla']:
            raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'vanilla']")

    def build(self, input_shape):
        """Create ``U``, ``V`` of shape ``(input_shape[1], units)`` and the bias parameters."""
        kernel_shape = (input_shape[1], self.units)

        if self.init_type == "ones":
            self.U = self.add_weight(name='U', shape=kernel_shape, initializer=self.init, regularizer=self.regu, trainable=True)
            self.V = self.add_weight(name='V', shape=kernel_shape, initializer=self.init_rest, regularizer=self.regv, trainable=True)
        elif self.init_type == "equivar":
            self.init_u = equivar_initializer(self.depth, input_shape[1], self.init)
            self.init_v = equivar_initializer(self.depth, input_shape[1], self.init)
            self.U = self.add_weight(name='U', shape=kernel_shape, initializer=self.init_u, regularizer=self.regu, trainable=True)
            self.V = self.add_weight(name='V', shape=kernel_shape, initializer=self.init_v, regularizer=self.regv, trainable=True)
        elif self.init_type == "vanilla":
            self.U = self.add_weight(name='U', shape=kernel_shape, initializer=self.init, regularizer=self.regu, trainable=True)
            self.V = self.add_weight(name='V', shape=kernel_shape, initializer=self.init, regularizer=self.regv, trainable=True)
        elif self.init_type == "root":
            # Split ONE sample W0: U = sign(W0)*|W0|^(1/depth), V = |W0|^(1/depth),
            # so U * V**(depth-1) == W0 at initialization.
            init_matrix = self.init(shape=kernel_shape)
            init_signs_w = tf.math.sign(init_matrix)
            root_abs_init_w = tf.math.pow(tf.math.abs(init_matrix), 1.0 / self.depth)
            self.U = self.add_weight(name='U', shape=kernel_shape, initializer='zeros', regularizer=self.regu, trainable=True)
            self.V = self.add_weight(name='V', shape=kernel_shape, initializer='zeros', regularizer=self.regv, trainable=True)
            self.U.assign(init_signs_w * root_abs_init_w)
            self.V.assign(root_abs_init_w)

        if self.use_bias:
            if self.factorize_bias:
                bias_shape = (self.units,)
                if self.init_type == "ones":
                    # Fresh initializer instance with the same config as `init`.
                    bias_init_config = self.init.get_config()
                    self.bias_init = self.init.__class__(**bias_init_config)
                    self.BU = self.add_weight(name='BU', shape=bias_shape, initializer=self.bias_init, regularizer=self.regu, trainable=True)
                    self.BV = self.add_weight(name='BV', shape=bias_shape, initializer=self.init_rest, regularizer=self.regv, trainable=True)
                elif self.init_type == "equivar":
                    # Single shared BV (like V) with the BU/BV regularizers; call() raises BV to depth-1.
                    self.init_bias_factor = equivar_initializer(self.depth, self.units, self.init)
                    self.BU = self.add_weight(name='BU', shape=bias_shape, initializer=self.init_bias_factor, regularizer=self.regu, trainable=True)
                    self.BV = self.add_weight(name='BV', shape=bias_shape, initializer=equivar_initializer(self.depth, self.units, self.init),
                                              regularizer=self.regv, trainable=True)
                elif self.init_type == "vanilla":
                    self.BU = self.add_weight(name='BU', shape=bias_shape, initializer=self.init, regularizer=self.regu, trainable=True)
                    self.BV = self.add_weight(name='BV', shape=bias_shape, initializer=self.init, regularizer=self.regv, trainable=True)
                elif self.init_type == "root":
                    bias_init_config = self.init.get_config()
                    self.bias_init = self.init.__class__(**bias_init_config)
                    init_vector = self.bias_init(shape=bias_shape)
                    init_signs_b = tf.math.sign(init_vector)
                    root_abs_init_b = tf.math.pow(tf.math.abs(init_vector), 1.0 / self.depth)
                    self.BU = self.add_weight(name='BU', shape=bias_shape, initializer='zeros', regularizer=self.regu, trainable=True)
                    self.BV = self.add_weight(name='BV', shape=bias_shape, initializer='zeros', regularizer=self.regv, trainable=True)
                    self.BU.assign(init_signs_b * root_abs_init_b)
                    self.BV.assign(root_abs_init_b)
            else:
                self.bias = self.add_weight(name='bias', shape=(self.units,), initializer='zeros', trainable=True)
        else:
            self.bias = None

    def call(self, inputs):
        """Return ``activation(inputs @ (U * V**(depth-1)) + bias)``."""
        # Reconstruct weights
        W_reconstructed = tf.multiply(self.U, tf.pow(x=self.V, y=(self.depth - 1)))

        # Reconstruct bias if factorize_bias=True
        if self.use_bias and self.factorize_bias:
            B_reconstructed = tf.multiply(self.BU, tf.pow(x=self.BV, y=(self.depth - 1)))

        # Compute output
        output = inputs @ W_reconstructed
        if self.use_bias:
            if self.factorize_bias:
                output += B_reconstructed
            else:
                output += self.bias

        return self.activation(output)

    def get_config(self):
        """Return constructor arguments so that ``from_config(get_config())`` rebuilds the layer."""
        config = super().get_config().copy()
        config.update({
            'units': self.units,
            'activation': tf.keras.activations.serialize(self.activation),
            'la': self._la_total,  # un-normalized; __init__ divides by depth itself
            'depth': self.depth,
            'init': self.init,
            'init_rest': self.init_rest,
            'init_type': self.init_type,
            'use_bias': self.use_bias,
            'factorize_bias': self.factorize_bias
        })
        return config


# Dense Layer with Hadamard Power Parametrization of real-valued depth 'depth'
class HadamardDensePower(tf.keras.layers.Layer):
    """Dense layer with a two-factor Hadamard power parametrization of real-valued depth.

    The kernel is reconstructed as ``W = U * |V| ** (depth - 1)`` (elementwise). This
    mimics a depth-``depth`` Hadamard product whose last ``depth - 1`` factors are tied;
    the absolute value keeps the power defined for non-integer ``depth``. With
    ``factorize_bias=True`` the bias is rebuilt the same way from ``BU`` and ``BV``.

    Regularization: ``U``/``BU`` get an L2 penalty of ``la / depth`` and ``V``/``BV``
    one of ``(depth - 1) * la / depth``.

    Args:
        units: Number of output units.
        activation: Activation identifier, resolved with ``tf.keras.activations.get``.
        la: Overall L2 regularization strength (normalized by ``depth`` internally).
        depth: Real-valued depth of the parametrization.
        init: Initializer for the first factor (default ``HeNormal()``).
        init_rest: Initializer for ``V``/``BV`` when ``init_type == "ones"``
            (default ``Ones()``).
        init_type: One of ``'ones'``, ``'equivar'``, ``'root'``, ``'vanilla'``.
        use_bias: Whether to add a bias vector.
        factorize_bias: Whether the bias is factorized like the kernel.

    Raises:
        ValueError: If ``init_type`` is not one of the supported values.
    """

    def __init__(self, units=1, activation='linear', la=0, depth=2, init=None, init_rest=None, init_type="ones",
                 use_bias=True, factorize_bias=False, **kwargs):
        super(HadamardDensePower, self).__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.factorize_bias = factorize_bias
        self.depth = depth
        # Constructor value, reported by get_config() so that from_config() does not
        # divide by depth a second time; self.la holds the normalized strength.
        self._la_arg = la
        self.la = la / self.depth # to normalize scale of induced regularizers
        self.init = init if init is not None else tf.keras.initializers.HeNormal()
        self.init_rest = init_rest if init_rest is not None else tf.keras.initializers.Ones()
        self.init_type = init_type
        self.regu = tf.keras.regularizers.l2(self.la)
        self.regv = tf.keras.regularizers.l2((self.depth-1)*self.la)

        if self.init_type not in ['ones', 'equivar', 'root', 'vanilla']:
            raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'vanilla']")

    def _fresh_init(self):
        """Return a new instance of ``self.init`` rebuilt from its config.

        Used wherever several weights would otherwise share one initializer instance
        (some Keras versions return identical values when the same unseeded instance
        is called repeatedly). Falls back to ``self.init`` if it has no ``get_config``.
        """
        if hasattr(self.init, 'get_config'):
            return self.init.__class__(**self.init.get_config())
        return self.init

    def build(self, input_shape):
        in_dim = input_shape[-1]
        kernel_shape = (in_dim, self.units)

        if self.init_type == "root":
            # Split each initial weight w into U = sign(w)*|w|**(1/depth) and
            # V = |w|**(1/depth), so that U * |V|**(depth-1) == w at initialization.
            init_matrix = self.init(shape=kernel_shape)
            root_abs_init_w = tf.math.pow(tf.math.abs(init_matrix), 1.0 / self.depth)
            u_init = v_init = 'zeros'  # overwritten by assign() below
        elif self.init_type == "ones":
            u_init, v_init = self.init, self.init_rest
        elif self.init_type == "equivar":
            self.init_u = equivar_initializer(self.depth, in_dim, self.init)
            self.init_v = equivar_initializer(self.depth, in_dim, self.init)
            u_init, v_init = self.init_u, self.init_v
        else:  # "vanilla": independent draws for U and V
            u_init, v_init = self.init, self._fresh_init()

        self.U = self.add_weight(name='U', shape=kernel_shape, initializer=u_init, regularizer=self.regu, trainable=True)
        self.V = self.add_weight(name='V', shape=kernel_shape, initializer=v_init, regularizer=self.regv, trainable=True)
        if self.init_type == "root":
            self.U.assign(tf.math.sign(init_matrix) * root_abs_init_w)
            self.V.assign(root_abs_init_w)

        if not self.use_bias:
            self.bias = None
        elif not self.factorize_bias:
            self.bias = self.add_weight(name='bias', shape=(self.units,), initializer='zeros', trainable=True)
        else:
            self._build_factorized_bias()

    def _build_factorized_bias(self):
        """Create the bias factors ``BU``/``BV`` according to ``init_type``."""
        bias_shape = (self.units,)
        if self.init_type == "root":
            # Same sign/root split as the kernel, so BU * |BV|**(depth-1) equals the draw.
            self.bias_init = self._fresh_init()
            init_vector = self.bias_init(shape=bias_shape)
            root_abs_init_b = tf.math.pow(tf.math.abs(init_vector), 1.0 / self.depth)
            bu_init = bv_init = 'zeros'  # overwritten by assign() below
        elif self.init_type == "ones":
            self.bias_init = self._fresh_init()
            bu_init, bv_init = self.bias_init, self.init_rest
        elif self.init_type == "equivar":
            self.init_bias_u = equivar_initializer(self.depth, self.units, self.init)
            self.init_bias_v = equivar_initializer(self.depth, self.units, self.init)
            bu_init, bv_init = self.init_bias_u, self.init_bias_v
        else:  # "vanilla"
            bu_init, bv_init = self._fresh_init(), self._fresh_init()

        self.BU = self.add_weight(name='BU', shape=bias_shape, initializer=bu_init, regularizer=self.regu, trainable=True)
        self.BV = self.add_weight(name='BV', shape=bias_shape, initializer=bv_init, regularizer=self.regv, trainable=True)
        if self.init_type == "root":
            self.BU.assign(tf.math.sign(init_vector) * root_abs_init_b)
            self.BV.assign(root_abs_init_b)

    def call(self, inputs):
        exponent = self.depth - 1
        # |V| keeps the power well defined for non-integer depth.
        W_reconstructed = tf.multiply(self.U, tf.pow(tf.abs(self.V), exponent))
        output = inputs @ W_reconstructed
        if self.use_bias:
            if self.factorize_bias:
                output += tf.multiply(self.BU, tf.pow(tf.abs(self.BV), exponent))
            else:
                output += self.bias
        return self.activation(output)

    def get_config(self):
        """Return a config from which ``from_config`` rebuilds an equivalent layer."""
        config = super().get_config().copy()
        config.update({
            'units': self.units,
            'activation': self.activation,
            'la': self._la_arg,  # un-normalized, so __init__ normalizes exactly once
            'depth': self.depth,
            'init': self.init,
            'init_rest': self.init_rest,
            'init_type': self.init_type,
            'use_bias': self.use_bias,
            'factorize_bias': self.factorize_bias
        })
        return config

# Structured Sparse Dense layer (neuron-wise, i.e., all incoming weights to one unit + its bias form one group via shared overparametrization) 
class StrHadamardDense(tf.keras.layers.Layer):
    """Dense layer with neuron-wise (column-grouped) structured Hadamard overparametrization.

    The kernel is reconstructed as ``W = U1 * (d_2 * ... * d_depth)`` where every
    ``d_k`` has shape ``(units,)`` and is broadcast over the rows of ``U1``. All
    incoming weights of one output unit therefore share one multiplicative factor.
    With ``factorize_bias=True`` the bias ``B1`` is scaled by the same per-unit
    factor, so a unit's weights and its bias form one group.

    ``U1``/``B1`` get an L2 penalty of ``la / depth``; the shared factors get
    ``sqrt(groupsize) * la / depth`` to normalize for group size.

    Args:
        units: Number of output units.
        activation: Activation identifier, resolved with ``tf.keras.activations.get``.
        la: Overall L2 regularization strength (normalized by ``depth`` internally).
        depth: Number of factors (``U1`` plus ``depth - 1`` shared vectors).
        init: Initializer (class or instance) for ``U1``; default ``HeNormal``.
        init_rest: Initializer (class or instance) for the shared factors when
            ``init_type == "ones"``; default ``Ones``.
        init_type: One of ``'ones'``, ``'equivar'``, ``'root'``, ``'vanilla'``.
        use_bias: Whether to add a bias vector.
        factorize_bias: Whether the bias is scaled by the shared per-unit factors.
        groupsize: Group size used to scale the shared-factor penalty (default 1.0).

    Raises:
        ValueError: If ``init_type`` is not one of the supported values.
    """

    def __init__(self, units=1, activation='linear', la=0, depth=2, init=None, init_rest=None, init_type="ones",
                 use_bias=True, factorize_bias=True, groupsize=None, **kwargs):
        # Earlier get_config() output contained the derived 'blowup_factor', which is not
        # a constructor argument; drop it so such configs can still be loaded.
        kwargs.pop('blowup_factor', None)
        super(StrHadamardDense, self).__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.factorize_bias = factorize_bias
        self.depth = depth
        self._la_arg = la  # constructor value, reported by get_config()
        self.la = la / self.depth # to normalize scale of induced regularizers
        init = init if init is not None else tf.keras.initializers.HeNormal
        init_rest = init_rest if init_rest is not None else tf.keras.initializers.Ones
        # build() calls self.init(shape=...) and self.init.get_config(), which need an
        # initializer *instance*; instantiate classes (including the defaults) here.
        self.init = init() if isinstance(init, type) else init
        self.init_rest = init_rest() if isinstance(init_rest, type) else init_rest
        self.init_type = init_type

        self.reg = tf.keras.regularizers.l2(self.la)
        self.groupsize = groupsize if groupsize is not None else 1.0 # equally sized groups
        # Plain Python floats: np.float no longer exists in NumPy >= 1.24 and
        # tf.math.sqrt does not accept an integer groupsize.
        self.blowup_factor = math.sqrt(float(self.groupsize))
        self.blownupreg = tf.keras.regularizers.l2(self.blowup_factor * self.la) # using this regu for the grouped weights to achieve group size normalization

        if self.init_type not in ['ones', 'equivar', 'root', 'vanilla']:
            raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'vanilla']")

    def _fresh_init(self):
        """Return a new instance of ``self.init`` rebuilt from its config.

        Used wherever several weights would otherwise share one initializer instance
        (some Keras versions return identical values when the same unseeded instance
        is called repeatedly). Falls back to ``self.init`` if it has no ``get_config``.
        """
        if hasattr(self.init, 'get_config'):
            return self.init.__class__(**self.init.get_config())
        return self.init

    def build(self, input_shape):
        in_dim = input_shape[-1]
        kernel_shape = (in_dim, self.units)
        diag_shape = (self.units,)  # one shared factor per output unit (column of U1)
        diag_ids = range(2, self.depth + 1)  # empty when depth == 1

        if self.init_type == "ones":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init, regularizer=self.reg, trainable=True)
            self.diag_weights = [self.add_weight(name='U{}'.format(i), shape=diag_shape, initializer=self.init_rest,
                                                 regularizer=self.blownupreg, trainable=True) for i in diag_ids]
        elif self.init_type == "equivar": # TODO: check the math of applying this init to FC(diag(diag(...))) structured overparametrization
            self.init_factor = equivar_initializer(self.depth, in_dim, self.init) # TODO: check if identical inits are created and replace by individual instances
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init_factor, regularizer=self.reg, trainable=True)
            self.diag_weights = [self.add_weight(name='U{}'.format(i), shape=diag_shape,
                                                 initializer=equivar_initializer(self.depth, in_dim, self.init),
                                                 regularizer=self.blownupreg, trainable=True) for i in diag_ids]
        elif self.init_type == "vanilla":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init, regularizer=self.reg, trainable=True)
            self.diag_weights = [self.add_weight(name='U{}'.format(i), shape=diag_shape, initializer=self._fresh_init(),
                                                 regularizer=self.blownupreg, trainable=True) for i in diag_ids]
        else:  # "root"
            init_matrix = self.init(shape=kernel_shape)
            col_wise_norms = tf.norm(init_matrix, ord=2, axis=0)  # (units,)
            depth_th_roots = tf.pow(col_wise_norms, 1.0 / self.depth)  # r_j = ||w_j|| ** (1/depth)
            # Balanced split: column j of U1 gets norm r_j and every shared factor starts
            # at r_j, so U1 * prod(d_k) reproduces init_matrix exactly for any depth.
            col_scale = tf.pow(depth_th_roots, self.depth - 1)
            scaled_init_matrix = init_matrix / tf.reshape(col_scale, (1, -1))
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer='zeros', regularizer=self.reg, trainable=True)
            self.diag_weights = [self.add_weight(name='U{}'.format(i), shape=diag_shape, initializer='zeros',
                                                 regularizer=self.blownupreg, trainable=True) for i in diag_ids]
            self.U1.assign(scaled_init_matrix)
            for weight_var in self.diag_weights:
                weight_var.assign(depth_th_roots)

        if not self.use_bias:
            self.bias = None
            return
        if not self.factorize_bias:
            self.bias = self.add_weight(name='bias', shape=(self.units,), initializer='zeros', trainable=True)
            return

        bias_shape = (self.units,)
        if self.init_type == "ones":
            self.bias_init = self._fresh_init()
            self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.bias_init, regularizer=self.reg, trainable=True)
        elif self.init_type == "equivar":
            self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=equivar_initializer(self.depth, self.units, self.init),
                                      regularizer=self.reg, trainable=True)
        elif self.init_type == "vanilla":
            self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self._fresh_init(), regularizer=self.reg, trainable=True)
        else:  # "root"
            self.bias_init = self._fresh_init()
            init_vector = self.bias_init(shape=bias_shape)
            root_abs_init_b = tf.math.pow(tf.math.abs(init_vector), 1.0 / self.depth)
            self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer='zeros', regularizer=self.reg, trainable=True)
            self.B1.assign(tf.math.sign(init_vector) * root_abs_init_b)

    def call(self, inputs):
        W_reconstructed = self.U1
        group_scale = None
        if self.diag_weights:  # empty when depth == 1
            group_scale = tf.reduce_prod(tf.stack(self.diag_weights, axis=0), axis=0)  # (units,)
            # Broadcast the per-unit factor over the rows: column-wise grouping.
            W_reconstructed = W_reconstructed * tf.reshape(group_scale, (1, -1))

        output = inputs @ W_reconstructed
        if self.use_bias:
            if self.factorize_bias:
                # The bias shares its unit's group factor.
                output += self.B1 if group_scale is None else self.B1 * group_scale
            else:
                output += self.bias

        return self.activation(output)

    def get_config(self):
        """Return a config from which ``from_config`` rebuilds an equivalent layer.

        The derived ``blowup_factor`` is omitted because it is recomputed from
        ``groupsize`` and is not a constructor argument.
        """
        config = super().get_config().copy()
        config.update({
            'units': self.units,
            'activation': self.activation,
            'la': self._la_arg,  # un-normalized, so __init__ normalizes exactly once
            'depth': self.depth,
            'init': self.init,
            'init_rest': self.init_rest,
            'init_type': self.init_type,
            'use_bias': self.use_bias,
            'factorize_bias': self.factorize_bias,
            'groupsize': self.groupsize
        })
        return config


# Dense layer with almost neuron-sparse structured parametrization (all outgoing weights from previous layer to one neuron sans biases of current layer form one group)
class StrHadamardDenseV2(tf.keras.layers.Layer):
    """Dense layer with input-wise (row-grouped) structured Hadamard overparametrization.

    The kernel is reconstructed as ``W = (d_2 * ... * d_depth)[:, None] * U1`` where every
    ``d_k`` has shape ``(in_dim,)``. All outgoing weights of one input feature therefore
    share one multiplicative factor (row-wise grouping). Biases are not part of the groups.
    With ``factorize_bias=True`` the bias is the elementwise product
    ``B1 * B2 * ... * B_depth``.

    ``U1`` and the bias factors get an L2 penalty of ``la / depth``; the shared row
    factors get ``sqrt(groupsize) * la / depth``.

    Args:
        units: Number of output units.
        activation: Activation identifier, resolved with ``tf.keras.activations.get``.
        la: Overall L2 regularization strength (normalized by ``depth`` internally).
        depth: Number of factors (``U1`` plus ``depth - 1`` shared vectors).
        init: Initializer (class or instance) for ``U1``; default ``HeNormal()``.
        init_rest: Initializer (class or instance) for the remaining factors when
            ``init_type == "ones"``; default ``Ones()``.
        init_type: One of ``'ones'``, ``'equivar'``, ``'root'``, ``'vanilla'``
            (``None`` means ``'ones'``).
        use_bias: Whether to add a bias vector.
        factorize_bias: Whether the bias is a Hadamard product of ``depth`` vectors.
        groupsize: Group size used to scale the shared-factor penalty (default 1.0).

    Raises:
        ValueError: If ``init_type`` is not one of the supported values.
    """

    def __init__(self, units=1, activation='linear', la=0, depth=2, init=None, init_rest=None, init_type="ones",
                 use_bias=True, factorize_bias=False, groupsize=None, **kwargs):
        # Earlier get_config() output contained the derived 'blowup_factor', which is not
        # a constructor argument; drop it so such configs can still be loaded.
        kwargs.pop('blowup_factor', None)
        super(StrHadamardDenseV2, self).__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.factorize_bias = factorize_bias
        self.depth = depth
        self._la_arg = la  # constructor value, reported by get_config()
        self.la = la / self.depth # to normalize scale of induced regularizers
        init = init if init is not None else tf.keras.initializers.HeNormal()
        init_rest = init_rest if init_rest is not None else tf.keras.initializers.Ones()
        # build() calls self.init(shape=...) and self.init.get_config(), which need an
        # initializer *instance*; instantiate classes passed by the caller.
        self.init = init() if isinstance(init, type) else init
        self.init_rest = init_rest() if isinstance(init_rest, type) else init_rest
        self.init_type = init_type if init_type is not None else 'ones'
        self.reg = tf.keras.regularizers.l2(self.la)
        self.groupsize = groupsize if groupsize is not None else 1.0 # equally sized groups
        # Plain Python floats: np.float no longer exists in NumPy >= 1.24 and
        # tf.math.sqrt does not accept an integer groupsize.
        self.blowup_factor = math.sqrt(float(self.groupsize))
        self.blownupreg = tf.keras.regularizers.l2(self.blowup_factor * self.la) # using this regu for the grouped weights to achieve group size normalization

        if self.init_type not in ['ones', 'equivar', 'root', 'vanilla']:
            raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'vanilla']")

    def _fresh_init(self):
        """Return a new instance of ``self.init`` rebuilt from its config.

        Used wherever several weights would otherwise share one initializer instance
        (some Keras versions return identical values when the same unseeded instance
        is called repeatedly). Falls back to ``self.init`` if it has no ``get_config``.
        """
        if hasattr(self.init, 'get_config'):
            return self.init.__class__(**self.init.get_config())
        return self.init

    def build(self, input_shape):
        in_dim = input_shape[-1]
        kernel_shape = (in_dim, self.units)
        diag_shape = (in_dim,)  # one shared factor per input feature (row of U1)
        factor_ids = range(2, self.depth + 1)  # empty when depth == 1

        if self.init_type == "ones":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init, regularizer=self.reg, trainable=True)
            self.diag_weights = [self.add_weight(name='U{}'.format(i), shape=diag_shape, initializer=self.init_rest,
                                                 regularizer=self.blownupreg, trainable=True) for i in factor_ids]
        elif self.init_type == "equivar": # TODO: check the math of applying this init to FC(diag(diag(...))) structured overparametrization
            self.init_factor = equivar_initializer(self.depth, in_dim, self.init) # TODO: check if identical inits are created and replace by individual instances
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init_factor, regularizer=self.reg, trainable=True)
            self.diag_weights = [self.add_weight(name='U{}'.format(i), shape=diag_shape,
                                                 initializer=equivar_initializer(self.depth, in_dim, self.init),
                                                 regularizer=self.blownupreg, trainable=True) for i in factor_ids]
        elif self.init_type == "vanilla":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init, regularizer=self.reg, trainable=True)
            self.diag_weights = [self.add_weight(name='U{}'.format(i), shape=diag_shape, initializer=self._fresh_init(),
                                                 regularizer=self.blownupreg, trainable=True) for i in factor_ids]
        else:  # "root"
            init_matrix = self.init(shape=kernel_shape)
            row_wise_norms = tf.norm(init_matrix, ord=2, axis=1)  # (in_dim,)
            depth_th_roots = tf.pow(row_wise_norms, 1.0 / self.depth)  # r_i = ||w_i|| ** (1/depth)
            # Balanced split: row i of U1 gets norm r_i and every shared factor starts
            # at r_i, so prod(d_k)[:, None] * U1 reproduces init_matrix for any depth.
            row_scale = tf.pow(depth_th_roots, self.depth - 1)
            scaled_init_matrix = init_matrix / tf.reshape(row_scale, (-1, 1))
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer='zeros', regularizer=self.reg, trainable=True)
            self.U1.assign(scaled_init_matrix)
            self.diag_weights = [self.add_weight(name='U{}'.format(i), shape=diag_shape, initializer='zeros',
                                                 regularizer=self.blownupreg, trainable=True) for i in factor_ids]
            for weight_var in self.diag_weights:
                weight_var.assign(depth_th_roots)

        if not self.use_bias:
            self.bias = None
            return
        if not self.factorize_bias:
            self.bias = self.add_weight(name='bias', shape=(self.units,), initializer='zeros', trainable=True)
            return

        bias_shape = (self.units,)
        if self.init_type == "ones":
            self.bias_init = self._fresh_init()
            self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.bias_init, regularizer=self.reg, trainable=True)
            self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer=self.init_rest,
                                                 regularizer=self.reg, trainable=True) for i in factor_ids]
        elif self.init_type == "equivar":
            self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=equivar_initializer(self.depth, self.units, self.init),
                                      regularizer=self.reg, trainable=True)
            self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape,
                                                 initializer=equivar_initializer(self.depth, self.units, self.init),
                                                 regularizer=self.reg, trainable=True) for i in factor_ids]
        elif self.init_type == "vanilla":
            self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self._fresh_init(), regularizer=self.reg, trainable=True)
            self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer=self._fresh_init(),
                                                 regularizer=self.reg, trainable=True) for i in factor_ids]
        else:  # "root": B1 = sign(b)|b|**(1/depth), others |b|**(1/depth), product == b
            self.bias_init = self._fresh_init()
            init_vector = self.bias_init(shape=bias_shape)
            root_abs_init_b = tf.math.pow(tf.math.abs(init_vector), 1.0 / self.depth)
            self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer='zeros', regularizer=self.reg, trainable=True)
            self.B1.assign(tf.math.sign(init_vector) * root_abs_init_b)
            self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer='zeros',
                                                 regularizer=self.reg, trainable=True) for i in factor_ids]
            for bias_var in self.other_biases:
                bias_var.assign(root_abs_init_b)

    def call(self, inputs):
        W_reconstructed = self.U1
        if self.diag_weights:  # empty when depth == 1
            row_scale = tf.reduce_prod(tf.stack(self.diag_weights, axis=0), axis=0)  # (in_dim,)
            # Always broadcast over the columns (row-wise grouping), even if in_dim == units.
            W_reconstructed = W_reconstructed * tf.reshape(row_scale, (-1, 1))

        output = inputs @ W_reconstructed
        if self.use_bias:
            if self.factorize_bias:
                B_reconstructed = self.B1
                for bias in self.other_biases:
                    B_reconstructed = B_reconstructed * bias
                output += B_reconstructed
            else:
                output += self.bias

        return self.activation(output)

    def get_config(self):
        """Return a config from which ``from_config`` rebuilds an equivalent layer.

        The derived ``blowup_factor`` is omitted because it is recomputed from
        ``groupsize`` and is not a constructor argument.
        """
        config = super().get_config().copy()
        config.update({
            'units': self.units,
            'activation': self.activation,
            'la': self._la_arg,  # un-normalized, so __init__ normalizes exactly once
            'depth': self.depth,
            'init': self.init,
            'init_rest': self.init_rest,
            'init_type': self.init_type,
            'use_bias': self.use_bias,
            'factorize_bias': self.factorize_bias,
            'groupsize': self.groupsize
        })
        return config

   
############### old HadamardDiag and GroupHadamardDiag

# Diagonal layer building block
class SimplyConnected(tf.keras.layers.Layer):
    """Diagonal building block: scales every input feature by its own trainable factor.

    Computes ``inputs * w`` where ``w`` has shape ``(input_shape[-1],)`` and carries an
    L2 penalty of strength ``la``.

    Args:
        la: L2 regularization strength for ``w``.
        multfac_initializer: Initializer for ``w`` (default ``tf.initializers.Ones``).
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``). This also lets
            ``from_config`` accept the base-layer keys that ``get_config`` emits.
    """

    def __init__(self, la=0, multfac_initializer=tf.initializers.Ones, **kwargs):
        super(SimplyConnected, self).__init__(**kwargs)
        self.la = la
        self.multfac_initializer = multfac_initializer

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], ),
            initializer=self.multfac_initializer,
            regularizer=tf.keras.regularizers.l2(self.la),
            trainable=True,
        )

    def call(self, inputs):
        # Broadcasts w along the last axis of inputs.
        return tf.math.multiply(inputs, self.w)

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'la': self.la
        })
        return config
    
# Grouping layer building block
class GroupConnected(tf.keras.layers.Layer):
    """Grouping building block: collapses each group of input features into one output.

    For group ``g`` with feature indices ``group_idx[g]``, output column ``g`` is
    ``gather(inputs, group_idx[g], axis=1) @ w[g]`` where ``w[g]`` has shape
    ``(len(group_idx[g]), 1)``. Each ``w[g]`` carries an L2 penalty of strength ``la``.
    NOTE(review): the gather on axis 1 plus matmul assumes 2-D ``(batch, features)``
    inputs -- confirm for other ranks.

    Args:
        group_idx: Sequence of index sequences, one per group (required).
        la: L2 regularization strength.
        group_initializer: Initializer for the group weights (default ``HeNormal``).
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``). This also lets
            ``from_config`` accept the base-layer keys that ``get_config`` emits.

    Raises:
        TypeError: If ``group_idx`` is ``None``.
    """

    def __init__(self, group_idx=None, la=0, group_initializer=tf.keras.initializers.HeNormal, **kwargs):
        if group_idx is None:
            raise TypeError("GroupConnected requires group_idx: a sequence of index lists, one per group")
        super(GroupConnected, self).__init__(**kwargs)
        self.la = la
        self.input_shapes = [len(gii) for gii in group_idx]
        self.group_idx = group_idx
        self.group_initializer = group_initializer

    def build(self, input_shape):
        self.w = [self.add_weight(
            shape=(inps, 1),
            initializer=self.group_initializer,
            regularizer=tf.keras.regularizers.l2(self.la),
            trainable=True) for inps in self.input_shapes]

    def call(self, inputs):
        # One (batch, 1) column per group, stacked to (batch, n_groups, 1) and squeezed.
        group_outputs = [tf.matmul(tf.gather(inputs, ind, axis=1), w)
                         for ind, w in zip(self.group_idx, self.w)]
        return tf.squeeze(tf.stack(group_outputs, axis=1), axis=-1)

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'group_idx': self.group_idx,
            'la': self.la
        })
        return config
    
# Diagonal layers of length (depth-1) followed by FC output
class HadamardDiag(tf.keras.layers.Layer):
    """Stack of ``depth - 1`` diagonal (elementwise-scaling) layers followed by a dense output.

    Each diagonal layer is a ``SimplyConnected`` with L2 strength ``la``; the final
    ``Dense`` layer has ``units`` outputs, no activation, and L2 strength ``la`` on
    its kernel.

    Args:
        units: Number of output units.
        la: L2 regularization strength used for every factor.
        depth: Total number of factors (``depth - 1`` diagonal layers + the dense layer).
        kernel_initializer: Initializer for the dense kernel.
        multfac_initializer: Initializer for the diagonal factors.
        use_bias: Whether the dense output layer has a bias.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units=1, la=0, depth=2, kernel_initializer=tf.keras.initializers.HeNormal, multfac_initializer=tf.initializers.Ones, use_bias=False, **kwargs):
        # super() must name this class; the previous HadamardLayer reference was wrong.
        super(HadamardDiag, self).__init__(**kwargs)
        self.units = units
        self.depth = depth
        self.la = la
        self.use_bias = use_bias  # consumed in build()
        self.reg = tf.keras.regularizers.l2(self.la)
        self.kernel_initializer = kernel_initializer
        self.multfac_initializer = multfac_initializer

    def build(self, input_shape):
        self.fc = tf.keras.layers.Dense(units=self.units,
                                        use_bias=self.use_bias,
                                        bias_regularizer=None,
                                        activation=None,
                                        kernel_regularizer=self.reg,
                                        kernel_initializer=self.kernel_initializer
                                        )
        # create list of diagonal layers
        self.diaglayers = [SimplyConnected(la=self.la, multfac_initializer=self.multfac_initializer) for _ in range(self.depth - 1)]

        # use sequential model class for diagonal block
        self.diagblock = tf.keras.Sequential(self.diaglayers)

    def call(self, input):
        # With depth == 1 there are no diagonal layers; skip the empty Sequential.
        hidden = self.diagblock(input) if self.diaglayers else input
        return self.fc(hidden)

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'units': self.units,
            'la': self.la,
            'depth': self.depth,
            'use_bias': self.use_bias
        })
        return config

# Grouping layer followed by diagonal layers followed by FC output
class GroupHadamardLayerOld(tf.keras.layers.Layer):
    """Grouping layer, then (depth - 2) diagonal layers, then a dense output.

    The input is passed through ``GroupConnected`` (defined elsewhere in this
    module), then through ``depth - 2`` ``SimplyConnected`` layers, and finally
    through a bias-free ``Dense(units)``. The regularization strength is
    divided by ``depth``, and that value is given to every sub-layer.

    Args:
        units: number of output units.
        group_idx: group specification forwarded to ``GroupConnected``.
        la: total L2 regularization strength (stored internally as la / depth).
        depth: total number of factors.
        init_group: initializer forwarded to ``GroupConnected``.
        multfac_initializer: initializer for the dense kernel and the
            diagonal layers.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units=1, group_idx=None, la=0, depth=2, init_group=tf.keras.initializers.HeNormal, multfac_initializer=tf.keras.initializers.Ones, **kwargs):
        super(GroupHadamardLayerOld, self).__init__(**kwargs)
        self.units = units
        self.depth = depth
        # Normalise so the induced penalty scale does not grow with depth.
        self.la = la / self.depth
        self.reg = tf.keras.regularizers.l2(self.la)
        self.group_idx = group_idx
        self.init_group = init_group
        self.multfac_initializer = multfac_initializer

    def build(self, input_shape):
        self.fc = tf.keras.layers.Dense(input_shape=input_shape,
                                        units=self.units,
                                        use_bias=False,
                                        bias_regularizer=None,
                                        activation=None,
                                        kernel_regularizer=self.reg,
                                        kernel_initializer=self.multfac_initializer
                                        )
        self.gc = GroupConnected(group_idx=self.group_idx, la=self.la, group_initializer=self.init_group)
        self.diaglayers = [SimplyConnected(la=self.la, multfac_initializer=self.multfac_initializer)
                           for _ in range(self.depth - 2)]
        # With the default depth=2 there are no diagonal layers; Keras 3 cannot
        # build an empty Sequential, so the block is skipped in that case.
        self.diagblock = tf.keras.Sequential(self.diaglayers) if self.diaglayers else None

    def call(self, input):
        hidden = self.gc(input)
        if self.diagblock is not None:
            hidden = self.diagblock(hidden)
        return self.fc(hidden)

    def get_config(self):
        """Return the layer config.

        ``la`` is reported as the value originally passed to ``__init__``
        (undoing the division by ``depth``), so ``from_config`` reproduces the
        same regularization strength instead of dividing by ``depth`` again.
        """
        config = super().get_config().copy()
        config.update({
            'units': self.units,
            'group_idx': self.group_idx,
            'la': self.la * self.depth,
            'depth': self.depth
        })
        return config
    
import tensorflow as tf


class GroupHadamardLayer(tf.keras.layers.Layer):
    """Group-wise Hadamard-factorised linear layer (no bias).

    The effective kernel ``W_eff`` of shape ``(input_dim, units)`` is rebuilt
    on every call. For an input feature ``j`` belonging to group ``g``::

        W_eff[j, :] = w[j] * prod(group_scalars[g, :]) * fc_weight[g, :]

    That is one per-feature weight, ``depth - 2`` per-group scalars and one
    per-group output row. Every factor receives an L2 penalty of strength
    ``la / depth``.

    Args:
        units: number of output units.
        group_idx: sequence of index sequences, one per group, selecting
            positions along the last input axis. The rows are reassembled with
            ``tf.dynamic_stitch`` and then multiplied with the input, so the
            groups should together cover every feature index exactly once
            (NOTE(review): this is not validated here).
        la: total regularization strength (stored internally as la / depth).
        depth: number of factors.
        init_group: initializer for the per-feature weights ``w``; defaults to
            a fresh ``HeNormal()`` per layer.
        multfac_initializer: initializer for ``fc_weight``; defaults to a fresh
            ``Ones()`` per layer.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units=1, group_idx=None, la=0, depth=2,
                 init_group=None,
                 multfac_initializer=None, **kwargs):
        super(GroupHadamardLayer, self).__init__(**kwargs)
        self.units = units
        self.group_idx = group_idx
        self.depth = depth
        # Effective regularization strength: la/depth.
        self.la = la / depth
        # Create initializer instances per layer. A default instance evaluated once
        # at class-definition time would be shared by every layer, and Keras returns
        # identical values from repeated calls to one unseeded initializer instance.
        self.init_group = (init_group if init_group is not None
                           else tf.keras.initializers.HeNormal())
        self.multfac_initializer = (multfac_initializer if multfac_initializer is not None
                                    else tf.keras.initializers.Ones())

    def build(self, input_shape):
        if self.group_idx is None or len(self.group_idx) == 0:
            raise ValueError("GroupHadamardLayer requires a non-empty 'group_idx' "
                             "(one index sequence per group).")
        reg = tf.keras.regularizers.l2(self.la)
        input_dim = input_shape[-1]
        num_groups = len(self.group_idx)

        # Unified weight vector for all input features (first factor).
        self.w = self.add_weight(
            shape=(input_dim, 1),
            initializer=self.init_group,
            regularizer=reg,
            trainable=True,
            name="w"
        )

        # Only depth-2 group scalars are needed: the unified weight is the first
        # factor and the fc weight is the last one.
        if self.depth > 2:
            self.group_scalars = self.add_weight(
                shape=(num_groups, self.depth - 2),
                initializer=tf.keras.initializers.Ones(),
                regularizer=reg,
                trainable=True,
                name="group_scalars"
            )
        else:
            self.group_scalars = None

        # Final mapping from groups to output units (last factor).
        self.fc_weight = self.add_weight(
            shape=(num_groups, self.units),
            initializer=self.multfac_initializer,
            regularizer=reg,
            trainable=True,
            name="fc_weight"
        )
        super(GroupHadamardLayer, self).build(input_shape)

    def call(self, inputs):
        # Product of each group's intermediate scalars, computed once per call.
        group_scale = (tf.reduce_prod(self.group_scalars, axis=1)
                       if self.group_scalars is not None else None)

        effective_weight_parts = []
        indices_parts = []
        for g, grp in enumerate(self.group_idx):
            grp = tf.convert_to_tensor(grp, dtype=tf.int32)
            w_eff_g = tf.gather(self.w, grp)  # (|grp|, 1)
            if group_scale is not None:
                w_eff_g = w_eff_g * group_scale[g]
            # Broadcasting: (|grp|, 1) * (units,) -> (|grp|, units)
            effective_weight_parts.append(w_eff_g * self.fc_weight[g, :])
            indices_parts.append(grp)

        # Place each group's rows back at their feature positions -> (input_dim, units).
        W_eff = tf.dynamic_stitch(indices_parts, effective_weight_parts)
        return tf.matmul(inputs, W_eff)

    def get_config(self):
        config = super(GroupHadamardLayer, self).get_config()
        config.update({
            'units': self.units,
            'group_idx': self.group_idx,
            'la': self.la * self.depth,  # return original la value
            'depth': self.depth,
            'init_group': tf.keras.initializers.serialize(self.init_group),
            'multfac_initializer': tf.keras.initializers.serialize(self.multfac_initializer),
        })
        return config

    
# grouping layer followed by FC output
class DiffGroupLasso(tf.keras.layers.Layer):
    """Grouping layer followed by a bias-free dense output (two factors).

    With ``group_idx=None`` the grouping step is a single ``Dense(1)`` over all
    inputs. Otherwise it is a ``GroupConnected`` layer (defined elsewhere in
    this module). Both factors get an L2 penalty of strength ``la / 2``.

    Args:
        units: number of output units of the final dense layer.
        group_idx: group specification forwarded to ``GroupConnected``, or
            ``None`` to treat all inputs as one group.
        la: total L2 regularization strength (stored internally as la / 2).
        kernel_initializer: initializer for the output dense kernel.
        multfac_initializer: initializer for the grouping step.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units=1, group_idx=None, la=0, kernel_initializer=tf.initializers.Ones, multfac_initializer=tf.keras.initializers.HeNormal, **kwargs):
        super(DiffGroupLasso, self).__init__(**kwargs)
        self.units = units
        # Split the penalty evenly between the two factors.
        self.la = la / 2.0
        self.reg = tf.keras.regularizers.l2(self.la)
        self.group_idx = group_idx
        self.kernel_initializer = kernel_initializer
        self.multfac_initializer = multfac_initializer

    def build(self, input_shape):
        if self.group_idx is None:
            # Output layer uses self.units, as in the grouped branch, so the
            # output width always matches the requested ``units``.
            self.fc = tf.keras.layers.Dense(units=self.units,
                                            use_bias=False,
                                            bias_regularizer=None,
                                            activation=None,
                                            kernel_regularizer=self.reg,
                                            kernel_initializer=self.kernel_initializer
                                            )
            # All features form a single group collapsed to one feature.
            self.gc = tf.keras.layers.Dense(input_shape=input_shape,
                                            units=1,
                                            use_bias=False,
                                            bias_regularizer=None,
                                            activation=None,
                                            kernel_regularizer=self.reg,
                                            kernel_initializer=self.multfac_initializer
                                            )
        else:
            self.fc = tf.keras.layers.Dense(input_shape=input_shape,
                                            units=self.units,
                                            use_bias=False,
                                            bias_regularizer=None,
                                            activation=None,
                                            kernel_regularizer=self.reg,
                                            kernel_initializer=self.kernel_initializer
                                            )
            self.gc = GroupConnected(group_idx=self.group_idx, la=self.la, multfac_initializer=self.multfac_initializer)

    def call(self, input):
        return self.fc(self.gc(input))

    def get_config(self):
        """Return the layer config.

        ``la`` is reported as the value originally passed to ``__init__``
        (undoing the halving), so ``from_config`` rebuilds the layer with the
        same regularization strength.
        """
        config = super().get_config().copy()
        config.update({
            'units': self.units,
            'group_idx': self.group_idx,
            'la': self.la * 2.0
        })
        return config