from __future__ import print_function
from __future__ import absolute_import
from __future__ import division

import os
import warnings
import numpy as np
import keras as k

import keras.backend as backend
import keras.layers as layers
import keras.models as models
import keras.utils as keras_utils

from keras.engine.topology import Layer
from keras_applications import correct_pad
from keras_applications import get_submodules_from_kwargs
from keras_applications import imagenet_utils
from keras_applications.imagenet_utils import decode_predictions
from keras_applications.imagenet_utils import _obtain_input_shape

from python.slalom.Activation import Activation

# Base URL of the pre-trained MobileNetV2 weight release.
# NOTE(review): the earlier TODO asked to change this path to v1.1; the URL
# below already points at the v1.1 release. Nothing in the visible part of
# this file references this constant -- confirm whether it is still needed.
BASE_WEIGHT_PATH = ('https://github.com/JonathanCMitchell/mobilenet_v2_keras/'
                    'releases/download/v1.1/')


def preprocess_input(x, **kwargs):
    """Prepare a batch of images for MobileNetV2.

    Thin wrapper that delegates to `imagenet_utils.preprocess_input`
    with `mode='tf'`.

    # Arguments
        x: a 4D numpy array of RGB values within [0, 255].
        **kwargs: forwarded unchanged to
            `imagenet_utils.preprocess_input`.

    # Returns
        Preprocessed array.
    """
    preprocess = imagenet_utils.preprocess_input
    return preprocess(x, mode='tf', **kwargs)


def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

class ReLU6(Layer):
    """Rectified linear unit whose output is capped at `max_value`.

    # Arguments
        max_value: upper saturation threshold handed to `backend.relu`
            (defaults to 6).
        **kwargs: forwarded to the base `Layer` constructor.
    """

    def __init__(self, max_value=6, **kwargs):
        super(ReLU6, self).__init__(**kwargs)
        # Remember the threshold so `call` honours the constructor argument
        # instead of silently falling back to a hard-coded 6.
        self.max_value = max_value

    def call(self, input):
        """Apply the capped ReLU to `input` and return the result."""
        return backend.relu(input, max_value=self.max_value)

    def get_config(self):
        """Return the base layer config extended with `max_value`."""
        config = super(ReLU6, self).get_config()
        config['max_value'] = self.max_value
        return config

