import logging
import os

import numpy as np

import tensorflow as tf

# Switch TensorFlow to eager mode at import time, before any graph is built.
# NOTE(review): tf.executing_eagerly() only *queries* the mode and its return
# value is discarded here, so this first call has no effect.
tf.executing_eagerly()
# NOTE(review): presumably redundant on TF 2.x, where eager execution is the
# default — confirm the TensorFlow version this project targets.
tf.compat.v1.enable_eager_execution()

from tensorflow.keras.callbacks import ReduceLROnPlateau

from ariel_tests.language.nlp import Vocabulary
from ariel_tests.models.baseEmbedding import SentenceEmbedding
# from ariel_tests.models.lsnnEmbedding import LsnnEmbedding

from ariel_tests.models.rnnEmbedding import RnnEmbedding, RnnVaeEmbedding
from ariel_tests.models.transformerEmbedding import KerasTransformerEmbedding
from ariel_tests.models.naiveArielEmbedding import ArithmeticCodingEmbedding
from ariel_tests.models.lmArielEmbedding import LmArielEmbedding

# from ariel_tests.models.lsnnEmbedding import LsnnEmbedding

# Module-level logger used for training/loading progress and config warnings.
logger = logging.getLogger(__name__)


def _parse_int_token(tokens, suffix):
    """Return the integer prefix of the first token shaped like ``<int><suffix>``.

    Tokens that end with *suffix* but whose prefix is not an integer (for
    example a directory component such as ``'/data/trained'`` when looking
    for ``'d'``), and empty tokens, are skipped rather than aborting the search.

    :param tokens: candidate strings, scanned in order.
    :param suffix: marker the token must end with (e.g. ``'d'``, ``'nrl'``).
    :return: the parsed ``int``, or ``None`` if no token matches.
    """
    for token in tokens:
        if len(token) > len(suffix) and token.endswith(suffix):
            try:
                return int(token[:-len(suffix)])
            except ValueError:
                continue
    return None


def _log_path_for(model_filename):
    """Return the ``log/`` directory that sits beside the model's parent folder.

    ``'exp/models/emb.pkl'`` -> ``'exp/log/'``. When the path has fewer than
    two ``'/'`` separators there is no such grandparent component, so a
    relative ``'log/'`` is returned instead of failing.
    """
    parts = model_filename.rsplit('/', 2)
    if len(parts) < 3:
        return 'log/'
    return parts[0] + '/log/'


