from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
from keras.layers import *
from keras import layers
from keras.models import Model, Sequential
from keras import backend as K
from keras.engine import get_source_inputs
from keras.utils import layer_utils
from keras.utils.data_utils import get_file
from keras.applications.imagenet_utils import _obtain_input_shape
from keras.engine.topology import load_weights_from_hdf5_group_by_name, h5py
import tensorflow as tf
from python.slalom.Activation import ResNetActivation
from python.slalom.Activation import ResNetBottom
# Download URLs for the pre-trained ResNet-50 ImageNet weights, with and
# without the fc1000 classifier. ResNet50() only fetches them (via get_file)
# when weights == 'imagenet' and layers == 50.
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5'
WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'


class ResNetBlock(Layer):
    """A ResNet residual block packaged as a single Keras layer.

    The block owns two lists of sub-layers:

    - ``path1``: the main branch.
    - ``path2``: the projection shortcut, left empty for identity blocks.

    Their outputs are combined by ``self.bottom``, a ``ResNetBottom`` layer
    created in :meth:`create_layers`.

    In the bottleneck variant (``basic=False``) the convolutions are
    bias-free. Each BatchNormalization + ReLU pair is replaced by a single
    ``ResNetActivation(act_mode='bnzerorelu', use_bias=True)`` layer. The
    biases of the last main-path conv and of the shortcut conv are held by
    ``self.bottom`` (see :meth:`copy_data`).

    When ``enclave_list`` is a list, the first call to :meth:`create_layers`
    appends a description of every operation in the block to it. The
    description is framed by ``["resblock_start", ...]`` and
    ``["resblock_compl"]`` entries.

    NOTE(review): :meth:`create_layers` assumes channels_last tensors. It
    reads the channel count from ``shape[3]``.
    """

    def __init__(self, kernel_size, filters, stage, block, identity=False, strides=(2, 2),
                 path1=None, path2=None, merge_act=None, quantize=False, bits_w=8, bits_x=8,
                 slalom=False, slalom_integrity=False, slalom_privacy=False, use_bn=False, basic=False, sgxutils=None, privacy=None, enclave_list=None, **kwargs):
        """
        # Arguments
            kernel_size: kernel size (int) of the middle conv of the main path.
            filters: the three filter counts ``[f1, f2, f3]`` of the main-path
                convs.
            stage, block: used to build layer names such as ``res2a_branch2a``.
            identity: if True, the shortcut is the block input (no projection
                conv) and the first main-path conv is not strided.
            strides: strides of the first main-path conv and of the shortcut
                conv of projection blocks.
            path1, path2, merge_act: optional pre-built sub-layers. When
                ``path1``/``path2`` are non-empty, :meth:`build` reuses them
                instead of creating new ones.
            quantize, bits_w: when ``quantize`` is set and there is no
                shortcut conv, :meth:`call` scales the shortcut by
                ``2 ** bits_w``.
            bits_x, slalom, slalom_integrity, slalom_privacy: stored on the
                layer only.
            use_bn: whether :meth:`get_layers` keeps BatchNormalization layers.
            basic: build the two-conv "basic" block instead of the bottleneck.
            privacy, sgxutils: forwarded to the ResNetActivation/ResNetBottom
                sub-layers.
            enclave_list: optional list that receives a description of the
                block's operations the first time it is built.
        """
        super(ResNetBlock, self).__init__(**kwargs)

        self.quantize = quantize
        self.slalom = slalom
        self.slalom_integrity = slalom_integrity
        self.slalom_privacy = slalom_privacy
        self.bits_w = bits_w
        self.bits_x = bits_x
        self.range_w = 2 ** bits_w
        self.range_x = 2 ** bits_x

        self.kernel_size = kernel_size
        self.filters = filters
        self.strides = strides
        self.identity = identity

        self.basic = basic
        self.stage = stage
        self.block = block
        self.conv_name_base = 'res' + str(stage) + block + '_branch'
        self.bn_name_base = 'bn' + str(stage) + block + '_branch'
        self.use_bn = use_bn

        self.path1 = [] if path1 is None else path1
        self.path2 = [] if path2 is None else path2
        self.merge_act = merge_act
        self.privacy = privacy
        self.sgxutils = sgxutils
        self.bottom = None
        self.enclave_list = enclave_list
        # Cleared at the end of create_layers(). It guards enclave_list so a
        # block's operations are recorded only once.
        self.not_built = True

    @staticmethod
    def _append_and_build(path, layer, shape):
        """Append ``layer`` to ``path``, build it for ``shape`` and return its output shape."""
        path.append(layer)
        layer.build(shape)
        return layer.compute_output_shape(shape)

    def _make_bn_relu(self, shape):
        """Create and build the fused BN+ReLU ``ResNetActivation`` that follows a conv.

        It replaces the original BatchNormalization + Activation('relu') pair
        and carries a per-channel bias (channels_last assumed).
        """
        _, _, _, channels = shape
        activation = ResNetActivation(act_mode='bnzerorelu',
                                      privacy=self.privacy,
                                      sgxutils=self.sgxutils,
                                      use_bias=True,
                                      bias_shape=channels)
        activation.build(shape)
        return activation

    def _record_conv_bn_relu(self, in_shape, out_shape, kernel, strides, padding, weights):
        """Append the enclave description of a conv followed by bn and relu.

        The conv record is ``["conv2d", in_shape, out_shape, kernel, strides,
        padding, weights, None]``. The trailing ``None`` is presumably the
        absent bias; confirm with the enclave consumer.
        """
        self.enclave_list.append(["conv2d", in_shape, out_shape, kernel, strides, padding, weights, None])
        self.enclave_list.append(["bn", out_shape, 1, 1e-3, 0.999])
        self.enclave_list.append(["relu", out_shape])

    def create_layers(self, input_shape, privacy, sgxutils):
        """Create and build the block's sub-layers for ``input_shape``.

        # Arguments
            input_shape: shape of the block input (channels_last assumed).
            privacy, sgxutils: forwarded to the ``ResNetBottom`` merge layer.
        """
        if self.basic:
            self.create_layers_basic(input_shape)
            return

        filters1, filters2, filters3 = self.filters
        # not_built and enclave_list do not change until the end of this method.
        record = self.not_built and self.enclave_list is not None

        # 2a: 1x1 conv. Projection blocks downsample here; identity blocks do not.
        if self.identity:
            conv_stride = (1, 1)
            conv = Conv2D(filters1, (1, 1), name=self.conv_name_base + '2a',
                          input_shape=input_shape, use_bias=False)
        else:
            conv_stride = self.strides
            conv = Conv2D(filters1, (1, 1), strides=self.strides,
                          name=self.conv_name_base + '2a',
                          input_shape=input_shape, use_bias=False)
        shape = self._append_and_build(self.path1, conv, input_shape)

        # The block's output has the spatial size of 2a and filters3 channels.
        resblock_shape = list(shape)
        resblock_shape[3] = filters3
        if record:
            self.enclave_list.append(["resblock_start", input_shape, resblock_shape, self.strides, self.identity])
        kernel_weight = conv.get_weights()[0]
        if record:
            self._record_conv_bn_relu(input_shape, shape, (1, 1), conv_stride, (0, 0), kernel_weight)
        self.path1.append(self._make_bn_relu(shape))

        # 2b: kernel_size x kernel_size conv with 'same' padding.
        conv_in_shape = shape
        conv = Conv2D(filters2, self.kernel_size, padding='same',
                      name=self.conv_name_base + '2b', input_shape=shape, use_bias=False)
        shape = self._append_and_build(self.path1, conv, shape)
        kernel_weight = conv.get_weights()[0]
        if record:
            self._record_conv_bn_relu(conv_in_shape, shape, (self.kernel_size, self.kernel_size),
                                      (1, 1), (1, 1), kernel_weight)
        self.path1.append(self._make_bn_relu(shape))

        # 2c: 1x1 conv to filters3 channels. Its BN/bias is applied by self.bottom.
        conv_in_shape = shape
        conv = Conv2D(filters3, (1, 1), name=self.conv_name_base + '2c',
                      input_shape=shape, use_bias=False)
        shape = self._append_and_build(self.path1, conv, shape)
        _, _, _, bias_shape = shape
        kernel_weight = conv.get_weights()[0]
        if record:
            self._record_conv_bn_relu(conv_in_shape, shape, (1, 1), (1, 1), (0, 0), kernel_weight)

        # Projection shortcut (strided 1x1 conv) for non-identity blocks.
        if not self.identity:
            shortcut = Conv2D(filters3, (1, 1), strides=self.strides,
                              name=self.conv_name_base + '1', input_shape=input_shape, use_bias=False)
            shape = self._append_and_build(self.path2, shortcut, input_shape)

        self.bottom = ResNetBottom(right_norm=not self.identity, privacy=privacy, sgxutils=sgxutils,
                                   use_bias=True, bias_shape=bias_shape)
        self.bottom.build(shape)

        if record:
            self.enclave_list.append(["resblock_compl"])
        self.not_built = False

    def copy_data(self, other):
        """Copy weights from ``other``, a block whose main-path convs have biases.

        ``other.path1[0]``, ``[2]`` and ``[4]`` (and ``other.path2[0]`` for
        projection blocks) must yield ``[kernel, bias]``. The kernels go to
        this block's bias-free convs at the same indices. The biases of the
        first two convs go to the ResNetActivation layers at ``path1[1]`` and
        ``path1[3]``. The biases of the last main-path conv and the shortcut
        conv go to ``self.bottom``.
        """
        # Compare this block with `other`. Previously each side was compared with itself.
        assert len(self.path1) == len(other.path1)
        assert len(self.path2) == len(other.path2)
        assert self.identity == other.identity

        for conv_idx, act_idx in ((0, 1), (2, 3)):
            src_weights = other.path1[conv_idx].get_weights()
            kernel_data, bias_data = src_weights[0], src_weights[1]
            self.path1[conv_idx].set_weights([kernel_data])
            self.path1[act_idx].set_weights([bias_data])

        last_weights = other.path1[4].get_weights()
        bias_left = last_weights[1]
        self.path1[4].set_weights([last_weights[0]])

        if self.identity:
            self.bottom.set_weights([bias_left])
        else:
            right_weights = other.path2[0].get_weights()
            self.path2[0].set_weights([right_weights[0]])
            self.bottom.set_weights([bias_left, right_weights[1]])

    def create_layers_basic(self, input_shape):
        """Create the two-conv "basic" block (used for ResNet-18/34).

        NOTE(review): ``self.bottom`` is never created on this path. It is
        still ``None`` when ``self.bottom.build(shape)`` runs below, so this
        raises AttributeError as written. It needs a ``ResNetBottom``
        configured for biased convs; confirm the intended arguments.
        """
        filters, _, _ = self.filters

        shape = input_shape
        if self.identity:
            self.path1.append(Conv2D(filters, self.kernel_size, name=self.conv_name_base + '2a', input_shape=shape, padding='same'))
        else:
            self.path1.append(Conv2D(filters, self.kernel_size, strides=self.strides, name=self.conv_name_base + '2a', input_shape=shape, padding='same'))

        self.path1[-1].build(shape)
        shape = self.path1[-1].compute_output_shape(shape)
        self.path1.append(Activation('relu'))
        self.path1[-1].build(shape)

        self.path1.append(Conv2D(filters, self.kernel_size,
                                 padding='same', name=self.conv_name_base + '2b', input_shape=shape))
        self.path1[-1].build(shape)
        shape = self.path1[-1].compute_output_shape(shape)

        if not self.identity and self.strides != (1, 1):
            shape = input_shape
            self.path2.append(Conv2D(filters, (1, 1), strides=self.strides,
                                     name=self.conv_name_base + '1', input_shape=shape))
            self.path2[-1].build(shape)
            shape = self.path2[-1].compute_output_shape(shape)

        self.bottom.build(shape)

    def build(self, input_shape):
        """Build pre-supplied sub-layers, or create them on first build."""
        super(ResNetBlock, self).build(input_shape)

        if not (self.path1 or self.path2):
            self.create_layers(input_shape, privacy=self.privacy, sgxutils=self.sgxutils)
            return

        shape = input_shape
        for layer in self.path1:
            layer.build(shape)
            shape = layer.compute_output_shape(shape)

        shape = input_shape
        for layer in self.path2:
            layer.build(shape)
            shape = layer.compute_output_shape(shape)

        # merge_act is None unless it was passed to __init__.
        if self.merge_act is not None:
            self.merge_act.build(shape)

    def compute_output_shape(self, input_shape):
        """Return the shape of the main path fed through ``self.bottom``."""
        shape = input_shape
        for layer in self.path1:
            shape = layer.compute_output_shape(shape)
        return self.bottom.compute_output_shape(shape)

    def get_layers(self):
        """Return the sub-layers used for by-name weight loading.

        BatchNormalization layers are dropped unless ``use_bn`` is set.
        ``merge_act`` is appended only when present, so no ``None`` entry is
        returned.
        """
        if self.use_bn:
            block_layers = self.path1 + self.path2
        else:
            block_layers = [layer for layer in self.path1 + self.path2
                            if not isinstance(layer, BatchNormalization)]
        if self.merge_act is not None:
            block_layers.append(self.merge_act)
        return block_layers

    def call(self, inputs):
        out1 = inputs
        for layer in self.path1:
            out1 = layer(out1)

        out2 = inputs
        for layer in self.path2:
            out2 = layer(out2)

        # Identity shortcut: rescale the input by the weight quantization factor.
        if self.quantize and not self.path2:
            out2 *= 2**self.bits_w

        return self.bottom(out1, right=out2)

    def get_config(self):
        # NOTE(review): this config is not round-trippable through from_config.
        # It omits the required `stage`/`block` arguments and includes
        # `conv_name_base`/`bn_name_base`, which __init__ does not accept.
        config = {
            'kernel_size': self.kernel_size,
            'filters': self.filters,
            'strides': self.strides,
            'identity': self.identity,
            'conv_name_base': self.conv_name_base,
            'bn_name_base': self.bn_name_base,
            'basic': self.basic
        }

        base_config = super(ResNetBlock, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


def identity_block(input_tensor, kernel_size, filters, stage, block):
    """Build a bottleneck residual block whose shortcut is the raw input.

    The main path is 1x1 conv -> BN -> ReLU, then (kernel_size) conv -> BN ->
    ReLU, then 1x1 conv -> BN. The result is summed with ``input_tensor`` and
    passed through a final ReLU.

    # Arguments
        input_tensor: input tensor
        kernel_size: kernel size of the middle conv layer on the main path
        filters: three integers, the filter counts of the main-path convs
        stage: integer stage label, used for generating layer names
        block: block label ('a', 'b', ...), used for generating layer names
    # Returns
        Output tensor for the block.
    """
    first_filters, mid_filters, last_filters = filters
    channel_axis = 3 if K.image_data_format() == 'channels_last' else 1
    name_suffix = str(stage) + block + '_branch'
    conv_prefix = 'res' + name_suffix
    bn_prefix = 'bn' + name_suffix

    # (filters, kernel, padding, name tag) of each main-path convolution.
    main_path = (
        (first_filters, (1, 1), 'valid', '2a'),
        (mid_filters, kernel_size, 'same', '2b'),
        (last_filters, (1, 1), 'valid', '2c'),
    )
    y = input_tensor
    for idx, (n_filters, ksize, pad, tag) in enumerate(main_path):
        y = Conv2D(n_filters, ksize, padding=pad, name=conv_prefix + tag)(y)
        y = BatchNormalization(axis=channel_axis, name=bn_prefix + tag)(y)
        # The last conv's ReLU is applied only after the residual sum.
        if idx < len(main_path) - 1:
            y = Activation('relu')(y)

    merged = layers.add([y, input_tensor])
    return Activation('relu')(merged)


def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    """Build a bottleneck residual block with a projection (conv) shortcut.

    Both the first main-path 1x1 conv and the 1x1 shortcut conv use
    ``strides``. From stage 3 on this is (2, 2), which halves the spatial
    size.

    # Arguments
        input_tensor: input tensor
        kernel_size: kernel size of the middle conv layer on the main path
        filters: three integers, the filter counts of the main-path convs
        stage: integer stage label, used for generating layer names
        block: block label ('a', 'b', ...), used for generating layer names
        strides: strides of the first main-path conv and of the shortcut conv
    # Returns
        Output tensor for the block.
    """
    f_in, f_mid, f_out = filters
    channel_axis = 1
    if K.image_data_format() == 'channels_last':
        channel_axis = 3
    tag_root = str(stage) + block + '_branch'

    def conv_bn(tensor, n_filters, ksize, tag, **conv_kwargs):
        # A named convolution immediately followed by its named batch norm.
        tensor = Conv2D(n_filters, ksize, name='res' + tag_root + tag, **conv_kwargs)(tensor)
        return BatchNormalization(axis=channel_axis, name='bn' + tag_root + tag)(tensor)

    main = conv_bn(input_tensor, f_in, (1, 1), '2a', strides=strides)
    main = Activation('relu')(main)
    main = conv_bn(main, f_mid, kernel_size, '2b', padding='same')
    main = Activation('relu')(main)
    main = conv_bn(main, f_out, (1, 1), '2c')

    shortcut = conv_bn(input_tensor, f_out, (1, 1), '1', strides=strides)

    summed = layers.add([main, shortcut])
    return Activation('relu')(summed)


def ResNet50(include_top=True, weights='imagenet',
             input_tensor=None, input_shape=None,
             pooling=None, classes=1000, layers=50, privacy=False, sgxutils=None, enclave_list=None):

    """Instantiates the ResNet50 architecture.
    Optionally loads weights pre-trained on ImageNet.
    Note that the data format convention used by the model is
    the one specified in your Keras config at `~/.keras/keras.json`.
    When using TensorFlow, for best performance you should
    set `"image_data_format": "channels_last"` in the config.
    # Arguments
        include_top: whether to include the fully-connected
            layer at the top of the network.
        weights: one of `None` (random initialization),
              'imagenet' (pre-training on ImageNet),
              or the path to the weights file to be loaded.
        input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
            to use as image input for the model.
        input_shape: optional shape tuple, only to be specified
            if `include_top` is False (otherwise the input shape
            has to be `(224, 224, 3)` (with `channels_last` data format)
            or `(3, 224, 224)` (with `channels_first` data format).
            It should have exactly 3 inputs channels,
            and width and height should be no smaller than 197.
            E.g. `(200, 200, 3)` would be one valid value.
        pooling: Optional pooling mode for feature extraction
            when `include_top` is `False`.
            - `None` means that the output of the model will be
                the 4D tensor output of the
                last convolutional layer.
            - `avg` means that global average pooling
                will be applied to the output of the
                last convolutional layer, and thus
                the output of the model will be a 2D tensor.
            - `max` means that global max pooling will
                be applied.
        classes: optional number of classes to classify images
            into, only to be specified if `include_top` is True, and
            if no `weights` argument is specified.
    # Returns
        A Keras model instance.
    # Raises
        ValueError: in case of invalid argument for `weights`,
            or invalid input shape.
    """
    if not (weights in {'imagenet', None} or os.path.exists(weights)):
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization), `imagenet` '
                         '(pre-training on ImageNet), '
                         'or the path to the weights file to be loaded.')

    if weights == 'imagenet' and include_top and classes != 1000:
        raise ValueError('If using `weights` as imagenet with `include_top`'
                         ' as true, `classes` should be 1000')

    assert layers in [18, 34, 50, 101, 152]
    use_bn = (layers == 50)
    basic = (layers in [18, 34])

    if layers == 18:
        num_layers = [2, 2, 2, 2]
    elif layers == 34:
        num_layers = [3, 4, 6, 3]
    elif layers == 50:
        num_layers = [3, 4, 6, 3]
    elif layers == 101:
        num_layers = [3, 4, 23, 3]
    elif layers == 152:
        num_layers = [3, 8, 36, 3]

    # Determine proper input shape
    input_shape = _obtain_input_shape(input_shape,
                                      default_size=224,
                                      min_size=197,
                                      data_format=K.image_data_format(),
                                      require_flatten=include_top,
                                      weights=weights)

    if input_tensor is None:
        img_input = Input(shape=input_shape)
    else:
        if not K.is_keras_tensor(input_tensor):
            img_input = Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor
    if K.image_data_format() == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    if basic:
        x = Conv2D(64, (7, 7), strides=(2, 2), padding='same', name='conv1')(img_input)
        x = Activation('relu')(x)
        x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
        x = ResNetBlock(3, [64, 64, 256], stage=2, block='a', use_bn=use_bn, basic=basic)(x)
    else:
        x = ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
        input_shape = x.shape
        conv_layer = Conv2D(64, (7, 7), strides=(2, 2), padding='valid', name='conv1', use_bias=False) 
        x = conv_layer(x)
        kernel_weight = conv_layer.get_weights()[0]
        #bias_weight   = conv_layer.get_weights()[1]
        print(kernel_weight.shape)
        if enclave_list is not None:
            enclave_list.append(["conv2d", input_shape, x.shape, (7, 7), (2, 2), (0, 0), kernel_weight, None, None])
            enclave_list.append(["bn", x.shape, 1, 1e-3, 0.999])
            enclave_list.append(["relu", x.shape])
        
        #print("========================")
        #x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
        #x = Activation('relu')(x)
        #x = MaxPooling2D((3, 3), strides=(2, 2))(x)
        _, _, _, bias_shape = x.shape
        pool_in_shape = x.shape
        x = ResNetActivation(act_mode='bnrelupool', 
                             privacy=privacy, 
                             sgxutils=sgxutils,
                             pool_window=(3, 3),
                             strides=(2, 2),
                             use_bias=True,
                             bias_shape=bias_shape)(x, training=True)

        if enclave_list is not None:
            enclave_list.append(["pool", pool_in_shape, x.shape, (3, 3), (2, 2), (0, 0), 0])
        print("========================")
        print(x.shape)
        x = ResNetBlock(3, [64, 64, 256], stage=2, block='a', strides=(1, 1), use_bn=use_bn, basic=basic, privacy=privacy, sgxutils=sgxutils, enclave_list=enclave_list)(x)
        print(x.shape)
    for i in range(num_layers[0] - 1):
        x = ResNetBlock(3, [64, 64, 256], stage=2, block=chr(ord('b') + i), identity=True, use_bn=use_bn, basic=basic, privacy=privacy, sgxutils=sgxutils, enclave_list=enclave_list)(x)

    x = ResNetBlock(3, [128, 128, 512], stage=3, block='a', use_bn=use_bn, basic=basic, privacy=privacy, sgxutils=sgxutils, enclave_list=enclave_list)(x)
    for i in range(num_layers[1] - 1):
        x = ResNetBlock(3, [128, 128, 512], stage=3, block=chr(ord('b') + i), identity=True, use_bn=use_bn, basic=basic, privacy=privacy, sgxutils=sgxutils, enclave_list=enclave_list)(x)

    x = ResNetBlock(3, [256, 256, 1024], stage=4, block='a', use_bn=use_bn, basic=basic, privacy=privacy, sgxutils=sgxutils, enclave_list=enclave_list)(x)
    for i in range(num_layers[2] - 1):
        x = ResNetBlock(3, [256, 256, 1024], stage=4, block=chr(ord('b') + i), identity=True, use_bn=use_bn, basic=basic, privacy=privacy, sgxutils=sgxutils, enclave_list=enclave_list)(x)

    x = ResNetBlock(3, [512, 512, 2048], stage=5, block='a', use_bn=use_bn, basic=basic, privacy=privacy, sgxutils=sgxutils, enclave_list=enclave_list)(x)
    for i in range(num_layers[3] - 1):
        x = ResNetBlock(3, [512, 512, 2048], stage=5, block=chr(ord('b') + i), identity=True, use_bn=use_bn, basic=basic, privacy=privacy, sgxutils=sgxutils, enclave_list=enclave_list)(x)
    pool_size = x.shape
    if basic:
        x = GlobalAveragePooling2D()(x)
    else:
        x = AveragePooling2D((7, 7), name='avg_pool')(x)
        x = Flatten()(x)
    if enclave_list is not None:
        enclave_list.append(["pool", pool_size, (1, 1, 1, 2048), (7, 7), (1, 1), (0, 0), 1])
    
    linear_in_shape = [1, 1, 1, 2048]
    linear_out_shape = [1, 1000]
    dense_layer= Dense(classes, activation='softmax', name='fc1000', use_bias=False)
    x = dense_layer(x)
    if enclave_list is not None:
        enclave_list.append(["linear", linear_in_shape, linear_out_shape, dense_layer.get_weights()[0], np.zeros(1000).astype(np.float32)])
        #enclave_list.append(["relu",   linear_out_shape])
    _, bias_shape = x.shape

    x = ResNetActivation(act_mode='bias_add', 
                         privacy=True, 
                         sgxutils=sgxutils,
                         use_bias=True,
                         bias_shape=bias_shape)(x)


    '''
    '''
    '''
    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    '''
    if input_tensor is not None:
        inputs = get_source_inputs(input_tensor)
    else:
        inputs = img_input
    # Create model.
    model = Model(inputs, x, name='resnet50')

    # load weights
    if weights == 'imagenet' and layers == 50:
        if include_top:
            weights_path = get_file(
                'resnet50_weights_tf_dim_ordering_tf_kernels.h5',
                WEIGHTS_PATH,
                cache_subdir='models',
                md5_hash='a7b3fe01876f51b976af0dea6bc144eb')
        else:
            weights_path = get_file(
                'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5',
                WEIGHTS_PATH_NO_TOP,
                cache_subdir='models',
                md5_hash='a268eb855778b3df3c7506639542a6af')

        with h5py.File(weights_path, mode='r') as f:
            if 'layer_names' not in f.attrs and 'model_weights' in f:
                f = f['model_weights']

            import itertools
            all_layers = [[l] if not isinstance(l, ResNetBlock) else l.get_layers() for l in model.layers]
            all_layers = list(itertools.chain.from_iterable(all_layers))
            load_weights_from_hdf5_group_by_name(f, all_layers)

        if K.backend() == 'theano':
            layer_utils.convert_all_kernels_in_model(model)
    
    return model
