from keras.layers import BatchNormalization
from keras.layers import Layer
import keras.backend as K
from keras.utils import conv_utils
import tensorflow as tf
from keras.layers import initializers

class Activation(BatchNormalization):
    """Batch normalization fused with an optional activation / pooling step.

    ``act_mode`` selects what ``call`` applies after the inherited
    ``BatchNormalization.call``:

    * ``'bn'``         -- batch normalization only.
    * ``'bnrelu'``     -- ReLU clipped at 6 (``K.relu(x, max_value=6)``).
    * ``'bnadd'``      -- adds ``skip_input`` (which must then be given).
    * ``'bnrelupool'`` -- ReLU followed by 2D max pooling using
      ``pool_window`` / ``strides``.
    * ``'bnzerorelu'`` -- plain (unclipped) ReLU.

    When ``privacy`` is True, ``build`` and ``call`` delegate to the
    ``sgxutils`` object (``setup_batchnormsp_dark`` / ``batchnorm_dark``)
    instead of running the Keras computation.
    """

    # Modes accepted by __init__ and dispatched on in call().
    _MODES = ('bn', 'bnrelu', 'bnadd', 'bnzerorelu', 'bnrelupool')

    def __init__(self, act_mode, privacy=False, sgxutils=None, pool_window=None, strides=None, epsilon=0.0, **kwargs):
        """Create the layer.

        Args:
            act_mode: one of ``_MODES``; validated here.
            privacy: when True, build/call are delegated to ``sgxutils``.
            sgxutils: helper object used only on the privacy path.
            pool_window: pooling window; unpacked as a pair in
                ``compute_output_shape`` and used by ``'bnrelupool'``.
            strides: pooling strides, same form as ``pool_window``.
            epsilon: NOTE(review): accepted but never used. It is not
                forwarded to ``BatchNormalization``, and because an
                ``epsilon=`` keyword is captured here it never reaches
                ``kwargs`` either, so the BN layer always uses the Keras
                default epsilon. Confirm whether callers expect it to apply.
            **kwargs: forwarded to ``BatchNormalization.__init__``.

        Raises:
            AssertionError: if ``act_mode`` is not a supported mode.
        """
        super(Activation, self).__init__(**kwargs)

        if act_mode not in self._MODES:
            print("Activation mode of operation error")
            # Explicit raise instead of ``assert False`` so the check is not
            # stripped when Python runs with -O.
            raise AssertionError("Activation mode of operation error")
        self.act_mode    = act_mode
        self.privacy     = privacy
        self.sgxutils    = sgxutils
        self.pool_window = pool_window
        self.strides     = strides

    def build(self, input_shape):
        """Create the BN weights, optionally set up the SGX side, and choose
        the axes reduced by ``tf.reduce_mean`` on the privacy path of ``call``.
        """
        BatchNormalization.build(self, input_shape)
        if self.privacy is True:
            self.sgxutils.setup_batchnormsp_dark(input_shape,
                                                 privacy=self.privacy,
                                                 eps=self.epsilon,
                                                 momentum=self.momentum)
        rank = len(list(input_shape))
        if rank == 4:
            self.mean_axis = (1, 2, 3)
        elif rank == 2:
            self.mean_axis = 1
        # NOTE(review): other ranks leave ``mean_axis`` unset; only the
        # privacy path of ``call`` reads it.

    def compute_output_shape(self, input_shape):
        """Return the output shape for ``input_shape``.

        Only ``'bnrelupool'`` changes the shape. In that mode ``input_shape``
        is indexed as rank 4: axes 1 and 2 are shrunk with 'valid' pooling
        arithmetic (``conv_utils.conv_output_length``), axes 0 and 3 pass
        through unchanged.
        """
        # ``==`` rather than ``is``: string identity depends on interning and
        # is not guaranteed for mode strings constructed at runtime.
        if self.act_mode == 'bnrelupool':
            left_w, right_w = self.pool_window
            left_s, right_s = self.strides
            rows = conv_utils.conv_output_length(input_shape[1], left_w,  'valid', left_s)
            cols = conv_utils.conv_output_length(input_shape[2], right_w, 'valid', right_s)
            return (input_shape[0], rows, cols, input_shape[3])
        return input_shape

    def call(self, input, skip_input=None, training=None):
        """Apply batch normalization followed by the configured activation.

        Args:
            input: tensor to normalize.
            skip_input: residual tensor added in ``'bnadd'`` mode. On the
                privacy path it defaults to ``input`` when omitted.
            training: forwarded to ``BatchNormalization.call`` or to
                ``sgxutils.batchnorm_dark``.

        Returns:
            The activated tensor. On the privacy path this is the first
            element of the pair returned by ``sgxutils.batchnorm_dark``.

        Raises:
            AssertionError: ``'bnadd'`` mode without ``skip_input``, or an
                unsupported mode.
        """
        if self.privacy is True:
            if skip_input is None:
                skip_input = input
            # The second element of the returned pair is not used here.
            output, _ = self.sgxutils.batchnorm_dark(
                input=input,
                means=tf.reduce_mean(input, axis=self.mean_axis),
                skip_input=skip_input,
                training=training,
                act_mode=self.act_mode)
            return output

        x = BatchNormalization.call(self, input, training=training)

        if self.act_mode == 'bn':
            return x
        elif self.act_mode == 'bnrelu':
            return K.relu(x, max_value=6)
        elif self.act_mode == 'bnadd':
            if skip_input is None:
                print("Activation BN ADD skip input is none")
                raise AssertionError("Activation BN ADD skip input is none")
            return x + skip_input
        elif self.act_mode == 'bnrelupool':
            return K.pool2d(K.relu(x), self.pool_window, strides=self.strides, pool_mode='max')
        elif self.act_mode == 'bnzerorelu':
            return K.relu(x)
        else:
            print("Activation mode of operation error")
            raise AssertionError("Activation mode of operation error")


