import tensorflow as tf

from abstract_model import AbstractModelGenerator
from abstract_model import SequenceModelGenerator

from resnet import resnet50
from vision_transformer import get_a_vision_transformer_layer
from separate_transformer import get_transformer_model
from seq2seq import Translator
from nlp_transformer import Transformer


def get_model_generator(args, input_shape, output_nodes):
    """Instantiate the model generator selected by ``args.model_type``.

    The registry below is scanned in order and the first entry whose key
    compares equal (``==``) to ``args.model_type`` is constructed with
    ``(args, input_shape, output_nodes)``.

    Args:
        args: Configuration object; only ``args.model_type`` is read here,
            the whole object is forwarded to the generator constructor.
        input_shape: Forwarded unchanged to the generator constructor.
        output_nodes: Forwarded unchanged to the generator constructor.

    Returns:
        A new instance of the matching generator class.

    Raises:
        ValueError: if ``args.model_type`` matches no registered name.
    """
    registry = (
        ('dnn', DNNModelGenerator),
        ('cnn', CNNModelGenerator),
        ('resnet', ResNetModelGenerator),
        ('vit', VisionTransformerGenerator),
        ('lstm', LSTMModelGenerator),
        ('transformer', TransformerGenerator),
        ('sequence_lstm', SequenceLSTMGenerator),
        ('sequence_transformer', SequenceTransformerGenerator),
    )
    for type_name, generator_cls in registry:
        if args.model_type == type_name:
            return generator_cls(args, input_shape, output_nodes)
    raise ValueError(
        '{0} is not a valid model_type.'.format(args.model_type))


class DNNModelGenerator(AbstractModelGenerator):
    """Model generator built from fully connected (Dense) layers only."""

    def get_main_model(self, x, all_output_nodes):
        """Stack Dense layers on top of ``x``.

        The input is flattened, then passed through ``self.depth - 1`` ReLU
        Dense layers of width ``self.args.n_hidden_nodes``, and finished with
        a linear Dense layer of width ``all_output_nodes``.

        Args:
            x: Keras tensor to build on.
            all_output_nodes: Number of units in the final linear layer.

        Returns:
            Output tensor of the final Dense layer.
        """
        # self.depth / self.args presumably come from AbstractModelGenerator.
        hidden = tf.keras.layers.Flatten()(x)
        for _ in range(self.depth - 1):
            dense = tf.keras.layers.Dense(
                self.args.n_hidden_nodes, activation='relu')
            hidden = dense(hidden)
        head = tf.keras.layers.Dense(all_output_nodes, activation='linear')
        return head(hidden)


class CNNModelGenerator(AbstractModelGenerator):
    """Model generator: a stack of 3x3 convolutions followed by a dense head."""

    def get_main_model(self, x, all_output_nodes):
        """Build ``self.depth - 2`` conv layers plus two Dense layers on ``x``.

        Each convolution is a ReLU ``Conv2D`` with ``self.args.n_hidden_nodes``
        filters, a 3x3 kernel and ``'SAME'`` padding. The result is flattened,
        fed to a ReLU Dense layer twice as wide as ``n_hidden_nodes``, and
        finished with a linear Dense layer of width ``all_output_nodes``.

        Args:
            x: Keras tensor to build on (Conv2D presumably expects a 4-D
                (batch, height, width, channels) input — confirm with caller).
            all_output_nodes: Number of units in the final linear layer.

        Returns:
            Output tensor of the final Dense layer.
        """
        features = x
        for _ in range(self.depth - 2):
            conv = tf.keras.layers.Conv2D(
                self.args.n_hidden_nodes, (3, 3),
                activation='relu', padding='SAME')
            features = conv(features)
        features = tf.keras.layers.Flatten()(features)
        hidden_dense = tf.keras.layers.Dense(
            2 * self.args.n_hidden_nodes, activation='relu')
        features = hidden_dense(features)
        head = tf.keras.layers.Dense(all_output_nodes, activation='linear')
        return head(features)


class ResNetModelGenerator(AbstractModelGenerator):
    """Model generator that delegates to the project's ``resnet50`` builder."""

    def get_main_model(self, x, all_output_nodes):
        """Build the network via ``resnet50`` with input rescaling enabled.

        Args:
            x: Keras tensor forwarded to ``resnet50``.
            all_output_nodes: Forwarded to ``resnet50`` as its first argument
                (presumably the output width — confirm in resnet.py).

        Returns:
            Whatever ``resnet50`` returns.
        """
        output = resnet50(all_output_nodes, x, rescale_inputs=True)
        return output


