import logging

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import AveragePooling1D, Embedding

from ariel_tests.language.sentence_generators import AeGenerator
from ariel_tests.models.IndRNN import IndRNNCell
from ariel_tests.models.rnnEmbedding import KLDivergenceLayer, sampling
# from lnslnsnlsnsnnslsnns.lsnn.keras_custom_layers import ReadOut, RegularizationLoss, MaskedAvPooling, OneHot
# from lnslnsnlsnsnnslsnns.lsnn.spiking_models_keras import Izhikevich, Rulkov, FN, LIF, ALIF_noB, ALIF, Simplest

tf.compat.v1.disable_eager_execution()
from tensorflow.python.client import timeline

from tensorflow.keras import backend as K
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Dense, Lambda, RNN, RepeatVector, Permute, Flatten, Softmax
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import Callback

from ariel_tests.models.baseEmbedding import SentenceEmbedding, word_accuracy, SentenceEncoder, SentenceDecoder, \
    plot_history, word_accuracy_no_pad

logger = logging.getLogger(__name__)


class LsnnSentenceEncoder(SentenceEncoder):
    """Encode a sentence into a flat numeric code with a Keras encoder model.

    ``code_type`` selects which of the model outputs make up the code:
    ``'voltage'`` uses output 0, ``'threshold'`` uses output 2,
    ``'mixed_vt'`` joins outputs 0 and 2, and ``'all'`` joins every output.
    """

    def __init__(self, model, vocabulary, reverseInput=False, code_type='all'):
        super(LsnnSentenceEncoder, self).__init__()
        settings = {
            'model': model,
            'vocabulary': vocabulary,
            'reverseInput': reverseInput,
            'code_type': code_type,
        }
        self.__dict__.update(settings)

    def encode(self, sentence):
        """Return the squeezed code array for ``sentence``.

        Raises:
            NotImplementedError: if ``code_type`` is not one of the supported
                values (raised only after the model prediction has run).
        """
        vocab = self.vocabulary

        # NOTE: a pad index is put in front of the start token; the original
        # author notes the embedding uses index = 0 as a masking value.
        token_ids = [vocab.padIndex, vocab.startIndex] + \
                    vocab.sentenceToIndices(sentence) + \
                    [vocab.endIndex]

        if self.reverseInput:
            token_ids = token_ids[::-1]

        # Single-sentence batch: shape (1, len(token_ids)).
        batch = np.array(token_ids, dtype=np.int32)[np.newaxis, :]
        outputs = self.model.predict(batch)

        kind = self.code_type
        if kind == 'voltage':
            code = outputs[0]
        elif kind == 'threshold':
            code = outputs[2]
        elif kind == 'mixed_vt':
            code = np.concatenate([outputs[0], outputs[2]], axis=1)
        elif kind == 'all':
            parts = list(outputs)
            # A single output needs no concatenation.
            code = np.concatenate(parts, axis=1) if len(parts) > 1 else parts[0]
        else:
            raise NotImplementedError

        return np.squeeze(code)


