import math
import numpy as np
import tensorflow as tf
from initializers import TwiceTruncatedNormalInitializer, equivar_initializer, equivar_initializer_conv2d

# Dense layer with hadamard product parametrization of depth 'depth'
class HadamardDense(tf.keras.layers.Layer):
    """Dense layer whose kernel is the Hadamard (elementwise) product of ``depth`` factors.

    The effective kernel is ``U1 * U2 * ... * U_depth``. If ``factorize_bias`` is True, the
    bias is factorized the same way (``B1 * ... * B_depth``). Otherwise a single
    zero-initialized, unregularized bias is used. Every factor carries an L2 penalty of
    ``la / depth``.

    Args:
        units: Number of output units.
        activation: Activation identifier, serialized dict or callable
            (resolved with ``tf.keras.activations.get``).
        la: Overall L2 strength; each factor is regularized with ``la / depth``.
        depth: Number of Hadamard factors.
        init: Base initializer (instance, identifier or serialized dict).
            Defaults to ``HeNormal()``.
        init_rest: Initializer for factors 2..depth when ``init_type == 'ones'``.
            Defaults to ``Ones()``.
        init_type: Initialization scheme:
            ``'ones'``    - first factor from ``init``, remaining factors from ``init_rest``;
            ``'equivar'`` - every factor from ``equivar_initializer(depth, fan_in, init)``;
            ``'vanilla'`` - every factor from ``init``;
            ``'root'``    - one sample ``W0`` is drawn from ``init``. The first factor is
                            ``sign(W0) * |W0|**(1/depth)`` and the others are
                            ``|W0|**(1/depth)``, so their product equals ``W0``.
        use_bias: Whether to add a bias term.
        factorize_bias: Whether the bias is itself a Hadamard product of ``depth`` factors.

    Raises:
        ValueError: If ``init_type`` is not one of the supported schemes.
    """

    def __init__(self, units=1, activation='linear', la=0, depth=2, init=None, init_rest=None, init_type="ones",
                 use_bias=True, factorize_bias=False, **kwargs):
        super(HadamardDense, self).__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.factorize_bias = factorize_bias
        self.depth = depth
        # Keep the user-supplied value so get_config() round-trips without dividing twice.
        self._la_input = la
        self.la = la / self.depth  # to normalize scale of induced regularizers
        # initializers.get() passes instances through unchanged and also resolves
        # identifiers / serialized dicts (as produced by get_config()).
        self.init = tf.keras.initializers.get(init) if init is not None else tf.keras.initializers.HeNormal()
        self.init_rest = tf.keras.initializers.get(init_rest) if init_rest is not None else tf.keras.initializers.Ones()
        self.init_type = init_type
        self.reg = tf.keras.regularizers.l2(self.la)

        if self.init_type not in ['ones', 'equivar', 'root', 'vanilla']:
            raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'vanilla']")

    def build(self, input_shape):
        """Create the kernel factors and the bias (or bias factors) for ``input_shape``."""
        # Fan-in is the last input axis (as in tf.keras Dense). For rank-2 inputs
        # this is input_shape[1].
        fan_in = input_shape[-1]
        kernel_shape = (fan_in, self.units)

        if self.init_type == "ones":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init,
                                      regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape, initializer=self.init_rest,
                                                  regularizer=self.reg, trainable=True)
                                  for i in range(2, self.depth + 1)]
        elif self.init_type == "equivar":
            self.init_factor = equivar_initializer(self.depth, fan_in, self.init)
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init_factor,
                                      regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape,
                                                  initializer=equivar_initializer(self.depth, fan_in, self.init),
                                                  regularizer=self.reg, trainable=True)
                                  for i in range(2, self.depth + 1)]
        elif self.init_type == "vanilla":
            self.init_factor = self.init
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init,
                                      regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape, initializer=self.init,
                                                  regularizer=self.reg, trainable=True)
                                  for i in range(2, self.depth + 1)]
        elif self.init_type == "root":
            # Draw one sample W0 and split it into factors whose product is W0.
            init_matrix = self.init(shape=kernel_shape)
            init_signs_w = tf.math.sign(init_matrix)
            root_abs_init_w = tf.math.pow(tf.math.abs(init_matrix), 1.0 / self.depth)
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer='zeros',
                                      regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape, initializer='zeros',
                                                  regularizer=self.reg, trainable=True)
                                  for i in range(2, self.depth + 1)]
            self.U1.assign(init_signs_w * root_abs_init_w)
            for weight_var in self.other_weights:
                weight_var.assign(root_abs_init_w)

        if self.use_bias:
            if self.factorize_bias:
                bias_shape = (self.units,)
                if self.init_type == "ones":
                    # Fresh copy of `init` for the bias factor.
                    self.bias_init = self.init.__class__(**self.init.get_config())
                    self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.bias_init,
                                              regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape,
                                                         initializer=self.init_rest,
                                                         regularizer=self.reg, trainable=True)
                                         for i in range(2, self.depth + 1)]
                elif self.init_type == "equivar":
                    self.init_bias_factor = equivar_initializer(self.depth, self.units, self.init)
                    self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.init_bias_factor,
                                              regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape,
                                                         initializer=equivar_initializer(self.depth, self.units, self.init),
                                                         regularizer=self.reg, trainable=True)
                                         for i in range(2, self.depth + 1)]
                elif self.init_type == "vanilla":
                    self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.init,
                                              regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer=self.init,
                                                         regularizer=self.reg, trainable=True)
                                         for i in range(2, self.depth + 1)]
                elif self.init_type == "root":
                    self.bias_init = self.init.__class__(**self.init.get_config())
                    init_vector = self.bias_init(shape=bias_shape)
                    init_signs_b = tf.math.sign(init_vector)
                    root_abs_init_b = tf.math.pow(tf.math.abs(init_vector), 1.0 / self.depth)
                    self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer='zeros',
                                              regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer='zeros',
                                                         regularizer=self.reg, trainable=True)
                                         for i in range(2, self.depth + 1)]
                    self.B1.assign(init_signs_b * root_abs_init_b)
                    for bias_var in self.other_biases:
                        bias_var.assign(root_abs_init_b)
            else:
                self.bias = self.add_weight(name='bias', shape=(self.units,), initializer='zeros', trainable=True)
        else:
            self.bias = None

    def call(self, inputs):
        """Return ``activation(inputs @ W + b)``, rebuilding ``W`` (and ``b``) from their factors."""
        # Reconstruct the kernel as the elementwise product of all factors.
        W_reconstructed = self.U1
        for weight in self.other_weights:
            W_reconstructed = W_reconstructed * weight

        output = inputs @ W_reconstructed
        if self.use_bias:
            if self.factorize_bias:
                B_reconstructed = self.B1
                for bias in self.other_biases:
                    B_reconstructed = B_reconstructed * bias
                output += B_reconstructed
            else:
                output += self.bias

        return self.activation(output)

    def get_config(self):
        """Return a serializable config from which ``from_config`` rebuilds an equivalent layer."""
        config = super().get_config().copy()
        config.update({
            'units': self.units,
            'activation': tf.keras.activations.serialize(self.activation),
            # Original (undivided) value: __init__ divides by depth again.
            'la': self._la_input,
            'depth': self.depth,
            'init': tf.keras.initializers.serialize(self.init),
            'init_rest': tf.keras.initializers.serialize(self.init_rest),
            'init_type': self.init_type,
            'use_bias': self.use_bias,
            'factorize_bias': self.factorize_bias
        })
        return config