class ResNetBottom(Layer):
    """Merge step of a residual block: ``relu(BN_l(left) + right')``.

    ``right'`` is ``BN_r(right)`` (a second BatchNormalization) when
    ``right_norm`` is True, otherwise ``right`` unchanged. With ``use_bias``,
    a zero-initialised learned bias is added to each normalized branch before
    its batch norm.

    When ``privacy`` is True and ``sgxutils`` is provided, ``build`` and
    ``call`` delegate to ``sgxutils.resnet_setup_bottom`` /
    ``sgxutils.resnet_bottom_op`` instead of the Keras computation.
    """

    def __init__(self, right_norm, privacy=False, sgxutils=None, use_bias=False, epsilon=0.0, bias_shape=1, **kwargs):
        """Create the layer.

        Args:
            right_norm: if True, the right branch gets its own batch norm
                (and bias when ``use_bias``); the SGX setup mode becomes
                ``"downsample"`` instead of ``"normal"``.
            privacy: enables the ``sgxutils`` path (together with ``sgxutils``).
            sgxutils: helper object used only on the privacy path.
            use_bias: add learned biases of shape ``[bias_shape]``.
            epsilon: epsilon for both batch norms and the SGX setup. It used
                to be ignored (hard-coded ``0.0``); the default is unchanged.
            bias_shape: length of each bias vector.
            **kwargs: forwarded to ``Layer.__init__``.
        """
        super(ResNetBottom, self).__init__(**kwargs)
        self.right_norm  = right_norm
        self.privacy     = privacy
        self.sgxutils    = sgxutils
        self.left_bnorm  = None
        self.right_bnorm = None
        self.use_bias    = use_bias
        self.bias_shape  = bias_shape
        self.bias_l      = None
        self.bias_r      = None
        # Honour the caller's epsilon instead of silently replacing it.
        self.eps         = epsilon
        # Guards resnet_setup_bottom so it runs at most once.
        self.not_built   = True

    def build(self, input_shape):
        """Create the batch-norm sub-layers and biases, and set up the SGX
        side once when the privacy path is enabled."""
        self.left_bnorm = BatchNormalization(axis=3, name='left', epsilon=self.eps, input_shape=input_shape)
        self.left_bnorm.build(input_shape)

        # Axes reduced by tf.reduce_mean on the privacy path of call().
        rank = len(list(input_shape))
        if rank == 4:
            self.mean_axis = (1, 2)
        elif rank == 2:
            self.mean_axis = 1

        if self.use_bias and self.bias_l is None:
            self.bias_l = self.add_weight(shape=[self.bias_shape],
                                          dtype=tf.float32,
                                          initializer=initializers.get('zeros'),
                                          name='bias_l')

        if self.right_norm:
            self.right_bnorm = BatchNormalization(axis=3, name='right', epsilon=self.eps, input_shape=input_shape)
            self.right_bnorm.build(input_shape)

            if self.use_bias and self.bias_r is None:
                self.bias_r = self.add_weight(shape=[self.bias_shape],
                                              dtype=tf.float32,
                                              initializer=initializers.get('zeros'),
                                              name='bias_r')

        if self.privacy and self.sgxutils is not None:
            # NOTE(review): assumes bias_l / bias_r are the first entries of
            # get_weights(), which presumably requires use_bias=True -- confirm.
            mode = "normal"
            bias_l = self.get_weights()[0]
            bias_r = bias_l
            if self.right_norm:
                mode   = "downsample"
                bias_r = self.get_weights()[1]

            if self.not_built:
                self.sgxutils.resnet_setup_bottom(mode=mode,
                                                  in_size=input_shape,
                                                  out_size=input_shape,
                                                  eps=self.eps,
                                                  momentum=0.0,
                                                  bias_data_l=bias_l,
                                                  bias_data_r=bias_r
                                                  )
                self.not_built = False

    def compute_output_shape(self, input_shape):
        """The merge is element-wise, so the output shape equals the shape
        of the ``left`` input."""
        return input_shape

    def call(self, left, right=None, training=None):
        """Combine ``left`` and ``right`` and apply ReLU.

        Args:
            left: main-branch tensor (normalized by ``left_bnorm``).
            right: shortcut-branch tensor; normalized only if ``right_norm``.
                NOTE(review): ``None`` is not handled -- both paths use it.
            training: forwarded to the batch-norm sub-layers.

        Returns:
            ``K.relu`` of the branch sum, or the result of
            ``sgxutils.resnet_bottom_op`` on the privacy path.
        """
        if self.privacy and self.sgxutils is not None:
            return self.sgxutils.resnet_bottom_op(left_in=left,
                                                  right_in=right,
                                                  mean_left  = tf.reduce_mean(left,  axis=self.mean_axis),
                                                  mean_right = tf.reduce_mean(right, axis=self.mean_axis)
                                                  )

        left_res = left
        if self.use_bias:
            left_res = K.bias_add(left, self.bias_l)
        left_res = self.left_bnorm(left_res, training=training)

        if self.right_norm:
            right_res = right
            if self.use_bias:
                right_res = K.bias_add(right, self.bias_r)
            right_res = self.right_bnorm(right_res, training=training)
            res       = left_res + right_res
        else:
            res       = left_res + right

        return K.relu(res)



