"""VGG16 model for Keras.
# Reference
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
    https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from keras_applications import get_submodules_from_kwargs
from keras_applications import imagenet_utils
from keras_applications.imagenet_utils import decode_predictions
from keras_applications.imagenet_utils import _obtain_input_shape

import keras.backend as backend
import keras.layers as layers
import keras.models as models
import keras.utils as keras_utils

# Module-level alias of the shared ImageNet preprocessing helper from
# keras_applications.imagenet_utils, so callers can import it from this module
# alongside VGG16.
preprocess_input = imagenet_utils.preprocess_input

# Common download location for both VGG16 weight files.
_WEIGHTS_URL_PREFIX = ('https://github.com/fchollet/deep-learning-models/'
                       'releases/download/v0.1/')

# Weights downloaded by VGG16() when weights='imagenet' and include_top=True.
WEIGHTS_PATH = (_WEIGHTS_URL_PREFIX +
                'vgg16_weights_tf_dim_ordering_tf_kernels.h5')

# Weights downloaded by VGG16() when weights='imagenet' and include_top=False.
WEIGHTS_PATH_NO_TOP = (_WEIGHTS_URL_PREFIX +
                       'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5')


# (filters, number of 3x3 conv layers) for each of VGG16's five conv blocks,
# in network order. Layer names are derived from the 1-based block/conv
# indices: 'block1_conv1', 'block1_conv2', 'block1_pool', ..., 'block5_pool'.
_VGG16_CONV_BLOCKS = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))


def _build_conv_base(x):
    """Apply VGG16's five convolutional blocks to `x`.

    Returns `(x, plan)` where `x` is the output of 'block5_pool' and `plan`
    is an ordered list describing every layer for later enclave recording:
    ``('conv2d', layer, in_shape, out_shape)`` for each convolution and
    ``('pool', in_shape, out_shape)`` for each max-pool. Shapes are the
    tensors' ``.shape`` attributes. Layer weights are deliberately NOT read
    here: they are overwritten later if pretrained weights are loaded.
    """
    plan = []
    for block_idx, (filters, n_convs) in enumerate(_VGG16_CONV_BLOCKS, 1):
        for conv_idx in range(1, n_convs + 1):
            conv = layers.Conv2D(filters, (3, 3),
                                 activation='relu',
                                 padding='same',
                                 name='block%d_conv%d' % (block_idx, conv_idx))
            in_shape = x.shape
            x = conv(x)
            plan.append(('conv2d', conv, in_shape, x.shape))
        pool_in_shape = x.shape
        x = layers.MaxPooling2D((2, 2), strides=(2, 2),
                                name='block%d_pool' % block_idx)(x)
        plan.append(('pool', pool_in_shape, x.shape))
    return x, plan


def _append_conv_entries(enclave_list, plan):
    """Append conv2d/relu/pool entries described by `plan` to `enclave_list`.

    Reads each conv layer's current kernel and bias, so it must run after any
    weights have been loaded into the model.
    """
    for step in plan:
        if step[0] == 'conv2d':
            _, conv, in_shape, out_shape = step
            kernel_data, bias_data = conv.get_weights()
            # presumably (kernel size, stride, padding) — matches the 3x3,
            # stride-1, 'same' Conv2D built in _build_conv_base.
            enclave_list.append(["conv2d", in_shape, out_shape,
                                 (3, 3), (1, 1), (1, 1),
                                 kernel_data, bias_data])
            enclave_list.append(["relu", out_shape])
        else:
            _, in_shape, out_shape = step
            # presumably (window, stride, padding) of the 2x2/2 max-pool;
            # NOTE(review): meaning of the trailing 0 is not visible here.
            enclave_list.append(["pool", in_shape, out_shape,
                                 (2, 2), (2, 2), (0, 0), 0])


def _append_dense_entries(enclave_list, flatten_in_shape, head_layers,
                          classes):
    """Append linear/relu entries for the fc1 -> fc2 -> predictions head.

    `flatten_in_shape` is `backend.int_shape` of the tensor fed to 'flatten'
    (batch dim first); the first linear entry's input dims are that shape
    with the batch dim replaced by 1, e.g. [1, 7, 7, 512] for the default
    224x224 channels-last input. Must run after weights are loaded.
    """
    fc1, fc2, predictions = head_layers
    in_dims = [1] + list(flatten_in_shape[1:])

    kernel_data, bias_data = fc1.get_weights()
    enclave_list.append(["linear", in_dims, [1, 4096], kernel_data, bias_data])
    enclave_list.append(["relu", [1, 1, 1, 4096]])

    kernel_data, bias_data = fc2.get_weights()
    enclave_list.append(["linear", [1, 1, 1, 4096], [1, 4096],
                         kernel_data, bias_data])
    enclave_list.append(["relu", [1, 1, 1, 4096]])

    kernel_data, bias_data = predictions.get_weights()
    # Output width follows `classes` (previously hard-coded to 1000).
    enclave_list.append(["linear", [1, 1, 1, 4096], [1, classes],
                         kernel_data, bias_data])


def VGG16(include_top=True,
          weights='imagenet',
          input_tensor=None,
          input_shape=None,
          pooling=None,
          classes=1000,
          sgxutils=None,
          enclave_list=None,
          **kwargs):
    """Build a VGG16 Keras model, optionally describing its layers in `enclave_list`.

    Args:
        include_top: whether to add the 'flatten', 'fc1', 'fc2' and
            'predictions' layers after the convolutional base.
        weights: None (random initialization), 'imagenet' (download and load
            the ImageNet weights), or a path to a weights file to load.
        input_tensor: optional tensor to use as the model input; wrapped with
            `layers.Input` if it is not already a Keras tensor.
        input_shape: optional input shape, validated/defaulted by
            `_obtain_input_shape` (default size 224, minimum size 32).
        pooling: only used when `include_top` is False: 'avg' or 'max' adds a
            global average/max pooling layer; any other value leaves the
            'block5_pool' output as the model output.
        classes: number of units in the 'predictions' layer; must be 1000
            when `weights='imagenet'` and `include_top` is True.
        sgxutils: only compared against None. Conv/relu/pool entries are
            recorded only when both `sgxutils` and `enclave_list` are given.
        enclave_list: optional list that receives one entry per recorded
            layer, in network order:
                ["conv2d", in_shape, out_shape, (3, 3), (1, 1), (1, 1),
                 kernel, bias]
                ["relu", shape]
                ["pool", in_shape, out_shape, (2, 2), (2, 2), (0, 0), 0]
                ["linear", in_dims, out_dims, kernel, bias]  (include_top)
            Kernels and biases are read after weight loading, so they are
            the values the returned model actually uses.
        **kwargs: accepted and ignored.

    Returns:
        A `models.Model` named 'vgg16'.

    Raises:
        ValueError: e.g. for an invalid `weights` argument, or for
            `weights='imagenet'` with `include_top` and `classes != 1000`.
    """
    if not (weights in {'imagenet', None} or os.path.exists(weights)):
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization), `imagenet` '
                         '(pre-training on ImageNet), '
                         'or the path to the weights file to be loaded.')

    if weights == 'imagenet' and include_top and classes != 1000:
        raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                         ' as true, `classes` should be 1000')
    # Determine proper input shape
    input_shape = _obtain_input_shape(input_shape,
                                      default_size=224,
                                      min_size=32,
                                      data_format=backend.image_data_format(),
                                      require_flatten=include_top,
                                      weights=weights)

    if input_tensor is None:
        img_input = layers.Input(shape=input_shape)
    elif not backend.is_keras_tensor(input_tensor):
        img_input = layers.Input(tensor=input_tensor, shape=input_shape)
    else:
        img_input = input_tensor

    # Blocks 1-5 (conv + relu, then 2x2 max-pool per block).
    x, conv_plan = _build_conv_base(img_input)

    head_layers = None
    flatten_in_shape = None
    if include_top:
        # Classification block
        flatten_in_shape = backend.int_shape(x)
        x = layers.Flatten(name='flatten')(x)
        fc1 = layers.Dense(4096, activation='relu', name='fc1')
        x = fc1(x)
        fc2 = layers.Dense(4096, activation='relu', name='fc2')
        x = fc2(x)
        predictions = layers.Dense(classes, activation='softmax',
                                   name='predictions')
        x = predictions(x)
        head_layers = (fc1, fc2, predictions)
    elif pooling == 'avg':
        x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
        x = layers.GlobalMaxPooling2D()(x)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = keras_utils.get_source_inputs(input_tensor)
    else:
        inputs = img_input
    # Create model.
    model = models.Model(inputs, x, name='vgg16')

    # Load weights.
    if weights == 'imagenet':
        if include_top:
            weights_path = keras_utils.get_file(
                'vgg16_weights_tf_dim_ordering_tf_kernels.h5',
                WEIGHTS_PATH,
                cache_subdir='models',
                file_hash='64373286793e3c8b2b4e3219cbf3544b')
        else:
            weights_path = keras_utils.get_file(
                'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
                WEIGHTS_PATH_NO_TOP,
                cache_subdir='models',
                file_hash='6d6bbae143d832006294945121d1f1fc')
        model.load_weights(weights_path)
        if backend.backend() == 'theano':
            keras_utils.convert_all_kernels_in_model(model)
    elif weights is not None:
        model.load_weights(weights)

    # Record layers only now, after any weight loading, so the recorded
    # kernels/biases match the returned model (reading them during graph
    # construction captured the random initial values instead).
    if sgxutils is not None and enclave_list is not None:
        _append_conv_entries(enclave_list, conv_plan)
    # NOTE(review): unlike the conv entries, the dense entries require only
    # `enclave_list` (not `sgxutils`); this asymmetry is kept from the
    # original code — confirm whether it is intended.
    if head_layers is not None and enclave_list is not None:
        _append_dense_entries(enclave_list, flatten_in_shape, head_layers,
                              classes)

    return model