


import logging

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.python.client import timeline

from ariel_tests.language.sentence_generators import TransformerGenerator
from ariel_tests.models.baseEmbedding import SentenceEmbedding, SentenceEncoder, SentenceDecoder, plot_history
from ariel_tests.models.transformer.transformer import Transformer
from ariel_tests.models.transformer.utils import LRSchedulerPerStep

logger = logging.getLogger(__name__)


class TransformerSentenceEncoder(SentenceEncoder):
    """Encode one sentence into a flat latent vector using a Transformer encoder.

    ``transformer`` is expected to expose ``maxlen`` and ``s2s.encode_model``
    (in this file it is a ``KerasTransformerEmbedding``).
    """

    def __init__(self, transformer, vocabulary):
        super(TransformerSentenceEncoder, self).__init__()
        self.__dict__.update(transformer=transformer, vocabulary=vocabulary)

    def encode(self, sentence):
        """Return the flattened encoder output for ``sentence``.

        The sentence is turned into indices framed as
        ``[padIndex, startIndex, <sentence indices>, endIndex]``. It is then
        padded with ``padIndex`` at the end, or cut at the end, to exactly
        ``transformer.maxlen`` positions and run through the encoder as a
        batch of one.

        :param sentence: input sentence, converted via
            ``vocabulary.sentenceToIndices``.
        :return: the encoder's prediction for the single-sentence batch,
            flattened to 1-D.
        """
        vocab = self.vocabulary
        maxlen = self.transformer.maxlen

        sequence = ([vocab.padIndex, vocab.startIndex]
                    + vocab.sentenceToIndices(sentence)
                    + [vocab.endIndex])

        # Cut over-long sequences at the END so the leading pad/start markers
        # are kept. pad_sequences defaults to truncating='pre', which would
        # drop them. Slicing the outer batch list (length 1) would have no
        # effect, so the truncation has to be requested here.
        indices = pad_sequences([sequence],
                                maxlen=maxlen,
                                value=vocab.padIndex,
                                padding='post',
                                truncating='post')

        z = self.transformer.s2s.encode_model.predict_on_batch(indices)
        return z.flatten()


class TransformerSentenceDecoder(SentenceDecoder):
    """Decode a flat latent vector back into a sentence with a Transformer.

    ``transformer`` must expose ``maxlen``, ``latent_dim`` and
    ``s2s.decode_noise_fast`` (in this file it is a
    ``KerasTransformerEmbedding``).
    """

    def __init__(self,
                 transformer,
                 vocabulary):
        super(TransformerSentenceDecoder, self).__init__()
        vars(self).update(transformer=transformer, vocabulary=vocabulary)

    def decode(self, z):
        """Turn ``z`` into a sentence.

        ``z`` is reshaped to ``(1, maxlen, latent_dim)`` and decoded. The first
        decoded sentence is tokenized and trimmed with
        ``vocabulary.fromStartToEnd``, then re-joined with
        ``vocabulary.tokensToSentence``. If that result has length 1, its
        single element is returned instead of the container.
        """
        model = self.transformer
        vocab = self.vocabulary

        latent = z.reshape(1, model.maxlen, model.latent_dim)
        decoded = model.s2s.decode_noise_fast(latent)

        trimmed = vocab.fromStartToEnd(vocab.sentenceToTokens(decoded[0]))
        result = vocab.tokensToSentence(trimmed)

        return result[0] if len(result) == 1 else result


