import logging
import os

import numpy as np

import tensorflow as tf

tf.executing_eagerly()
tf.compat.v1.enable_eager_execution()

from tensorflow.keras import Model, Input
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.layers import Dense, Lambda, Layer, Embedding, \
    GRU, LSTM, Softmax
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.python.client import timeline

from ariel_tests.language.sentence_generators import AeGenerator, VaeGenerator
from ariel_tests.models.baseEmbedding import SentenceEmbedding, word_accuracy, SentenceEncoder, SentenceDecoder, \
    plot_history, word_accuracy_no_pad

logger = logging.getLogger(__name__)

'''
some snippets have been taken from:
    http://tiao.io/posts/implementing-variational-autoencoders-in-keras-beyond-the-quickstart-tutorial/
    https://github.com/keras-team/keras/blob/master/examples/variational_autoencoder.py
    https://github.com/bjlkeng/sandbox/blob/master/notebooks/vae-inverse_autoregressive_flows/made.py
'''


# reparameterization trick
# instead of sampling from Q(z|X), sample eps = N(0,I)
# z = z_mean + sqrt(var)*eps
def sampling(args):
    """Draw a latent vector with the reparameterization trick.

    Noise is sampled from an isotropic unit Gaussian and then scaled and
    shifted, i.e. ``z = z_mean + exp(0.5 * z_log_var) * eps``.

    # Arguments:
        args (tensor): pair ``(z_mean, z_log_var)``, the mean and log of
            variance of Q(z|X)
    # Returns:
        z (tensor): sampled latent vector
    """
    mean, log_var = args
    # The batch size is read dynamically, the latent size statically.
    noise_shape = (K.shape(mean)[0], K.int_shape(mean)[1])
    # random_normal defaults to mean=0 and std=1.0
    noise = K.random_normal(shape=noise_shape)
    std = K.exp(0.5 * log_var)
    return mean + std * noise


class KLDivergenceLayer(Layer):
    """Pass-through layer that registers a KL divergence penalty.

    The layer returns its inputs untouched; its only effect is to add the
    batch-averaged KL term computed from ``(mu, log_var)`` to the losses of
    the model it is part of.
    """

    def __init__(self, *args, **kwargs):
        # Set before the base constructor runs, exactly like the original.
        self.is_placeholder = True
        super().__init__(*args, **kwargs)

    def call(self, inputs):
        """Add the KL loss for ``inputs = (mu, log_var)`` and return ``inputs``."""
        mean, log_variance = inputs

        # Closed-form KL terms, summed over the last axis.
        terms = 1 + log_variance - K.square(mean) - K.exp(log_variance)
        kl_per_sample = -.5 * K.sum(terms, axis=-1)

        self.add_loss(K.mean(kl_per_sample), inputs=inputs)
        return inputs


class RnnSentenceEncoder(SentenceEncoder):
    """Turn a sentence into the output of a Keras encoder model.

    # Arguments:
        model: object exposing ``predict``; called with one batch of indices
        vocabulary: provides ``padIndex``, ``startIndex``, ``endIndex`` and
            ``sentenceToIndices``
        reverseInput: when true the index sequence is reversed before it is
            fed to the model
    """

    def __init__(self, model, vocabulary, reverseInput=False):
        super(RnnSentenceEncoder, self).__init__()
        vars(self).update(model=model,
                          vocabulary=vocabulary,
                          reverseInput=reverseInput)

    def encode(self, sentence):
        """Return the squeezed model prediction for ``sentence``."""
        vocab = self.vocabulary
        # NOTE: the sequence is framed as <pad> <start> ... <end>; the
        # original note says the embedding uses index = 0 as a masking value.
        framed = [vocab.padIndex,
                  vocab.startIndex,
                  *vocab.sentenceToIndices(sentence),
                  vocab.endIndex]

        ordered = framed[::-1] if self.reverseInput else framed
        batch_of_one = np.array(ordered, dtype=np.int32)[np.newaxis, :]

        return np.squeeze(self.model.predict(batch_of_one))


