import tensorflow as tf
from tensorflow.keras.layers import Layer, Conv2D, Conv2DTranspose, DepthwiseConv2D, SeparableConv2D, BatchNormalization, Dropout, Activation, ReLU
# BatchNormalization._USE_V2_BEHAVIOR = False # if using @function
from tensorflow.keras.regularizers import l2
import numpy as np

def flops_act(shapeOut):
    """Count the element-wise operations performed by an activation.

    One operation is counted per output element: the product of every
    dimension of ``shapeOut`` except the first (presumably the batch
    dimension -- TODO confirm with callers).

    Args:
        shapeOut: Indexable output shape, e.g. a tuple of ints.

    Returns:
        ``np.prod`` of ``shapeOut[1:]`` (a numpy scalar; ``1.0`` when only one
        dimension is given).
    """
    per_sample_dims = shapeOut[1:]
    return np.prod(per_sample_dims)

def mem_trans(shapeIn, shapeOut):
    """Count the elements read plus written by a layer.

    Each shape contributes the product of its dimensions after the first
    (presumably the batch dimension -- TODO confirm with callers).

    Args:
        shapeIn: Indexable input shape.
        shapeOut: Indexable output shape.

    Returns:
        ``np.prod(shapeIn[1:]) + np.prod(shapeOut[1:])``.
    """
    return sum(np.prod(shape[1:]) for shape in (shapeIn, shapeOut))