def getEmbedding(model_filename,
                 gzip_filename_train=None,
                 gzip_filename_val=None,
                 grammar=None,
                 epochs=117,
                 steps_per_epoch=32,
                 batch_size=256):
    """Load a sentence embedding from disk, or build and train it if missing.

    The configuration is read from *model_filename*, e.g.
    ``embedding_vae_cudnngru_128d_2nrl_activationtanh_keep80_v1.pkl``:

    * embedding type, substring test in this order (first match wins):
      ``vae``, ``ae``, ``transformer``, ``lmariel``, ``nariel``;
    * recurrent cell, substring test: ``cudnngru``, ``gru``, ``indrnn``,
      ``transformer`` (no rnn), ``lsnn``; otherwise ``gru`` with a warning;
    * ``<int>d`` token: latent dimension (default 16);
    * ``<int>nrl`` token: number of recurrent layers (default 2);
    * ``activation<name>`` token: activation (default ``'tanh'``);
    * ``keep<int>`` token: ``keep`` argument given in percent, so
      ``keep80`` -> 0.8 (default 1.0).

    Type and rnn detection are plain substring tests on the *whole* path, so
    directory names take part in them as well.

    :param model_filename: path the model is loaded from / saved to; it also
        encodes the configuration described above.
    :param gzip_filename_train: training data, forwarded to ``embedding.train``.
    :param gzip_filename_val: validation data, forwarded to ``embedding.train``.
    :param grammar: used to build the ``Vocabulary``; also passed to the
        transformer and arithmetic-coding embeddings.
    :param epochs: forwarded to ``embedding.train``.
    :param steps_per_epoch: forwarded to ``embedding.train``.
    :param batch_size: forwarded to ``embedding.train``.
    :return: the loaded or newly built embedding.
    :raises ValueError: if no known embedding type occurs in *model_filename*.
    """
    if os.path.exists(model_filename):
        logger.info('Loading embedding from file: %s', model_filename)
        return SentenceEmbedding.load(model_filename)

    logger.info('Training embedding to file: %s', model_filename)

    # Get grammar and vocabulary from oracle
    vocabulary = Vocabulary.fromGrammar(grammar)

    # Hyper-parameters are encoded as '_'-separated tokens of the filename.
    tokens = model_filename.split('_')

    latent_dim = _parse_int_token(tokens, 'd')
    if latent_dim is None:
        latent_dim = 16
        logger.warning("""Specify the number of units in the latent space
        in the model_filename, like #d
        (e.g. model_filename = embedding_vae_gru_128d_2nrl_v1.pkl)
        This time it has been set as 16""")

    vocab_size = vocabulary.getMaxVocabularySize()
    emb_dim = int(np.sqrt(vocab_size)) + 1

    # 'cudnngru' must be tested before 'gru', which is a substring of it.
    if 'cudnngru' in model_filename:
        rnn = 'cudnngru'
    elif 'gru' in model_filename:
        rnn = 'gru'
    elif 'indrnn' in model_filename:
        rnn = 'indrnn'
    elif 'transformer' in model_filename:
        rnn = None
    elif 'lsnn' in model_filename:
        rnn = 'lsnn'
    else:
        rnn = 'gru'
        logger.warning("""The only rnn implemented are "cudnngru", "gru", "indrnn" and "lsnn"
        ("transformer" embeddings use none).
        Specify the one to use in the model_filename
        (e.g. model_filename = embedding_vae_gru_128d_2nrl_v1.pkl)
        This time it has been set as "gru" """)

    n_rec = _parse_int_token(tokens, 'nrl')
    if n_rec is None:
        n_rec = 2
        logger.warning("""Specify the number of recurrent layers in the embedding
        in the model_filename, like #nrl
        (e.g. model_filename = embedding_vae_gru_128d_2nrl_v1.pkl)
        This time it has been set as 2 """)

    activation = next((t.replace("activation", "") for t in tokens if 'activation' in t), None)
    if activation is None:
        activation = 'tanh'
        logger.warning("""Specify the activation in the embedding
        in the model_filename, like activationtanh
        (e.g. model_filename = embedding_vae_cudnngru_%dd_2nrl_Q_activationlinear_v1.pkl)
        This time it has been set as tanh """)

    # Optional 'keep<percent>' token; silently defaults to 1.0 when absent.
    # Tokens that merely contain 'keep' (e.g. a directory name) are skipped.
    keep = 1.
    for token in tokens:
        if 'keep' in token:
            try:
                keep = int(token.replace("keep", "")) / 100
            except ValueError:
                continue
            break

    maxlen = 70  # was 25
    kwargs = {'vocabulary': vocabulary,
              'emb_dim': emb_dim,
              'rec_dim': 128,
              'latent_dim': latent_dim,
              'n_rec': n_rec,
              'rnn': rnn,
              'name': model_filename,
              'activation': activation,
              'keep': keep,
              'maxlen': maxlen,
              }

    # Order matters: 'vae' contains 'ae', so it must be tested first.
    if 'vae' in model_filename:
        embedding = RnnVaeEmbedding(**kwargs)
    elif 'ae' in model_filename:
        embedding = RnnEmbedding(**kwargs)
    elif 'transformer' in model_filename:
        kwargs.update({'grammar': grammar})
        embedding = KerasTransformerEmbedding(**kwargs)
    elif 'lmariel' in model_filename:
        embedding = LmArielEmbedding(vocabulary=vocabulary,
                                     emb_dim=emb_dim,
                                     latent_dim=latent_dim)
    elif 'nariel' in model_filename:
        embedding = ArithmeticCodingEmbedding(grammar, ndim=latent_dim,
                                              precision=5,
                                              transform=None,  # alternative: 'orthonormal'
                                              name=model_filename)
    else:
        raise ValueError("""Define the type of embedding in the filename. The options implemented are
        '_ae_', '_vae_', '_transformer_', '_lmariel_' and '_nariel_' """)

    # prepare the data generation
    logger.warning("""\n\nTraining with Biased Dataset Generator""")

    # Only embeddings exposing a `model` attribute are trained and saved here;
    # any other embedding is returned as built, without being saved.
    if hasattr(embedding, 'model'):
        embedding.model.summary()

        log_path = _log_path_for(model_filename)
        callbacks = getCallbacks()

        embedding.train(gzip_filename_train=gzip_filename_train, gzip_filename_val=gzip_filename_val,
                        epochs=epochs, steps_per_epoch=steps_per_epoch, batch_size=batch_size, callbacks=callbacks,
                        model_filename=model_filename, log_path=log_path
                        )

        embedding.save(model_filename)

    return embedding


def getCallbacks(monitor='loss', factor=0.2, patience=5, min_lr=1e-5):
    """Build the list of Keras callbacks used when training an embedding.

    The list currently holds a single ``ReduceLROnPlateau`` learning-rate
    scheduler. Calling ``getCallbacks()`` with no arguments watches the
    training ``'loss'``, multiplies the rate by 0.2 after 5 epochs without
    improvement, and never goes below 1e-5.

    :param monitor: quantity watched by ``ReduceLROnPlateau``.
    :param factor: factor the learning rate is multiplied by on a plateau.
    :param patience: epochs with no improvement before the rate is reduced.
    :param min_lr: lower bound on the learning rate.
    :return: a new list of callbacks, which the caller may extend freely.
    """
    # NOTE: ModelCheckpoint, TensorBoard, CSVLogger, TerminateOnNaN and
    # EarlyStopping are deliberately not enabled; add them here if needed.
    return [ReduceLROnPlateau(monitor=monitor,
                              factor=factor,
                              patience=patience,
                              min_lr=min_lr)]
