import logging
import os

import matplotlib.pyplot as plt
import numpy as np

import tensorflow as tf

# NOTE: the return value of this call is discarded, so the statement has no
# visible effect here.
tf.executing_eagerly()
# Switch TensorFlow to eager execution through the TF1 compatibility API.
# NOTE(review): presumably this must run before the Keras imports below, which
# would explain why it sits in the middle of the import block -- confirm.
tf.compat.v1.enable_eager_execution()

import tensorflow.keras.backend as K
from tensorflow.python.keras.preprocessing.sequence import pad_sequences

from ariel_tests.language.nlp import postprocessSentence, tokenize, NltkGrammarSampler

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class SentenceEmbedding(object):
    """Base class for sentence embedding models.

    Subclasses provide the encoder/decoder pair through ``getEncoder()`` and
    ``getDecoder()``. ``reconstruct()`` additionally reads a ``model``
    attribute that this base class never sets.
    """

    def __init__(self, vocabulary, latent_dim, reverse_input):
        """Store the constructor arguments as instance attributes."""
        self.__dict__.update(vocabulary=vocabulary, latent_dim=latent_dim, reverse_input=reverse_input)

    def __getinitargs__(self):
        """Return the tuple of arguments needed to re-create this instance.

        An explicitly assigned ``_init_args`` attribute takes precedence.
        Otherwise the three arguments stored by ``__init__`` are returned, so
        the method no longer fails when ``_init_args`` was never set.
        """
        try:
            return self._init_args
        except AttributeError:
            return (self.vocabulary, self.latent_dim, self.reverse_input)

    def getEncoder(self):
        """Return the encoder object; must be implemented by subclasses."""
        raise NotImplementedError()

    def getDecoder(self):
        """Return the decoder object; must be implemented by subclasses."""
        raise NotImplementedError()

    def interpolate(self, sentence1, sentence2, nb_samples_test=10, removeDuplicates=False):
        """Decode points on the straight line between two sentence codes.

        Args:
            sentence1: sentence whose code is the start of the line.
            sentence2: sentence whose code is the end of the line.
            nb_samples_test: number of evenly spaced points, end points
                included.
            removeDuplicates: if true, skip a decoded sentence that equals the
                previously kept one.

        Returns:
            The list of decoded sentences, ordered from sentence1 to
            sentence2.
        """
        # Encode sentences
        encoder = self.getEncoder()
        code1 = np.array(encoder.encode(sentence1))
        code2 = np.array(encoder.encode(sentence2))

        decoder = self.getDecoder()
        interpolations = []
        for x in np.linspace(0, 1, nb_samples_test):
            sentence = decoder.decode(x * code2 + (1 - x) * code1)
            if not removeDuplicates or not interpolations or sentence != interpolations[-1]:
                interpolations.append(sentence)
        return interpolations

    def _generateTrainingData(self, generator, batch_size, vocabulary=None):
        """Yield padded training batches built from a sentence generator.

        Args:
            generator: callable taking ``batch_size`` and returning an
                iterator over batches (lists) of sentences.
            batch_size: value forwarded to ``generator``.
            vocabulary: only used when the instance has no ``vocabulary``
                attribute yet.

        Yields:
            ``([x_enc, x_dec], y_dec_oh)`` where ``x_enc`` and ``x_dec`` are
            int32 index arrays padded with the vocabulary's pad index and
            ``y_dec_oh`` is the float32 one-hot encoding of the decoder
            targets.
        """
        if not hasattr(self, 'vocabulary') and vocabulary is not None:
            self.vocabulary = vocabulary

        # NOTE: iterate with a for loop rather than calling next() inside
        # ``while True``: a StopIteration escaping next() within a generator
        # body is turned into a RuntimeError on Python 3.7+, so a finite
        # source would crash instead of simply ending the stream.
        for sentences in generator(batch_size):
            sentences = [postprocessSentence(sentence)
                         for sentence in sentences]
            indices = [self.vocabulary.tokensToIndices(tokenize(sentence))
                       for sentence in sentences]

            maxSentenceLen = max(len(tokens) for tokens in indices)
            padIndex = self.vocabulary.padIndex

            # Encoder input: the (optionally reversed) token indices,
            # right-padded to the longest sentence of the batch
            if self.reverse_input:
                encoderTokens = [tokens[::-1] for tokens in indices]
            else:
                encoderTokens = indices
            x_enc = pad_sequences(encoderTokens,
                                  maxlen=maxSentenceLen,
                                  value=padIndex,
                                  padding='post')
            x_enc = np.array(x_enc, dtype=np.int32)

            # Decoder input: start symbol followed by the token indices
            x_dec = pad_sequences([[self.vocabulary.startIndex] + tokens for tokens in indices],
                                  maxlen=maxSentenceLen + 1,
                                  value=padIndex,
                                  padding='post')
            x_dec = np.array(x_dec, dtype=np.int32)

            # Decoder target: token indices followed by the end symbol
            y_dec = pad_sequences([tokens + [self.vocabulary.endIndex] for tokens in indices],
                                  maxlen=maxSentenceLen + 1,
                                  value=padIndex,
                                  padding='post')
            y_dec_oh = np.array(indicesToOneHot(y_dec, self.vocabulary.getMaxVocabularySize()),
                                dtype=np.float32)
            yield [x_enc, x_dec], y_dec_oh

    def reconstruct(self, sentence):
        """Run a sentence through ``self.model`` and return the decoded text.

        The decoded index sequence is cut at the first end symbol. If no end
        symbol is produced, the last decoded token is dropped instead.
        """
        self.model.reset_states()

        # Frame the sentence indices with a leading pad and start symbol and
        # a trailing end symbol
        indices = [self.vocabulary.padIndex, self.vocabulary.startIndex] + \
                  self.vocabulary.sentenceToIndices(sentence) + \
                  [self.vocabulary.endIndex]

        if self.reverse_input:
            x_enc = np.array(indices[::-1], dtype=np.int32)
        else:
            x_enc = np.array(indices, dtype=np.int32)

        # Decoder input: the framed sequence with one more pad symbol in front
        x_dec = np.array([self.vocabulary.padIndex] + indices, dtype=np.int32)

        xr = self.model.predict(
            [x_enc[np.newaxis, :], x_dec[np.newaxis, :]])

        # NOTE: atleast_1d keeps a single-token prediction indexable; squeeze
        # alone would collapse it to a 0-d array and len() below would fail
        indices = np.atleast_1d(np.squeeze(np.argmax(xr, axis=-1)))

        # NOTE: truncate the sequence to the first found end symbol
        idx = np.where(indices == self.vocabulary.endIndex)[0]
        if len(idx) > 0:
            endPosition = idx[0]
        else:
            # NOTE: if no end symbol was found, ignore the last decoded token
            endPosition = len(indices) - 1
        indices = indices[:endPosition]

        return self.vocabulary.indicesToSentence(indices)

    def coverage(self, grammar, stochastic=False, nb_samples_test=None):
        """Return the fraction of grammar sentences reconstructed exactly.

        Args:
            grammar: grammar handed to ``NltkGrammarSampler`` (stochastic
                mode) or to ``nltk.parse.generate.generate`` (default mode).
            stochastic: sample sentences one at a time from the sampler
                instead of enumerating the grammar.
            nb_samples_test: number of sentences to draw in stochastic mode,
                or the ``n`` argument of ``generate`` otherwise.

        Returns:
            Number of sentences whose encode/decode round trip equals the
            input, divided by the number of sentences tested; 0.0 when no
            sentence was tested.

        Raises:
            ValueError: if ``stochastic`` is set without ``nb_samples_test``.
        """
        encoder = self.getEncoder()
        decoder = self.getDecoder()

        if stochastic:
            # Stochastic sampling of the grammar
            if nb_samples_test is None:
                raise ValueError('The number of samples must be specified in stochastic sampling mode!')
            sampler = NltkGrammarSampler(grammar)
            sentences = (sentence
                         for _ in range(nb_samples_test)
                         for sentence in sampler.generate(1))
        else:
            # Depth-first search of the grammar
            from nltk.parse.generate import generate
            sentences = (' '.join(tokens)
                         for tokens in generate(grammar, n=nb_samples_test))

        nbValid = 0
        nbTotal = 0
        for sentence in sentences:
            recons = decoder.decode(encoder.encode(sentence))
            if recons == sentence:
                nbValid += 1
            nbTotal += 1

        # Guard against an empty sample set instead of dividing by zero
        if nbTotal == 0:
            return 0.0
        return float(nbValid) / nbTotal

    def save(self, filename):
        """Placeholder: only logs that saving is not implemented."""
        logger.info('Save method of this embedding not implemented')

    def load(self, filename):
        """Placeholder: only logs that loading is not implemented."""
        logger.info('Load method of this embedding not implemented')