class Conv(Layer):
    """Convolution block with optional batch normalization, activation and dropout.

    ``conv_type`` selects the internal structure (``[...]`` marks optional parts):

    * ``'std'``: ``Conv2D`` (``Conv2DTranspose`` if ``transposed``) -> [BN]
      -> [``activation`` -> [Dropout]].
    * ``'mobileV1'``: depthwise conv -> [BN] -> ReLU6 -> 1x1 conv -> [BN]
      -> [ReLU6 -> [Dropout]].
    * ``'mobileV2'``: 1x1 conv to ``6 * num_filters_out`` channels -> [BN]
      -> ReLU6 -> depthwise conv -> [BN] -> ReLU6 -> [Dropout] -> 1x1 conv to
      ``num_filters_out`` channels -> [BN] -> [ReLU6].
    * ``'mixture'``: separable conv -> [BN] -> ReLU6 -> 1x1 conv to
      ``4 * num_filters_out`` channels -> [BN] -> [ReLU6 -> [Dropout]].

    BN layers exist when ``with_bn`` is set. The final activation exists when
    ``with_act`` is set. Dropout exists when ``dropout_rate > 0``. In every
    variant except ``'mobileV2'``, dropout is applied only together with the
    final activation. Only ``'std'`` uses ``activation``; the other variants
    hard-code ReLU6.

    Args:
        num_filters_out: Output channels of the block (``'mixture'`` outputs
            ``4 * num_filters_out``).
        kernel_size, strides, dilation: Forwarded to the spatial convolution;
            the 1x1 convolutions use their Keras defaults for stride/dilation.
        activation: Activation name/callable for the ``'std'`` variant.
        padding: Padding mode for every convolution.
        kernel_initializer: Initializer for every convolution kernel
            (including depthwise/pointwise kernels).
        dropout_rate: Dropout rate; ``0`` disables dropout.
        use_bias: Whether the convolutions have a bias.
        transposed: Use ``Conv2DTranspose``; only valid with ``conv_type='std'``.
        weight_decay: If ``> 0``, an L2 regularizer with this factor is applied
            to every convolution kernel (including depthwise/pointwise kernels).
        with_bn: Add BatchNormalization after each convolution.
        with_act: Apply the final activation (and dropout, where applicable).
        conv_type: One of ``'std'``, ``'mobileV1'``, ``'mobileV2'``, ``'mixture'``.
        renorm, renorm_momentum: Forwarded to every ``BatchNormalization``.
        name: Layer name; also used as prefix for all sub-layer names.
        **kwargs: Forwarded to ``Layer``.

    Raises:
        AssertionError: ``transposed`` is set with a non-``'std'`` ``conv_type``.
        ValueError: ``conv_type`` is not recognized.
    """

    def __init__(self, num_filters_out, kernel_size=3, strides=1, dilation=1, activation='relu',
                 padding='same', kernel_initializer='he_normal', dropout_rate=0, use_bias=False,
                 transposed=False, weight_decay=0, with_bn=False, with_act=True, conv_type='std',
                 renorm=False, renorm_momentum=0.99, name='', **kwargs):
        super(Conv, self).__init__(name=name, **kwargs)
        # Transposed convolution is only implemented for the 'std' variant.
        assert(not (transposed and conv_type != 'std'))

        self.num_filters_out = num_filters_out
        self.kernel_size = kernel_size
        self.strides = strides
        self.dilation = dilation
        self.activation = activation
        self.padding = padding
        self.kernel_initializer = kernel_initializer
        self.dropout_rate = dropout_rate
        # Store the plain value: a 1-tuple here would be truthy after a
        # get_config()/from_config() round trip and silently enable biases.
        self.use_bias = use_bias
        self.transposed = transposed
        self.weight_decay = weight_decay
        self.with_bn = with_bn
        self.with_act = with_act
        self.conv_type = conv_type
        self.renorm = renorm
        self.renorm_momentum = renorm_momentum

        self.with_dropout = dropout_rate > 0

        kernel_regularizer = l2(weight_decay) if weight_decay > 0 else None

        # Arguments shared by every regular (Conv2D / Conv2DTranspose) convolution.
        conv_kwargs = dict(padding=padding, kernel_initializer=kernel_initializer,
                           use_bias=use_bias, kernel_regularizer=kernel_regularizer)
        # DepthwiseConv2D / SeparableConv2D create their kernels from the
        # depthwise_* / pointwise_* arguments, so the kernel initializer and
        # weight decay must be forwarded under those names to take effect.
        dw_kwargs = dict(padding=padding, use_bias=use_bias,
                         depthwise_initializer=kernel_initializer,
                         depthwise_regularizer=kernel_regularizer)
        sep_kwargs = dict(dw_kwargs, pointwise_initializer=kernel_initializer,
                          pointwise_regularizer=kernel_regularizer)

        def batch_norm(suffix):
            return BatchNormalization(name=name + suffix, renorm=renorm,
                                      renorm_momentum=renorm_momentum)

        # NOTE: sub-layer attribute names (including the inconsistent
        # ``act_dw`` in mobileV1 vs ``dw_act`` in mobileV2) and their creation
        # order are kept unchanged so existing checkpoints keep matching.
        if self.with_dropout:
            self.drop = Dropout(dropout_rate, name=name + '_drop')

        if conv_type == 'std':
            conv_cls = Conv2DTranspose if transposed else Conv2D
            self.conv = conv_cls(num_filters_out, kernel_size, strides=strides, dilation_rate=dilation,
                                 name=name, **conv_kwargs)
            if with_bn:
                self.bn = batch_norm('_bn')
            if with_act:
                self.act = Activation(activation, name=name + '_act')
        elif conv_type == 'mobileV1':
            self.dw_conv = DepthwiseConv2D(kernel_size, strides=strides, dilation_rate=dilation,
                                           name=name + '_dw', **dw_kwargs)
            if with_bn:
                self.dw_bn = batch_norm('_dw_bn')
            self.act_dw = ReLU(6., name=name + '_dw_relu')
            self.nin_conv = Conv2D(num_filters_out, 1, name=name + '_1x1', **conv_kwargs)
            if with_bn:
                self.nin_bn = batch_norm('_1x1_bn')
            if with_act:
                self.nin_act = ReLU(6., name=name + '_1x1_relu')
        elif conv_type == 'mobileV2':
            expansion = 6
            # NOTE: in our network architecture the input and output have the
            # same number of filters, so the expansion is relative to the output.
            self.nin_conv = Conv2D(num_filters_out * expansion, 1, name=name + '_1x1', **conv_kwargs)
            if with_bn:
                self.nin_bn = batch_norm('_1x1_bn')
            self.nin_act = ReLU(6., name=name + '_1x1_relu')
            self.dw_conv = DepthwiseConv2D(kernel_size, strides=strides, dilation_rate=dilation,
                                           name=name + '_dw', **dw_kwargs)
            if with_bn:
                self.dw_bn = batch_norm('_dw_bn')
            self.dw_act = ReLU(6., name=name + '_dw_relu')
            self.lin_conv = Conv2D(num_filters_out, 1, name=name + '_lin', **conv_kwargs)
            if with_bn:
                self.lin_bn = batch_norm('_lin_bn')
            if with_act:
                self.lin_act = ReLU(6., name=name + '_act')
        elif conv_type == 'mixture':
            self.sep_conv = SeparableConv2D(num_filters_out, kernel_size, strides=strides,
                                            dilation_rate=dilation, name=name + '_dws', **sep_kwargs)
            if with_bn:
                self.sep_bn = batch_norm('_dws_bn')
            self.sep_act = ReLU(6., name=name + '_dws_relu')
            self.nin_conv = Conv2D(num_filters_out * 4, 1, name=name + '_1x1', **conv_kwargs)
            if with_bn:
                self.nin_bn = batch_norm('_1x1_bn')
            if with_act:
                self.nin_act = ReLU(6., name=name + '_1x1_relu')
        else:
            raise ValueError("Unknown conv_type %r; expected 'std', 'mobileV1', 'mobileV2' or 'mixture'"
                             % (conv_type,))

    def call(self, inputs, training=False):
        """Run the block selected by ``conv_type`` on ``inputs``.

        ``training`` is forwarded to every BatchNormalization and Dropout layer.
        """
        if self.conv_type == 'std':
            output = self.conv(inputs)
            if self.with_bn:
                output = self.bn(output, training=training)
            if self.with_act:
                output = self.act(output)
                if self.with_dropout:
                    output = self.drop(output, training=training)
        elif self.conv_type == 'mobileV1':
            # MobileNet V1: depthwise -> pointwise.
            output = self.dw_conv(inputs)
            if self.with_bn:
                output = self.dw_bn(output, training=training)
            output = self.act_dw(output)
            output = self.nin_conv(output)
            if self.with_bn:
                output = self.nin_bn(output, training=training)
            if self.with_act:
                output = self.nin_act(output)
                if self.with_dropout:
                    output = self.drop(output, training=training)
        elif self.conv_type == 'mobileV2':
            # MobileNet V2 inverted residual body (without the skip connection).
            # 1x1 expansion
            output = self.nin_conv(inputs)
            if self.with_bn:
                output = self.nin_bn(output, training=training)
            output = self.nin_act(output)
            # depth-wise
            output = self.dw_conv(output)
            if self.with_bn:
                output = self.dw_bn(output, training=training)
            output = self.dw_act(output)
            # Here dropout sits before the linear projection, independent of with_act.
            if self.with_dropout:
                output = self.drop(output, training=training)
            # linear 1x1 projection
            output = self.lin_conv(output)
            if self.with_bn:
                output = self.lin_bn(output, training=training)
            if self.with_act:
                output = self.lin_act(output)
        elif self.conv_type == 'mixture':
            # Adaptative Inference Cost With Convolutional Neural Mixture Models
            output = self.sep_conv(inputs)
            if self.with_bn:
                output = self.sep_bn(output, training=training)
            output = self.sep_act(output)
            output = self.nin_conv(output)
            if self.with_bn:
                output = self.nin_bn(output, training=training)
            if self.with_act:
                output = self.nin_act(output)
                if self.with_dropout:
                    output = self.drop(output, training=training)
        else:
            # Unreachable when constructed through __init__, which validates conv_type.
            raise ValueError("Unknown conv_type %r" % (self.conv_type,))

        return output

    def get_config(self):
        """Return the constructor arguments so the layer can be rebuilt via ``from_config``."""
        config = super(Conv, self).get_config()
        config.update({'num_filters_out': self.num_filters_out,
                       'kernel_size': self.kernel_size,
                       'strides': self.strides,
                       'dilation': self.dilation,
                       'activation': self.activation,
                       'padding': self.padding,
                       'kernel_initializer': self.kernel_initializer,
                       'dropout_rate': self.dropout_rate,
                       'use_bias': self.use_bias,
                       'transposed': self.transposed,
                       'weight_decay': self.weight_decay,
                       'with_bn': self.with_bn,
                       'with_act': self.with_act,
                       'conv_type': self.conv_type,
                       'renorm': self.renorm,
                       'renorm_momentum': self.renorm_momentum})
        return config