# Refactored Hadamard dense layer: per-factor initialization logic is moved into get_initializer()
class HadamardDenseRefac(tf.keras.layers.Layer):
    """Dense layer with a depth-``depth`` Hadamard-product kernel (and optionally bias).

    The effective kernel is ``U1 * ... * U_depth`` (elementwise). The per-factor initializer
    is chosen by :meth:`get_initializer` according to ``init_type``
    (``'ones'``, ``'equivar'``, ``'root'`` or ``'vanilla'``). Each factor carries an L2
    penalty of ``la / depth``.

    Raises:
        ValueError: If ``init_type`` is not one of the supported schemes.
    """

    def __init__(self, units=1, activation='linear', la=0, depth=2, init=None, init_rest=None, init_type="ones", use_bias=True, factorize_bias=False, **kwargs):
        super(HadamardDenseRefac, self).__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.factorize_bias = factorize_bias
        self.depth = depth
        # Keep the user-supplied value so get_config() round-trips without dividing twice.
        self._la_input = la
        self.la = la / self.depth
        # initializers.get() passes instances through and resolves identifiers / dicts.
        self.init = tf.keras.initializers.get(init) if init is not None else tf.keras.initializers.HeNormal()
        self.init_rest = tf.keras.initializers.get(init_rest) if init_rest is not None else tf.keras.initializers.Ones()
        self.init_type = init_type
        self.reg = tf.keras.regularizers.l2(self.la)

        if self.init_type not in ['ones', 'equivar', 'root', 'vanilla']:
            raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'vanilla']")

    def build(self, input_shape):
        """Create the kernel factors and the bias (or bias factors)."""
        # Fan-in is the last input axis (as in tf.keras Dense).
        fan_in = input_shape[-1]
        kernel_shape = (fan_in, self.units)
        self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.get_initializer(fan_in, 1),
                                  regularizer=self.reg, trainable=True)
        self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape,
                                              initializer=self.get_initializer(fan_in, i),
                                              regularizer=self.reg, trainable=True)
                              for i in range(2, self.depth + 1)]

        if self.use_bias:
            if self.factorize_bias:
                self.B1 = self.add_weight(name='B1', shape=(self.units,),
                                          initializer=self.get_initializer(self.units, 1, is_bias=True),
                                          regularizer=self.reg, trainable=True)
                self.other_biases = [self.add_weight(name='B{}'.format(i), shape=(self.units,),
                                                     initializer=self.get_initializer(self.units, i, is_bias=True),
                                                     regularizer=self.reg, trainable=True)
                                     for i in range(2, self.depth + 1)]
            else:
                self.bias = self.add_weight(name='bias', shape=(self.units,), initializer='zeros', trainable=True)

        # The variables already hold their values, so release the cached 'root' samples.
        self._root_kernel_sample = None
        self._root_bias_sample = None

    def call(self, inputs):
        """Return ``activation(inputs @ W + b)``, rebuilding ``W`` (and ``b``) from their factors."""
        W_reconstructed = self.U1
        for weight in self.other_weights:
            # Out-of-place multiply: in-place `*=` on a tf.Variable is not supported.
            W_reconstructed = W_reconstructed * weight

        output = inputs @ W_reconstructed
        if self.use_bias:
            if self.factorize_bias:
                B_reconstructed = self.B1
                for bias in self.other_biases:
                    B_reconstructed = B_reconstructed * bias
                output += B_reconstructed
            else:
                output += self.bias

        return self.activation(output)

    def get_initializer(self, dim, factor_number, is_bias=False):
        """Return the initializer for one factor of the kernel or bias.

        Args:
            dim: Fan-in for kernel factors; ``units`` for bias factors.
            factor_number: 1-based index of the factor within the Hadamard product.
            is_bias: Whether the factor belongs to the bias (shape ``(units,)``)
                rather than the kernel (shape ``(dim, units)``).
        """
        if self.init_type == "root":
            shape = (self.units,) if is_bias else (dim, self.units)
            attr = '_root_bias_sample' if is_bias else '_root_kernel_sample'
            sample = getattr(self, attr, None)
            # Draw one sample per kernel/bias on its first factor and reuse it for the
            # remaining factors, so the product of all factors equals the sample.
            if factor_number == 1 or sample is None or tuple(sample.shape) != tuple(shape):
                sample = self.init(shape=shape)
                setattr(self, attr, sample)
            root_abs = tf.math.pow(tf.math.abs(sample), 1.0 / self.depth)
            values = tf.math.sign(sample) * root_abs if factor_number == 1 else root_abs
            return self._fixed_value_initializer(values)
        elif self.init_type == "ones":
            return self.init if factor_number == 1 else self.init_rest
        elif self.init_type == "equivar":
            return equivar_initializer(self.depth, dim, self.init)
        elif self.init_type == "vanilla":
            return self.init

    @staticmethod
    def _fixed_value_initializer(values):
        """Wrap a precomputed tensor as a Keras-compatible initializer callable."""
        def initializer(shape=None, dtype=None, **kwargs):
            return values if dtype is None else tf.cast(values, dtype)
        return initializer

    def get_config(self):
        """Return a serializable config from which ``from_config`` rebuilds an equivalent layer."""
        config = super().get_config().copy()
        config.update({
            'units': self.units,
            'activation': tf.keras.activations.serialize(self.activation),
            # Original (undivided) value: __init__ divides by depth again.
            'la': self._la_input,
            'depth': self.depth,
            'init': tf.keras.initializers.serialize(self.init),
            'init_rest': tf.keras.initializers.serialize(self.init_rest),
            'init_type': self.init_type,
            'use_bias': self.use_bias,
            'factorize_bias': self.factorize_bias
        })
        return config


