from typing import Tuple
import tensorflow as tf


class BasicBlock(tf.keras.layers.Layer):
    """Two-convolution residual block (the "basic block" of ResNet-18/34).

    Computes ``relu(shortcut(x) + bn2(conv2(relu(bn1(conv1(x))))))``.
    ``shortcut`` is the identity when the main path preserves the input
    shape. Otherwise it is a 1x1 convolution followed by batch normalization
    (a projection), so that the two tensors can be added.

    Args:
        filter_num: Number of output channels of both 3x3 convolutions.
        stride: Stride of the first 3x3 convolution. The projection shortcut
            uses the same stride when one is needed.
    """

    def __init__(self, filter_num, stride=1):
        super(BasicBlock, self).__init__()
        # Kept so build() can decide whether a projection shortcut is needed.
        self.filter_num = filter_num
        self.stride = stride
        self.conv1 = tf.keras.layers.Conv2D(filters=filter_num,
                                            kernel_size=(3, 3),
                                            strides=stride,
                                            padding="same")
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.conv2 = tf.keras.layers.Conv2D(filters=filter_num,
                                            kernel_size=(3, 3),
                                            strides=1,
                                            padding="same")
        self.bn2 = tf.keras.layers.BatchNormalization()
        if stride != 1:
            self.downsample = self._make_projection(filter_num, stride)
        else:
            # Identity shortcut. It accepts `training` so call() can treat it
            # like the projection. build() replaces it with a projection if
            # the input channel count turns out to differ from filter_num.
            self.downsample = lambda x, training=None: x

    @staticmethod
    def _make_projection(filter_num, stride):
        """Return a 1x1 conv + batch-norm shortcut matching the main path."""
        projection = tf.keras.Sequential()
        projection.add(tf.keras.layers.Conv2D(filters=filter_num,
                                              kernel_size=(1, 1),
                                              strides=stride))
        projection.add(tf.keras.layers.BatchNormalization())
        return projection

    def build(self, input_shape):
        # With stride 1 the spatial size is preserved, but the channel count
        # may still change (in_channels -> filter_num). An identity shortcut
        # would then fail in `add`, so project it instead. An unknown channel
        # dimension keeps the identity shortcut, as before.
        in_channels = tf.compat.dimension_value(
            tf.TensorShape(input_shape)[-1])
        if (self.stride == 1 and in_channels is not None
                and in_channels != self.filter_num):
            self.downsample = self._make_projection(self.filter_num,
                                                    self.stride)
        super(BasicBlock, self).build(input_shape)

    def call(self, inputs, training=None, **kwargs):
        # Pass `training` explicitly so the shortcut's BatchNorm (if any)
        # runs in the same mode as bn1/bn2.
        residual = self.downsample(inputs, training=training)

        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)

        output = tf.nn.relu(tf.keras.layers.add([residual, x]))

        return output


def make_basic_block_layer(filter_num, blocks, stride=1):
    """Build one ResNet stage as a Sequential stack of BasicBlocks.

    Only the first block uses ``stride`` (and so may downsample). Every later
    block uses stride 1. At least one block is always created, even when
    ``blocks`` is less than 1.

    Args:
        filter_num: Output channels of every block in the stage.
        blocks: Number of BasicBlocks in the stage.
        stride: Stride of the first block.

    Returns:
        A ``tf.keras.Sequential`` containing the blocks in order.
    """
    block_strides = [stride] + [1] * (blocks - 1)
    return tf.keras.Sequential(
        [BasicBlock(filter_num, stride=s) for s in block_strides])


class ResNet18(tf.keras.Model):
    """ResNet-18 backbone that returns globally average-pooled features.

    There is no classification head: ``call`` ends with global average
    pooling over the 512-channel output of the last stage.

    Args:
        image_size: Spatial size ``(height, width)`` of the input images.
            Any sequence works, e.g. ``[32, 32]``. For 32x32 inputs the stem
            follows the SimCLR CIFAR variant: a 3x3 stride-1 first
            convolution and no max pooling. Any other size uses the standard
            7x7 stride-2 convolution followed by 3x3 stride-2 max pooling.
    """

    def __init__(self, image_size: Tuple[int, int]):
        super(ResNet18, self).__init__()
        self.image_size = image_size
        # Decide the stem once, from an order-normalised tuple. A list such
        # as [32, 32] then behaves like (32, 32), and call() cannot disagree
        # with the choice made here.
        self._small_input = self._is_small_image(image_size)
        # Smaller kernel size in conv2d and no max pooling according to SimCLR
        # paper (Appendix B.9), see: https://arxiv.org/pdf/2002.05709.pdf
        if self._small_input:
            self.conv1 = tf.keras.layers.Conv2D(filters=64,
                                                kernel_size=(3, 3),
                                                strides=1,
                                                padding="same")
        else:
            self.conv1 = tf.keras.layers.Conv2D(filters=64,
                                                kernel_size=(7, 7),
                                                strides=2,
                                                padding="same")

        self.bn1 = tf.keras.layers.BatchNormalization()
        # Always created (it has no weights). It is applied only for the
        # non-32x32 stem.
        self.pool1 = tf.keras.layers.MaxPool2D(pool_size=(3, 3),
                                               strides=2,
                                               padding="same")

        self.layer1 = make_basic_block_layer(filter_num=64,
                                             blocks=2)
        self.layer2 = make_basic_block_layer(filter_num=128,
                                             blocks=2,
                                             stride=2)
        self.layer3 = make_basic_block_layer(filter_num=256,
                                             blocks=2,
                                             stride=2)
        self.layer4 = make_basic_block_layer(filter_num=512,
                                             blocks=2,
                                             stride=2)

        self.avgpool = tf.keras.layers.GlobalAveragePooling2D()

    @staticmethod
    def _is_small_image(image_size):
        """Return True when ``image_size`` is (32, 32) in any sequence form."""
        try:
            return tuple(image_size) == (32, 32)
        except TypeError:
            # Not iterable (e.g. a bare int): never the 32x32 case.
            return False

    def call(self, inputs, training=None, mask=None):
        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        if not self._small_input:
            x = self.pool1(x)
        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)
        x = self.avgpool(x)

        return x
