import logging
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical
from tensorflow.python.client import timeline

from ariel_tests.language.sentence_generators import ArielGenerator, getNoisyInput
from ariel_tests.models.AriEL.AriEL import AriEL
from ariel_tests.models.AriEL.keras_layers import predefined_model
from ariel_tests.models.baseEmbedding import SentenceEmbedding, SentenceEncoder, SentenceDecoder, plot_history

# Module-level logger; used below for training-interrupt and model-loading messages.
logger = logging.getLogger(__name__)


class lmArielSentenceEncoder(SentenceEncoder):
    """Encode sentences into AriEL's continuous latent space.

    Wraps ``ariel.encode`` in a Keras ``Model`` that maps a batch of padded
    token-index sequences to latent vectors.
    """

    def __init__(self,
                 ariel,
                 vocabulary,
                 latent_dim,
                 maxlen):
        """
        Args:
            ariel: AriEL instance whose ``encode`` builds the encoding graph.
            vocabulary: vocabulary providing ``padIndex``, ``startIndex``,
                ``endIndex`` and ``sentenceToIndices``.
            latent_dim: latent dimensionality; accepted for interface parity
                with the decoder but not used here.
            maxlen: length every index sequence is padded to before encoding.
        """
        super(lmArielSentenceEncoder, self).__init__()
        self.__dict__.update(ariel=ariel, vocabulary=vocabulary, maxlen=maxlen)

        # Batch of token-index sequences of any length. Named 'encoder_input'
        # so it cannot clash with the decoder's 'decoder_input' layer.
        encoder_input = Input(shape=(None,), name='encoder_input')
        continuous_output = ariel.encode(encoder_input)
        self.encoder_model = Model(inputs=encoder_input, outputs=continuous_output)

    def encode(self, sentences):
        """Encode one sentence or a list of sentences into latent vectors.

        Each sentence is turned into ``[PAD, START] + tokens + [END]`` and
        left-padded with PAD to ``self.maxlen`` via ``pad_sequences``.

        Args:
            sentences: a single sentence string or a list of sentence strings.

        Returns:
            The encoder model's prediction for the batch. When exactly one
            sentence is encoded (passed as a string or as a one-element
            list), the leading batch axis is dropped.
        """
        if isinstance(sentences, str):
            sentences = [sentences]

        vocab = self.vocabulary
        list_indices = [
            [vocab.padIndex, vocab.startIndex] + vocab.sentenceToIndices(sentence) + [vocab.endIndex]
            for sentence in sentences
        ]
        # pad_sequences accepts ragged lists directly. Wrapping them in
        # np.array first creates an object array (deprecated) or raises
        # ValueError on NumPy >= 1.24.
        np_i = pad_sequences(list_indices, maxlen=self.maxlen,
                             value=vocab.padIndex, padding='pre')

        z = self.encoder_model.predict(np_i)
        if len(sentences) == 1:
            z = z[0]
        return z


class lmArielSentenceDecoder(SentenceDecoder):
    """Decode AriEL latent vectors back into sentences (or raw token indices)."""

    def __init__(self, ariel, vocabulary, latent_dim, maxlen):
        """
        Args:
            ariel: AriEL instance whose ``decode`` builds the decoding graph.
            vocabulary: vocabulary used to turn indices back into text.
            latent_dim: size of each latent vector fed to the decoder.
            maxlen: accepted for interface parity with the encoder; unused.
        """
        super(lmArielSentenceDecoder, self).__init__()
        self.__dict__.update(ariel=ariel, vocabulary=vocabulary)

        latent_input = Input(shape=(latent_dim,), name='decoder_input')
        self.decoder_model = Model(inputs=latent_input, outputs=ariel.decode(latent_input))

    def decode(self, z, return_indices=False):
        """Decode latent vectors.

        Args:
            z: a single latent vector (1-D) or a batch of them (2-D).
            return_indices: when True, return the predicted index rows
                instead of sentence strings.

        Returns:
            One sentence (or index row) for a 1-D ``z``, otherwise a list of
            sentences (or the array of index rows). Special tokens are
            stripped only in the single-vector case.
        """
        is_single = len(z.shape) == 1
        batch = z[None] if is_single else z
        # The model's first output holds the predicted token indices.
        predicted_indices = self.decoder_model.predict(batch)[0]

        vocab = self.vocabulary
        decoded = []
        for row in predicted_indices:
            tokens = vocab.indicesToTokens(row.astype(int))
            if is_single:
                tokens = vocab.removeSpecialTokens(tokens)
            decoded.append(vocab.tokensToSentence(tokens))

        result = predicted_indices if return_indices else decoded
        return result[0] if is_single else result