# Conv2D layer with unstructured Hadamard product parametrization of depth 'depth'
class HadamardConv2D(tf.keras.layers.Layer):
    """2D convolution whose kernel is the Hadamard product of ``depth`` same-shaped factors.

    The effective kernel ``U1 * ... * U_depth`` has shape
    ``(kernel_h, kernel_w, in_channels, filters)`` and is applied with ``tf.nn.conv2d``.
    ``in_channels`` is read from the last input axis, i.e. channels-last inputs.
    Each factor carries an L2 penalty of ``la / depth``.

    Args:
        filters: Number of output channels.
        kernel_size: Int or pair ``(kernel_h, kernel_w)``.
        strides: Passed to ``tf.nn.conv2d``; defaults to ``[1, 1, 1, 1]``.
        activation: Activation identifier, serialized dict or callable.
        la: Overall L2 strength; each factor is regularized with ``la / depth``.
        depth: Number of Hadamard factors.
        init: Base initializer (instance, identifier or serialized dict); defaults to ``HeNormal()``.
        init_rest: Initializer for factors 2..depth when ``init_type == 'ones'``; defaults to ``Ones()``.
        init_type: One of ``'ones'``, ``'equivar'``, ``'root'``, ``'vanilla'``.
        use_bias: Whether to add a bias term.
        factorize_bias: Whether the bias is itself a Hadamard product of ``depth`` factors.
        padding: ``'same'`` or ``'valid'`` (case-insensitive); any other value falls back to ``'VALID'``.

    Raises:
        ValueError: If ``init_type`` is not one of the supported schemes.
    """

    def __init__(self, filters, kernel_size, strides=None, activation='linear', la=0, depth=2,
                 init=None, init_rest=None, init_type="ones", use_bias=True,
                 factorize_bias=True, padding='valid', **kwargs):
        super(HadamardConv2D, self).__init__(**kwargs)
        self.filters = filters
        # Accept a single int for square kernels.
        self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.factorize_bias = factorize_bias
        self.depth = depth
        # Keep the user-supplied value so get_config() round-trips without dividing twice.
        self._la_input = la
        self.la = la / self.depth
        self.strides = strides if strides is not None else [1, 1, 1, 1]
        self.init = tf.keras.initializers.get(init) if init is not None else tf.keras.initializers.HeNormal()
        self.init_rest = tf.keras.initializers.get(init_rest) if init_rest is not None else tf.keras.initializers.Ones()
        self.init_type = init_type
        self.reg = tf.keras.regularizers.l2(self.la)
        # Case-insensitive, so the 'SAME'/'VALID' value written by get_config() round-trips.
        self.padding = 'SAME' if isinstance(padding, str) and padding.upper() == 'SAME' else 'VALID'

        if self.init_type not in ['ones', 'equivar', 'root', 'vanilla']:
            raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'vanilla']")

    def build(self, input_shape):
        """Create the kernel factors and the bias (or bias factors)."""
        n_channels = input_shape[-1]
        kernel_shape = (self.kernel_size[0], self.kernel_size[1], n_channels, self.filters)

        if self.init_type == "ones":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init,
                                      regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape, initializer=self.init_rest,
                                                  regularizer=self.reg, trainable=True)
                                  for i in range(2, self.depth + 1)]
        elif self.init_type == "equivar":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape,
                                      initializer=equivar_initializer_conv2d(self.depth, self.kernel_size, n_channels, self.init),
                                      regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape,
                                                  initializer=equivar_initializer_conv2d(self.depth, self.kernel_size, n_channels, self.init),
                                                  regularizer=self.reg, trainable=True)
                                  for i in range(2, self.depth + 1)]
        elif self.init_type == "vanilla":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init,
                                      regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape, initializer=self.init,
                                                  regularizer=self.reg, trainable=True)
                                  for i in range(2, self.depth + 1)]
        elif self.init_type == "root":
            # Draw one sample K0 and split it into factors whose product is K0.
            init_kernel = self.init(shape=kernel_shape)
            init_signs_w = tf.math.sign(init_kernel)
            root_abs_init_w = tf.math.pow(tf.math.abs(init_kernel), 1.0 / self.depth)
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer='zeros',
                                      regularizer=self.reg, trainable=True)
            self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape, initializer='zeros',
                                                  regularizer=self.reg, trainable=True)
                                  for i in range(2, self.depth + 1)]
            self.U1.assign(init_signs_w * root_abs_init_w)
            for weight_var in self.other_weights:
                weight_var.assign(root_abs_init_w)

        if self.use_bias:
            if self.factorize_bias:
                bias_shape = (self.filters,)
                if self.init_type == "ones":
                    # Fresh copy of `init` for the bias factor.
                    self.bias_init = self.init.__class__(**self.init.get_config())
                    self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.bias_init,
                                              regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape,
                                                         initializer=self.init_rest,
                                                         regularizer=self.reg, trainable=True)
                                         for i in range(2, self.depth + 1)]
                elif self.init_type == "equivar":
                    self.B1 = self.add_weight(name='B1', shape=bias_shape,
                                              initializer=equivar_initializer(self.depth, self.filters, self.init),
                                              regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape,
                                                         initializer=equivar_initializer(self.depth, self.filters, self.init),
                                                         regularizer=self.reg, trainable=True)
                                         for i in range(2, self.depth + 1)]
                elif self.init_type == "vanilla":
                    self.init_bias_factor = self.init
                    self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer=self.init_bias_factor,
                                              regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer=self.init,
                                                         regularizer=self.reg, trainable=True)
                                         for i in range(2, self.depth + 1)]
                elif self.init_type == "root":
                    self.bias_init = self.init.__class__(**self.init.get_config())
                    init_vector = self.bias_init(shape=bias_shape)
                    init_signs_b = tf.math.sign(init_vector)
                    root_abs_init_b = tf.math.pow(tf.math.abs(init_vector), 1.0 / self.depth)
                    self.B1 = self.add_weight(name='B1', shape=bias_shape, initializer='zeros',
                                              regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape, initializer='zeros',
                                                         regularizer=self.reg, trainable=True)
                                         for i in range(2, self.depth + 1)]
                    self.B1.assign(init_signs_b * root_abs_init_b)
                    for bias_var in self.other_biases:
                        bias_var.assign(root_abs_init_b)
            else:
                self.bias = self.add_weight(name='bias', shape=(self.filters,), initializer='zeros', trainable=True)
        else:
            self.bias = None

    def call(self, inputs):
        """Convolve with the reconstructed kernel, add the bias and apply the activation."""
        W_reconstructed = self.U1
        for weight in self.other_weights:
            W_reconstructed = W_reconstructed * weight

        output = tf.nn.conv2d(inputs, W_reconstructed, strides=self.strides, padding=self.padding)

        if self.use_bias:
            if self.factorize_bias:
                B_reconstructed = self.B1
                for bias in self.other_biases:
                    B_reconstructed = B_reconstructed * bias
                output = tf.nn.bias_add(output, B_reconstructed)
            else:
                output = tf.nn.bias_add(output, self.bias)

        return self.activation(output)

    def get_config(self):
        """Return a serializable config from which ``from_config`` rebuilds an equivalent layer."""
        config = super().get_config().copy()
        config.update({
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'activation': tf.keras.activations.serialize(self.activation),
            # Original (undivided) value: __init__ divides by depth again.
            'la': self._la_input,
            'depth': self.depth,
            'init': tf.keras.initializers.serialize(self.init),
            'init_rest': tf.keras.initializers.serialize(self.init_rest),
            'init_type': self.init_type,
            'use_bias': self.use_bias,
            'factorize_bias': self.factorize_bias,
            'padding': self.padding
        })
        return config



