import tensorflow as tf

from official.vision.image_classification.resnet.resnet_model import \
    _gen_l2_regularizer
from official.vision.image_classification.resnet.resnet_model import \
    BATCH_NORM_DECAY
from official.vision.image_classification.resnet.resnet_model import \
    BATCH_NORM_EPSILON
from official.vision.image_classification.resnet.resnet_model import conv_block
from official.vision.image_classification.resnet.resnet_model import \
    identity_block
from official.vision.image_classification.resnet.resnet_model import \
    initializers
from tensorflow.python.keras import backend
from official.vision.image_classification.resnet import imagenet_preprocessing

# Short alias for the Keras layers namespace used by the model builder below.
layers = tf.keras.layers


def resnet50(num_classes,
             img_input,
             batch_size=None,
             use_l2_regularizer=True,
             rescale_inputs=False):
    """Builds the ResNet50 forward graph on top of `img_input`.

    Unlike a model factory, this function does not create an input layer or
    wrap the result in a `tf.keras.Model`; it only stacks layers on the tensor
    it is given and returns the resulting output tensor.

    Args:
      num_classes: `int` number of units of the final `fc1000` Dense layer.
      img_input: tensor the network is applied to. Presumably a channels-last
        image batch (the rescale constant has shape [1, 1, 3] and the
        channels_first branch permutes the last axis to the front) -- confirm
        against the caller.
      batch_size: accepted for interface compatibility; not read anywhere in
        this function.
      use_l2_regularizer: forwarded to `_gen_l2_regularizer` for the stem
        Conv2D and the final Dense layer, and to every conv/identity block.
      rescale_inputs: if true, the input is multiplied by 255 and the
        ImageNet `CHANNEL_MEANS` are subtracted before the first layer.

    Returns:
        The output tensor of the last layer: a float32 'linear' activation
        applied to the `fc1000` Dense output (no softmax is applied here).
    """
    x = img_input
    if rescale_inputs:
        # Hub image modules expect inputs in the range [0, 1]. Map them back
        # to the range the trained model expects: scale to [0, 255] and
        # subtract the per-channel means. Kept as a lambda so the constant is
        # created with the dtype of whatever tensor the layer receives.
        rescale = layers.Lambda(
            lambda img: img * 255.0 - backend.constant(
                imagenet_preprocessing.CHANNEL_MEANS,
                shape=[1, 1, 3],
                dtype=img.dtype),
            name='rescale')
        x = rescale(img_input)

    channels_first = backend.image_data_format() == 'channels_first'
    if channels_first:
        # Move the last (channel) axis in front of the two spatial axes.
        x = layers.Permute((3, 1, 2))(x)
    bn_axis = 1 if channels_first else 3

    # Stem: pad -> 7x7/2 conv -> batch norm -> ReLU -> 3x3/2 max pool.
    x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
    stem_conv = layers.Conv2D(
        64, (7, 7),
        strides=(2, 2),
        padding='valid',
        use_bias=False,
        kernel_initializer='he_normal',
        kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
        name='conv1')
    x = stem_conv(x)
    stem_bn = layers.BatchNormalization(
        axis=bn_axis,
        momentum=BATCH_NORM_DECAY,
        epsilon=BATCH_NORM_EPSILON,
        name='bn_conv1')
    x = stem_bn(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)

    # (stage number, filters of the three convs in each block, block ids).
    # The first id of every stage is built with conv_block, the remaining
    # ones with identity_block.
    stage_specs = (
        (2, (64, 64, 256), 'abc'),
        (3, (128, 128, 512), 'abcd'),
        (4, (256, 256, 1024), 'abcdef'),
        (5, (512, 512, 2048), 'abc'),
    )
    for stage, filters, block_ids in stage_specs:
        # Only stage 2 passes strides explicitly; every other stage relies on
        # conv_block's own default.
        stride_kwargs = {'strides': (1, 1)} if stage == 2 else {}
        x = conv_block(
            x,
            3, list(filters),
            stage=stage,
            block=block_ids[0],
            use_l2_regularizer=use_l2_regularizer,
            **stride_kwargs)
        for block_id in block_ids[1:]:
            x = identity_block(
                x,
                3, list(filters),
                stage=stage,
                block=block_id,
                use_l2_regularizer=use_l2_regularizer)

    # Classification head: global average pool followed by a Dense layer.
    x = layers.GlobalAveragePooling2D()(x)
    classifier = layers.Dense(
        num_classes,
        kernel_initializer=initializers.RandomNormal(stddev=0.01),
        kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
        bias_regularizer=_gen_l2_regularizer(use_l2_regularizer),
        name='fc1000')
    x = classifier(x)

    # The last activation is 'linear', so the Dense outputs pass through
    # unchanged; the layer is created with dtype=float32 because a softmax
    # followed by the model loss cannot be done in float16 due to numeric
    # issues.
    return layers.Activation('linear', dtype='float32')(x)
