import numpy as np
import tensorflow as tf
from tensorflow import keras
from groupy.gconv.tensorflow_gconv.splitgconv2d import gconv2d, gconv2d_util

class GroupConv2d(keras.layers.Layer):
    """Group-equivariant 2D convolution layer built on GrouPy's ``gconv2d``.

    The filter is created via ``gconv2d_util`` and applied with ``gconv2d``;
    a trainable bias and an optional activation are then applied. Inputs are
    expected in NHWC layout.

    Args:
        k: Spatial kernel size (passed to ``gconv2d_util`` as ``ksize``).
        n_out: Number of output feature maps per group element. GrouPy's
            ``gconv2d`` emits ``n_out * |h_output|`` channels
            (|C4| = 4, |D4| = 8), and the bias is sized to match.
        h_input: Group of the input: ``'Z2'`` (plain image), ``'C4'`` or
            ``'D4'``.
        h_output: Group of the output, forwarded to ``gconv2d_util``
            (which rejects unsupported ``(h_input, h_output)`` pairs).
        strides: Convolution strides, forwarded to ``gconv2d``.
        padding: ``'VALID'`` or ``'SAME'``, forwarded to ``gconv2d``.
        activation: ``'relu'``, ``'sigmoid'`` or ``'tanh'``; any other value
            (e.g. ``None`` or ``'linear'``) leaves the output linear.
        name: Optional layer name.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    # Number of group elements folded into the channel axis for each group:
    # Z2 adds no group axis, C4 has 4 rotations, D4 has 4 rotations x 2 flips.
    _GROUP_ORDER = {'Z2': 1, 'C4': 4, 'D4': 8}

    def __init__(self, k, n_out, h_input, h_output, strides=1, padding='VALID', activation='relu', name=None, **kwargs):

        super().__init__(name=name, **kwargs)

        self.k = k
        self.n_out = n_out
        self.h_input = h_input
        self.h_output = h_output
        self.strides = strides
        self.padding = padding
        self.activation = activation

    def _group_order(self, group, arg_name):
        """Return the order of ``group``, raising ValueError if unsupported."""
        try:
            return self._GROUP_ORDER[group]
        except KeyError:
            raise ValueError(
                'GroupConv2d: unsupported %s=%r; expected one of %s'
                % (arg_name, group, sorted(self._GROUP_ORDER))
            ) from None

    def build(self, input_shape):
        """Create the bias and filter weights.

        Args:
            input_shape: Input shape, expected to be (BATCH, H, W, C) with a
                statically known channel count C.

        Raises:
            ValueError: if ``h_input``/``h_output`` is not one of 'Z2', 'C4',
                'D4', if C is unknown, or if C is not a multiple of the order
                of ``h_input``. ``gconv2d_util`` may additionally reject the
                ``(h_input, h_output)`` pair.
        """
        in_order = self._group_order(self.h_input, 'h_input')
        out_order = self._group_order(self.h_output, 'h_output')

        # dimension_value turns a TF1 Dimension into a plain int (or None).
        self.n_in = tf.compat.dimension_value(input_shape[-1])
        if self.n_in is None:
            raise ValueError(
                'GroupConv2d requires a known channel dimension, got input_shape=%s'
                % (input_shape,))
        if self.n_in % in_order:
            raise ValueError(
                'GroupConv2d: input channels (%d) must be a multiple of %d for h_input=%r'
                % (self.n_in, in_order, self.h_input))

        # gconv2d_util expects the number of input feature maps *per group
        # element*, hence the division by the input group order.
        gconv_indices, gconv_shape_info, w_shape = gconv2d_util(h_input=self.h_input, h_output=self.h_output,
                                                                in_channels=self.n_in // in_order, out_channels=self.n_out,
                                                                ksize=self.k)

        # The output carries one copy of each feature map per element of the
        # OUTPUT group, so the bias is sized from h_output (not h_input).
        # The bias is created before the filter to keep the original weight
        # order, which matters when loading previously saved weights.
        self.b = self.add_weight(shape=(self.n_out * out_order,), initializer="random_normal", dtype=tf.float32, trainable=True, name='biases')

        self.gconv_indices = gconv_indices
        self.gconv_shape_info = gconv_shape_info

        self.w = self.add_weight(shape=w_shape, initializer="random_normal", dtype=tf.float32, trainable=True, name='w')

    def call(self, inputs):
        """Apply the group convolution, add the bias, then the activation."""
        a = gconv2d(input=inputs, filter=self.w, strides=self.strides, padding=self.padding,
                    gconv_indices=self.gconv_indices, gconv_shape_info=self.gconv_shape_info)

        # Adds the 1-D bias along the last (channel) axis and checks that its
        # length matches that axis.
        a = tf.nn.bias_add(a, self.b)

        if self.activation == 'relu':
            return tf.keras.activations.relu(a)
        if self.activation == 'sigmoid':
            return tf.keras.activations.sigmoid(a)
        if self.activation == 'tanh':
            return tf.keras.activations.tanh(a)
        # Any other value (None, 'linear', unrecognised names) is linear.
        return a

    def get_config(self):
        """
        This function generates a config in order to save the model at its current state.

        A model using GroupConv2d layer(s) can then be saved using
            >> model.save('./model.h5')
        and loaded with
            >> model = tf.keras.models.load_model('./model.h5', custom_objects={'GroupConv2d': GroupConv2d})
        """

        config = super().get_config()
        config.update({"k": self.k,
                       "n_out": self.n_out,
                       "h_input": self.h_input,
                       "h_output": self.h_output,
                       "strides": self.strides,
                       "padding": self.padding,
                       "activation": self.activation})

        return config