class LsnnSentenceDecoder(SentenceDecoder):
    """Greedily decode a flat code vector back into a sentence.

    The code vector is first split into the list of initial recurrent states
    expected by the decoder model (see ``_statesInCellFormat``); tokens are
    then generated one at a time by repeatedly calling ``model.predict`` on
    the growing index sequence.
    """

    def __init__(self, model, vocabulary, maxlen, code_type='all'):
        super(LsnnSentenceDecoder, self).__init__()
        self.__dict__.update(model=model, vocabulary=vocabulary, maxlen=maxlen, code_type=code_type)

    def _selectTokenIdx(self, dist):
        """Return the index of the largest entry of ``dist`` (greedy choice)."""
        return np.argmax(dist)

    def _statesInCellFormat(self, z):
        """Split the flat code ``z`` into a list of ``(1, size)`` state arrays.

        The per-state sizes are read from the first cell of the layer named
        ``"decode-rnn-0"``. For ``'mixed_vt'`` and ``'voltage'`` the states
        that are not part of the code are filled with zeros, and five states
        are always produced.

        Raises:
            NotImplementedError: for ``'threshold'`` or an unknown code type.
        """
        state_size = self.model.get_layer("decode-rnn-0").cell.state_size[0]

        if self.code_type == 'mixed_vt':
            # The code is two equal halves: first half -> state 0, second -> state 2.
            old_v, old_b = np.split(z, 2)
            old_v, old_b = old_v[np.newaxis, :], old_b[np.newaxis, :]

            old_z = np.zeros((1, state_size[1]))
            old_ib = np.zeros((1, state_size[3]))
            old_zb = np.zeros((1, state_size[4]))

            states = [old_v, old_z, old_b, old_ib, old_zb]

        elif self.code_type == 'voltage':
            # The whole code is state 0; every other state starts at zero.
            old_v = z[np.newaxis, :]
            old_z = np.zeros((1, state_size[1]))
            old_b = np.zeros((1, state_size[2]))
            old_ib = np.zeros((1, state_size[3]))
            old_zb = np.zeros((1, state_size[4]))

            states = [old_v, old_z, old_b, old_ib, old_zb]

        elif self.code_type == 'threshold':
            raise NotImplementedError("decoding from code_type 'threshold' is not implemented")

        elif self.code_type == 'all':
            # Cumulative sizes are the split points; the last split point is
            # the total length, so the trailing piece is dropped.
            c_state_size = np.cumsum(state_size).tolist()
            states = np.split(z, c_state_size)[:-1]
            states = [np.expand_dims(state, axis=0) for state in states]
        else:
            raise NotImplementedError('unknown code_type: %r' % (self.code_type,))

        return states

    def decode(self, z, sentence=None, maxSentenceLength=64):
        """Decode the code vector ``z`` into a sentence.

        Args:
            z: flat code, laid out as produced for this ``code_type``.
            sentence: ignored; kept for interface compatibility.
            maxSentenceLength: number of tokens to generate when ``maxlen``
                was given as ``None`` at construction time. When ``maxlen``
                is set it takes precedence, as before.

        Returns:
            The result of ``vocabulary.tokensToSentence``; if that result has
            length 1, its single element is returned instead.
        """
        states = self._statesInCellFormat(z)

        # Previously ``range(self.maxlen)`` raised TypeError for maxlen=None.
        n_steps = self.maxlen if self.maxlen is not None else maxSentenceLength

        indices = [self.vocabulary.padIndex, self.vocabulary.startIndex]
        for _ in range(n_steps):
            indices_arr = np.array([indices], dtype=np.int32)
            dist = self.model.predict([indices_arr] + states)
            dist = np.squeeze(dist)

            # Pick the next token from the distribution at the last time step.
            nextIndex = self._selectTokenIdx(dist[-1])
            indices.append(nextIndex)

        tokens = self.vocabulary.indicesToTokens(indices)
        tokens = self.vocabulary.fromStartToEnd(tokens)
        decoded = self.vocabulary.tokensToSentence(tokens)

        if len(decoded) == 1:
            decoded = decoded[0]
        return decoded


def RecurrentUnit(units, n_regular=0, cell_name='rnn', stateful_and_not=False, return_sequences=True, name='',
                  params=None):
    """Wrap the cell chosen by ``selectCell`` in a Keras ``RNN`` layer.

    The layer always returns its states (``return_state=True``).

    Returns:
        A single ``RNN`` layer, or, when ``stateful_and_not`` is true, a
        ``(non_stateful, stateful)`` pair of layers. Both layers of the pair
        wrap the same cell object and are given the same ``name``.
    """
    cell = selectCell(cell_name, units, n_regular, params)
    shared_kwargs = dict(return_state=True, return_sequences=return_sequences, name=name)

    plain_rnn = RNN([cell], **shared_kwargs)
    if not stateful_and_not:
        return plain_rnn

    stateful_rnn = RNN([cell], stateful=True, **shared_kwargs)
    return plain_rnn, stateful_rnn