class HadamardConv2DRefac(tf.keras.layers.Layer):
    """2D convolution with a Hadamard-product kernel whose factor initializers come from :meth:`get_initializer`.

    The effective kernel ``U1 * ... * U_depth`` has shape
    ``(kernel_h, kernel_w, in_channels, filters)`` and is applied with ``tf.nn.conv2d``.
    ``in_channels`` is read from the last input axis. Each factor carries an L2
    penalty of ``la / depth``.

    Raises:
        ValueError: If ``init_type`` is not one of ``'ones'``, ``'equivar'``, ``'root'``, ``'vanilla'``.
    """

    def __init__(self, filters, kernel_size, strides=None, activation='linear', la=0, depth=1,
                 init=None, init_rest=None, init_type="ones", use_bias=True,
                 factorize_bias=True, padding='valid', **kwargs):
        super(HadamardConv2DRefac, self).__init__(**kwargs)
        self.filters = filters
        # Accept a single int for square kernels.
        self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        self.strides = strides if strides is not None else (1, 1)
        self.activation = tf.keras.activations.get(activation)
        # depth must be assigned before it is used to normalize `la`.
        self.depth = depth
        # Keep the user-supplied value so get_config() round-trips without dividing twice.
        self._la_input = la
        self.la = la / self.depth
        self.init = tf.keras.initializers.get(init) if init is not None else tf.keras.initializers.HeNormal()
        self.init_rest = tf.keras.initializers.get(init_rest) if init_rest is not None else tf.keras.initializers.Ones()
        self.init_type = init_type
        self.use_bias = use_bias
        self.factorize_bias = factorize_bias
        self.padding = 'SAME' if isinstance(padding, str) and padding.lower() == 'same' else 'VALID'
        self.reg = tf.keras.regularizers.l2(self.la)

        if self.init_type not in ['ones', 'equivar', 'root', 'vanilla']:
            raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'vanilla']")

    def build(self, input_shape):
        """Create the kernel factors and the bias (or bias factors)."""
        n_channels = input_shape[-1]
        kernel_shape = (self.kernel_size[0], self.kernel_size[1], n_channels, self.filters)

        self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.get_initializer(kernel_shape, 1),
                                  regularizer=self.reg, trainable=True)
        self.other_weights = [self.add_weight(name='U{}'.format(i), shape=kernel_shape,
                                              initializer=self.get_initializer(kernel_shape, i),
                                              regularizer=self.reg, trainable=True)
                              for i in range(2, self.depth + 1)]

        if self.use_bias:
            bias_shape = (self.filters,)
            if self.factorize_bias:
                self.B1 = self.add_weight(name='B1', shape=bias_shape,
                                          initializer=self.get_initializer(bias_shape, 1, True),
                                          regularizer=self.reg, trainable=True)
                self.other_biases = [self.add_weight(name='B{}'.format(i), shape=bias_shape,
                                                     initializer=self.get_initializer(bias_shape, i, True),
                                                     regularizer=self.reg, trainable=True)
                                     for i in range(2, self.depth + 1)]
            else:
                self.bias = self.add_weight(name='bias', shape=bias_shape, initializer='zeros', trainable=True)

        # The variables already hold their values, so release the cached 'root' samples.
        self._root_kernel_sample = None
        self._root_bias_sample = None

    def call(self, inputs):
        """Convolve with the reconstructed kernel, add the bias and apply the activation."""
        W_reconstructed = self.U1
        for weight in self.other_weights:
            # Out-of-place multiply: in-place `*=` on a tf.Variable is not supported.
            W_reconstructed = W_reconstructed * weight

        output = tf.nn.conv2d(inputs, W_reconstructed, strides=self.strides, padding=self.padding)

        if self.use_bias:
            if self.factorize_bias:
                B_reconstructed = self.B1
                for bias in self.other_biases:
                    B_reconstructed = B_reconstructed * bias
                output = tf.nn.bias_add(output, B_reconstructed)
            else:
                output = tf.nn.bias_add(output, self.bias)

        return self.activation(output)

    def get_initializer(self, shape, factor_number, is_bias=False):
        """Return the initializer for one kernel or bias factor.

        Args:
            shape: Shape of the variable being created.
            factor_number: 1-based index of the factor within the Hadamard product.
            is_bias: Whether the factor belongs to the (1-D) bias.
        """
        if self.init_type == "root":
            attr = '_root_bias_sample' if is_bias else '_root_kernel_sample'
            sample = getattr(self, attr, None)
            # Draw one sample per kernel/bias on its first factor and reuse it for the
            # remaining factors, so the product of all factors equals the sample.
            if factor_number == 1 or sample is None or tuple(sample.shape) != tuple(shape):
                sample = self.init(shape=shape)
                setattr(self, attr, sample)
            root_abs_init = tf.math.pow(tf.math.abs(sample), 1.0 / self.depth)
            # Only the first factor carries the sign.
            values = tf.math.sign(sample) * root_abs_init if factor_number == 1 else root_abs_init
            return self._fixed_value_initializer(values)
        elif self.init_type == "equivar":
            if is_bias:
                # A bias has no (kh, kw, in, out) layout, so use the 1-D initializer
                # (as HadamardConv2D does for its bias factors).
                return equivar_initializer(self.depth, self.filters, self.init)
            return equivar_initializer_conv2d(self.depth, self.kernel_size, shape[-2], self.init)
        elif self.init_type == "vanilla":
            return self.init
        elif self.init_type == "ones":
            return self.init if factor_number == 1 else self.init_rest

    @staticmethod
    def _fixed_value_initializer(values):
        """Wrap a precomputed tensor as a Keras-compatible initializer callable."""
        def initializer(shape=None, dtype=None, **kwargs):
            return values if dtype is None else tf.cast(values, dtype)
        return initializer

    def get_config(self):
        """Return a serializable config from which ``from_config`` rebuilds an equivalent layer."""
        config = super().get_config().copy()
        config.update({
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'activation': tf.keras.activations.serialize(self.activation),
            # Original (undivided) value: __init__ divides by depth again.
            'la': self._la_input,
            'depth': self.depth,
            'init': tf.keras.initializers.serialize(self.init),
            'init_rest': tf.keras.initializers.serialize(self.init_rest),
            'init_type': self.init_type,
            'use_bias': self.use_bias,
            'factorize_bias': self.factorize_bias,
            'padding': self.padding
        })
        return config


