# TODO: incorporate bias factorization in this custom layer
from keras import layers
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D
try:
    from keras.layers.convolutional import Conv                     
except ImportError:
    from keras.layers.convolutional.base_conv import Conv   

class SparseConv(Conv):
    """Convolution whose kernel is rescaled by a learned multiplicative factor.

    Before the convolution is applied, the kernel is multiplied elementwise by
    ``abs(multfac) ** (depth - 1)``.  ``multfac`` has size 1 along every
    kernel axis except ``position_sparsity``, where it has the kernel's size,
    so it broadcasts one scale per index of that axis.

    Args:
        *args: Positional arguments forwarded to ``Conv``.
        position_sparsity: Kernel axis along which ``multfac`` varies
            (negative values index from the end, as in list indexing).
        depth: Factorization depth; the exponent applied to ``abs(multfac)``
            is ``depth - 1``.
        multfac_initializer: Initializer identifier or instance for
            ``multfac``.
        multfac_regularizer: Regularizer identifier, serialized config or
            instance for ``multfac``.
        la: Optional regularization strength.  It is only used when neither
            ``multfac_regularizer`` nor ``kernel_regularizer`` was supplied;
            in that case the kernel gets ``L2(la)`` and ``multfac`` gets
            ``L2((depth - 1) * la)``.  Otherwise it is ignored.
        **kwargs: Keyword arguments forwarded to ``Conv``.
    """

    def __init__(self, *args, position_sparsity=-1, depth=2, multfac_initializer='ones', multfac_regularizer=None, la=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.position_sparsity = position_sparsity
        self.depth = depth
        self.multfac_initializer = tf.keras.initializers.get(multfac_initializer)
        self.factorize_bias = False

        if multfac_regularizer is None and self.kernel_regularizer is None and la is not None:
            self.multfac_regularizer = tf.keras.regularizers.L2((self.depth - 1) * la)
            self.kernel_regularizer = tf.keras.regularizers.L2(la)
        else:
            # Normalise to a regularizer object (or None).  After
            # `from_config` this argument arrives as a serialized dict;
            # storing it raw would make a later `get_config()` try to
            # serialize a dict instead of a regularizer.
            self.multfac_regularizer = tf.keras.regularizers.get(multfac_regularizer)

    def build(self, input_shape):
        """Create the parent's weights, then the broadcastable ``multfac``."""
        super().build(input_shape)
        kernel_shape = self.kernel.shape
        # Size 1 everywhere except along the chosen axis, so that the factor
        # broadcasts against the kernel.
        multfac_shape = [1] * len(kernel_shape)
        multfac_shape[self.position_sparsity] = kernel_shape[self.position_sparsity]

        self.multfac = self.add_weight(
            name='multfac',
            shape=tuple(multfac_shape),
            initializer=self.multfac_initializer,
            regularizer=self.multfac_regularizer,
            trainable=True,
            dtype=self.dtype,
        )

    def convolution_op(self, inputs, kernel):
        """Convolve ``inputs`` with ``kernel * abs(multfac) ** (depth - 1)``."""
        if self.padding == "causal":
            tf_padding = "VALID"  # Causal padding handled in `call`.
        elif isinstance(self.padding, str):
            tf_padding = self.padding.upper()
        else:
            tf_padding = self.padding

        # The absolute value keeps the scale non-negative regardless of the
        # sign of the learned factor.
        modified_kernel = tf.multiply(kernel, tf.pow(x=tf.abs(self.multfac), y=(self.depth - 1)))

        return tf.nn.convolution(
            inputs,
            modified_kernel,
            strides=list(self.strides),
            padding=tf_padding,
            dilations=list(self.dilation_rate),
            data_format=self._tf_data_format,
            name=self.__class__.__name__,
        )

    def get_config(self):
        """Return the parent config extended with this layer's own arguments.

        ``la`` is not emitted: the regularizers it may have produced are
        serialized directly (``multfac_regularizer`` here, and
        ``kernel_regularizer`` is set on the instance in ``__init__``).
        """
        config = super().get_config()
        config.update({
            'position_sparsity': self.position_sparsity,
            'depth': self.depth,
            'multfac_initializer': tf.keras.initializers.serialize(self.multfac_initializer),
            'multfac_regularizer': tf.keras.regularizers.serialize(self.multfac_regularizer)
        })
        return config

class SparseConv2D(SparseConv):
    """Two-dimensional variant of ``SparseConv``.

    Forwards every argument to ``SparseConv`` unchanged and pins ``rank`` to 2.
    """

    def __init__(self, filters, kernel_size, la=None, position_sparsity=-1, depth=2, **kwargs):
        # `rank` is fixed here; everything else is passed straight through.
        super().__init__(
            rank=2,
            filters=filters,
            kernel_size=kernel_size,
            position_sparsity=position_sparsity,
            depth=depth,
            la=la,
            **kwargs,
        )
  

# --- StrConv family ---

class StrConv(Conv):
    """Convolution whose kernel is rescaled along one axis by learned vectors.

    ``depth - 1`` trainable vectors (``diag_2`` ... ``diag_<depth>``) are
    created, each as long as the kernel axis selected by
    ``position_sparsity``.  Their elementwise product is broadcast along that
    axis and multiplied into the kernel before the convolution.

    Args:
        *args: Positional arguments forwarded to ``Conv``.
        depth: Number of factors (the kernel plus ``depth - 1`` vectors).
            Must be at least 1.
        la: Regularization strength.  Every factor, including the kernel,
            receives ``L2(la / depth)``; this replaces any
            ``kernel_regularizer`` passed through ``kwargs``.
        multfac_initializer: Initializer identifier or instance for the
            diagonal vectors and the multiplicative bias factors.
        factorize_bias: When true and ``use_bias`` is set, the weights ``B1``
            and ``B_2`` ... ``B_<depth>`` are created in ``build``.
        position_sparsity: Kernel axis the vectors scale (negative values
            index from the end).
        **kwargs: Keyword arguments forwarded to ``Conv``.

    NOTE(review): nothing in this class reads ``B1`` / ``bias_factors`` after
    they are created, so the factorized bias does not take part in
    ``convolution_op``.  Confirm whether the forward pass is meant to use
    them.
    """

    def __init__(self, *args, depth=2, la=0,
                 multfac_initializer=tf.keras.initializers.Ones(),
                 factorize_bias=True,
                 position_sparsity=-1,
                 **kwargs):
        super().__init__(*args, **kwargs)
        if depth < 1:
            # depth == 0 used to fail with ZeroDivisionError below; negative
            # values only failed later, inside the convolution.
            raise ValueError(
                'depth must be at least 1, got {}'.format(depth))
        self.depth = depth
        # Keep the caller's value so that get_config() can round-trip it;
        # `self.la` holds the per-factor strength actually applied.
        self._la_total = la
        self.la = la / self.depth
        self.kernel_regularizer = tf.keras.regularizers.L2(self.la)
        self.multfac_initializer = tf.keras.initializers.get(multfac_initializer)
        self.factorize_bias = factorize_bias
        self.position_sparsity = position_sparsity

    def _sparsity_axis(self, ndim):
        """Return ``position_sparsity`` as a non-negative index into ``ndim`` axes."""
        if self.position_sparsity >= 0:
            return self.position_sparsity
        return ndim + self.position_sparsity

    def build(self, input_shape):
        """Create the parent's weights, the diagonal vectors and the bias factors."""
        # Let the parent build create self.kernel (and self.bias, if applicable)
        super().build(input_shape)
        kernel_shape = self.kernel.shape.as_list()
        group_size = kernel_shape[self._sparsity_axis(len(kernel_shape))]
        # Use the same effective strength for diagonal weights.
        diag_reg = tf.keras.regularizers.L2(self.la)
        self.diag_weights = []
        for i in range(2, self.depth + 1):
            dw = self.add_weight(
                name='diag_{}'.format(i),
                shape=(group_size,),
                initializer=self.multfac_initializer,
                regularizer=diag_reg,
                trainable=True,
                dtype=self.dtype)
            self.diag_weights.append(dw)
        if self.use_bias and self.factorize_bias:
            bias_reg = tf.keras.regularizers.L2(self.la)
            self.B1 = self.add_weight(
                name='B1',
                shape=(self.filters,),
                initializer=self.bias_initializer,
                regularizer=bias_reg,
                trainable=True,
                dtype=self.dtype)
            self.bias_factors = []
            for i in range(2, self.depth + 1):
                bf = self.add_weight(
                    name='B_{}'.format(i),
                    shape=(self.filters,),
                    initializer=self.multfac_initializer,
                    regularizer=bias_reg,
                    trainable=True,
                    dtype=self.dtype)
                self.bias_factors.append(bf)

    def convolution_op(self, inputs, kernel):
        """Convolve ``inputs`` with the kernel scaled by the product of the diagonal vectors."""
        kernel_reconstructed = kernel
        # With depth == 1 there are no diagonal factors; the empty product
        # cannot be reshaped to the axis length, so use the kernel as is.
        if self.diag_weights:
            # Stack and multiply the diagonal weights.
            stacked = tf.stack(self.diag_weights, axis=0)
            prod_diag = tf.reduce_prod(stacked, axis=0)
            kernel_shape = kernel.shape.as_list()
            axis = self._sparsity_axis(len(kernel_shape))
            # Size 1 on every axis but the scaled one, so the product
            # broadcasts against the kernel.
            diag_shape = [1] * len(kernel_shape)
            diag_shape[axis] = kernel_shape[axis]
            # Reconstruct the effective kernel.
            kernel_reconstructed = kernel * tf.reshape(prod_diag, diag_shape)
        if self.padding == "causal":
            tf_padding = "VALID"
        elif isinstance(self.padding, str):
            tf_padding = self.padding.upper()
        else:
            tf_padding = self.padding
        return tf.nn.convolution(
            inputs,
            kernel_reconstructed,
            strides=list(self.strides),
            padding=tf_padding,
            dilations=list(self.dilation_rate),
            data_format=self._tf_data_format,
            name=self.__class__.__name__)

    def get_config(self):
        """Return the parent config extended with this layer's own arguments.

        ``la`` is emitted as originally passed to ``__init__`` (not the
        per-factor ``self.la``), because ``__init__`` divides it by ``depth``
        again when the layer is rebuilt from this config.
        """
        config = super().get_config()
        config.update({
            'position_sparsity': self.position_sparsity,
            'depth': self.depth,
            'la': self._la_total,
            'multfac_initializer': tf.keras.initializers.serialize(self.multfac_initializer),
            'factorize_bias': self.factorize_bias
        })
        return config

# Structured sparse convolution layer for 2D convolutions.
class StrConv2D(StrConv):
    """Two-dimensional variant of ``StrConv``.

    Forwards every argument to ``StrConv`` unchanged and pins ``rank`` to 2.
    """

    def __init__(self, filters, kernel_size, la=0, position_sparsity=-1, depth=2, **kwargs):
        # `rank` is fixed here; everything else is passed straight through.
        super().__init__(
            rank=2,
            filters=filters,
            kernel_size=kernel_size,
            position_sparsity=position_sparsity,
            depth=depth,
            la=la,
            **kwargs,
        )
  

