from keras import layers
import tensorflow as tf
try:
    from keras.layers.convolutional import Conv                     
except ImportError:
    from keras.layers.convolutional.base_conv import Conv   

class SparseConv(Conv):
    """Keras convolution whose kernel is rescaled by a learnable per-slice factor.

    The kernel actually used by the convolution is::

        kernel * |multfac| ** (depth - 1)

    ``multfac`` has size 1 on every kernel axis except ``position_sparsity``.
    Each slice along that axis therefore shares a single learnable factor.
    Together with the L2 penalties set up via ``la``, this factorization is
    presumably meant to push whole slices towards zero (structured sparsity).

    Args:
        *args: Forwarded unchanged to the Keras ``Conv`` base class.
        position_sparsity: Kernel axis that receives its own factors. The
            default -1 selects the last kernel axis. In Keras's Conv kernel
            layout that should be the filters axis; verify for your Keras
            version.
        depth: Factorization depth. The factor is raised to ``depth - 1``.
            Must be >= 1; ``depth == 1`` leaves the kernel unscaled.
        multfac_initializer: Initializer (instance or identifier) for
            ``multfac``.
        multfac_regularizer: Regularizer for ``multfac``. May be an instance,
            an identifier, a serialized config, or None.
        la: Convenience regularization strength. It is only applied when
            neither ``multfac_regularizer`` nor ``kernel_regularizer`` is
            given. In that case the kernel gets ``L2(la)`` and ``multfac``
            gets ``L2((depth - 1) * la)``. Otherwise ``la`` is ignored.
        **kwargs: Forwarded unchanged to the Keras ``Conv`` base class.

    Raises:
        ValueError: If ``depth < 1``. That would mean a negative exponent on
            the factor and a negative L2 strength.
    """

    def __init__(self, *args, position_sparsity=-1, depth=2,
                 multfac_initializer='ones', multfac_regularizer=None,
                 la=None, **kwargs):
        # Fail fast: depth < 1 inverts the factor (negative power) and would
        # produce a negative L2 weight in the `la` branch below.
        if depth < 1:
            raise ValueError(f"`depth` must be >= 1, got {depth!r}.")
        super().__init__(*args, **kwargs)
        self.position_sparsity = position_sparsity
        self.depth = depth
        self.multfac_initializer = tf.keras.initializers.get(multfac_initializer)
        # Not read anywhere in this class; presumably kept for external code.
        self.factorize_bias = False

        if multfac_regularizer is None and self.kernel_regularizer is None and la is not None:
            # Derive both penalties from the single strength `la`.
            self.multfac_regularizer = tf.keras.regularizers.L2((self.depth - 1) * la)
            self.kernel_regularizer = tf.keras.regularizers.L2(la)
        else:
            # `get` resolves None, identifiers, instances, and the serialized
            # dict written by `get_config`. Without it, a layer rebuilt via
            # `from_config` would pass a raw dict to `add_weight`.
            self.multfac_regularizer = tf.keras.regularizers.get(multfac_regularizer)

    def build(self, input_shape):
        """Create the base layer's weights, then the ``multfac`` factor."""
        super().build(input_shape)
        kernel_shape = self.kernel.shape
        # Broadcastable against the kernel: size 1 on every axis except the
        # sparsity axis, which matches the kernel's extent there.
        multfac_shape = [1] * len(kernel_shape)
        multfac_shape[self.position_sparsity] = kernel_shape[self.position_sparsity]

        self.multfac = self.add_weight(
            name='multfac',
            shape=tuple(multfac_shape),
            initializer=self.multfac_initializer,
            regularizer=self.multfac_regularizer,
            trainable=True,
            dtype=self.dtype,
        )

    def convolution_op(self, inputs, kernel):
        """Convolve ``inputs`` with ``kernel * |multfac| ** (depth - 1)``.

        Padding, strides, dilation and data format are handled by the base
        class's ``convolution_op``. Only the kernel is substituted here.
        """
        scale = tf.pow(x=tf.abs(self.multfac), y=(self.depth - 1))
        return super().convolution_op(inputs, tf.multiply(kernel, scale))

    def get_config(self):
        """Return the base Conv config plus this layer's extra arguments.

        ``la`` is not stored. Its effect is already captured by the
        serialized ``kernel_regularizer`` and ``multfac_regularizer``.
        """
        config = super().get_config()
        config.update({
            'position_sparsity': self.position_sparsity,
            'depth': self.depth,
            'multfac_initializer': tf.keras.initializers.serialize(self.multfac_initializer),
            'multfac_regularizer': tf.keras.regularizers.serialize(self.multfac_regularizer),
        })
        return config

class SparseConv2D(SparseConv):
    """Two-dimensional ``SparseConv``.

    The convolution rank is pinned to 2. Everything else is forwarded to
    ``SparseConv`` by keyword.

    Args:
        filters: Number of output filters.
        kernel_size: Spatial size of the convolution window.
        la: Convenience regularization strength (see ``SparseConv``).
        position_sparsity: Kernel axis carrying the learnable factors.
        depth: Factorization depth (see ``SparseConv``).
        **kwargs: Remaining Conv / SparseConv keyword arguments. ``rank`` must
            not be among them, because it is supplied here; passing it raises
            ``TypeError``.
    """

    def __init__(self, filters, kernel_size, la=None, position_sparsity=-1, depth=2, **kwargs):
        super().__init__(
            rank=2,
            filters=filters,
            kernel_size=kernel_size,
            position_sparsity=position_sparsity,
            depth=depth,
            la=la,
            **kwargs,
        )