# Dense layer with shared Hadamard product parametrization of depth 'depth'
class HadamardDenseShared(tf.keras.layers.Layer):
    """Dense layer with a shared Hadamard parametrization ``W = U * V**(depth - 1)``.

    Only two trainable factors are kept: ``U`` appears once in the product and ``V``
    appears ``depth - 1`` times. Their L2 penalties are ``la / depth`` and
    ``(depth - 1) * la / depth`` respectively. With ``factorize_bias`` the bias is built
    the same way from ``BU`` and ``BV``.

    Args:
        units: Number of output units.
        activation: Activation identifier, serialized dict or callable.
        la: Overall L2 strength.
        depth: Integer (or integral float) depth of the parametrization.
        init: Base initializer (instance, identifier or serialized dict); defaults to ``HeNormal()``.
        init_rest: Initializer for ``V``/``BV`` when ``init_type == 'ones'``; defaults to ``Ones()``.
        init_type: One of ``'ones'``, ``'equivar'``, ``'root'``, ``'vanilla'``.
        use_bias: Whether to add a bias term.
        factorize_bias: Whether the bias is factorized as ``BU * BV**(depth - 1)``.

    Raises:
        ValueError: If ``depth`` is not integral or ``init_type`` is unsupported.
    """

    def __init__(self, units=1, activation='linear', la=0, depth=2, init=None, init_rest=None, init_type="ones",
                 use_bias=True, factorize_bias=False, **kwargs):
        super(HadamardDenseShared, self).__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.factorize_bias = factorize_bias
        self.init = tf.keras.initializers.get(init) if init is not None else tf.keras.initializers.HeNormal()
        self.init_rest = tf.keras.initializers.get(init_rest) if init_rest is not None else tf.keras.initializers.Ones()
        self.init_type = init_type
        # Accept ints and integral floats (e.g. 3.0) only.
        if not isinstance(depth, int) and (not isinstance(depth, float) or depth % 1 != 0):
            raise ValueError("depth must be an integer.")
        self.depth = depth
        # Keep the user-supplied value so get_config() round-trips without dividing twice.
        self._la_input = la
        self.la = la / self.depth  # to normalize scale of induced regularizers
        # V occurs (depth - 1) times in the product, so its penalty is scaled accordingly.
        self.regu = tf.keras.regularizers.l2(self.la)
        self.regv = tf.keras.regularizers.l2((self.depth - 1) * self.la)

        if self.init_type not in ['ones', 'equivar', 'root', 'vanilla']:
            raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'vanilla']")

    def build(self, input_shape):
        """Create ``U``/``V`` and the bias (or ``BU``/``BV``)."""
        # Fan-in is the last input axis (as in tf.keras Dense).
        fan_in = input_shape[-1]
        kernel_shape = (fan_in, self.units)

        if self.init_type == "ones":
            self.U = self.add_weight(name='U', shape=kernel_shape, initializer=self.init,
                                     regularizer=self.regu, trainable=True)
            self.V = self.add_weight(name='V', shape=kernel_shape, initializer=self.init_rest,
                                     regularizer=self.regv, trainable=True)
        elif self.init_type == "equivar":
            self.init_u = equivar_initializer(self.depth, fan_in, self.init)
            self.init_v = equivar_initializer(self.depth, fan_in, self.init)
            self.U = self.add_weight(name='U', shape=kernel_shape, initializer=self.init_u,
                                     regularizer=self.regu, trainable=True)
            self.V = self.add_weight(name='V', shape=kernel_shape, initializer=self.init_v,
                                     regularizer=self.regv, trainable=True)
        elif self.init_type == "vanilla":
            self.U = self.add_weight(name='U', shape=kernel_shape, initializer=self.init,
                                     regularizer=self.regu, trainable=True)
            self.V = self.add_weight(name='V', shape=kernel_shape, initializer=self.init,
                                     regularizer=self.regv, trainable=True)
        elif self.init_type == "root":
            # Draw one sample W0 and split it so that U * V**(depth - 1) == W0.
            init_matrix = self.init(shape=kernel_shape)
            init_signs_w = tf.math.sign(init_matrix)
            root_abs_init_w = tf.math.pow(tf.math.abs(init_matrix), 1.0 / self.depth)
            self.U = self.add_weight(name='U', shape=kernel_shape, initializer='zeros',
                                     regularizer=self.regu, trainable=True)
            self.V = self.add_weight(name='V', shape=kernel_shape, initializer='zeros',
                                     regularizer=self.regv, trainable=True)
            self.U.assign(init_signs_w * root_abs_init_w)
            self.V.assign(root_abs_init_w)

        if self.use_bias:
            if self.factorize_bias:
                bias_shape = (self.units,)
                if self.init_type == "ones":
                    # Fresh copy of `init` for the bias factor.
                    self.bias_init = self.init.__class__(**self.init.get_config())
                    self.BU = self.add_weight(name='BU', shape=bias_shape, initializer=self.bias_init,
                                              regularizer=self.regu, trainable=True)
                    self.BV = self.add_weight(name='BV', shape=bias_shape, initializer=self.init_rest,
                                              regularizer=self.regv, trainable=True)
                elif self.init_type == "equivar":
                    # A single BV (not a list) with the U/V regularizers, mirroring the kernel.
                    self.init_bias_factor = equivar_initializer(self.depth, self.units, self.init)
                    self.BU = self.add_weight(name='BU', shape=bias_shape, initializer=self.init_bias_factor,
                                              regularizer=self.regu, trainable=True)
                    self.BV = self.add_weight(name='BV', shape=bias_shape,
                                              initializer=equivar_initializer(self.depth, self.units, self.init),
                                              regularizer=self.regv, trainable=True)
                elif self.init_type == "vanilla":
                    self.BU = self.add_weight(name='BU', shape=bias_shape, initializer=self.init,
                                              regularizer=self.regu, trainable=True)
                    self.BV = self.add_weight(name='BV', shape=bias_shape, initializer=self.init,
                                              regularizer=self.regv, trainable=True)
                elif self.init_type == "root":
                    self.bias_init = self.init.__class__(**self.init.get_config())
                    init_vector = self.bias_init(shape=bias_shape)
                    init_signs_b = tf.math.sign(init_vector)
                    root_abs_init_b = tf.math.pow(tf.math.abs(init_vector), 1.0 / self.depth)
                    self.BU = self.add_weight(name='BU', shape=bias_shape, initializer='zeros',
                                              regularizer=self.regu, trainable=True)
                    self.BV = self.add_weight(name='BV', shape=bias_shape, initializer='zeros',
                                              regularizer=self.regv, trainable=True)
                    self.BU.assign(init_signs_b * root_abs_init_b)
                    self.BV.assign(root_abs_init_b)
            else:
                self.bias = self.add_weight(name='bias', shape=(self.units,), initializer='zeros', trainable=True)
        else:
            self.bias = None

    def call(self, inputs):
        """Return ``activation(inputs @ (U * V**(depth-1)) + b)``."""
        W_reconstructed = tf.multiply(self.U, tf.pow(x=self.V, y=(self.depth - 1)))

        output = inputs @ W_reconstructed
        if self.use_bias:
            if self.factorize_bias:
                output += tf.multiply(self.BU, tf.pow(x=self.BV, y=(self.depth - 1)))
            else:
                output += self.bias

        return self.activation(output)

    def get_config(self):
        """Return a serializable config from which ``from_config`` rebuilds an equivalent layer."""
        config = super().get_config().copy()
        config.update({
            'units': self.units,
            'activation': tf.keras.activations.serialize(self.activation),
            # Original (undivided) value: __init__ divides by depth again.
            'la': self._la_input,
            'depth': self.depth,
            'init': tf.keras.initializers.serialize(self.init),
            'init_rest': tf.keras.initializers.serialize(self.init_rest),
            'init_type': self.init_type,
            'use_bias': self.use_bias,
            'factorize_bias': self.factorize_bias
        })
        return config