class Shuffle(Layer):
    """Permute the channels (last axis) of the input with a fixed random permutation.

    The permutation is drawn once in ``build`` from the input's channel count
    and reused on every call.

    NOTE(review): the permutation is a plain numpy attribute, not a layer
    weight, so it is not stored in checkpoints. Pass ``seed`` if a rebuilt
    layer must reproduce the same permutation.

    Args:
        seed: Optional integer seed. When given, the permutation comes from a
            dedicated ``np.random.RandomState(seed)`` and is reproducible.
            When ``None`` (default), numpy's global RNG is used.
        **kwargs: Forwarded to ``Layer``.
    """

    def __init__(self, seed=None, **kwargs):
        super(Shuffle, self).__init__(**kwargs)
        self.seed = seed

    def build(self, input_shape):
        """Draw the channel permutation for the last dimension of ``input_shape``.

        Raises:
            ValueError: If the last dimension is unknown (``None``).
        """
        num_channels = input_shape[-1]
        if num_channels is None:
            raise ValueError('Shuffle requires the last (channel) dimension to be known at build time.')
        if self.seed is None:
            self.indices = np.random.permutation(num_channels)
        else:
            self.indices = np.random.RandomState(self.seed).permutation(num_channels)
        super(Shuffle, self).build(input_shape)

    def _get_indices(self):
        """Return the channel permutation drawn in ``build``."""
        return self.indices

    def call(self, inputs):
        """Reorder the last axis of ``inputs`` according to the stored permutation."""
        output = tf.gather(inputs, self.indices, axis=-1)
        return output

    def get_config(self):
        """Return the layer config including ``seed``."""
        config = super(Shuffle, self).get_config()
        config.update({'seed': self.seed})
        return config