class RnnSentenceDecoder(SentenceDecoder):
    """Generate a sentence from a latent vector, one token at a time.

    # Arguments:
        model: object exposing ``predict``; called with
            ``[token_prefix, latent_batch]``
        vocabulary: provides ``padIndex``, ``startIndex``,
            ``indicesToTokens``, ``fromStartToEnd`` and ``tokensToSentence``
        maxlen: number of tokens generated after the <pad> <start> prefix
    """

    def __init__(self, model, vocabulary, maxlen):
        super(RnnSentenceDecoder, self).__init__()
        vars(self).update(model=model, vocabulary=vocabulary, maxlen=maxlen)

    def _selectTokenIdx(self, dist):
        """Greedy choice: index of the largest entry of ``dist``."""
        return np.argmax(dist)

    def decode(self, z):
        """Decode latent vector ``z`` into a sentence.

        Exactly ``maxlen`` prediction steps are run; the full prefix is fed
        back to the model at every step and only the distribution of the
        last position is used to pick the next token.
        """
        vocab = self.vocabulary
        generated = [vocab.padIndex, vocab.startIndex]

        for _step in range(self.maxlen):
            prefix = np.array([generated], dtype=np.int32)
            probabilities = np.squeeze(self.model.predict([prefix, z[None]]))

            # Greedy pick from the distribution of the last position.
            generated.append(self._selectTokenIdx(probabilities[-1]))

        sentence = vocab.tokensToSentence(
            vocab.fromStartToEnd(vocab.indicesToTokens(generated)))

        # A one-element result is unwrapped to its single item.
        return sentence[0] if len(sentence) == 1 else sentence


def RecurrentUnit(rnn, **kwargs):
    """Build the recurrent Keras layer selected by ``rnn``.

    # Arguments:
        rnn (str): one of 'gru', 'cudnngru', 'lstm', 'cudnnlstm'; any other
            value (including 'indrnn', which has no implementation here)
            falls back to an LSTM with a warning
        **kwargs: forwarded unchanged to the layer constructor
    # Returns:
        a ``GRU`` or ``LSTM`` layer instance, never ``None``
    """
    # The dedicated CuDNN layer classes are not imported by this module.
    # These names used to yield ``None``, which only failed later with
    # "'NoneType' object is not callable". In TF 2.x the regular GRU/LSTM
    # layers select the cuDNN kernel on their own when the configuration
    # allows it, so the plain layers are used for the 'cudnn*' names.
    if rnn in ('gru', 'cudnngru'):
        return GRU(**kwargs)
    if rnn in ('lstm', 'cudnnlstm'):
        return LSTM(**kwargs)

    logger.warning("""The recurrent unit you specified is not implemented,
        we run the code on a lstm instead""")
    return LSTM(**kwargs)