def selectCell(cell_name, units, n_regular, params=None):
    """Build the recurrent cell identified by ``cell_name``.

    Args:
        cell_name: one of ``'ALIF'``, ``'ALIFnoB'``, ``'LIF'``, ``'Simplest'``,
            ``'LSTM'``, ``'GRU'``, ``'FN'``, ``'IndRNN'``, ``'Rulkov'`` or
            ``'Izhikevich'``.
        units: number of recurrent units of the cell.
        n_regular: forwarded as ``n_regular`` to the ``ALIF`` / ``ALIF_noB``
            cells only.
        params: optional mapping overriding the built-in defaults. When given
            it must contain the keys ``'delay'``, ``'refractory_steps'``,
            ``'dampening_factor'``, ``'thr'`` and ``'tau'``.

    Returns:
        The constructed cell object.

    Raises:
        KeyError: if ``params`` is given but lacks one of the required keys.
        NotImplementedError: if ``cell_name`` is not recognised.

    NOTE(review): ``ALIF``, ``ALIF_noB``, ``LIF``, ``Simplest``, ``FN``,
    ``Rulkov`` and ``Izhikevich`` are not imported in the visible part of this
    file (the corresponding import lines are commented out), so those branches
    raise ``NameError`` unless the names are provided elsewhere -- confirm.
    """
    # Built-in defaults (descriptions taken from the original inline notes).
    n_delay = 10  # number of delays
    n_ref = 3  # number of refractory steps
    beta = 1.8  # adaptive threshold beta scaling parameter
    tau_v = 10  # tau for PSP decay in LSNN and output neurons
    in_neuron_sign = None
    rec_neuron_sign = None
    rewiring_connectivity = -1  # possible usage of rewiring with ALIF and LIF (0.1 is default)
    dampening_factor = 0.3
    thr = .01  # threshold at which the LSNN neurons spike
    tau_adaptation = 20
    dt = 1.

    if params is not None:
        n_delay = params['delay']
        n_ref = params['refractory_steps']
        dampening_factor = np.float32(params['dampening_factor'])
        thr = params['thr']
        tau_v = params['tau']

    if cell_name == 'ALIF':
        cell = ALIF(n_rec=units, tau=tau_v, n_delay=n_delay,
                    n_refractory=n_ref, dt=dt,
                    tau_adaptation=tau_adaptation, beta=beta,
                    thr=thr, rewiring_connectivity=rewiring_connectivity,
                    in_neuron_sign=in_neuron_sign, rec_neuron_sign=rec_neuron_sign,
                    dampening_factor=dampening_factor,
                    n_regular=n_regular
                    )
    elif cell_name == 'ALIFnoB':
        cell = ALIF_noB(n_rec=units, tau=tau_v, n_delay=n_delay,
                        n_refractory=n_ref, dt=dt,
                        tau_adaptation=tau_adaptation, beta=beta,
                        thr=thr,
                        dampening_factor=dampening_factor,
                        n_regular=n_regular
                        )
    elif cell_name == 'LIF':
        # LIF always uses a single delay step, regardless of ``params['delay']``.
        cell = LIF(n_rec=units, tau=tau_v, thr=thr,
                   n_refractory=n_ref, dt=dt, n_delay=1,
                   dampening_factor=dampening_factor,
                   )
    elif cell_name == 'Simplest':
        cell = Simplest(n_rec=units, n_delay=1
                        )
    elif cell_name == 'LSTM':
        cell = tf.keras.layers.LSTMCell(units)
    elif cell_name == 'GRU':
        cell = tf.keras.layers.GRUCell(units)
    elif cell_name == 'FN':
        cell = FN(units)
    elif cell_name == 'IndRNN':
        cell = IndRNNCell(units)
    elif cell_name == 'Rulkov':
        cell = Rulkov(units)
    elif cell_name == 'Izhikevich':
        cell = Izhikevich(n_rec=units, tau=tau_v, n_delay=n_delay,
                          n_refractory=n_ref, dt=dt,
                          thr=thr, rewiring_connectivity=rewiring_connectivity,
                          in_neuron_sign=in_neuron_sign, rec_neuron_sign=rec_neuron_sign,
                          dampening_factor=dampening_factor)
    else:
        raise NotImplementedError('unknown cell_name: %r' % (cell_name,))
    return cell


