import tensorflow as tf
from nlp_transformer import masked_loss


class AbstractModelGenerator(object):
    """Base class for building and compiling a two-output Keras model.

    Subclasses provide the network body via ``get_main_model``; this class
    adds the input layer, splits the body's result into two outputs and
    compiles the model with the optimizer chosen in ``args``.
    """

    def __init__(self, args, input_shape, output_nodes):
        """Keep the settings needed to build the model.

        Args:
            args: Configuration object; ``depth`` is read here, ``optimizer``
                and ``lr`` are read later by ``get_model``.
            input_shape: Shape handed to ``tf.keras.layers.Input``; must
                support ``len()``.
            output_nodes: Output size; the main model is asked for twice
                this many nodes, which are then split into two outputs.
        """
        self.args = args
        self.input_shape = input_shape
        self.output_nodes = output_nodes
        self.depth = self.args.depth
        self.vocab_size = 0

    def get_main_model(self, x, all_output_nodes):
        """Build the network body; must be overridden by subclasses.

        Args:
            x: The Keras input tensor.
            all_output_nodes: Total number of output nodes requested.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def get_structure(self):
        """Assemble the uncompiled model.

        Returns:
            A ``tf.keras.models.Model`` whose outputs are the two pieces
            obtained by splitting the main model's result along the last axis.
        """
        # Input shapes with more than one dimension are fed as float32,
        # anything else as int32.
        input_type = tf.float32 if len(self.input_shape) > 1 else tf.int32

        input_layer = tf.keras.layers.Input(
            shape=self.input_shape, dtype=input_type)
        body = self.get_main_model(input_layer, 2 * self.output_nodes)
        return tf.keras.models.Model(
            inputs=input_layer, outputs=tf.split(body, 2, axis=-1))

    def get_model(self):
        """Build the model and compile it.

        Returns:
            A ``(model, loss_fn)`` tuple: the compiled Keras model and the
            categorical cross-entropy loss (on logits) it was compiled with.

        Raises:
            ValueError: If ``args.optimizer`` is neither ``'adam'`` nor
                ``'sgd'``.
        """
        network = self.get_structure()

        requested = self.args.optimizer
        if requested not in ('adam', 'sgd'):
            raise ValueError(
                '{0} is not a valid optimizer.'.format(requested))

        if requested == 'adam':
            optimizer_cls = tf.keras.optimizers.Adam
        else:
            optimizer_cls = tf.keras.optimizers.SGD
        optimizer = optimizer_cls(learning_rate=self.args.lr)

        loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        network.compile(optimizer=optimizer, loss=loss_fn)
        return network, loss_fn

    def set_vocab_size(self, vocab_size):
        """Record the vocabulary size on the generator."""
        self.vocab_size = vocab_size


class SequenceModelGenerator(AbstractModelGenerator):
    """Model generator whose single input carries two concatenated segments.

    The Keras input is an int32 tensor holding an ``x`` part and a ``y`` part
    side by side on the last axis; ``get_structure`` splits them apart and
    passes both to ``get_main_model``.
    NOTE(review): presumably ``x``/``y`` are source and target token
    sequences, given the input/output vocabulary sizes - confirm with callers.
    """

    def __init__(self, args, input_shape, output_nodes):
        """Store the configuration and the sizes of the two input segments.

        Args:
            args: Configuration object; ``depth`` is read here and ``lr`` is
                read later by ``get_model``.
            input_shape: Size of the first (``x``) segment on the last axis.
            output_nodes: Size of the second (``y``) segment on the last axis.
        """
        # Run the base initialiser so inherited attributes (output_nodes,
        # vocab_size, ...) exist on every instance.
        super(SequenceModelGenerator, self).__init__(
            args, input_shape, output_nodes)
        self.shapes = [input_shape, output_nodes]
        # The model input holds both segments, so its size is their sum.
        self.input_shape = self.shapes[0] + self.shapes[1]
        self.input_vocab_size = 0
        self.output_vocab_size = 0

    def get_main_model(self, x, y):
        """Build the network body; must be overridden by subclasses.

        Args:
            x: First segment of the split input tensor.
            y: Second segment of the split input tensor.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def get_structure(self):
        """Assemble the uncompiled model.

        Returns:
            A ``tf.keras.models.Model`` mapping the combined int32 input to
            whatever ``get_main_model`` returns for the two split segments.
        """
        combined = self.input_shape
        # Keras expects a shape tuple. ``(value)`` is just ``value``, so wrap
        # a scalar length explicitly; pass existing sequences through as-is.
        if isinstance(combined, (tuple, list)):
            shape = combined
        else:
            shape = (combined,)
        inputs = tf.keras.layers.Input(shape=shape, dtype=tf.int32)
        x, y = tf.split(inputs, [self.shapes[0], self.shapes[1]], -1)
        outputs = self.get_main_model(x, y)
        model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
        return model

    def get_model(self):
        """Build the model and compile it with Adam and ``masked_loss``.

        Unlike the base class, ``args.optimizer`` is not consulted here; Adam
        is always used with ``args.lr`` as the learning rate.

        Returns:
            A ``(model, loss_fn)`` tuple: the compiled Keras model and the
            ``masked_loss`` callable it was compiled with.
        """
        model = self.get_structure()
        adam = tf.keras.optimizers.Adam(learning_rate=self.args.lr)
        loss_fn = masked_loss
        model.compile(optimizer=adam, loss=loss_fn)
        return model, loss_fn

    def set_vocab_size(self, vocab_size):
        """Record the vocabulary sizes.

        Args:
            vocab_size: Indexable pair; element 0 is stored as the input
                vocabulary size and element 1 as the output vocabulary size.
        """
        self.input_vocab_size = vocab_size[0]
        self.output_vocab_size = vocab_size[1]