def MobileNetV2(input_shape=None,
                alpha=1.0,
                include_top=True,
                weights='imagenet',
                input_tensor=None,
                pooling=None,
                classes=1000,
                privacy=False,
                sgxutils=None,
                full_enclave=False,
                enclave_list=None,
                **kwargs):
    """Build a MobileNetV2 model whose BN/ReLU stages use `Activation` layers.

    # Arguments
        input_shape: optional shape tuple without the batch axis. When it is
            given together with `input_tensor`, one spatial dimension of the
            two is checked for agreement.
        alpha: multiplier applied to the filter counts of the first
            convolution and of every inverted residual block; values above
            1.0 also widen the last 1x1 convolution.
        include_top: when true, append global average pooling followed by a
            linear `Dense(classes)` layer.
        weights: `None`, `'imagenet'`, or a path to an existing file.
            NOTE: this function only validates the argument (and the
            `alpha` / input size that go with `'imagenet'`); no weights are
            loaded anywhere in this function.
        input_tensor: optional tensor to use as the model input.
        pooling: `'avg'` or `'max'`; only consulted when `include_top` is
            false.
        classes: number of units of the final dense layer.
        privacy: forwarded to every `Activation` layer.
        sgxutils: forwarded to every `Activation` layer.
        full_enclave: accepted but not used by this function.
        enclave_list: optional list that is mutated in place. When it is not
            `None`, one entry per constructed layer (kind, shapes, kernel /
            stride / padding, weights, ...) is appended to it.
        **kwargs: accepted but not used by this function.

    # Returns
        A `models.Model` named `mobilenetv2_<alpha>_<rows>`.

    # Raises
        ValueError: on an invalid `weights` value, on `classes != 1000` with
            ImageNet weights and `include_top`, on an unsupported `alpha`
            with ImageNet weights, or when `input_tensor` is not a valid
            tensor / disagrees with `input_shape`.
    """
    if not (weights in {'imagenet', None} or os.path.exists(weights)):
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization), `imagenet` '
                         '(pre-training on ImageNet), '
                         'or the path to the weights file to be loaded.')

    if weights == 'imagenet' and include_top and classes != 1000:
        raise ValueError('If using `weights` as `"imagenet"` with `include_top` '
                         'as true, `classes` should be 1000')

    # Determine proper input shape and default size.
    # If both input_shape and input_tensor are used, they should match
    if input_shape is not None and input_tensor is not None:
        try:
            is_input_t_tensor = backend.is_keras_tensor(input_tensor)
        except ValueError:
            try:
                is_input_t_tensor = backend.is_keras_tensor(
                    keras_utils.get_source_inputs(input_tensor))
            except ValueError:
                raise ValueError('input_tensor: ', input_tensor,
                                 'is not type input_tensor')
        if not is_input_t_tensor:
            raise ValueError('input_tensor specified: ', input_tensor,
                             'is not a keras tensor')
        # The previous code compared the *function* `image_data_format`
        # with a string, so its channels_first branch never ran (and would
        # have compared the channel axis with a spatial one). A single
        # comparison is right for both layouts: index 2 of the batched
        # tensor and index 1 of `input_shape` are the rows for
        # channels_first and the columns for channels_last.
        if backend.int_shape(input_tensor)[2] != input_shape[1]:
            raise ValueError('input_shape: ', input_shape,
                             'and input_tensor: ', input_tensor,
                             'do not meet the same shape requirements')

    supported_sizes = [96, 128, 160, 192, 224]

    # If input_shape is None, infer shape from input_tensor
    if input_shape is None and input_tensor is not None:
        try:
            tensor_is_keras = backend.is_keras_tensor(input_tensor)
        except ValueError:
            raise ValueError('input_tensor: ', input_tensor,
                             'is type: ', type(input_tensor),
                             'which is not a valid type')

        if not tensor_is_keras:
            default_size = 224
        else:
            if backend.image_data_format() == 'channels_first':
                rows = backend.int_shape(input_tensor)[2]
                cols = backend.int_shape(input_tensor)[3]
            else:
                rows = backend.int_shape(input_tensor)[1]
                cols = backend.int_shape(input_tensor)[2]

            if rows == cols and rows in supported_sizes:
                default_size = rows
            else:
                default_size = 224

    # If input_shape is None and no input_tensor
    elif input_shape is None:
        default_size = 224

    # If input_shape is not None, assume default size
    else:
        if backend.image_data_format() == 'channels_first':
            rows = input_shape[1]
            cols = input_shape[2]
        else:
            rows = input_shape[0]
            cols = input_shape[1]

        if rows == cols and rows in supported_sizes:
            default_size = rows
        else:
            default_size = 224

    input_shape = _obtain_input_shape(input_shape,
                                      default_size=default_size,
                                      min_size=32,
                                      data_format=backend.image_data_format(),
                                      require_flatten=include_top,
                                      weights=weights)

    if backend.image_data_format() == 'channels_last':
        row_axis, col_axis = (0, 1)
    else:
        row_axis, col_axis = (1, 2)
    rows = input_shape[row_axis]
    cols = input_shape[col_axis]

    if weights == 'imagenet':
        if alpha not in [0.35, 0.50, 0.75, 1.0, 1.3, 1.4]:
            raise ValueError('If imagenet weights are being loaded, '
                             'alpha can be one of `0.35`, `0.50`, `0.75`, '
                             '`1.0`, `1.3` or `1.4` only.')

        if rows != cols or rows not in supported_sizes:
            # `rows` only feeds the model name below.
            rows = 224
            warnings.warn('`input_shape` is undefined or non-square, '
                          'or `rows` is not in [96, 128, 160, 192, 224].'
                          ' Weights for input shape (224, 224) will be'
                          ' loaded as the default.')

    if input_tensor is None:
        img_input = layers.Input(shape=input_shape)
    else:
        if not backend.is_keras_tensor(input_tensor):
            img_input = layers.Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor

    first_block_filters = _make_divisible(32 * alpha, 8)

    # Stem: explicit symmetric (1, 1) zero padding followed by a 'valid'
    # 3x3 stride-2 convolution. The value from `correct_pad` is only printed.
    print("padding size", correct_pad(backend, img_input, 3))
    x = layers.ZeroPadding2D(padding=(1, 1),
                             name='Conv1_pad')(img_input)

    conv = layers.Conv2D(first_block_filters,
                         kernel_size=3,
                         strides=(2, 2),
                         padding='valid',
                         use_bias=False,
                         name='Conv1')
    x = conv(x)
    if enclave_list is not None:
        # Fetch the kernel only when it is actually recorded, as
        # `_inverted_res_block` does.
        enclave_list.append(["conv2d", img_input.shape,
                             x.shape, (3, 3), (2, 2), (1, 1),
                             conv.get_weights()[0], None])

    x = Activation(act_mode='bnrelu',
                   privacy=privacy,
                   epsilon=1e-3,
                   momentum=0.999,
                   sgxutils=sgxutils,
                   name='bn_Conv1')(x, training=True)

    if enclave_list is not None:
        enclave_list.append(["bn", x.shape,
                             1, 1e-3, 0.999])

    # (filters, stride, expansion) for inverted residual blocks 0..16; the
    # position in the tuple is the block id.
    block_specs = (
        (16, 1, 1),
        (24, 2, 6), (24, 1, 6),
        (32, 2, 6), (32, 1, 6), (32, 1, 6),
        (64, 2, 6), (64, 1, 6), (64, 1, 6), (64, 1, 6),
        (96, 1, 6), (96, 1, 6), (96, 1, 6),
        (160, 2, 6), (160, 1, 6), (160, 1, 6),
        (320, 1, 6),
    )
    for block_id, (filters, stride, expansion) in enumerate(block_specs):
        x = _inverted_res_block(x, sgxutils=sgxutils, privacy=privacy,
                                filters=filters, alpha=alpha, stride=stride,
                                expansion=expansion, block_id=block_id,
                                enclave_list=enclave_list)

    # no alpha applied to last conv as stated in the paper:
    # if the width multiplier is greater than 1 we
    # increase the number of output channels
    if alpha > 1.0:
        last_block_filters = _make_divisible(1280 * alpha, 8)
    else:
        last_block_filters = 1280

    conv_last = layers.Conv2D(last_block_filters,
                              kernel_size=1,
                              use_bias=False,
                              name='Conv_1')

    in_size = x.shape
    x = conv_last(x)
    if enclave_list is not None:
        enclave_list.append(["conv2d", in_size,
                             x.shape, (1, 1), (1, 1), (0, 0),
                             conv_last.get_weights()[0], None])

    # NOTE(review): unlike 'bn_Conv1', this layer is called without
    # `training=True` -- confirm that the difference is intentional.
    x = Activation(act_mode='bnrelu',
                   privacy=privacy,
                   epsilon=1e-3,
                   momentum=0.999,
                   sgxutils=sgxutils,
                   name='Conv_1_bn')(x)
    if enclave_list is not None:
        enclave_list.append(["bn", x.shape,
                             1, 1e-3, 0.999])

    if include_top:
        in_size = x.shape
        x = layers.GlobalAveragePooling2D()(x)
        if enclave_list is not None:
            # NOTE(review): the recorded pool kernel is hard-coded to (7, 7);
            # presumably this matches the final feature map of a 224x224
            # input -- confirm before using other input sizes.
            enclave_list.append(["pool", in_size, (1, 1, 1, x.shape[1]),
                                 (7, 7), (1, 1), (0, 0), 1])
        in_size = (1, 1, 1, x.shape[1])
        dense_layer = layers.Dense(classes, activation='linear',
                                   use_bias=True, name='dense')
        x = dense_layer(x)
        if enclave_list is not None:
            enclave_list.append(["linear", in_size, x.shape,
                                 dense_layer.get_weights()[0],
                                 dense_layer.get_weights()[1]])
    else:
        if pooling == 'avg':
            x = layers.GlobalAveragePooling2D()(x)
        elif pooling == 'max':
            x = layers.GlobalMaxPooling2D()(x)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = keras_utils.get_source_inputs(input_tensor)
    else:
        inputs = img_input

    # Create model.
    model = models.Model(inputs, x,
                         name='mobilenetv2_%0.2f_%s' % (alpha, rows))

    return model


