import os
from datetime import timedelta

import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import *

from transformers import AutoTokenizer, TFAutoModelForCausalLM, AutoConfig

from GenericTools.KerasTools.esoteric_layers import AddLossLayer
from GenericTools.KerasTools.esoteric_models.transformer import GPT
from GenericTools.KerasTools.esoteric_models.wizard_of_wikipedia import tf_ContextKnowledgeEncoder, metrics_wow, \
    UniversalSentenceEmbedding
from GenericTools.LeanguageTreatmentTools.random_language import random_indices
from GenericTools.StayOrganizedTools.utils import timeStructured
from ariel_tests.models.AriEL.AriEL_decoder import *
from ariel_tests.models.AriEL.AriEL_encoder import *

# Absolute directory that contains this module (symlinks resolved by realpath).
CDIR = os.path.dirname(os.path.realpath(__file__))
# Absolute path of the `data` folder located two directory levels above CDIR.
DATADIR = os.path.abspath(os.path.join(CDIR, os.pardir, os.pardir, 'data'))


def EndToEndModelPretrainedGPT2AriEL(num_layers=5, d_model=256, num_heads=2, dff=512, input_vocab_size=int(5e4),
                                     target_vocab_size=int(5e4), max_pos=1024, rate=.1, max_knowledge=5, pad_idx=0):
    """Build an end-to-end Keras model whose AriEL decoder is driven by pretrained GPT-2.

    The HuggingFace ``gpt2`` checkpoint is cached under ``DATADIR/gpt2``: when no
    ``config.json`` is found there, the weights are obtained with
    ``from_pretrained('gpt2')`` and saved into that folder; otherwise they are
    loaded from the folder.

    Args:
        num_layers, d_model, num_heads, dff, input_vocab_size, max_pos, rate, pad_idx:
            forwarded positionally, in this order, to ``tf_ContextKnowledgeEncoder``.
            ``d_model`` is also the decoder ``lat_dim``, ``max_pos`` the decoder
            ``maxlen`` and ``pad_idx`` the decoder ``PAD``.
        target_vocab_size: accepted for interface compatibility; not used here.
        max_knowledge: size of the second axis of the ``know_tokens`` input.

    Returns:
        A ``tf.keras`` model with inputs
        ``[src_tokens, know_tokens, chosen_knowledge, tgt_tokens]`` whose output
        is the second element of the decoder output (decoder built with
        ``output_type='both'``). ``tgt_tokens`` is declared as an input but is
        not connected to the output.
    """
    cke = tf_ContextKnowledgeEncoder(num_layers, d_model, num_heads, dff, input_vocab_size, max_pos, rate, pad_idx)

    MODELPATH = os.path.join(DATADIR, 'gpt2')
    CONFIGPATH = os.path.join(MODELPATH, 'config.json')
    # makedirs also creates a missing DATADIR and does not fail if the folder
    # already exists (os.mkdir would raise in both situations).
    os.makedirs(MODELPATH, exist_ok=True)

    if not os.path.isfile(CONFIGPATH):
        # First run: fetch the pretrained weights and cache them locally.
        pretrained_lm = TFAutoModelForCausalLM.from_pretrained('gpt2')
        pretrained_lm.save_pretrained(MODELPATH)
    else:
        pretrained_lm = TFAutoModelForCausalLM.from_pretrained(MODELPATH)

    # Wrap GPT-2 as a Keras model mapping a token sequence to the softmax
    # distribution at the last position of the sequence.
    input_sentence = Input((None,))
    sentence = Lambda(lambda x: tf.cast(x, tf.int32))(input_sentence)
    logits = pretrained_lm(sentence).logits[:, -1, :]
    output = Softmax()(logits)
    gpt2 = Model(input_sentence, output)

    decoder = ArielDecoderLayer2(
        lat_dim=d_model,
        maxlen=max_pos,
        language_model=gpt2,
        PAD=pad_idx,
        output_type='both')

    src_tokens = Input((None,))
    tgt_tokens = Input((None,))
    know_tokens = Input((max_knowledge, None))
    chosen_knowledge = Input((1,))
    e, m, _ = cke([src_tokens, know_tokens, chosen_knowledge])
    # Drop axes 1 and 2 of the second encoder output before pooling.
    m = Lambda(lambda x: tf.squeeze(x, [1, 2]))(m)
    code = UniversalSentenceEmbedding()([e, m])
    decoded = decoder(code)
    model = tf.keras.models.Model([src_tokens, know_tokens, chosen_knowledge, tgt_tokens], decoded[1])
    return model