# Dense Layer with Hadamard Power Parametrization of real-valued depth 'depth'
class HadamardDensePower(tf.keras.layers.Layer):
    """Dense layer with kernel ``W = U * |V|**(depth - 1)`` (Hadamard power parametrization).

    All ``depth - 1`` trailing Hadamard factors share the single tensor ``V``,
    so ``depth`` may be real-valued. ``U`` is L2-penalized with ``la / depth``
    and ``V`` with ``(depth - 1) * la / depth``.

    Args:
        units: Number of output units.
        activation: Anything accepted by ``tf.keras.activations.get``.
        la: Regularization strength; divided by ``depth`` internally.
        depth: Depth of the parametrization (may be non-integer).
        init: Initializer for ``U``; also the base initializer for the
            "equivar", "vanilla" and "root" schemes. Defaults to ``HeNormal()``.
        init_rest: Initializer for ``V`` / ``BV`` when ``init_type == "ones"``.
            Defaults to ``Ones()``.
        init_type: One of ``'ones'``, ``'equivar'``, ``'root'``, ``'vanilla'``.
        use_bias: Whether the layer adds a bias.
        factorize_bias: If True, the bias is ``BU * |BV|**(depth - 1)`` with the
            same regularizers as the kernel factors. Otherwise it is a single
            unregularized ``bias`` weight initialized to zeros.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``init_type`` is not one of the supported schemes.
    """

    def __init__(self, units=1, activation='linear', la=0, depth=2, init=None, init_rest=None, init_type="ones",
                 use_bias=True, factorize_bias=False, **kwargs):
        super(HadamardDensePower, self).__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.factorize_bias = factorize_bias
        self.depth = depth
        self.la = la / self.depth # to normalize scale of induced regularizers
        self.init = init if init is not None else tf.keras.initializers.HeNormal()
        self.init_rest = init_rest if init_rest is not None else tf.keras.initializers.Ones()
        self.init_type = init_type
        self.regu = tf.keras.regularizers.l2(self.la)
        # V stands in for depth-1 factors in call(), hence the (depth-1)-fold penalty.
        self.regv = tf.keras.regularizers.l2((self.depth-1)*self.la)

        if self.init_type not in ['ones', 'equivar', 'root', 'vanilla']:
            raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'vanilla']")

    def build(self, input_shape):
        """Create ``U``/``V`` (and the bias weights) according to ``init_type``."""
        kernel_shape = (input_shape[1], self.units)
        if self.init_type == "ones":
            self.U = self.add_weight(name='U', shape=kernel_shape, initializer=self.init, regularizer=self.regu, trainable=True)
            self.V = self.add_weight(name='V', shape=kernel_shape, initializer=self.init_rest, regularizer=self.regv, trainable=True)
        elif self.init_type == "equivar":
            self.init_u = equivar_initializer(self.depth, input_shape[1], self.init)
            self.init_v = equivar_initializer(self.depth, input_shape[1], self.init)
            self.U = self.add_weight(name='U', shape=kernel_shape, initializer=self.init_u, regularizer=self.regu, trainable=True)
            self.V = self.add_weight(name='V', shape=kernel_shape, initializer=self.init_v, regularizer=self.regv, trainable=True)
        elif self.init_type == "vanilla":
            self.U = self.add_weight(name='U', shape=kernel_shape, initializer=self.init, regularizer=self.regu, trainable=True)
            self.V = self.add_weight(name='V', shape=kernel_shape, initializer=self.init, regularizer=self.regv, trainable=True)
        elif self.init_type == "root":
            # sign(w)*|w|**(1/d) * (|w|**(1/d))**(d-1) == w, so the initial kernel equals init_matrix.
            init_matrix = self.init(shape=kernel_shape)
            init_signs_w = tf.math.sign(init_matrix)
            root_abs_init_w = tf.math.pow(tf.math.abs(init_matrix), 1.0 / self.depth)
            self.U = self.add_weight(name='U', shape=kernel_shape, initializer='zeros', regularizer=self.regu, trainable=True)
            self.V = self.add_weight(name='V', shape=kernel_shape, initializer='zeros', regularizer=self.regv, trainable=True)
            self.U.assign(init_signs_w * root_abs_init_w)
            self.V.assign(root_abs_init_w)

        if self.use_bias:
            if self.factorize_bias:
                if self.init_type == "ones":
                    # Fresh copy of the kernel initializer for the bias.
                    bias_init_config = self.init.get_config()
                    self.bias_init = self.init.__class__(**bias_init_config)
                    self.BU = self.add_weight(name='BU', shape=(self.units,), initializer=self.bias_init, regularizer=self.regu, trainable=True)
                    self.BV = self.add_weight(name='BV', shape=(self.units,), initializer=self.init_rest, regularizer=self.regv, trainable=True)
                elif self.init_type == "equivar":
                    self.init_bias_u = equivar_initializer(self.depth, self.units, self.init)
                    self.init_bias_v = equivar_initializer(self.depth, self.units, self.init)
                    self.BU = self.add_weight(name='BU', shape=(self.units,), initializer=self.init_bias_u, regularizer=self.regu, trainable=True)
                    self.BV = self.add_weight(name='BV', shape=(self.units,), initializer=self.init_bias_v, regularizer=self.regv, trainable=True)
                elif self.init_type == "vanilla":
                    self.BU = self.add_weight(name='BU', shape=(self.units,), initializer=self.init, regularizer=self.regu, trainable=True)
                    self.BV = self.add_weight(name='BV', shape=(self.units,), initializer=self.init, regularizer=self.regv, trainable=True)
                elif self.init_type == "root":
                    bias_init_config = self.init.get_config()
                    self.bias_init = self.init.__class__(**bias_init_config)
                    init_vector = self.bias_init(shape=(self.units,))
                    init_signs_b = tf.math.sign(init_vector)
                    root_abs_init_b = tf.math.pow(tf.math.abs(init_vector), 1.0 / self.depth)
                    self.BU = self.add_weight(name='BU', shape=(self.units,), initializer='zeros', regularizer=self.regu, trainable=True)
                    self.BV = self.add_weight(name='BV', shape=(self.units,), initializer='zeros', regularizer=self.regv, trainable=True)
                    self.BU.assign(init_signs_b * root_abs_init_b)
                    self.BV.assign(root_abs_init_b)
            else:
                self.bias = self.add_weight(name='bias', shape=(self.units,), initializer='zeros', trainable=True)
        else:
            self.bias = None

    def call(self, inputs):
        """Return ``activation(inputs @ (U * |V|**(depth-1)) + bias)``."""
        W_reconstructed = tf.multiply(self.U, tf.pow(x = tf.abs(self.V), y = (self.depth-1)))

        if self.use_bias and self.factorize_bias:
            B_reconstructed = tf.multiply(self.BU, tf.pow(x = tf.abs(self.BV), y = (self.depth-1)))

        output = inputs @ W_reconstructed
        if self.use_bias:
            if self.factorize_bias:
                output += B_reconstructed
            else:
                output += self.bias

        return self.activation(output)

    def get_config(self):
        """Return the constructor arguments needed to recreate this layer.

        ``self.la`` holds ``la / depth`` (see ``__init__``). The original ``la``
        is therefore restored here; otherwise ``from_config`` (e.g. in
        ``clone_model`` or model loading) would divide by ``depth`` a second
        time and weaken the regularization.
        """
        config = super().get_config().copy()
        config.update({
            'units': self.units,
            'activation': self.activation,
            'la': self.la * self.depth,
            'depth': self.depth,
            'init': self.init,
            'init_rest': self.init_rest,
            'init_type': self.init_type,
            'use_bias': self.use_bias,
            'factorize_bias': self.factorize_bias
        })
        return config