class RnnEmbedding(SentenceEmbedding):
    """Sentence autoencoder built from stacked recurrent layers.

    Three Keras models that share the same layers are created at
    construction time: ``self.model`` (encoder and decoder trained end to
    end), ``self.encoder`` (token indices -> latent vector) and
    ``self.decoder`` (token prefix + latent vector -> token distribution for
    every position).

    # Arguments:
        vocabulary: project vocabulary object; this class calls
            ``getMaxVocabularySize()`` on it and hands it to the generators
            and to the sentence encoder/decoder wrappers
        emb_dim: output size of the word embedding
        rnn: recurrent unit name understood by ``RecurrentUnit``
        rec_dim: units of the intermediate recurrent layers
        latent_dim: units of the last encoder layer and of the first decoder
            layer, i.e. the size of the sentence vector
        n_rec: number of recurrent layers on each side
        maxlen: number of decoding steps used by ``getDecoder()``
        reverse_input: forwarded to the generators and to ``getEncoder()``
        name: stored on the instance, not used inside this class
        activation: activation of every recurrent layer
        keep: forwarded to ``AeGenerator`` (meaning defined there)
    """

    def __init__(self,
                 vocabulary,
                 emb_dim=100,
                 rnn='gru',
                 rec_dim=512,
                 latent_dim=256,
                 n_rec=1,
                 maxlen=25,
                 reverse_input=True,
                 name='embedding',
                 activation='tanh',
                 keep=1.):

        super(RnnEmbedding, self).__init__(vocabulary, latent_dim, reverse_input)
        self.__dict__.update(emb_dim=emb_dim, rnn=rnn, rec_dim=rec_dim,
                             activation=activation, maxlen=maxlen,
                             n_rec=n_rec, reverse_input=reverse_input, name=name, keep=keep)

        self.loss = 'categorical_crossentropy'

        self.defineModels()

    def initialize_encoder(self):
        """Create the word embedding and the stack of encoder layers."""
        # NOTE: the zero value is reserved for masking (mask_zero=True)
        self.word_embedding = Embedding(input_dim=self.vocabulary.getMaxVocabularySize(),
                                        output_dim=self.emb_dim,
                                        mask_zero=True,
                                        name='word-embedding')

        # All layers but the last one emit full sequences of size rec_dim.
        self.list_encoders = []
        for i in range(self.n_rec - 1):
            kwargs = {'units': self.rec_dim,
                      'return_state': False,
                      'return_sequences': True,
                      'name': 'encode-rnn-%s' % (i),
                      'activation': self.activation}
            self.list_encoders.append(RecurrentUnit(self.rnn, **kwargs))

        # The last layer returns its final state, which becomes the latent vector.
        kwargs = {'units': self.latent_dim,
                  'return_state': True,
                  'return_sequences': False,
                  'name': 'encode-rnn-%s' % (self.n_rec - 1),
                  'activation': self.activation}
        self.list_encoders.append(RecurrentUnit(self.rnn, **kwargs))

    def _encode(self, x):
        """Map the index tensor ``x`` to the latent tensor ``z``."""
        encoder_outputs = self.word_embedding(x)
        for encoder in self.list_encoders[:-1]:
            encoder_outputs = encoder(encoder_outputs)

        # NOTE(review): unpacking two values matches a unit with a single
        # state (e.g. GRU); an LSTM with return_state=True yields three
        # values -- confirm before training with rnn='lstm'.
        _, z = self.list_encoders[-1](encoder_outputs)

        # Identity layer, only there to give the latent tensor the name "z".
        z = Lambda(lambda x: x, name="z")(z)

        return z

    def initialize_decoder(self):
        """Create the decoder layers and the output softmax layer."""
        # First decoder layer has latent_dim units so that the latent vector
        # can be used as its initial state.
        kwargs = {'units': self.latent_dim,
                  'return_sequences': True,
                  'stateful': False,
                  'name': 'decode-rnn-0',
                  'activation': self.activation}
        self.decoder_0 = RecurrentUnit(self.rnn, **kwargs)

        self.list_decoders = []
        for i in range(self.n_rec - 1):
            kwargs = {'units': self.rec_dim,
                      'return_sequences': True,
                      'stateful': False,
                      'name': 'decode-rnn-%s' % (i + 1),
                      'activation': self.activation}
            self.list_decoders.append(RecurrentUnit(self.rnn, **kwargs))

        self.last_dense = Dense(self.vocabulary.getMaxVocabularySize(), activation='softmax', name='out')

    def _decode(self, x, z=None):
        """Return per-position token distributions for the prefix ``x``.

        When ``z`` is given it is used as the initial state of the first
        decoder layer. The word embedding is shared with the encoder.
        """
        xe = self.word_embedding(x)

        if z is not None:
            decoder_outputs = self.decoder_0(xe, initial_state=[z])
        else:
            decoder_outputs = self.decoder_0(xe)

        for decoder in self.list_decoders:
            decoder_outputs = decoder(decoder_outputs)

        return self.last_dense(decoder_outputs)

    def _autoencoder(self):
        """Build the trainable model: ``[x_enc, x_dec] -> reconstruction``."""
        x_enc = Input(shape=(None,), dtype=np.int32, name='x_enc')
        x_dec = Input(shape=(None,), dtype=np.int32, name='x_dec')
        z = self._encode(x_enc)
        xr = self._decode(x_dec, z)
        return Model([x_enc, x_dec], xr, name='autoencoder')

    def _encoder(self):
        """Build the stand-alone encoder model: ``indices -> z``."""
        x = Input(shape=(None,), dtype=np.int32, name='x_ee')
        z = self._encode(x)
        return Model(x, z, name='encoder')

    def _decoder(self):
        """Build the stand-alone decoder model: ``[indices, state] -> distributions``."""
        x = Input(shape=(None,), dtype=np.int32, name='x_dd')
        z = Input(shape=(self.latent_dim,), dtype=np.float32, name='state')

        xr = self._decode(x, z)
        return Model([x, z], xr, name='decoder')

    def defineModels(self):
        """Create all layers, then the three models that share them."""
        self.initialize_decoder()
        self.initialize_encoder()
        self.model = self._autoencoder()
        self.encoder = self._encoder()
        self.decoder = self._decoder()

    def getEncoder(self):
        """Return a sentence-level wrapper around ``self.encoder``."""
        return RnnSentenceEncoder(self.encoder, self.vocabulary, self.reverse_input)

    def getDecoder(self):
        """Return a sentence-level wrapper around ``self.decoder``."""
        return RnnSentenceDecoder(self.decoder, self.vocabulary, self.maxlen)

    def train(
            self,
            gzip_filename_train,
            gzip_filename_val,
            epochs,
            steps_per_epoch=32,
            batch_size=32,
            callbacks=None,
            model_filename=None,
            verbose=1,
            profile=False,
            log_path=None):
        """Compile and fit the autoencoder on the two gzip datasets.

        # Arguments:
            gzip_filename_train, gzip_filename_val: paths handed to
                ``AeGenerator`` for training and validation data
            epochs: number of epochs passed to ``fit_generator``
            steps_per_epoch: forwarded to the training generator only
            batch_size: forwarded to both generators
            callbacks: optional list of Keras callbacks (``None`` means none)
            model_filename: base name used by ``plot_history`` and, when
                training is interrupted with Ctrl-C, by ``save``
            verbose: Keras verbosity level
            profile: when true a Chrome trace is written to
                'timeline.ctf.json' after training
            log_path: accepted for interface compatibility, not used here
        """
        # A fresh list per call instead of a shared mutable default.
        callbacks = [] if callbacks is None else callbacks

        run_metadata = None
        if profile:
            # compat.v1 keeps this working where the top-level TF1 names
            # (tf.RunMetadata) no longer exist.
            # NOTE(review): the metadata object is never attached to
            # compile/fit in this method, so the trace written below is
            # built from empty step stats.
            run_metadata = tf.compat.v1.RunMetadata()

        optimizer = Adam(lr=1e-4, clipvalue=0.5)
        self.model.compile(
            optimizer,
            loss=self.loss,
            metrics=[word_accuracy, 'categorical_crossentropy', word_accuracy_no_pad],
        )

        generator_train = AeGenerator(gzip_filepath=gzip_filename_train, vocabulary=self.vocabulary,
                                      batch_size=batch_size, steps_per_epoch=steps_per_epoch,
                                      reverse_input=self.reverse_input, keep=self.keep)
        generator_val = AeGenerator(gzip_filepath=gzip_filename_val, vocabulary=self.vocabulary,
                                    batch_size=batch_size, reverse_input=self.reverse_input, keep=self.keep)
        try:
            self.model.summary()
            history = self.model.fit_generator(generator_train,
                                               epochs=epochs,
                                               validation_data=generator_val,
                                               use_multiprocessing=False,
                                               shuffle=False,
                                               verbose=verbose,
                                               callbacks=callbacks)

            plot_history(history, model_filename, epochs)

        except KeyboardInterrupt:
            # Best effort: keep what has been learned so far.
            logger.info("Training interrupted by the user")
            if model_filename is not None:
                self.save(model_filename)

        if profile:
            trace = timeline.Timeline(step_stats=run_metadata.step_stats)
            with open('timeline.ctf.json', 'w') as f:
                f.write(trace.generate_chrome_trace_format())

    def save(self, filename):
        """Write the three models to ``filename`` + '.h5' / '_encoder.h5' / '_decoder.h5'."""
        self.model.save(filename + '.h5')
        self.encoder.save(filename + '_encoder.h5')
        self.decoder.save(filename + '_decoder.h5')

    def load(self, filename):
        """Replace the three models with the ones stored by ``save(filename)``."""
        logger.info('Loading embedding from file: %s', os.path.abspath(filename))

        # The custom metrics must be known to Keras to restore compiled models.
        custom_objects = {'word_accuracy': word_accuracy,
                          'word_accuracy_no_pad': word_accuracy_no_pad}

        self.model = load_model(filename + '.h5', custom_objects=custom_objects)
        self.encoder = load_model(filename + '_encoder.h5', custom_objects=custom_objects)
        self.decoder = load_model(filename + '_decoder.h5', custom_objects=custom_objects)


