import sys
sys.path.append("../")

import tensorflow as tf
import numpy as np
from keras.layers import *
from keras import Model
from keras import layers as Layers

# Scalar int32 tensor 368 -- presumably the validation/eval image size used by
# the input pipeline elsewhere (not referenced in this file).
val_target_size = tf.constant(368, dtype=tf.int32)

# Per-channel std that preprocess_input(mode='torch') divides by; ResNet.call
# multiplies the preprocessed batch by this tensor.
std = [0.229, 0.224, 0.225]
std_tensor = tf.constant(std, dtype=tf.float32)

# PCA lighting-noise parameters for the train-time augmentation in
# ResNet.call: eigvals (apparently on the 0-255 pixel scale, since the
# augmented image is clipped to [0, 255]) and a matrix whose COLUMNS are the
# matching eigenvectors (column j is scaled by alpha[j] * eigval[j]).
# NOTE(review): the rows appear in reverse order relative to the commonly
# published RGB ImageNet eigenvectors (e.g. fb.resnet.torch), i.e. possibly
# BGR order, while preprocess_input(mode='torch') assumes RGB -- confirm.
eigval = tf.constant([55.46, 4.794, 1.148], dtype=tf.float32)
eigvec = tf.constant(
    [
        [-0.5836, -0.6948, 0.4203],
        [-0.5808, -0.0045, -0.8140],
        [-0.5675, 0.7192, 0.4009],
    ],
    dtype=tf.float32,
)


class ResBlock(Model):
    """Basic (two 3x3 convolution) residual block.

    Main path:  3x3 conv (stride ``stride``) -> BN -> ReLU -> 3x3 conv -> BN.
    Shortcut:   identity when ``stride == 1``; otherwise a 1x1 conv with the
                same stride followed by BN, so it matches the main path's
                spatial size and channel count.
    The two paths are summed and passed through ReLU.

    NOTE(review): with ``stride == 1`` the identity shortcut requires the
    input to already have ``channels`` channels; the add fails otherwise.

    Args:
        channels: number of filters of every convolution in the block.
        stride: stride of the first 3x3 conv and of the projection shortcut.
        name: optional layer name. By default Keras generates a unique name
            (``res_block``, ``res_block_1``, ...). If you pass one, keep it
            unique within the parent model: blocks sharing a name break
            name-keyed operations such as HDF5 ``save_weights``.
    """

    def __init__(self, channels, stride=1, name=None):
        super(ResBlock, self).__init__(name=name)
        # True when the block downsamples and therefore needs a projection
        # shortcut (attribute name kept for external readers).
        self.flag = (stride != 1)
        self.conv1 = Conv2D(channels, 3, stride, padding='same')
        self.bn1 = tf.keras.layers.experimental.SyncBatchNormalization()
        self.conv2 = Conv2D(channels, 3, padding='same')
        self.bn2 = tf.keras.layers.experimental.SyncBatchNormalization()
        # Stateless, so one instance is reused for both activations.
        self.relu = ReLU()
        if self.flag:
            self.bn3 = tf.keras.layers.experimental.SyncBatchNormalization()
            self.conv3 = Conv2D(channels, 1, stride)

    def call(self, x, training=None):
        """Apply the block; ``training`` is forwarded to the BN layers."""
        out = self.conv1(x)
        out = self.bn1(out, training=training)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out, training=training)

        shortcut = x
        if self.flag:
            # Project the input so it can be added to the downsampled path.
            shortcut = self.conv3(x)
            shortcut = self.bn3(shortcut, training=training)

        out = Layers.add([shortcut, out])
        return self.relu(out)