# Structured Sparse Dense layer (neuron-wise, i.e., all incoming weights to one unit + its bias form one group via shared overparametrization) 
class StrHadamardDense(tf.keras.layers.Layer):
    """Dense layer with neuron-wise (column-wise) structured Hadamard overparametrization.

    The kernel is ``W = U1 * s``, where ``s = U2 * ... * U_depth`` is a vector of
    length ``units``. Each entry of ``s`` scales all incoming weights of one
    output unit (one column of ``U1``). With ``factorize_bias`` the bias is
    ``B1 * s``, so it shares its unit's scale. The L2 penalties on ``U1`` and on
    the shared factors are presumably meant to induce group (neuron) sparsity.

    Args:
        units: Number of output units.
        activation: Anything accepted by ``tf.keras.activations.get``.
        la: Regularization strength; divided by ``depth`` internally.
        depth: Integer depth (``U1`` plus ``depth - 1`` shared factors).
        init: Initializer instance for ``U1`` / ``B1``. Defaults to ``HeNormal()``.
        init_rest: Initializer for the shared factors when ``init_type == "ones"``.
            Defaults to ``Ones()``.
        init_type: One of ``'ones'``, ``'equivar'``, ``'root'``, ``'vanilla'``.
        use_bias: Whether the layer adds a bias.
        factorize_bias: If True the bias is ``B1 * s``; otherwise a plain,
            unregularized ``bias`` initialized to zeros.
        groupsize: Group size for group-size normalization. The penalty on the
            shared factors is scaled by ``sqrt(groupsize)``. Defaults to 1.0.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``init_type`` is not one of the supported schemes.
    """

    def __init__(self, units=1, activation='linear', la=0, depth=2, init=None, init_rest=None, init_type="ones",
                 use_bias=True, factorize_bias=True, groupsize=None, **kwargs):
        super(StrHadamardDense, self).__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.factorize_bias = factorize_bias
        self.depth = depth
        self.la = la / self.depth # to normalize scale of induced regularizers
        # Initializer *instances* are required: build() calls self.init(shape=...)
        # and self.init.get_config(), which fail on the bare classes.
        self.init = init if init is not None else tf.keras.initializers.HeNormal()
        self.init_rest = init_rest if init_rest is not None else tf.keras.initializers.Ones()
        self.init_type = init_type

        self.reg = tf.keras.regularizers.l2(self.la)
        self.groupsize = groupsize if groupsize is not None else 1.0 # equally sized groups
        # float() so integer group sizes work; math.sqrt avoids tf.math.sqrt's float-only dtype requirement.
        self.blowup_factor = math.sqrt(float(self.groupsize))
        # Regularizer for the shared (grouped) factors, scaled for group size normalization.
        self.blownupreg = tf.keras.regularizers.l2(self.blowup_factor * self.la)

        if self.init_type not in ['ones', 'equivar', 'root', 'vanilla']:
            raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'vanilla']")

    def build(self, input_shape):
        """Create ``U1``, the shared per-unit factors and the bias weights."""
        kernel_shape = (input_shape[1], self.units)
        diag_indices = range(2, self.depth + 1)  # empty for depth == 1
        if self.init_type == "ones":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init, regularizer=self.reg, trainable=True)
            self.diag_weights = [self.add_weight(name='U{}'.format(i), shape=(self.units,), initializer=self.init_rest, regularizer=self.blownupreg, trainable=True) for i in diag_indices]
        elif self.init_type == "equivar": # TODO: check the math of applying this init to FC(diag(diag(...))) structured overparametrization
            self.init_factor = equivar_initializer(self.depth, input_shape[1], self.init) # TODO: check if identical inits are created and replace by individual instances
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init_factor, regularizer=self.reg, trainable=True)
            self.diag_weights = [self.add_weight(name='U{}'.format(i), shape=(self.units,), initializer=equivar_initializer(self.depth, input_shape[1], self.init), regularizer=self.blownupreg, trainable=True) for i in diag_indices]
        elif self.init_type == "vanilla":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init, regularizer=self.reg, trainable=True)
            self.diag_weights = [self.add_weight(name='U{}'.format(i), shape=(self.units,), initializer=self.init, regularizer=self.blownupreg, trainable=True) for i in diag_indices]
        elif self.init_type == "root":
            init_matrix = self.init(shape=kernel_shape)
            col_norms = tf.norm(init_matrix, ord=2, axis=0)  # one norm per output unit, shape (units,)
            depth_th_roots = tf.pow(col_norms, 1.0 / self.depth)
            # Balanced split: every column of U1 and every shared factor gets magnitude
            # norm**(1/depth), so U1 * prod(diag_weights) == init_matrix for any depth.
            # divide_no_nan keeps an all-zero column at zero instead of producing NaN.
            scaled_init_matrix = tf.math.divide_no_nan(init_matrix, col_norms) * depth_th_roots
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer='zeros', regularizer=self.reg, trainable=True)
            self.diag_weights = [self.add_weight(name='U{}'.format(i), shape=(self.units,), initializer='zeros', regularizer=self.blownupreg, trainable=True) for i in diag_indices]
            self.U1.assign(scaled_init_matrix)
            for weight_var in self.diag_weights:
                weight_var.assign(depth_th_roots)

        if self.use_bias:
            if self.factorize_bias:
                if self.init_type == "ones":
                    # Fresh copy of the kernel initializer for the bias.
                    bias_init_config = self.init.get_config()
                    self.bias_init = self.init.__class__(**bias_init_config)
                    self.B1 = self.add_weight(name='B1', shape=(self.units,), initializer=self.bias_init, regularizer=self.reg, trainable=True)
                elif self.init_type == "equivar":
                    self.B1 = self.add_weight(name='B1', shape=(self.units,), initializer=equivar_initializer(self.depth, self.units, self.init), regularizer=self.reg, trainable=True)
                elif self.init_type == "vanilla":
                    self.B1 = self.add_weight(name='B1', shape=(self.units,), initializer=self.init, regularizer=self.reg, trainable=True)
                elif self.init_type == "root":
                    # NOTE(review): B1 is set to sign(b)*|b|**(1/depth), but call() multiplies it by
                    # the shared column scale, so the initial bias is not exactly init_vector.
                    bias_init_config = self.init.get_config()
                    self.bias_init = self.init.__class__(**bias_init_config)
                    init_vector = self.bias_init(shape=(self.units,))
                    init_signs_b = tf.math.sign(init_vector)
                    root_abs_init_b = tf.math.pow(tf.math.abs(init_vector), 1.0 / self.depth)
                    self.B1 = self.add_weight(name='B1', shape=(self.units,), initializer='zeros', regularizer=self.reg, trainable=True)
                    self.B1.assign(init_signs_b * root_abs_init_b)
            else:
                self.bias = self.add_weight(name='bias', shape=(self.units,), initializer='zeros', trainable=True)
        else:
            self.bias = None

    def call(self, inputs):
        """Return ``activation(inputs @ (U1 * s) + bias)`` with ``s`` the product of the shared factors."""
        kernel = self.U1
        group_scale = None
        if self.diag_weights:  # empty when depth == 1; tf.stack([]) would fail
            # One scale per output unit, shape (units,).
            group_scale = tf.reduce_prod(tf.stack(self.diag_weights, axis=0), axis=0)
            # Broadcasting (in, units) * (units,) scales each column, i.e. each unit's incoming weights.
            kernel = kernel * group_scale

        output = inputs @ kernel
        if self.use_bias:
            if self.factorize_bias:
                # The bias joins its unit's group via the same shared scale.
                bias = self.B1 if group_scale is None else self.B1 * group_scale
            else:
                bias = self.bias
            output += bias

        return self.activation(output)

    def get_config(self):
        """Return the constructor arguments needed to recreate this layer.

        ``la`` is restored from ``self.la`` (which stores ``la / depth``). The
        derived ``blowup_factor`` is not included: it is not a constructor
        argument, and passing it to ``from_config`` would be rejected by
        ``Layer.__init__``.
        """
        config = super().get_config().copy()
        config.update({
            'units': self.units,
            'activation': self.activation,
            'la': self.la * self.depth,
            'depth': self.depth,
            'init': self.init,
            'init_rest': self.init_rest,
            'init_type': self.init_type,
            'use_bias': self.use_bias,
            'factorize_bias': self.factorize_bias,
            'groupsize': self.groupsize
        })
        return config