def _inverted_res_block(inputs, expansion, stride, alpha, filters, block_id, privacy=False, sgxutils=None, enclave_list=None):
    """Append one inverted residual block (expand, depthwise, project) to `inputs`.

    # Arguments
        inputs: tensor entering the block.
        expansion: multiplier applied to the input channel count to size
            the 1x1 expansion convolution.
        stride: stride of the 3x3 depthwise convolution. A stride of 2 adds
            explicit (1, 1) zero padding and uses 'valid' padding; a stride
            of 1 uses 'same' padding.
        alpha: multiplier applied to `filters` before rounding with
            `_make_divisible(..., 8)`.
        filters: nominal number of output channels of the projection.
        block_id: block index used for layer names; a falsy value (block 0)
            skips the expansion stage and uses the 'expanded_conv_' prefix.
        privacy: forwarded to every `Activation` layer.
        sgxutils: forwarded to every `Activation` layer.
        enclave_list: optional list mutated in place. When not `None`, the
            entries describing this block are appended between an
            `["inv_start"]` and an `["inv_end"]` marker.

    # Returns
        Output tensor of the block.
    """
    if backend.image_data_format() == 'channels_first':
        channel_axis = 1
    else:
        channel_axis = -1
    record = enclave_list is not None

    in_channels = backend.int_shape(inputs)[channel_axis]
    pointwise_filters = _make_divisible(int(filters * alpha), 8)
    prefix = 'block_{}_'.format(block_id)
    x = inputs

    print("tensor shape in ", inputs.shape)
    if record:
        enclave_list.append(["inv_start"])

    if block_id:
        # Expand: 1x1 convolution widening the channels by `expansion`.
        expand_conv = layers.Conv2D(expansion * in_channels,
                                    kernel_size=1,
                                    padding='same',
                                    use_bias=False,
                                    activation=None,
                                    name=prefix + 'expand')
        x = expand_conv(x)
        if record:
            expand_kernel = expand_conv.get_weights()[0]
            enclave_list.append(["conv2d", inputs.shape,
                                 x.shape, (1, 1), (1, 1), (0, 0),
                                 expand_kernel, None])

        expand_bn = Activation(act_mode='bnrelu',
                               privacy=privacy,
                               axis=channel_axis,
                               epsilon=1e-3,
                               momentum=0.999,
                               sgxutils=sgxutils,
                               name=prefix + 'expand_BN')
        x = expand_bn(x, training=True)
        if record:
            enclave_list.append(["bn", x.shape,
                                 1, 1e-3, 0.999])
    else:
        # Block 0 has no expansion stage and uses its own name prefix.
        prefix = 'expanded_conv_'

    # Depthwise: the shape recorded for this stage is taken before any
    # explicit zero padding is applied.
    in_shape = x.shape
    padding_check = (0, 0)
    if stride == 2:
        padding_check = (1, 1)
        # NOTE(review): the `correct_pad` value is only printed (twice); the
        # padding actually applied is the symmetric `padding_check`.
        print(correct_pad(backend, x, 3))
        print("==============================================")
        print(correct_pad(backend, x, 3))

        x = layers.ZeroPadding2D(padding=padding_check,
                                 name=prefix + 'pad')(x)
    print("______________________________________________")

    depthwise_padding = 'same' if stride == 1 else 'valid'
    depthwise_conv = layers.DepthwiseConv2D(kernel_size=3,
                                            strides=stride,
                                            activation=None,
                                            use_bias=False,
                                            padding=depthwise_padding,
                                            name=prefix + 'depthwise')
    x = depthwise_conv(x)

    if record:
        # NOTE(review): for stride 1 the layer pads with 'same' while the
        # recorded padding stays (0, 0) -- confirm how the consumer of
        # `enclave_list` interprets this.
        enclave_list.append(["depthwise", in_shape,
                             x.shape, (3, 3), (stride, stride), padding_check,
                             depthwise_conv.get_weights()[0], None])

        print(in_shape, x.shape, stride, padding_check)

    print("______________________________________________")

    depthwise_bn = Activation(act_mode='bnrelu',
                              privacy=privacy,
                              axis=channel_axis,
                              epsilon=1e-3,
                              momentum=0.999,
                              sgxutils=sgxutils,
                              name=prefix + 'depthwise_BN')
    x = depthwise_bn(x)
    if record:
        enclave_list.append(["bn", x.shape,
                             1, 1e-3, 0.999])

    # Project: 1x1 convolution down to `pointwise_filters` channels.
    in_shape = x.shape
    project_conv = layers.Conv2D(pointwise_filters,
                                 kernel_size=1,
                                 padding='same',
                                 use_bias=False,
                                 activation=None,
                                 name=prefix + 'project')
    x = project_conv(x)
    if record:
        project_kernel = project_conv.get_weights()[0]
        enclave_list.append(["conv2d", in_shape,
                             x.shape, (1, 1), (1, 1), (0, 0),
                             project_kernel, None])

    # The block input is handed to the final layer as `skip_input` only
    # when the channel count is unchanged and the stride is 1.
    if in_channels == pointwise_filters and stride == 1:
        mode, skip_input, mode_num = 'bnadd', inputs, 2
    else:
        mode, skip_input, mode_num = 'bn', None, 0

    project_bn = Activation(act_mode=mode,
                            privacy=privacy,
                            axis=channel_axis,
                            epsilon=1e-3,
                            momentum=0.999,
                            sgxutils=sgxutils,
                            name=prefix + 'project_BN')
    x = project_bn(x, skip_input=skip_input)

    if record:
        enclave_list.append(["bn", x.shape,
                             mode_num, 1e-3, 0.999])
        enclave_list.append(["inv_end"])
    return x