class MixLogitsLayer(tf.keras.layers.Layer):
    """Randomly choose between two tensors while training; return the first one otherwise.

    The layer is called on a pair ``[logits_test, logits_train]``. In training
    mode a single scalar coin flip per call selects one of the two tensors;
    in inference mode ``logits_test`` is always returned.
    """

    def call(self, inputs, *args, training=None, **kwargs):
        """Mix the two inputs according to the training flag.

        Args:
            inputs: pair ``(logits_test, logits_train)``.
            training: Keras training flag (bool or tensor). When ``None`` the
                layer falls back to ``tf.keras.backend.learning_phase()``, which
                is what it consulted before the flag was honoured.

        Returns:
            ``logits_test`` in inference mode; in training mode either
            ``logits_test`` or ``logits_train``, chosen by the coin flip.
        """
        logits_test, logits_train = inputs

        # Scalar 0./1. switch: exactly one of the two tensors is selected.
        random_switch_1 = tf.cast(tf.random.uniform((), minval=0, maxval=2, dtype=tf.dtypes.int32), tf.float32)
        random_switch_2 = 1 - random_switch_1
        logits_sum = random_switch_1 * logits_test + random_switch_2 * logits_train

        # Prefer the explicit `training` argument that Keras passes to `call`;
        # only consult the global learning phase when nothing was provided.
        if training is None:
            training = tf.keras.backend.learning_phase()
        # The flag may be a Python bool, an int or a tensor: cast it so the
        # arithmetic blend below works in every case.
        lp = tf.cast(training, tf.float32)

        logits = lp * logits_sum + (1 - lp) * logits_test

        return logits


def E2E_AriEL(flag=''):
    """Return a builder for the end-to-end AriEL model, configured by ``flag``.

    ``flag`` is matched by substring, so several options can be combined in a
    single string:

    - ``'with_enc'``: an ``ArielEncoderLayer1`` encodes ``tgt_tokens[:, 1:]`` and
      an ``AddLossLayer`` (MSE, ``coef=.1``) ties that code to the context code.
    - ``'sigmoid'``: a sigmoid activation is applied to the code.
    - ``'layernorm'``: (only when ``'sigmoid'`` is absent) the code becomes
      ``(LayerNormalization(code) + 1) / 2``.
    - ``'with_lm'``: the decoder output is combined, through ``MixLogitsLayer``,
      with the softmax output of the GPT language model applied to ``tgt_tokens``.

    An empty flag enables none of the options.
    """

    def build_model(num_layers=5, d_model=256, num_heads=2, dff=512, input_vocab_size=int(5e4),
                    target_vocab_size=int(5e4), encoder_maxlen=1024, decoder_maxlen=1024, rate=.1, max_knowledge=5,
                    pad_idx=0):
        """Build the Keras model.

        ``target_vocab_size`` and ``encoder_maxlen`` are accepted but not used by
        this builder; the GPT layer is built with ``input_vocab_size`` as its
        vocabulary size and ``decoder_maxlen`` as its maximum position encoding.

        Returns:
            A ``tf.keras`` model with inputs
            ``[src_tokens, know_tokens, chosen_knowledge, tgt_tokens]``.
        """
        cke = tf_ContextKnowledgeEncoder(num_layers, d_model, num_heads, dff, input_vocab_size, rate=rate,
                                         pad_idx=pad_idx)

        gpt_layer = GPT(num_layers=num_layers, d_model=d_model, num_heads=num_heads, dff=dff,
                        target_vocab_size=input_vocab_size, maximum_position_encoding=decoder_maxlen, pad_idx=pad_idx,
                        rate=rate)

        def _softmax_lm(last_token_only):
            """Wrap the shared GPT layer as a Keras model with a softmax output."""
            input_sentence = Input((None,))
            sentence = Lambda(lambda x: tf.cast(x, tf.int32))(input_sentence)
            lm_logits = gpt_layer(sentence, output_type='embedding_projection')
            if last_token_only:
                lm_logits = lm_logits[:, -1, :]
            return Model(input_sentence, Softmax()(lm_logits))

        # Both wrappers share the weights of `gpt_layer`: the first one only
        # outputs the distribution at the last position (used by AriEL), the
        # second one outputs the distribution at every position.
        gpt2_ariel = _softmax_lm(last_token_only=True)
        gpt2_lm = _softmax_lm(last_token_only=False)

        decoder = ArielDecoderLayer2(
            lat_dim=d_model,
            maxlen=decoder_maxlen,
            language_model=gpt2_ariel,
            PAD=pad_idx,
            output_type='both')

        src_tokens = Input((None,))
        tgt_tokens = Input((None,))
        know_tokens = Input((max_knowledge, None))
        chosen_knowledge = Input((1,))
        e, m, _ = cke([src_tokens, know_tokens, chosen_knowledge])
        # Drop axes 1 and 2 of the second encoder output before pooling.
        m = Lambda(lambda x: tf.squeeze(x, [1, 2]))(m)
        code = UniversalSentenceEmbedding()([e, m])

        if 'with_enc' in flag:
            encoder = ArielEncoderLayer1(
                lat_dim=d_model,
                language_model=gpt2_ariel,
                PAD=pad_idx,
                maxlen=decoder_maxlen - 1
            )

            # The first target token is skipped, hence maxlen = decoder_maxlen - 1.
            code_too = encoder(tgt_tokens[:, 1:])
            # NOTE(review): AddLossLayer presumably registers the MSE between both
            # codes as an auxiliary loss and forwards one of them; confirm which
            # tensor it returns.
            code = AddLossLayer(loss=tf.keras.losses.MSE, coef=.1)([code_too, code])

        if 'sigmoid' in flag:
            code = tf.keras.layers.Activation('sigmoid')(code)
        elif 'layernorm' in flag:
            code = (tf.keras.layers.LayerNormalization()(code) + 1) / 2

        # The decoder is built with output_type='both'; keep its second output.
        logits = decoder(code)[1]

        if 'with_lm' in flag:
            logits_gpt2 = gpt2_lm(tgt_tokens)
            logits = MixLogitsLayer()([logits, logits_gpt2])

        model = tf.keras.models.Model([src_tokens, know_tokens, chosen_knowledge, tgt_tokens], logits)
        return model

    return build_model