class AnnealingLosses(Callback):
    """Callback that anneals the weight ``beta`` of a loss term.

    ``beta`` is increased by ``beta_step`` after every batch of epoch
    ``ramp_epoch`` (capped at 1) and forced to 1 after every batch of epoch
    ``final_epoch``. In all other epochs it is left as it is. ``alpha`` is
    stored but never modified.

    # Arguments:
        alpha: backend variable holding the first loss weight (kept as is)
        beta: backend variable holding the annealed loss weight
        ramp_epoch: epoch index during which beta grows batch by batch
        final_epoch: epoch index during which beta is set to 1
        beta_step: increment applied per batch during ``ramp_epoch``
    """

    def __init__(self, alpha, beta, ramp_epoch=7, final_epoch=8, beta_step=.001):
        # Let the Keras base class set up its own attributes first.
        super(AnnealingLosses, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.ramp_epoch = ramp_epoch
        self.final_epoch = final_epoch
        self.beta_step = beta_step
        self.epoch = 0

    def on_epoch_begin(self, epoch, logs=None):
        """Remember the current epoch index for ``on_batch_end``."""
        self.epoch = epoch

    def on_batch_end(self, epoch, logs=None):
        """Update ``beta`` according to the schedule.

        The first positional argument keeps its original name; it is not
        read, the epoch recorded in ``on_epoch_begin`` is used instead.
        """
        if self.epoch == self.ramp_epoch:
            new_beta = min(K.get_value(self.beta) + self.beta_step, 1.)
        elif self.epoch == self.final_epoch:
            new_beta = 1.
        else:
            # Outside the schedule both weights keep their current value, so
            # there is nothing to write back.
            return

        K.set_value(self.beta, new_beta)

        # logger.info(" epoch %s, alpha = %s, beta = %s" % (epoch, K.get_value(self.alpha), K.get_value(self.beta)))


def KL_loss(z_log_sigma, z_mean):
    """Build a Keras loss that returns only the KL term, spread over time.

    # Arguments:
        z_log_sigma: tensor used as the log-variance in the KL formula
        z_mean: tensor of latent means
    # Returns:
        ``kl_loss(y_true, y_pred)``; it ignores the values of both arguments
        and only reads the size of axis 1 of ``y_true``
    """

    def kl_loss(y_true, y_pred):
        n_steps = K.shape(y_true)[1]

        # One KL value per sample (mean over axis 1 of the latent tensors).
        per_sample = -0.5 * K.mean(
            1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma),
            axis=[1])

        # Repeat it for every time step and divide by the number of steps,
        # so that the values of one sample add up to its KL term.
        per_step = tf.tile(tf.expand_dims(per_sample, 1), [1, n_steps])
        return tf.math.divide(per_step, tf.dtypes.cast(n_steps, tf.float32))

    return kl_loss