class LsnnEmbedding(SentenceEmbedding):
    """Recurrent sentence autoencoder whose cell type is chosen by name.

    Three Keras models are built at construction time from shared layers:
    ``self.model`` (autoencoder), ``self.encoder`` and ``self.decoder``. The
    code of a sentence is the list of final states of the last encoder RNN.

    NOTE(review): ``OneHot``, ``ReadOut``, ``RegularizationLoss`` and
    ``MaskedAvPooling`` are not imported in the visible part of this file (the
    corresponding import lines are commented out) -- confirm they are provided
    elsewhere, otherwise construction raises ``NameError``.
    """

    def __init__(self,
                 vocabulary,
                 emb_dim=100,
                 rnn='gru',
                 rec_dim=512,
                 latent_dim=256,
                 n_rec=1,
                 reverse_input=True,
                 n_repeat=5,
                 name='embedding',
                 activation='tanh',
                 keep=1.,
                 maxlen=None,
                 model_filename='embedding_lsnnALIF_gru_16d_1nrl_Q_activationlinear_v1.pkl',
                 code_type='all',
                 noReadOut=True):
        """Store the hyper-parameters and build the three models.

        Args:
            vocabulary: vocabulary object; its ``getMaxVocabularySize()`` sets
                the input/output dimensionality.
            emb_dim, rnn, activation, name: stored on the instance; not used
                by the code visible in this class.
            rec_dim: units of every recurrent layer except the one holding the
                latent code.
            latent_dim: units of the last encoder layer and of the decoder
                layer named ``'decode-rnn-0'``.
            n_rec: number of stacked recurrent layers on each side.
            reverse_input: forwarded to the base class and to the generators.
            n_repeat: how many times in a row each token index is presented to
                the recurrent layers.
            keep: forwarded to ``AeGenerator``.
            maxlen: forwarded to the sentence decoder returned by
                ``getDecoder``.
            model_filename: only used to pick the cell type: the first
                ``'_'``-separated element containing ``'lsnn'``, minus its
                first four characters (``'lsnnALIF'`` -> ``'ALIF'``). An
                ``IndexError`` is raised if no element contains ``'lsnn'``.
            code_type: forwarded to the sentence encoder/decoder wrappers.
            noReadOut: if true the decoder head is ``Dense`` +
                ``MaskedAvPooling``; otherwise ``ReadOut`` +
                ``AveragePooling1D``.
        """
        self.dt = 1.
        self.reg = 1e-2
        self.adaptive_reg = False
        self.regularization_f0 = 10 / 1000
        self.regularization_f0_max = 100 / 1000
        tau_out = 10  # 'tau for PSP decay in LSNN and output neurons')
        self.decay = np.exp(-self.dt / tau_out)  # output layer psp decay, chose value between 15 and 30ms as for tau_v

        elementsFilename = model_filename.split('_')
        self.cell_name = [e for e in elementsFilename if 'lsnn' in e][0][4:]

        super(LsnnEmbedding, self).__init__(vocabulary, latent_dim, reverse_input)
        self.__dict__.update(emb_dim=emb_dim, rnn=rnn, rec_dim=rec_dim,
                             activation=activation, n_rec=n_rec,
                             reverseInput=reverse_input, name=name, keep=keep,
                             n_repeat=n_repeat, maxlen=maxlen,
                             code_type=code_type, noReadOut=noReadOut)

        self.vocab_size = self.vocabulary.getMaxVocabularySize()

        # NOTE: the zero value is reserved for masking
        self.loss = 'categorical_crossentropy'
        self.defineModels()

    def initialize_weights(self):
        """Create the layers shared by the autoencoder, encoder and decoder."""

        # embedding and readout

        self.word_embedding = OneHot(self.vocab_size, mask_number=0)

        self.read_out = ReadOut(self.decay, self.vocab_size, name='out_readout')
        self.last_dense = Dense(self.vocab_size, name='out_dense')

        # encoders: n_rec - 1 sequence-returning layers of rec_dim units, then
        # one final layer of latent_dim units that only returns its last step.

        self.encoders = []
        for i in range(self.n_rec - 1):
            encoder = RecurrentUnit(units=self.rec_dim,
                                    cell_name=self.cell_name,
                                    return_sequences=True,
                                    name='encode-rnn-%s' % (i))
            self.encoders.append(encoder)

        encoder = RecurrentUnit(units=self.latent_dim,
                                cell_name=self.cell_name,
                                return_sequences=False,
                                name='encode-rnn-%s' % (self.n_rec - 1))
        self.encoders.append(encoder)

        # decoders: 'decode-rnn-0' (latent_dim units) is stored LAST in the
        # list but applied FIRST by ``_decode``, since it receives the code as
        # its initial state.

        self.decoders = []

        for i in range(self.n_rec - 1):
            decoder = RecurrentUnit(units=self.rec_dim,
                                    cell_name=self.cell_name,
                                    name='decode-rnn-%s' % (i + 1))

            self.decoders.append(decoder)

        decoder = RecurrentUnit(units=self.latent_dim,
                                cell_name=self.cell_name,
                                name='decode-rnn-0')
        self.decoders.append(decoder)

    def _encode(self, x):
        """Map the index tensor ``x`` to the final states of the last encoder.

        Unlike ``_decode``, no ``RegularizationLoss`` is applied to the
        encoder outputs.
        """
        # Repeat every token index n_repeat times in a row along the time axis.
        output = RepeatVector(self.n_repeat)(x)
        output = Permute((2, 1))(output)
        repeated_question = Flatten()(output)
        xe = self.word_embedding(repeated_question)

        eos = [xe]
        for i in range(self.n_rec - 1):
            # Element 0 of an RNN result is its output; the rest are states.
            eos = self.encoders[i](eos[0])

        eos = self.encoders[-1](eos[0])

        # The code is every returned state of the last encoder layer.
        z = eos[1:]
        z = Lambda(lambda x: x, name="z")(z)

        return z

    def _decode(self, x, z=None):
        """Map decoder input indices ``x`` to per-token output distributions.

        Args:
            x: index tensor fed to the decoder.
            z: initial state(s) for the ``'decode-rnn-0'`` layer; when
                ``None`` the layer is called without an initial state.

        Returns:
            Softmax output over the vocabulary, pooled over windows of
            ``n_repeat`` steps.
        """
        # Same token repetition as in ``_encode``.
        output = RepeatVector(self.n_repeat)(x)
        output = Permute((2, 1))(output)
        repeated_question = Flatten()(output)
        xe = self.word_embedding(repeated_question)

        stateful = z is None
        if stateful:
            dos = self.decoders[-1](xe)
        else:
            dos = self.decoders[-1](xe, initial_state=z)

        # A fresh RegularizationLoss layer is applied to each decoder output.
        regLoss = RegularizationLoss(self.dt, self.reg, self.regularization_f0, self.regularization_f0_max)
        do = dos[0]
        do = regLoss(do)

        dos = [do] + dos[1:]

        for i in range(self.n_rec - 1):
            dos = self.decoders[i](dos[0])
            regLoss = RegularizationLoss(self.dt, self.reg, self.regularization_f0, self.regularization_f0_max)
            do = dos[0]
            do = regLoss(do)

            dos = [do] + dos[1:]

        z = dos[0]

        if self.noReadOut:
            z = self.last_dense(z)
            z = MaskedAvPooling(self.n_repeat)(z)
        else:
            z = self.read_out(z)
            z = AveragePooling1D(self.n_repeat)(z)

        y = Softmax(axis=-1, name='softmax')(z)
        return y

    def _autoencoder(self):
        """Build the training model: ``[x_enc, x_dec] -> reconstruction``."""
        x_enc = Input(shape=(None,), dtype=np.int32, name='x_enc')
        x_dec = Input(shape=(None,), dtype=np.int32, name='x_dec')
        z = self._encode(x_enc)
        xr = self._decode(x_dec, z)
        autoencoder = Model([x_enc, x_dec], xr, name='Autoencoder')
        return autoencoder

    def _encoder(self):
        """Build the stand-alone encoder model: ``x -> code states``."""
        x = Input(shape=(None,), dtype=np.int32, name='x')
        z = self._encode(x)

        encoder = Model(x, z, name='encoder')
        return encoder

    def _decoder(self):
        """Build the stand-alone decoder model: ``[x, states...] -> output``.

        One ``Input`` named ``state_<i>`` is created per state of the first
        cell of ``'decode-rnn-0'``. Requires ``self.model`` to exist already.
        """
        x = Input(shape=(None,), dtype=np.int32, name='x')

        state_size = self.model.get_layer("decode-rnn-0").cell.state_size[0]
        zs = []
        # general RNN might have several states
        if isinstance(state_size, int): state_size = [state_size]
        for i, latent_dim in enumerate(state_size):
            z = Input(shape=(latent_dim,), dtype=np.float32, name='state_{}'.format(i))
            zs.append(z)

        xr = self._decode(x, [zs])
        decoder = Model([x, zs], xr, name='decoder')
        return decoder

    def defineModels(self):
        """Create the shared layers, then the autoencoder, encoder and decoder.

        The order matters: ``_decoder`` reads a layer from ``self.model``.
        """
        self.initialize_weights()
        self.model = self._autoencoder()
        self.encoder = self._encoder()
        self.decoder = self._decoder()

    def getEncoder(self):
        """Return an ``LsnnSentenceEncoder`` wrapping ``self.encoder``."""
        return LsnnSentenceEncoder(self.encoder, self.vocabulary,
                                   self.reverse_input, code_type=self.code_type)

    def getDecoder(self):
        """Return an ``LsnnSentenceDecoder`` wrapping ``self.decoder``."""
        return LsnnSentenceDecoder(self.decoder, self.vocabulary,
                                   self.maxlen, code_type=self.code_type)

    def train(
            self,
            gzip_filename_train,
            gzip_filename_val,
            epochs,
            steps_per_epoch=32,
            batch_size=32,
            callbacks=None,
            model_filename=None,
            verbose=1,
            profile=False,
            log_path=None):
        """Compile and fit the autoencoder on gzip-backed sentence generators.

        Args:
            gzip_filename_train: path given to the training ``AeGenerator``.
            gzip_filename_val: path given to the validation ``AeGenerator``.
            epochs: number of training epochs.
            steps_per_epoch: forwarded to the training generator only.
            batch_size: forwarded to both generators.
            callbacks: optional list of Keras callbacks (``None`` = none).
            model_filename: passed to ``plot_history`` after a completed fit
                and used to save the model if training is interrupted. The
                model is not saved here after a normal completion.
            verbose: Keras verbosity level.
            profile: if true, write ``timeline.ctf.json`` after training.
            log_path: unused.
        """
        callbacks = [] if callbacks is None else callbacks

        if profile:
            # ``tf.RunOptions`` / ``tf.RunMetadata`` are not exposed by
            # TF 2.x; the ``tf.compat.v1`` aliases are, and this file already
            # relies on ``tf.compat.v1``.
            # NOTE(review): neither object is handed to compile/fit in this
            # method, so the written timeline is presumably empty -- confirm
            # how they are meant to be attached to the training session.
            run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
            run_metadata = tf.compat.v1.RunMetadata()

        optimizer = Adam(learning_rate=1e-4, clipvalue=0.5)
        self.model.compile(optimizer,
                           loss=self.loss,
                           metrics=[word_accuracy, 'categorical_crossentropy', word_accuracy_no_pad],
                           )

        generator_train = AeGenerator(gzip_filepath=gzip_filename_train, vocabulary=self.vocabulary,
                                      batch_size=batch_size, steps_per_epoch=steps_per_epoch,
                                      reverse_input=self.reverse_input, keep=self.keep)
        generator_val = AeGenerator(gzip_filepath=gzip_filename_val, vocabulary=self.vocabulary,
                                    batch_size=batch_size, reverse_input=self.reverse_input, keep=self.keep)
        try:
            history = self.model.fit_generator(generator_train,
                                               epochs=epochs,
                                               validation_data=generator_val,
                                               use_multiprocessing=False,
                                               shuffle=False,
                                               verbose=verbose,
                                               callbacks=callbacks)

            plot_history(history, model_filename, epochs)

        except KeyboardInterrupt:
            # Best effort: keep what was learned so far.
            logger.info("Training interrupted by the user")
            if model_filename is not None:
                self.save(model_filename)

        if profile:
            trace = timeline.Timeline(step_stats=run_metadata.step_stats)
            with open('timeline.ctf.json', 'w') as f:
                f.write(trace.generate_chrome_trace_format())