def quick_test():
    """Smoke-test the ``'with_lm'`` E2E AriEL model on small random inputs.

    Builds the model, prints the shapes returned by a direct call and by
    ``predict``, then compiles it and fits for 3 epochs of one step each.
    """
    tf.compat.v1.enable_eager_execution()
    pad_idx = 7
    max_knowledge = 5
    vocab_size = 10  # int(5e4)
    maxlen = 3
    batch_size = 11

    # The builder has no `max_pos` parameter (passing it raises TypeError);
    # `decoder_maxlen` is the argument that sizes the GPT positional encoding
    # and the decoder, and it matches the length of `tgt_tokens` below.
    model = E2E_AriEL('with_lm')(max_knowledge=max_knowledge, input_vocab_size=vocab_size, pad_idx=pad_idx,
                                 decoder_maxlen=maxlen)

    src_tokens = random_indices(vocab_size, pad_idx=pad_idx, batch_size=batch_size, maxlen=4)
    tgt_tokens = random_indices(vocab_size, pad_idx=pad_idx, maxlen=maxlen, batch_size=batch_size)
    # One random sentence per knowledge slot, stacked along a new axis 1.
    know_tokens = tf.concat([random_indices(vocab_size, pad_idx=pad_idx, batch_size=batch_size, maxlen=9)[:, None]
                             for _ in range(max_knowledge)], axis=1)

    chosen_knowledge = random_indices(max_knowledge, maxlen=1, batch_size=batch_size)
    input_tensors = [src_tokens, know_tokens, chosen_knowledge, tgt_tokens]

    print('Outputs shapes: ')
    output = model(input_tensors)
    print(output.shape)
    prediction = model.predict(input_tensors)
    print(prediction.shape)

    model.compile(
        'SGD', tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=metrics_wow(num_classes=vocab_size, mask_value=pad_idx)
    )
    model.fit(x=input_tensors, y=tgt_tokens, epochs=3, steps_per_epoch=1)


def test_ariel():
    """Encode and decode random tokens with AriEL around a tiny GPT, printing results and timings."""
    seq_len = 30
    pad = 7
    vocab = 10
    src_tokens = random_indices(vocab, pad_idx=pad, maxlen=seq_len, padding='post')

    # Tiny GPT wrapped as a Keras model that outputs the softmax at the last position.
    # NOTE(review): the input is fed to the GPT layer without an int32 cast here.
    tiny_gpt = GPT(num_layers=2, d_model=2, num_heads=2, dff=2, target_vocab_size=vocab,
                   maximum_position_encoding=seq_len, pad_idx=pad, rate=0.1)
    lm_input = Input((None,))
    last_logits = tiny_gpt(lm_input, output_type='embedding_projection')[:, -1, :]
    language_model = Model(lm_input, Softmax()(last_logits))

    # Encoder and decoder share the language model, padding index and length.
    shared_kwargs = dict(language_model=language_model, PAD=pad, maxlen=seq_len)
    encoder = ArielEncoderLayer1(**shared_kwargs)
    decoder = ArielDecoderLayer2(output_type='tokens', **shared_kwargs)

    def _timestamp():
        # Second element returned by timeStructured, used as seconds below.
        _, stamp = timeStructured(False, True)
        return stamp

    t_start = _timestamp()
    code = encoder(tf.constant(src_tokens))
    t_encoded = _timestamp()
    decoded = decoder(code)
    t_decoded = _timestamp()

    duration_encoding = timedelta(seconds=t_encoded - t_start)
    duration_decoding = timedelta(seconds=t_decoded - t_encoded)

    for item in (code, src_tokens, decoded):
        print(item)
    print('Encoding took {}'.format(duration_encoding))
    print('Decoding took {}'.format(duration_decoding))


# Script entry point: run the end-to-end smoke test when executed directly.
if __name__ == '__main__':
    quick_test()
    # Alternative check (disabled): AriEL encode/decode round trip with timings.
    # test_ariel()