def CAT_loss(z_log_sigma, z_mean):
    """Build a Keras loss that returns only the categorical cross-entropy.

    Both arguments are accepted so that the signature matches the other loss
    factories of this module; neither is used.
    """

    def cat_loss(y_true, y_pred):
        return K.categorical_crossentropy(y_true, y_pred)

    return cat_loss


def VAE_loss(z_log_sigma, z_mean):
    """Build a Keras loss equal to cross-entropy plus the KL term.

    # Arguments:
        z_log_sigma: tensor used as the log-variance in the KL formula
        z_mean: tensor of latent means
    # Returns:
        ``vae_loss(y_true, y_pred)`` with the shape of the cross-entropy
    """

    def vae_loss(y_true, y_pred):
        xent_loss = K.categorical_crossentropy(y_true, y_pred)
        kl_loss = -0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=[1])

        # The KL term has one value per sample, while the cross-entropy has
        # one value per sample and per extra axis (e.g. time) of the output.
        # Adding them directly would broadcast the batch axis of the KL term
        # against the LAST axis of the cross-entropy. Trailing axes are
        # appended so that every sample's KL value is added to that sample's
        # entries only.
        xent_rank = K.ndim(xent_loss)
        kl_rank = K.ndim(kl_loss)
        if xent_rank is not None and kl_rank is not None:
            for _ in range(xent_rank - kl_rank):
                kl_loss = K.expand_dims(kl_loss, axis=-1)

        return xent_loss + kl_loss

    return vae_loss