class ResNetActivation(Layer):
    """Optional bias add followed by batch norm + activation, chosen by mode.

    ``act_mode``:

    * ``'bnzerorelu'`` -- batch norm, then ReLU.
    * ``'bnrelupool'`` -- batch norm, ReLU, then 2D max pooling with
      ``pool_window`` / ``strides``.
    * ``'bias_add'``   -- bias add only (requires ``use_bias``).

    When ``privacy`` is True and ``sgxutils`` is provided, ``build`` and
    ``call`` delegate to ``sgxutils.resnet_setup_activation`` /
    ``sgxutils.resnet_activation_op``.
    """

    def __init__(self, act_mode, privacy=False, sgxutils=None, pool_window=None, strides=None, epsilon=0.0, use_bias=False, bias_shape=1, **kwargs):
        """Create the layer.

        Args:
            act_mode: ``'bnzerorelu'``, ``'bnrelupool'`` or ``'bias_add'``.
            privacy: enables the ``sgxutils`` path; requires ``use_bias``.
            sgxutils: helper object used only on the privacy path.
            pool_window: pooling window pair; ``None`` means a fresh
                ``[0, 0]`` (replaces a shared mutable default list).
            strides: pooling strides pair; ``None`` means a fresh ``[0, 0]``.
            epsilon: epsilon for the internal BatchNormalization and SGX setup.
            use_bias: add a learned zero-initialised bias of length ``bias_shape``.
            bias_shape: bias length (converted with ``int``).
            **kwargs: forwarded to ``Layer.__init__``.

        Raises:
            AssertionError: unsupported ``act_mode``, or ``use_bias`` missing
                when ``'bias_add'`` mode or ``privacy`` requires it.
        """
        super(ResNetActivation, self).__init__(**kwargs)

        # Explicit raises (not ``assert``) so validation survives python -O.
        if act_mode not in ('bnzerorelu', 'bnrelupool', 'bias_add'):
            print("Activation mode of operation error")
            raise AssertionError("Activation mode of operation error")

        self.act_mode    = act_mode
        if act_mode == 'bias_add' and not use_bias:
            raise AssertionError("'bias_add' mode requires use_bias=True")
        self.privacy     = privacy
        if self.privacy and not use_bias:
            raise AssertionError("privacy mode requires use_bias=True")

        self.sgxutils    = sgxutils
        # Fresh lists per instance; a list default would be shared by all.
        self.pool_window = [0, 0] if pool_window is None else pool_window
        self.strides     = [0, 0] if strides is None else strides
        self.use_bias    = use_bias
        self.bias_shape  = int(bias_shape)
        self.epsilon     = epsilon
        self.batch       = BatchNormalization(axis=3, epsilon=self.epsilon)
        self.bias        = None
        # Guards resnet_setup_activation so it runs at most once.
        self.not_built   = True

    def build(self, input_shape):
        """Build the batch norm (unless in ``'bias_add'`` mode) and bias, and
        set up the SGX side once when the privacy path is enabled."""
        # ``!=`` rather than ``is not``: string identity is not guaranteed.
        if self.act_mode != 'bias_add':
            self.batch.build(input_shape)

        # Axes reduced by tf.reduce_mean on the privacy path of call().
        rank = len(list(input_shape))
        if rank == 4:
            self.mean_axis = (1, 2)
        elif rank == 2:
            self.mean_axis = 1

        if self.use_bias and self.bias is None:
            self.bias = self.add_weight(shape=[self.bias_shape],
                                        dtype=tf.float32,
                                        initializer=initializers.get('zeros'),
                                        name='bias')

        # setting up sgx
        if self.privacy and self.sgxutils is not None:
            # NOTE(review): assumes the bias is the first entry of
            # get_weights() -- confirm against Keras weight ordering.
            bias_data = self.get_weights()[0]
            output_shape = self.compute_output_shape(input_shape)

            in_size_list  = list(input_shape)
            out_size_list = list(output_shape)

            # The SGX setup receives rank-4 sizes; rank-2 shapes [a, b] are
            # padded to [a, 1, 1, b].
            if len(in_size_list) == 2:
                in_size_list  = [in_size_list[0],  1, 1, in_size_list[1]]
                out_size_list = [out_size_list[0], 1, 1, out_size_list[1]]

            if self.not_built:
                self.sgxutils.resnet_setup_activation(mode=self.act_mode,
                                                      in_size=in_size_list,
                                                      out_size=out_size_list,
                                                      pool_window = self.pool_window,
                                                      pool_stride=self.strides,
                                                      eps=self.epsilon,
                                                      momentum=0.0,
                                                      bias_data=bias_data
                                                      )
                self.not_built = False

    def compute_output_shape(self, input_shape):
        """Return the output shape; only ``'bnrelupool'`` changes it, using
        'valid' pooling arithmetic on axes 1 and 2 of a rank-4 shape."""
        # ``==`` rather than ``is``: string identity is not guaranteed.
        if self.act_mode == 'bnrelupool':
            left_w, right_w = self.pool_window
            left_s, right_s = self.strides
            rows = conv_utils.conv_output_length(input_shape[1], left_w,  'valid', left_s)
            cols = conv_utils.conv_output_length(input_shape[2], right_w, 'valid', right_s)
            return (input_shape[0], rows, cols, input_shape[3])
        return input_shape

    def call(self, input, training=None):
        """Apply the optional bias and the configured activation.

        Args:
            input: input tensor.
            training: forwarded to the internal BatchNormalization.

        Returns:
            The transformed tensor, or the result of
            ``sgxutils.resnet_activation_op`` on the privacy path.

        Raises:
            AssertionError: unsupported mode.
        """
        if self.privacy and self.sgxutils is not None:
            return self.sgxutils.resnet_activation_op(input=input,
                                                      means=tf.reduce_mean(input, axis=self.mean_axis),
                                                      act_mode=self.act_mode)

        x = input
        if self.use_bias:
            x = K.bias_add(input, self.bias)

        if self.act_mode == 'bnrelupool':
            x = self.batch.call(x, training=training)
            return K.pool2d(K.relu(x), self.pool_window, strides=self.strides, pool_mode='max')
        elif self.act_mode == 'bnzerorelu':
            x = self.batch.call(x, training=training)
            return K.relu(x)
        elif self.act_mode == 'bias_add':
            return x
        else:
            print("Activation mode of operation error")
            raise AssertionError("Activation mode of operation error")