# Dense layer with almost neuron-sparse structured parametrization (all outgoing weights from previous layer sans biases of current layer form one group)
class StrHadamardDenseV2(tf.keras.layers.Layer):
    """Dense layer with input-wise (row-wise) structured Hadamard overparametrization.

    The kernel is ``W = s[:, None] * U1``, where ``s = U2 * ... * U_depth`` is a
    vector of length ``input_dim``. Each entry of ``s`` scales all outgoing
    weights of one input feature (one row of ``U1``). The bias is not part of
    these groups: with ``factorize_bias`` it is the plain Hadamard product
    ``B1 * B2 * ... * B_depth``.

    Args:
        units: Number of output units.
        activation: Anything accepted by ``tf.keras.activations.get``.
        la: Regularization strength; divided by ``depth`` internally.
        depth: Integer depth (``U1`` plus ``depth - 1`` shared factors).
        init: Initializer instance for ``U1`` / ``B1``. Defaults to ``HeNormal()``.
        init_rest: Initializer for the extra factors when ``init_type == "ones"``.
            Defaults to ``Ones()``.
        init_type: One of ``'ones'``, ``'equivar'``, ``'root'``, ``'vanilla'``.
            ``None`` means ``'ones'``.
        use_bias: Whether the layer adds a bias.
        factorize_bias: If True the bias is a ``depth``-fold Hadamard product.
            Otherwise it is a plain, unregularized ``bias`` initialized to zeros.
        groupsize: Group size for group-size normalization. The penalty on the
            shared factors is scaled by ``sqrt(groupsize)``. Defaults to 1.0.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``init_type`` is not one of the supported schemes.
    """

    def __init__(self, units=1, activation='linear', la=0, depth=2, init=None, init_rest=None, init_type="ones",
                 use_bias=True, factorize_bias=False, groupsize=None, **kwargs):
        super(StrHadamardDenseV2, self).__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.factorize_bias = factorize_bias
        self.depth = depth
        self.la = la / self.depth # to normalize scale of induced regularizers
        self.init = init if init is not None else tf.keras.initializers.HeNormal()
        self.init_rest = init_rest if init_rest is not None else tf.keras.initializers.Ones()
        self.init_type = init_type if init_type is not None else 'ones'
        self.reg = tf.keras.regularizers.l2(self.la)
        self.groupsize = groupsize if groupsize is not None else 1.0 # equally sized groups
        # float() so integer group sizes work; math.sqrt avoids tf.math.sqrt's float-only dtype requirement.
        self.blowup_factor = math.sqrt(float(self.groupsize))
        # Regularizer for the shared (grouped) factors, scaled for group size normalization.
        self.blownupreg = tf.keras.regularizers.l2(self.blowup_factor * self.la)

        if self.init_type not in ['ones', 'equivar', 'root', 'vanilla']:
            raise ValueError("init_type must be one of ['ones', 'equivar', 'root', 'vanilla']")

    def build(self, input_shape):
        """Create ``U1``, the shared per-input factors and the bias weights.

        ``diag_weights`` and (with a factorized bias) ``other_biases`` are always
        assigned; they are empty lists when ``depth == 1``.
        """
        input_dim = input_shape[1]
        kernel_shape = (input_dim, self.units)
        diag_shape = (input_dim,)  # tuple, not a bare int, so fan-based initializers work
        extra_indices = range(2, self.depth + 1)  # empty for depth == 1
        if self.init_type == "ones":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init, regularizer=self.reg, trainable=True)
            self.diag_weights = [self.add_weight(name='U{}'.format(i), shape=diag_shape, initializer=self.init_rest, regularizer=self.blownupreg, trainable=True) for i in extra_indices]
        elif self.init_type == "equivar": # TODO: check the math of applying this init to FC(diag(diag(...))) structured overparametrization
            self.init_factor = equivar_initializer(self.depth, input_dim, self.init) # TODO: check if identical inits are created and replace by individual instances
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init_factor, regularizer=self.reg, trainable=True)
            self.diag_weights = [self.add_weight(name='U{}'.format(i), shape=diag_shape, initializer=equivar_initializer(self.depth, input_dim, self.init), regularizer=self.blownupreg, trainable=True) for i in extra_indices]
        elif self.init_type == "vanilla":
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer=self.init, regularizer=self.reg, trainable=True)
            self.diag_weights = [self.add_weight(name='U{}'.format(i), shape=diag_shape, initializer=self.init, regularizer=self.blownupreg, trainable=True) for i in extra_indices]
        elif self.init_type == "root":
            init_matrix = self.init(shape=kernel_shape)
            row_norms = tf.norm(init_matrix, ord=2, axis=1)  # one norm per input feature, shape (input_dim,)
            depth_th_roots = tf.pow(row_norms, 1.0 / self.depth)
            # Balanced split: every row of U1 and every shared factor gets magnitude
            # norm**(1/depth), so prod(diag_weights)[:, None] * U1 == init_matrix for any depth.
            # divide_no_nan keeps an all-zero row at zero instead of producing NaN.
            scaled_init_matrix = (tf.math.divide_no_nan(init_matrix, tf.reshape(row_norms, (-1, 1)))
                                  * tf.reshape(depth_th_roots, (-1, 1)))
            self.U1 = self.add_weight(name='U1', shape=kernel_shape, initializer='zeros', regularizer=self.reg, trainable=True)
            self.U1.assign(scaled_init_matrix)
            self.diag_weights = [self.add_weight(name='U{}'.format(i), shape=diag_shape, initializer='zeros', regularizer=self.blownupreg, trainable=True) for i in extra_indices]
            for weight_var in self.diag_weights:
                weight_var.assign(depth_th_roots)

        if self.use_bias:
            if self.factorize_bias:
                if self.init_type == "ones":
                    # Fresh copy of the kernel initializer for the bias.
                    bias_init_config = self.init.get_config()
                    self.bias_init = self.init.__class__(**bias_init_config)
                    self.B1 = self.add_weight(name='B1', shape=(self.units,), initializer=self.bias_init, regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=(self.units,), initializer=self.init_rest, regularizer=self.reg, trainable=True) for i in extra_indices]
                elif self.init_type == "equivar":
                    self.B1 = self.add_weight(name='B1', shape=(self.units,), initializer=equivar_initializer(self.depth, self.units, self.init), regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=(self.units,), initializer=equivar_initializer(self.depth, self.units, self.init), regularizer=self.reg, trainable=True) for i in extra_indices]
                elif self.init_type == "vanilla":
                    self.B1 = self.add_weight(name='B1', shape=(self.units,), initializer=self.init, regularizer=self.reg, trainable=True)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=(self.units,), initializer=self.init, regularizer=self.reg, trainable=True) for i in extra_indices]
                elif self.init_type == "root":
                    # sign(b)*|b|**(1/d) * (|b|**(1/d))**(d-1) == b, so the initial bias equals init_vector.
                    bias_init_config = self.init.get_config()
                    self.bias_init = self.init.__class__(**bias_init_config)
                    init_vector = self.bias_init(shape=(self.units,))
                    init_signs_b = tf.math.sign(init_vector)
                    root_abs_init_b = tf.math.pow(tf.math.abs(init_vector), 1.0 / self.depth)
                    self.B1 = self.add_weight(name='B1', shape=(self.units,), initializer='zeros', regularizer=self.reg, trainable=True)
                    self.B1.assign(init_signs_b * root_abs_init_b)
                    self.other_biases = [self.add_weight(name='B{}'.format(i), shape=(self.units,), initializer='zeros', regularizer=self.reg, trainable=True) for i in extra_indices]
                    for bias_var in self.other_biases:
                        bias_var.assign(root_abs_init_b)
            else:
                self.bias = self.add_weight(name='bias', shape=(self.units,), initializer='zeros', trainable=True)
        else:
            self.bias = None

    def call(self, inputs):
        """Return ``activation(inputs @ (s[:, None] * U1) + bias)`` with ``s`` the product of the shared factors."""
        kernel = self.U1
        if self.diag_weights:  # empty when depth == 1; tf.stack([]) would fail
            # One scale per input feature, shape (input_dim,).
            group_scale = tf.reduce_prod(tf.stack(self.diag_weights, axis=0), axis=0)
            # Always scale rows (each input's outgoing weights), even when input_dim == units.
            kernel = kernel * tf.reshape(group_scale, (-1, 1))

        output = inputs @ kernel
        if self.use_bias:
            if self.factorize_bias:
                bias = self.B1
                for factor in self.other_biases:
                    bias = bias * factor
            else:
                bias = self.bias
            output += bias

        return self.activation(output)

    def get_config(self):
        """Return the constructor arguments needed to recreate this layer.

        ``la`` is restored from ``self.la`` (which stores ``la / depth``). The
        derived ``blowup_factor`` is not included: it is not a constructor
        argument, and passing it to ``from_config`` would be rejected by
        ``Layer.__init__``.
        """
        config = super().get_config().copy()
        config.update({
            'units': self.units,
            'activation': self.activation,
            'la': self.la * self.depth,
            'depth': self.depth,
            'init': self.init,
            'init_rest': self.init_rest,
            'init_type': self.init_type,
            'use_bias': self.use_bias,
            'factorize_bias': self.factorize_bias,
            'groupsize': self.groupsize
        })
        return config