class KerasTransformerEmbedding(SentenceEmbedding):
    """Sentence embedding backed by a Keras Transformer seq2seq model.

    Builds and compiles a ``Transformer`` on construction. It provides
    encoder/decoder wrappers and a generator-based training loop.
    """

    def __init__(self,
                 vocabulary,
                 emb_dim=100,
                 rnn='gru',
                 rec_dim=512,
                 latent_dim=256,
                 n_rec=1,
                 reverseInput=True,
                 name='embedding',
                 activation='tanh',
                 grammar=None,
                 maxlen=25,
                 keep=.75,
                 reverse_input=False):
        """Store the hyper-parameters and build the Transformer model.

        Only ``vocabulary``, ``latent_dim`` (used as ``d_model``) and
        ``maxlen`` (used as ``len_limit``) feed into the Transformer built
        by ``_autoencoder``. ``emb_dim``, ``rnn``, ``rec_dim``, ``n_rec``,
        ``reverseInput``, ``activation``, ``keep`` and ``name`` are only
        stored on the instance. Presumably they are kept for parity with
        other ``SentenceEmbedding`` implementations.

        :param grammar: required; stored on the instance, not otherwise used
            in this class.
        :raises Exception: if ``grammar`` is None.
        """
        # NOTE(review): the base initializer presumably sets self.vocabulary
        # and self.latent_dim, which _autoencoder and train read — confirm.
        super(KerasTransformerEmbedding, self).__init__(vocabulary, latent_dim, reverse_input)
        self.__dict__.update(
            emb_dim=emb_dim, rnn=rnn, rec_dim=rec_dim,
            activation=activation, n_rec=n_rec,
            reverseInput=reverseInput, name=name, keep=keep,
            grammar=grammar, maxlen=maxlen)

        if grammar is None:
            raise Exception('This transformer keras implementation needs access to a grammar!!')
        # NOTE: the zero value is reserved for masking.
        # self.loss is not passed to s2s.compile below; kept as an attribute.
        self.loss = 'categorical_crossentropy'

        self.model = self._autoencoder()

    def _autoencoder(self):
        """Create and compile the Transformer; return its Keras training model.

        Also stores the Transformer wrapper as ``self.s2s``.
        """
        # HQ 16/512 and GW 512: layers=2
        # GW 16: layers=20
        self.s2s = Transformer(self.vocabulary, len_limit=self.maxlen, d_model=self.latent_dim,
                               d_inner_hid=256,
                               n_head=8, d_k=64, d_v=64, layers=2, dropout=0.1)
        # Adam(lr=0.001, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
        self.s2s.compile(Adam(0.001, 0.9, 0.98, epsilon=1e-9))

        return self.s2s.model

    def getEncoder(self):
        """Return a ``TransformerSentenceEncoder`` bound to this embedding."""
        return TransformerSentenceEncoder(self, self.vocabulary)

    def getDecoder(self):
        """Return a ``TransformerSentenceDecoder`` bound to this embedding."""
        return TransformerSentenceDecoder(self, self.vocabulary)

    def train(self, gzip_filename_train, gzip_filename_val,
              epochs, steps_per_epoch=32, batch_size=32, callbacks=None,
              model_filename=None, verbose=1, profile=False, log_path=None):
        """Train ``self.s2s.model`` from gzip-backed sentence generators.

        A per-step learning-rate scheduler is always appended to the
        callbacks. When training completes, ``plot_history`` is called. If
        training is interrupted with Ctrl-C and ``model_filename`` is given,
        the current weights are saved to it. Weights are NOT saved here on
        normal completion.

        :param gzip_filename_train: path for the training ``TransformerGenerator``.
        :param gzip_filename_val: path for the validation ``TransformerGenerator``.
        :param epochs: number of epochs for ``fit_generator``.
        :param steps_per_epoch: forwarded to the training generator only.
        :param batch_size: forwarded to both generators.
        :param callbacks: optional iterable of Keras callbacks. It is copied,
            so the caller's list is not modified.
        :param model_filename: passed to ``plot_history``; also used as the
            weight file on interruption.
        :param verbose: Keras verbosity level.
        :param profile: if True, write ``timeline.ctf.json`` after training
            (see the review note below).
        :param log_path: currently unused.
        """
        # Copy instead of appending to the argument: a shared mutable default
        # would accumulate one extra LR scheduler per call, and a caller's
        # list would be modified in place.
        callbacks = list(callbacks) if callbacks is not None else []

        if profile:
            # NOTE(review): run_options / run_metadata are never handed to the
            # model (fit_generator takes no such arguments), so the trace
            # written at the end is presumably empty. tf.RunOptions is also
            # TF1 API (tf.compat.v1.RunOptions in TF2). Confirm before relying
            # on profiling output.
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()

        # LRSchedulerPerStep(d_model, warmup_steps) — argument meaning assumed
        # from the usual Transformer schedule; confirm in utils.
        lr_scheduler = LRSchedulerPerStep(self.latent_dim, 4000)
        callbacks.append(lr_scheduler)

        generator_train = TransformerGenerator(gzip_filepath=gzip_filename_train, vocabulary=self.vocabulary,
                                               batch_size=batch_size, steps_per_epoch=steps_per_epoch,
                                               maxlen=self.maxlen)
        generator_val = TransformerGenerator(gzip_filepath=gzip_filename_val, vocabulary=self.vocabulary,
                                             batch_size=batch_size, maxlen=self.maxlen)
        try:
            history = self.s2s.model.fit_generator(generator_train,
                                                   epochs=epochs,
                                                   validation_data=generator_val,
                                                   use_multiprocessing=False,
                                                   shuffle=False,
                                                   verbose=verbose,
                                                   callbacks=callbacks)

            plot_history(history, model_filename, epochs)

        except KeyboardInterrupt:
            logger.info("Training interrupted by the user")
            if model_filename is not None:
                self.s2s.model.save_weights(model_filename)

        if profile:
            trace = timeline.Timeline(step_stats=run_metadata.step_stats)
            with open('timeline.ctf.json', 'w') as f:
                f.write(trace.generate_chrome_trace_format())

    def save(self, filename):
        """Not implemented: nothing is written to ``filename``.

        Saving the encoder/decoder sub-models is hard because of their custom
        layers, so this only emits a warning.
        """
        # print(filename + '_generator.h5')
        # self.s2s.decoder_generator.save(filename + '_generator.h5')
        # self.s2s.encode_model.save(filename + '_encoder.h5')
        # Warning, not info: a silent no-op save would lose the trained model.
        logger.warning("Tricky to save so many custom layers")

    def load(self, filename):
        """Not implemented: nothing is read from ``filename``; emits a warning."""
        # self.s2s.decoder_generator = load_model(filename + '_generator.h5')
        # self.s2s.encode_model = load_model(filename + '_encoder.h5')
        logger.warning("Tricky to save so many custom layers")