class SentenceEncoder(object):
    """Interface for objects that encode a sentence; subclasses override encode()."""

    def encode(self, sentence):
        """Encode ``sentence``; this base implementation always raises."""
        raise NotImplementedError


class SentenceDecoder(object):
    """Interface for objects that decode a code; subclasses override decode()."""

    def decode(self, z):
        """Decode ``z``; this base implementation always raises."""
        raise NotImplementedError


def indicesToOneHot(indices, num_tokens):
    """Return the one-hot rows selected by ``indices``.

    Builds a ``num_tokens`` x ``num_tokens`` identity matrix and indexes its
    rows with ``indices``, so the result has one extra trailing axis of size
    ``num_tokens``.
    """
    identity = np.identity(num_tokens)
    return identity[indices]


def plot_history(history, model_filename, epochs):
    """Plot the training (and validation) loss curves and save them as a PDF.

    The figure is written to ``plots/loss.pdf`` below the grandparent
    directory of ``model_filename``; the ``plots`` directory is created when
    it does not exist yet.

    Args:
        history: object whose ``history`` mapping holds the per-epoch
            ``'loss'`` values and, optionally, ``'val_loss'`` values.
        model_filename: path of the model file, used only to locate the
            output directory.
        epochs: number of epochs trained; nothing is plotted unless it is
            positive.
    """
    if epochs <= 0:
        return

    model_dir, _ = os.path.split(model_filename)
    embedding_dir, _ = os.path.split(model_dir)
    plot_dir = os.path.join(embedding_dir, 'plots')
    # savefig does not create missing directories
    if not os.path.isdir(plot_dir):
        os.makedirs(plot_dir)
    plot_filename = os.path.join(plot_dir, 'loss.pdf')

    fig = plt.figure(figsize=(8, 8))
    try:
        ax = fig.add_subplot(111)
        ax.plot(history.history['loss'], label='train')
        # The validation curve is absent when training ran without
        # validation data
        if 'val_loss' in history.history:
            ax.plot(history.history['val_loss'], label='val')
        ax.set_title('model loss')
        ax.set_ylabel('loss')
        ax.set_xlabel('epoch')
        ax.legend()
        fig.savefig(plot_filename, bbox_inches='tight')
    finally:
        # pyplot keeps a reference to every figure it creates; close it so
        # repeated calls do not accumulate open figures
        plt.close(fig)