class LmArielEmbedding(SentenceEmbedding):
    """Sentence embedding built on AriEL driven by a next-token language model.

    A Keras language model (``language_model`` or the default built by
    ``_languageModel``) is compiled with categorical cross-entropy and handed
    to ``AriEL``, which provides the graphs used by ``getEncoder`` and
    ``getDecoder``.
    """

    def __init__(self,
                 vocabulary,
                 emb_dim=100,
                 latent_dim=512,
                 maxlen=25,
                 keep=1.,
                 reverse_input=False,
                 language_model=None,
                 size_lat_dim=1.):
        """
        Args:
            vocabulary: vocabulary object (pad/start/end indices and
                sentence <-> index conversion).
            emb_dim: forwarded to ``predefined_model`` and to ``AriEL``.
            latent_dim: forwarded to ``AriEL`` as ``lat_dim``.
            maxlen: forwarded to ``AriEL`` and used when padding training data.
            keep: forwarded to ``getNoisyInput`` when generating training data
                (presumably a token keep probability -- confirm in getNoisyInput).
            reverse_input: forwarded to ``SentenceEmbedding``.
            language_model: optional pre-built Keras model; when None the
                default from ``_languageModel`` is used.
            size_lat_dim: forwarded to ``AriEL`` as ``size_lat_dim``.
        """
        super(LmArielEmbedding, self).__init__(vocabulary, latent_dim, reverse_input)
        self.__dict__.update(vocabulary=vocabulary,
                             latent_dim=latent_dim,
                             emb_dim=emb_dim,
                             keep=keep,
                             maxlen=maxlen,
                             size_lat_dim=size_lat_dim)

        self.vocab_size = vocabulary.getMaxVocabularySize()

        # NOTE: the zero value is reserved for masking
        self.loss = 'categorical_crossentropy'

        self.model = language_model if language_model is not None else self._languageModel()
        optimizer = Adam(lr=.001)
        self.model.compile(loss=self.loss, optimizer=optimizer, metrics=['acc'])

        self.encoder_type = 1
        self.decoder_type = 2
        self.ariel = self._buildAriel()

    def _buildAriel(self):
        """Construct the AriEL wrapper around ``self.model`` from the stored settings.

        Shared by ``__init__`` and ``load`` so that both use the same
        configuration, including the user-supplied ``size_lat_dim``.
        """
        return AriEL(vocab_size=self.vocab_size,
                     emb_dim=self.emb_dim,
                     lat_dim=self.latent_dim,
                     maxlen=self.maxlen,
                     output_type='both',
                     language_model=self.model,
                     decoder_type=self.decoder_type,
                     encoder_type=self.encoder_type,
                     size_lat_dim=self.size_lat_dim,
                     PAD=self.vocabulary.padIndex)

    def _languageModel(self):
        """Build the default language model via ``predefined_model`` (140 units)."""
        return predefined_model(self.vocab_size, self.emb_dim, units=140)

    def getEncoder(self):
        """Return a sentence encoder backed by this embedding's AriEL."""
        return lmArielSentenceEncoder(self.ariel,
                                      self.vocabulary,
                                      self.latent_dim,
                                      self.maxlen)

    def getDecoder(self):
        """Return a sentence decoder backed by this embedding's AriEL."""
        return lmArielSentenceDecoder(self.ariel,
                                      self.vocabulary,
                                      self.latent_dim,
                                      self.maxlen)

    def _generateTrainingData_ariel(self, generator, batch_size):
        """Yield ``(inputs, targets)`` batches for next-token training, forever.

        Each sentence becomes ``[PAD, START] + tokens + [END]``. A random cut
        position is chosen: the prefix before it (left-padded to ``maxlen``
        and passed through ``getNoisyInput``) is the input, and the one-hot
        token at that position is the target.

        Args:
            generator: callable taking ``batch_size`` and returning an
                iterator of sentence batches.
            batch_size: forwarded to ``generator``.
        """
        vocab = self.vocabulary
        sentence_generator = generator(batch_size)
        while True:
            sentences = next(sentence_generator)

            list_input = []
            list_output = []
            for sentence_indices in vocab.sentencesToIndices(sentences):
                indices = [vocab.padIndex, vocab.startIndex] + sentence_indices + [vocab.endIndex]
                # Targets are drawn from positions 2..len-1: every real token
                # plus END, always conditioned on at least [PAD, START]. This
                # also works for empty sentences (the only target is END).
                next_token_pos = np.random.randint(2, len(indices))
                list_input.append(indices[:next_token_pos])
                list_output.append([indices[next_token_pos]])

            input_indices = pad_sequences(list_input, maxlen=self.maxlen,
                                          value=vocab.padIndex,
                                          padding='pre')
            output_indices = to_categorical(np.array(list_output), num_classes=self.vocab_size)

            input_indices = getNoisyInput(self.keep, input_indices)
            yield input_indices, output_indices

    def train(self, gzip_filename_train, gzip_filename_val,
              epochs, steps_per_epoch=32, batch_size=32, callbacks=None,
              model_filename=None, verbose=1, profile=False, log_path=None):
        """Fit the language model on gzipped sentence files.

        Args:
            gzip_filename_train: path given to the training ``ArielGenerator``.
            gzip_filename_val: path given to the validation ``ArielGenerator``.
            epochs: number of epochs for ``fit_generator``.
            steps_per_epoch: forwarded to the training ``ArielGenerator``.
            batch_size: forwarded to both generators.
            callbacks: Keras callbacks for ``fit_generator`` (default: none).
            model_filename: forwarded to ``plot_history``; weights are also
                saved here if training is interrupted with Ctrl-C.
            verbose: ``fit_generator`` verbosity.
            profile: if True, run with a full session trace and write the
                collected step stats to ``timeline.ctf.json`` (Chrome trace).
                The model stays compiled with tracing enabled afterwards.
            log_path: NOTE(review): currently unused.
        """
        callbacks = [] if callbacks is None else callbacks

        run_metadata = None
        if profile:
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            # The trace options only take effect if they reach Session.run.
            # With the TensorFlow backend, Keras forwards extra compile()
            # kwargs there, so recompile with them before the train function
            # is built.
            optimizer = getattr(self.model, 'optimizer', None) or Adam(lr=.001)
            self.model.compile(loss=self.loss, optimizer=optimizer, metrics=['acc'],
                               options=run_options, run_metadata=run_metadata)

        generator_train = ArielGenerator(gzip_filepath=gzip_filename_train, vocabulary=self.vocabulary,
                                         batch_size=batch_size, steps_per_epoch=steps_per_epoch, )
        generator_val = ArielGenerator(gzip_filepath=gzip_filename_val, vocabulary=self.vocabulary,
                                       batch_size=batch_size)
        try:
            history = self.model.fit_generator(generator_train,
                                               epochs=epochs,
                                               validation_data=generator_val,
                                               use_multiprocessing=False,
                                               shuffle=False,
                                               verbose=verbose,
                                               callbacks=callbacks)

            plot_history(history, model_filename, epochs)

        except KeyboardInterrupt:
            logger.info("Training interrupted by the user")
            if model_filename is not None:
                self.model.save_weights(model_filename)

        if profile:
            trace = timeline.Timeline(step_stats=run_metadata.step_stats)
            with open('timeline.ctf.json', 'w') as f:
                f.write(trace.generate_chrome_trace_format())

    def save(self, filename):
        """Save the full Keras language model as HDF5.

        ``.h5`` is appended unless ``filename`` already ends with it, matching
        ``load`` so that ``save(name)`` and ``load(name)`` use the same file.
        """
        if not filename.endswith('.h5'):
            filename += '.h5'
        self.model.save(filename)

    def load(self, filename):
        """Load the Keras language model from HDF5 and rebuild AriEL around it.

        ``.h5`` is appended unless ``filename`` already ends with it.
        """
        if not filename.endswith('.h5'):
            filename += '.h5'
        logger.info('Loading embedding from file: %s', os.path.abspath(filename))
        self.model = load_model(filename)

        # AriEL holds a reference to the language model, so it must be
        # rebuilt to wrap the freshly loaded one.
        self.ariel = self._buildAriel()