class VisionTransformerGenerator(AbstractModelGenerator):
    """Model generator built from ``get_a_vision_transformer_layer`` stages."""

    def get_main_model(self, x, all_output_nodes):
        """Build an initial stage, ``self.depth - 1`` layers, and a linear head.

        The first call passes width ``0`` and index ``-1``; its meaning is
        defined in ``vision_transformer`` (presumably an input/embedding
        stage — confirm there). The following calls use
        ``self.args.n_hidden_nodes`` and indices ``0 .. depth - 2``.

        Args:
            x: Keras tensor to build on.
            all_output_nodes: Number of units in the final linear layer.

        Returns:
            Output tensor of the final Dense layer.
        """
        num_stages = self.depth - 1
        # Forwarded to every call unchanged; presumably the image side
        # length — NOTE(review): confirm against vision_transformer.
        leading_dim = self.input_shape[0]
        x = get_a_vision_transformer_layer(x, 0, -1, num_stages, leading_dim)
        for stage_index in range(num_stages):
            x = get_a_vision_transformer_layer(
                x, self.args.n_hidden_nodes, stage_index, num_stages,
                leading_dim)
        # The stack must already have reduced to (batch, features) before
        # the Dense head.
        assert len(x.shape) == 2
        head = tf.keras.layers.Dense(all_output_nodes, activation='linear')
        return head(x)


class LSTMModelGenerator(AbstractModelGenerator):
    """Model generator: token embedding, stacked bidirectional LSTMs, dense head."""

    def get_main_model(self, x, all_output_nodes):
        """Build Embedding -> ``self.depth - 2`` BiLSTMs -> Flatten -> Dense.

        Args:
            x: Keras tensor of token ids fed to the Embedding layer
                (``self.vocab_size`` presumably comes from the base class).
            all_output_nodes: Number of units in the final linear layer.

        Returns:
            Output tensor of the final Dense layer.
        """
        x = tf.keras.layers.Embedding(
            self.vocab_size, self.args.n_hidden_nodes)(x)
        # Each direction gets half the width; Bidirectional concatenates the
        # two directions by default, so the output stays ~n_hidden_nodes wide.
        hidden_size = self.args.n_hidden_nodes // 2
        for _ in range(self.depth - 2):
            # Rebind x so each BiLSTM feeds the next layer. Without the
            # assignment the layer is built but never connected to the model
            # output, and the functional model silently omits it.
            x = tf.keras.layers.Bidirectional(
                tf.keras.layers.LSTM(hidden_size, return_sequences=True))(x)
        x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.Dense(all_output_nodes, activation='linear')(x)
        return x


class TransformerGenerator(AbstractModelGenerator):
    """Model generator built on ``separate_transformer.get_transformer_model``."""

    def get_main_model(self, x, all_output_nodes):
        """Build the transformer and put a linear Dense head on its first output.

        ``get_transformer_model`` returns a pair. Only the first element feeds
        the head; the second is discarded (its meaning is defined in
        ``separate_transformer``, not visible here).

        Args:
            x: Keras tensor to build on.
            all_output_nodes: Number of units in the final linear layer.

        Returns:
            Output tensor of the final Dense layer.
        """
        num_blocks = self.depth - 1
        first_output, _ = get_transformer_model(
            x, self.args.n_hidden_nodes, num_blocks, 0, self.vocab_size)
        head = tf.keras.layers.Dense(all_output_nodes, activation='linear')
        return head(first_output)


class SequenceLSTMGenerator(SequenceModelGenerator):
    """Sequence-to-sequence generator backed by ``seq2seq.Translator``."""

    def get_main_model(self, x, y):
        """Construct a ``Translator`` and apply it to the ``(x, y)`` pair.

        Args:
            x: Source-side tensor, passed as the first tuple element.
            y: Target-side tensor, passed as the second tuple element.

        Returns:
            Whatever the ``Translator`` call returns.
        """
        translator = Translator(
            self.args.n_hidden_nodes,
            self.input_vocab_size,
            self.output_vocab_size,
        )
        return translator((x, y))


class SequenceTransformerGenerator(SequenceModelGenerator):
    """Sequence-to-sequence generator backed by ``nlp_transformer.Transformer``."""

    def get_main_model(self, x, y):
        """Construct a ``Transformer`` and apply it to the ``(x, y)`` pair.

        Hyperparameters:
            - ``self.depth - 1`` layers.
            - Model width ``self.args.n_hidden_nodes``.
            - Feed-forward width four times the model width.
            - 8 attention heads and dropout rate 0.1 (both fixed).

        Args:
            x: Source-side tensor, passed as the first tuple element.
            y: Target-side tensor, passed as the second tuple element.

        Returns:
            Whatever the ``Transformer`` call returns.
        """
        width = self.args.n_hidden_nodes
        transformer = Transformer(
            num_layers=self.depth - 1,
            d_model=width,
            num_heads=8,
            dff=4 * width,
            input_vocab_size=self.input_vocab_size,
            target_vocab_size=self.output_vocab_size,
            dropout_rate=0.1,
        )
        return transformer((x, y))