class ResNet(Model):
    """ResNet-34-style classifier with built-in train-time augmentation.

    Layout: 7x7/2 stem conv -> BN -> ReLU -> 3x3/2 max-pool, then four stages
    of ``ResBlock``s (3, 4, 6 and 3 blocks with 64, 128, 256 and 512
    channels; the first block of stages 3-5 downsamples with stride 2),
    global average pooling, and a ``Dense(1000)`` head with no activation.
    The model therefore returns logits; pair it with a ``from_logits=True``
    loss.

    Inputs are expected to hold pixel values in ``[0, 255]``: the
    augmentation clips to that range and ``preprocess_input(mode='torch')``
    divides by 255. NOTE(review): a channels-last (batch, H, W, C) layout is
    assumed by the augmentation -- confirm.
    """

    def __init__(self):
        super(ResNet, self).__init__(name='ResNet')
        # Train-time augmentation; only used when ``training`` is truthy.
        self.random_flip = tf.keras.layers.RandomFlip(mode="horizontal")

        # Stem.
        self.conv1 = Conv2D(64, 7, 2, padding='same')
        self.bn = tf.keras.layers.experimental.SyncBatchNormalization()
        self.relu = ReLU()
        # NOTE(review): default 'valid' padding gives 55x55 instead of 56x56
        # for 224x224 inputs; reference ResNets pad this pool by 1. Kept as-is
        # to preserve the current architecture.
        self.mp1 = MaxPooling2D(3, 2)

        # Attribute names are kept stable: TF checkpoints key weights on them.
        self.conv2_1 = ResBlock(64)
        self.conv2_2 = ResBlock(64)
        self.conv2_3 = ResBlock(64)

        self.conv3_1 = ResBlock(128, 2)
        self.conv3_2 = ResBlock(128)
        self.conv3_3 = ResBlock(128)
        self.conv3_4 = ResBlock(128)

        self.conv4_1 = ResBlock(256, 2)
        self.conv4_2 = ResBlock(256)
        self.conv4_3 = ResBlock(256)
        self.conv4_4 = ResBlock(256)
        self.conv4_5 = ResBlock(256)
        self.conv4_6 = ResBlock(256)

        self.conv5_1 = ResBlock(512, 2)
        self.conv5_2 = ResBlock(512)
        self.conv5_3 = ResBlock(512)

        self.pool = GlobalAveragePooling2D()
        self.fc1 = Dense(1000)

    def _ordered_res_blocks(self):
        """Return the residual blocks of stages conv2..conv5 in forward order."""
        # Built on demand (not stored) so no extra checkpoint dependency
        # paths are created.
        return (
            self.conv2_1, self.conv2_2, self.conv2_3,
            self.conv3_1, self.conv3_2, self.conv3_3, self.conv3_4,
            self.conv4_1, self.conv4_2, self.conv4_3,
            self.conv4_4, self.conv4_5, self.conv4_6,
            self.conv5_1, self.conv5_2, self.conv5_3,
        )

    def _apply_train_augmentation(self, x, training):
        """Random horizontal flip plus per-image PCA lighting noise.

        Expects a rank-4 float32 batch (batch, H, W, 3) with values in
        [0, 255]; returns a batch of the same shape clipped to [0, 255].
        """
        x = self.random_flip(x, training=training)
        # Lighting noise: draw an independent alpha ~ N(0, 0.1) for EVERY
        # image and shift that image's pixels by eigvec @ (alpha * eigval).
        # (A single alpha for the whole batch would give every image the
        # same colour shift.)
        alpha = tf.random.normal([tf.shape(x)[0], 3], mean=0.0, stddev=0.1)
        # offset[b, i] = sum_j eigvec[i, j] * alpha[b, j] * eigval[j]
        offset = tf.matmul(alpha * eigval, eigvec, transpose_b=True)
        x = x + offset[:, tf.newaxis, tf.newaxis, :]
        return tf.clip_by_value(x, 0.0, 255.0)

    def call(self, x, training=None):
        """Run the network on a batch of images.

        Args:
            x: batch of images, presumably (batch, H, W, 3) with pixel values
                in [0, 255]; any numeric dtype, cast to float32 here.
            training: truthy enables augmentation and makes the BN layers
                run in training mode.

        Returns:
            Logits of shape (batch, 1000).
        """
        # The augmentation constants and std_tensor are float32; casting up
        # front lets uint8/float64 inputs work instead of failing on add/mul.
        x = tf.cast(x, tf.float32)
        if training:
            x = self._apply_train_augmentation(x, training)

        x = tf.keras.applications.imagenet_utils.preprocess_input(x, mode='torch')
        # NOTE(review): preprocess_input(mode='torch') yields
        # (x / 255 - mean) / std; multiplying by std undoes the std division,
        # so the net normalisation is x / 255 - mean. Confirm this is
        # intended before changing it (trained weights depend on it).
        x = x * std_tensor

        x = self.conv1(x)
        x = self.bn(x, training=training)
        x = self.relu(x)
        x = self.mp1(x)

        for block in self._ordered_res_blocks():
            x = block(x, training=training)

        x = self.pool(x)
        return self.fc1(x)


def build_resnet_batchnorm(optimizer, loss, metrics):
    """Create a ``ResNet`` and compile it with the given training settings.

    Args:
        optimizer: forwarded unchanged to ``Model.compile``.
        loss: forwarded unchanged to ``Model.compile``. The network's final
            ``Dense(1000)`` layer has no activation, so a logits-based loss
            (e.g. ``from_logits=True``) is presumably what callers want.
        metrics: forwarded unchanged to ``Model.compile``.

    Returns:
        The compiled ``ResNet`` instance.
    """
    compile_settings = dict(optimizer=optimizer, loss=loss, metrics=metrics)
    net = ResNet()
    net.compile(**compile_settings)
    return net