def word_accuracy(y_true, y_pred):
    """Return an elementwise match indicator between target and prediction.

    Both tensors are reduced with an argmax over their last axis; positions
    where the two argmax results agree are 1 and all others 0, cast to the
    Keras default float type.
    """
    target_ids = K.argmax(y_true, axis=-1)
    predicted_ids = K.argmax(y_pred, axis=-1)
    hits = K.equal(target_ids, predicted_ids)
    return K.cast(hits, K.floatx())


def word_accuracy_no_pad(y_true, y_pred, pad_value=0):
    """Return the token accuracy computed over non-padding targets only.

    Positions whose target index (argmax of ``y_true`` over the last axis)
    equals ``pad_value`` are ignored. The mean accuracy over the remaining
    positions is broadcast to a tensor whose shape is the first two
    dimensions of ``y_pred``.

    Args:
        y_true: target tensor; assumed one-hot over the last axis -- TODO
            confirm against the callers.
        y_pred: prediction tensor with at least two leading dimensions.
        pad_value: target index treated as padding. Defaults to 0, the value
            that used to be hard-coded.

    Returns:
        A tensor of shape ``(tf.shape(y_pred)[0], tf.shape(y_pred)[1])``
        filled with the mean accuracy, or with 0 when every target position
        is padding.
    """
    argm_true = K.argmax(y_true, axis=-1)
    argm_pred = K.argmax(y_pred, axis=-1)

    mask = K.not_equal(argm_true, pad_value)
    meaningful_true = tf.boolean_mask(argm_true, mask)
    meaningful_pred = tf.boolean_mask(argm_pred, mask)

    matches = K.cast(K.equal(meaningful_true, meaningful_pred),
                     K.floatx())

    # Average by hand with a denominator of at least one: the mean of an
    # empty selection (all targets are padding) would otherwise be NaN
    nb_tokens = tf.cast(tf.size(matches), matches.dtype)
    mean_acc = tf.reduce_sum(matches) / tf.maximum(nb_tokens, tf.ones_like(nb_tokens))

    batch_size = tf.shape(y_pred)[0]
    seq_len = tf.shape(y_pred)[1]
    return tf.tile(tf.reshape(mean_acc, [1, 1]), [batch_size, seq_len])