def sample(args):
    """Sample a latent vector from ``(z_mean, z_log_sigma)``.

    The KL terms of this module use ``z_log_sigma`` as a log-variance
    (``1 + z_log_sigma - z_mean^2 - exp(z_log_sigma)``), so the standard
    deviation that scales the noise is ``exp(0.5 * z_log_sigma)``. Scaling
    the noise by ``z_log_sigma`` itself, as was done before, samples from a
    different distribution than the one the KL term regularizes.

    # Arguments:
        args: pair ``(z_mean, z_log_sigma)`` of tensors with a batch axis
            and a latent axis
    # Returns:
        ``z_mean + exp(0.5 * z_log_sigma) * eps`` with ``eps ~ N(0, I)``
    """
    z_mean, z_log_sigma = args
    batch_size = K.shape(z_mean)[0]
    latent_dim = K.shape(z_mean)[1]

    epsilon_std = 1.
    epsilon = K.random_normal(shape=(batch_size, latent_dim),
                              mean=0., stddev=epsilon_std)
    return z_mean + K.exp(0.5 * z_log_sigma) * epsilon


class RnnVaeEmbedding(RnnEmbedding):
    """Variational version of ``RnnEmbedding``.

    The final encoder state is projected to a mean and a second tensor used
    as log-variance; the latent vector is either sampled from them (training
    model) or set to the mean (stand-alone encoder). Training uses a model
    with two identical outputs so that the KL and cross-entropy losses can
    be weighted separately and annealed by ``AnnealingLosses``.

    ``self.encoder`` is only created at the end of a successful ``train``
    call; ``defineModels`` builds the autoencoder and the decoder only.
    """

    def __init__(self, *args, **kwargs):
        super(RnnVaeEmbedding, self).__init__(*args, **kwargs)

    def initialize_encoder(self):
        """Create the parent's encoder layers plus the two VAE projections."""
        # Embedding and recurrent stack are identical to the parent class.
        super(RnnVaeEmbedding, self).initialize_encoder()

        self.dense_mean = Dense(self.latent_dim)
        self.dense_sigma = Dense(self.latent_dim)

    def _encode(self, x):
        """Map the index tensor ``x`` to a latent tensor.

        Side effect: ``self.loss_cat``, ``self.loss_kl`` and ``self.loss``
        are rebuilt on every call and bound to the tensors of that call.
        """
        encoder_outputs = self.word_embedding(x)
        for encoder in self.list_encoders[:-1]:
            encoder_outputs = encoder(encoder_outputs)

        _, z = self.list_encoders[-1](encoder_outputs)

        # VAE Z layer
        z_mean = self.dense_mean(z)
        z_log_sigma = self.dense_sigma(z)

        if self.deterministic:
            z = z_mean
        else:
            z = Lambda(sample, name='sampled_z')([z_mean, z_log_sigma])

        self.loss_cat = CAT_loss(z_log_sigma, z_mean)
        self.loss_kl = KL_loss(z_log_sigma, z_mean)
        self.loss = VAE_loss(z_log_sigma, z_mean)
        return z

    def _autoencoder(self):
        """Build the sampled autoencoder and the two-output annealing model."""
        self.deterministic = False
        x_enc = Input(shape=(None,), dtype=np.int32, name='x_enc')
        x_dec = Input(shape=(None,), dtype=np.int32, name='x_dec')
        z = self._encode(x_enc)
        xr = self._decode(x_dec, z)
        autoencoder = Model([x_enc, x_dec], xr, name='Autoencoder')

        # Same output twice: one copy per loss (KL, cross-entropy).
        self.annealing_model = Model([x_enc, x_dec], [xr, xr], name='annealed_Autoencoder')
        return autoencoder

    def defineModels(self):
        """Create the layers, the autoencoder and the decoder (no encoder yet)."""
        self.initialize_decoder()
        self.initialize_encoder()
        self.model = self._autoencoder()
        self.decoder = self._decoder()

    def _encoder(self):
        """Build the deterministic encoder model: ``indices -> z_mean``."""
        self.deterministic = True
        x = Input(shape=(None,), dtype=np.int32, name='x_enc_vae')
        z = self._encode(x)

        return Model(x, z, name='encoder')

    def getEncoder(self):
        """Return a sentence-level wrapper around ``self.encoder``.

        ``self.encoder`` exists only after ``train`` has completed.
        """
        return RnnSentenceEncoder(self.encoder, self.vocabulary, self.reverse_input)

    def _generateTrainingData_forAnnealing(self, generator, batch_size):
        """Yield ``(inputs, [targets, targets])`` for the two-output model.

        Batches come from ``self._generateTrainingData`` (not defined in
        this class; presumably inherited -- TODO confirm).
        """
        # Iterating with ``for`` ends this generator cleanly when the source
        # is exhausted; calling next() here would surface as a RuntimeError.
        for batch in self._generateTrainingData(generator, batch_size):
            yield batch[0], [batch[1], batch[1]]

    def train(
            self,
            gzip_filename_train,
            gzip_filename_val,
            epochs,
            steps_per_epoch=32,
            batch_size=32,
            callbacks=None,
            model_filename=None,
            verbose=1,
            profile=False,
            log_path=None):
        """Compile and fit the annealing model, then build the encoder.

        # Arguments:
            gzip_filename_train, gzip_filename_val: paths handed to
                ``VaeGenerator`` for training and validation data
            epochs: number of epochs passed to ``fit_generator``
            steps_per_epoch: forwarded to the training generator only
            batch_size: forwarded to both generators
            callbacks: optional list of Keras callbacks; it is copied, an
                ``AnnealingLosses`` callback is added to the copy
            model_filename: base name used by ``plot_history`` and, when
                training is interrupted with Ctrl-C, by ``save_weights``
            verbose: Keras verbosity level
            profile: when true a Chrome trace is written to
                'timeline.ctf.json' after training
            log_path: accepted for interface compatibility, not used here
        """
        # Work on a private copy: appending to the argument would modify the
        # caller's list, and with a shared default the annealing callback
        # would pile up across calls.
        callbacks = [] if callbacks is None else list(callbacks)

        run_metadata = None
        if profile:
            # compat.v1 keeps this working where the top-level TF1 names
            # (tf.RunMetadata) no longer exist.
            # NOTE(review): the metadata object is never attached to
            # compile/fit in this method, so the trace written below is
            # built from empty step stats.
            run_metadata = tf.compat.v1.RunMetadata()

        optimizer = Adam(lr=1e-4, clipvalue=0.5)

        # Loss weights live in backend variables so the callback can change
        # them while training runs: cross-entropy starts at 1, KL at 0.
        weight_cat = K.variable(1.0)
        weight_kl = K.variable(0.0)
        self.annealing_model.summary()
        self.annealing_model.compile(optimizer,
                                     loss=[self.loss_kl, self.loss_cat],
                                     loss_weights=[weight_kl, weight_cat],
                                     metrics=['categorical_crossentropy', word_accuracy, word_accuracy_no_pad])
        callbacks.append(AnnealingLosses(weight_cat, weight_kl))

        generator_train = VaeGenerator(gzip_filepath=gzip_filename_train, vocabulary=self.vocabulary,
                                       batch_size=batch_size, steps_per_epoch=steps_per_epoch,
                                       reverse_input=self.reverse_input)
        generator_val = VaeGenerator(gzip_filepath=gzip_filename_val, vocabulary=self.vocabulary,
                                     batch_size=batch_size, reverse_input=self.reverse_input)
        try:
            history = self.annealing_model.fit_generator(generator_train,
                                                         epochs=epochs,
                                                         validation_data=generator_val,
                                                         use_multiprocessing=False,
                                                         shuffle=False,
                                                         verbose=verbose,
                                                         callbacks=callbacks)

            plot_history(history, model_filename, epochs)

            # The encoder is defined here because building it earlier gave
            # problems with the definition of loss_kl (_encode overwrites
            # the loss closures every time it is called).
            self.encoder = self._encoder()

        except KeyboardInterrupt:
            logger.info("Training interrupted by the user")
            if model_filename is not None:
                self.model.save_weights(model_filename)

        if profile:
            trace = timeline.Timeline(step_stats=run_metadata.step_stats)
            with open('timeline.ctf.json', 'w') as f:
                f.write(trace.generate_chrome_trace_format())