class LsnnVaeEmbedding(LsnnEmbedding):
    """Variational variant of ``LsnnEmbedding`` that overrides ``_encode``.

    The encoder state is projected to ``z_mean`` / ``z_log_var``, passed
    through ``KLDivergenceLayer``, and the code is either ``z_mean``
    (``deterministic=True``) or the output of the ``sampling`` function.
    """

    def __init__(self, *args, deterministic=False, **kwargs):
        # Must be set before the parent constructor runs: the parent's
        # __init__ calls defineModels(), which calls _encode(), which reads
        # self.deterministic.
        self.deterministic = deterministic
        super(LsnnVaeEmbedding, self).__init__(*args, **kwargs)

    def _encode(self, x):
        """Map the index tensor ``x`` to a single latent tensor ``z``."""
        # NOTE(review): the parent's initialize_weights() creates
        # `self.word_embedding`; `wordEmbedding` is not assigned anywhere in
        # the visible code -- confirm a base class provides it.
        xe = self.wordEmbedding(x)

        encoder_outputs = xe
        for i in range(self.n_rec - 1):
            kwargs = {'units': self.rec_dim, 'return_state': False,
                      'return_sequences': True, 'name': 'encode-rnn-%s' % (i),
                      'activation': self.activation}
            # NOTE(review): the RecurrentUnit defined in this file takes
            # `units` as its first positional parameter and has no
            # `return_state` / `activation` parameters, so this call looks
            # like it raises TypeError -- confirm which RecurrentUnit was
            # intended here.
            encoder = RecurrentUnit(self.rnn, **kwargs)

            encoder_outputs = encoder(encoder_outputs)

        kwargs = {'units': self.latent_dim, 'return_state': True,
                  'return_sequences': False, 'name': 'encode-rnn-%s' % (self.n_rec - 1),
                  'activation': self.activation}
        # NOTE(review): same signature concern as in the loop above.
        encoder = RecurrentUnit(self.rnn, **kwargs)

        # Assumes the layer returns exactly (output, one state) -- TODO confirm
        # for cells that carry several states.
        _, encoder_state = encoder(encoder_outputs)

        # TODO: investigate what should be the proper non-linearity for log-variance output
        # z_mean = Dense(self.latent_dim, activation=self.activation,
        #               name='z_mean')(encoder_state)
        z_mean = Dense(self.latent_dim, activation='linear',
                       name='z_mean')(encoder_state)
        z_log_var = Dense(self.latent_dim, name='z_log_var')(encoder_state)

        # Presumably registers the KL term and passes its inputs through
        # unchanged -- verify against KLDivergenceLayer.
        z_mean, z_log_var = KLDivergenceLayer()([z_mean, z_log_var])

        if self.deterministic:
            z = z_mean
        else:
            z = Lambda(sampling, output_shape=(self.latent_dim,),
                       name='z')([z_mean, z_log_var])
        return z


# variation on VAE that can encode more complex posteriors


class AnnealingLosses(Callback):
    """Keras callback that ramps up the ``beta`` loss weight between epochs.

    ``alpha`` and ``beta`` are backend variables (read with ``K.get_value``
    and written with ``K.set_value``). ``alpha`` is written back unchanged;
    ``beta`` is left alone while the epoch index is below 2 and afterwards
    grows by 0.34 per epoch, capped at 1.
    """

    def __init__(self, alpha, beta):
        # Let the Keras base class initialise its own attributes.
        super(AnnealingLosses, self).__init__()
        self.alpha = alpha
        self.beta = beta

    def on_epoch_end(self, epoch, logs=None):
        """Update the loss weights and log their new values.

        Args:
            epoch: zero-based index of the epoch that just ended.
            logs: metrics dict supplied by Keras; not used.
        """
        new_alpha = K.get_value(self.alpha)
        if epoch < 2:
            new_beta = K.get_value(self.beta)
        else:
            new_beta = min(K.get_value(self.beta) + .34, 1.)

        K.set_value(self.alpha, new_alpha)
        K.set_value(self.beta, new_beta)

        logger.info(" epoch %s, alpha = %s, beta = %s" % (epoch, K.get_value(self.alpha), K.get_value(self.beta)))


class LsnnVaeEmbedding2(LsnnEmbedding):
    """VAE variant trained with separately weighted reconstruction/KL losses.

    ``_encode`` defines the loss functions as closures over the latent
    tensors and stores them on the instance; ``train`` compiles a two-output
    ``annealing_model`` whose loss weights are adjusted by ``AnnealingLosses``.
    """

    def __init__(self, *args, deterministic=False, **kwargs):
        # Must be set before the parent constructor runs: the parent's
        # __init__ calls defineModels(), which calls _encode(), which reads
        # self.deterministic.
        self.deterministic = deterministic
        super(LsnnVaeEmbedding2, self).__init__(*args, **kwargs)

    def _encode(self, x):
        """Map the index tensor ``x`` to a latent tensor and set the losses.

        Side effects: assigns ``self.loss_cat``, ``self.loss_kl`` and
        ``self.loss``.

        NOTE(review): the parent's defineModels() calls ``_encode`` twice
        (autoencoder, then encoder), creating new ``Dense`` layers each time,
        so the stored loss closures end up bound to the tensors of the LAST
        call -- confirm this is intended.
        """
        # NOTE(review): the parent's initialize_weights() creates
        # `self.word_embedding`; `wordEmbedding` is not assigned anywhere in
        # the visible code -- confirm a base class provides it.
        xe = self.wordEmbedding(x)

        encoder_outputs = xe
        for i in range(self.n_rec - 1):
            kwargs = {'units': self.rec_dim, 'return_state': False,
                      'return_sequences': True, 'name': 'encode-rnn-%s' % (i),
                      'activation': self.activation}
            # NOTE(review): the RecurrentUnit defined in this file takes
            # `units` as its first positional parameter and has no
            # `return_state` / `activation` parameters, so this call looks
            # like it raises TypeError -- confirm which RecurrentUnit was
            # intended here.
            encoder = RecurrentUnit(self.rnn, **kwargs)

            encoder_outputs = encoder(encoder_outputs)

        kwargs = {'units': self.latent_dim, 'return_state': True,
                  'return_sequences': False, 'name': 'encode-rnn-%s' % (self.n_rec - 1),
                  'activation': self.activation}
        # NOTE(review): same signature concern as in the loop above.
        encoder = RecurrentUnit(self.rnn, **kwargs)

        _, encoder_state = encoder(encoder_outputs)

        # TODO: investigate what should be the proper non-linearity for log-variance output
        # VAE Z layer
        z_mean = Dense(self.latent_dim)(encoder_state)
        z_log_sigma = Dense(self.latent_dim)(encoder_state)

        def sample(args):
            # NOTE(review): the noise is scaled by z_log_sigma itself, while
            # the KL terms below use K.exp(z_log_sigma) -- confirm whether an
            # exponential is missing here.
            z_mean, z_log_sigma = args
            batch_size = K.shape(z_mean)[0]

            epsilon_std = 1.
            epsilon = K.random_normal(shape=(batch_size, self.latent_dim),
                                      mean=0., stddev=epsilon_std)
            return z_mean + z_log_sigma * epsilon

        # note that "output_shape" isn't necessary with the TensorFlow backend
        # so you could write `Lambda(sample)([z_mean, z_log_sigma])`

        if self.deterministic:
            z = z_mean
        else:
            z = Lambda(sample, output_shape=(self.latent_dim,),
                       name='z')([z_mean, z_log_sigma])

        def kl_loss(x, x_decoded_mean):
            # Ignores both arguments; depends only on the latent tensors.
            kl_loss = -0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma))
            return kl_loss

        def vae_loss(x, x_decoded_mean):
            xent_loss = K.categorical_crossentropy(x, x_decoded_mean)
            kl_loss = -0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma))
            loss = xent_loss + kl_loss
            return loss

        self.loss_cat = K.categorical_crossentropy
        self.loss_kl = kl_loss
        self.loss = vae_loss
        return z

    def _autoencoder(self):
        """Build the autoencoder and the two-output ``annealing_model``.

        ``annealing_model`` shares inputs and layers with the returned
        autoencoder and emits the reconstruction twice, so one loss per
        output can be weighted independently.
        """
        x_enc = Input(shape=(None,), dtype=np.int32, name='x_enc')
        x_dec = Input(shape=(None,), dtype=np.int32, name='x_dec')
        z = self._encode(x_enc)
        xr = self._decode(x_dec, z)
        autoencoder = Model([x_enc, x_dec], xr, name='Autoencoder')

        self.annealing_model = Model([x_enc, x_dec], [xr, xr], name='annealed_Autoencoder')
        return autoencoder

    def _generateTrainingData_forAnnealing(self, generator, batch_size):
        """Yield ``(inputs, [targets, targets])`` batches for the annealing model.

        ``_generateTrainingData`` is not defined in the visible code --
        presumably inherited from a base class; verify.
        """
        new_generator = self._generateTrainingData(generator, batch_size)
        while True:
            batch = next(new_generator)
            yield batch[0], [batch[1], batch[1]]

    def train(self, generator, val_data,
              epochs, steps_per_epoch=32, batch_size=32, callbacks=None,
              model_filename=None, verbose=1, profile=False, log_path=None):
        """Fit ``annealing_model`` and save the autoencoder weights.

        Args:
            generator: source handed to ``_generateTrainingData``.
            val_data: forwarded as ``validation_data``.
            epochs: number of training epochs.
            steps_per_epoch: batches drawn per epoch.
            batch_size: forwarded to ``_generateTrainingData``.
            callbacks: optional list of Keras callbacks. The list is copied,
                so the caller's list is not modified.
            model_filename: if given, ``self.model`` weights are saved there
                after training, including after a keyboard interrupt.
            verbose: Keras verbosity level.
            profile: if true, write ``timeline.ctf.json`` after training.
            log_path: unused.
        """
        if profile:
            # ``tf.RunOptions`` / ``tf.RunMetadata`` are not exposed by
            # TF 2.x; the ``tf.compat.v1`` aliases are, and this file already
            # relies on ``tf.compat.v1``.
            # NOTE(review): neither object is handed to compile/fit in this
            # method, so the written timeline is presumably empty -- confirm
            # how they are meant to be attached to the training session.
            run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
            run_metadata = tf.compat.v1.RunMetadata()

        optimizer = Adam(learning_rate=1e-4, clipvalue=0.5)

        # Loss weights live in backend variables so the callback can change
        # them between epochs.
        weight_cat = K.variable(1.0)
        weight_kl = K.variable(0.0)
        self.annealing_model.compile(optimizer,
                                     loss=[self.loss_cat, self.loss_kl],
                                     loss_weights=[weight_cat, weight_kl],
                                     metrics=[word_accuracy, 'categorical_crossentropy'], )

        # Work on a private copy: appending to a shared default list (or to
        # the caller's list) would accumulate one AnnealingLosses per call.
        callbacks = list(callbacks) if callbacks is not None else []
        callbacks.append(AnnealingLosses(weight_cat, weight_kl))

        try:
            self.annealing_model.fit_generator(self._generateTrainingData_forAnnealing(generator, batch_size),
                                               epochs=epochs,
                                               steps_per_epoch=steps_per_epoch,
                                               validation_data=val_data,
                                               use_multiprocessing=False,
                                               shuffle=False,
                                               verbose=verbose,
                                               callbacks=callbacks)

        except KeyboardInterrupt:
            # Fall through to the save below so partial progress is kept.
            logger.info("Training interrupted by the user")

        if model_filename is not None:
            self.model.save_weights(model_filename)

        if profile:
            trace = timeline.Timeline(step_stats=run_metadata.step_stats)
            with open('timeline.ctf.json', 'w') as f:
                f.write(trace.generate_chrome_trace_format())
