import os, logging, json, shutil
from transformers import AutoTokenizer, TFAutoModelForCausalLM, AutoConfig
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Softmax
from tensorflow.keras.models import Model
import numpy as np
# from codecarbon import OfflineEmissionsTracker

from GenericTools.LeanguageTreatmentTools.nlp import Vocabulary
from GenericTools.StayOrganizedTools.VeryCustomSacred import CustomExperiment
from ariel_tests.models.lmArielEmbedding import LmArielEmbedding

# Resolved location of this script and its directory; CDIR is the base for the
# experiment, model and tokenizer paths built further down.
FILENAME = os.path.realpath(__file__)
CDIR = os.path.dirname(FILENAME)

# Four random decimal digits, used as a prefix of the experiment name.
random_string = ''.join(map(str, np.random.choice(10, 4)))
ex = CustomExperiment('{}-gia'.format(random_string), base_dir=CDIR, seed=11)
logger = logging.getLogger('mylogger')


# Default values for the experiment. The names defined here match the
# parameters of `main` one-to-one.
# NOTE(review): presumably `ex.config` behaves like a sacred config scope, i.e.
# each local below becomes a config entry that is injected into `main` by
# name -- confirm against CustomExperiment before renaming anything.
@ex.config
def config():
    model_name = 'gpt2'  # id passed to the transformers `from_pretrained` loaders; alternative: microsoft/DialoGPT-small
    latent_dim = [3]  # latent size(s) looped over when generating, e.g. [1, 5, 10, 50, 100]; must not be a list when is_encode_decode is True
    size_lat_dim = 1  # accepted by `main` but not used by any active code there
    is_encode_decode = False  # run the encode -> decode round trip on a few fixed sentences
    is_generate = True  # decode random codes into text and save them to generations.json
    n_generations = 2  # number of random codes (rows) decoded per latent size
    maxlen = 10  # forwarded to LmArielEmbedding in the generation branch; alternative: 100

# NOTE(review): the helpers are deliberately defined *before* the decorated
# `main`. If `ex.automain` behaves like sacred's (runs the experiment at
# decoration time when the file is executed as a script), anything defined
# after `main` would not exist yet when it runs -- confirm against
# CustomExperiment.
def _load_model_and_tokenizer(model_name):
    """Return ``(model, tokenizer)`` for *model_name*, using a cache under ``CDIR/data``.

    On the first call (no ``config.json`` in the model directory) the model and
    tokenizer are obtained with ``from_pretrained(model_name)`` and saved
    locally; later calls load them from the local directories instead.
    """
    safe_name = model_name.replace('/', '_')
    model_path = os.path.join(CDIR, 'data', safe_name)
    tok_path = os.path.join(CDIR, 'data', safe_name + '_tokenizer')

    # makedirs also creates the parent 'data' directory, which plain
    # os.mkdir would fail on when it does not exist yet.
    os.makedirs(tok_path, exist_ok=True)
    os.makedirs(model_path, exist_ok=True)

    # The presence of the saved model config marks the cache as populated.
    config_path = os.path.join(model_path, 'config.json')
    if not os.path.isfile(config_path):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = TFAutoModelForCausalLM.from_pretrained(model_name)

        model.save_pretrained(model_path)
        tokenizer.save_pretrained(tok_path)
        # The model config is also stored next to the tokenizer files
        # (presumably so AutoTokenizer can resolve the tokenizer type from the
        # local directory -- TODO confirm it is still required).
        model_config = AutoConfig.from_pretrained(model_name)
        model_config.save_pretrained(tok_path)
    else:
        model = TFAutoModelForCausalLM.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(tok_path)

    return model, tokenizer


def _build_next_token_model(model):
    """Wrap *model* in a Keras ``Model`` that outputs next-token probabilities.

    The input is cast to int32, fed to *model*, and the logits at the last
    position along axis 1 are passed through a softmax.
    """
    input_sentence = Input((None,))
    token_ids = Lambda(lambda x: tf.cast(x, tf.int32))(input_sentence)
    last_logits = model(token_ids).logits[:, -1, :]
    probabilities = Softmax()(last_logits)
    return Model(input_sentence, probabilities)


@ex.automain
def main(latent_dim, size_lat_dim, is_encode_decode, is_generate, model_name, n_generations, maxlen):
    """Run the encode/decode and/or generation demo and zip the experiment directory.

    Args:
        latent_dim: latent size for LmArielEmbedding. Must not be a list when
            ``is_encode_decode`` is set; for generation a single value or a
            list of values is accepted and each one is run in turn.
        size_lat_dim: currently unused; kept so the signature stays unchanged.
        is_encode_decode: if true, encode a few fixed sentences, decode the
            codes again and print target/prediction pairs.
        is_generate: if true, decode random codes into text, print them and
            store them in ``<exp_dir>/text/generations.json``.
        model_name: name passed to the transformers ``from_pretrained`` loaders.
        n_generations: number of random codes decoded per latent size.
        maxlen: forwarded to LmArielEmbedding in the generation branch.
    """
    exp_dir = os.path.join(CDIR, ex.observers[0].basedir)
    text_dir = os.path.join(exp_dir, 'text')

    model, tokenizer = _load_model_and_tokenizer(model_name)
    gpt2 = _build_next_token_model(model)

    tokens = list(tokenizer.get_vocab().keys())
    # The last token of the tokenizer vocabulary is registered as special token.
    vocabulary = Vocabulary(tokens=tokens, special_tokens=tokens[-1])

    if is_encode_decode:
        assert not isinstance(latent_dim, list), \
            'latent_dim must be a single value (not a list) when is_encode_decode is True'
        embedding = LmArielEmbedding(vocabulary=vocabulary, latent_dim=latent_dim, language_model=gpt2)

        # Sentences to round-trip through the encoder and decoder.
        encoded_input = tokenizer("Hello, I'm a single sentence! Hello, Hello Hello, Hellooooo!!")
        sentences = [
            vocabulary.indicesToSentence(encoded_input['input_ids']),
            '<|endoftext|> _ assert Ċ Ċ That Ġonly Ġhappens Ġin Ġsolitude Ċ Ċ _ if Ċ Ċ not Ċ Ċ have Ċ Ċ no Ġmind Ġof Ġwhat',
            '<|endoftext|> ' + vocabulary.indicesToSentence(encoded_input['input_ids']),
        ]

        print('\nEncode!')
        encoder = embedding.getEncoder()
        code = encoder.encode(sentences)
        print(code)

        print('\nDecode!')
        decoder = embedding.getDecoder()
        prediction = decoder.decode(code, remove_special_tokens=True)

        for target, predicted in zip(sentences, prediction):
            print('\ntarget:       ', target)
            print('prediction:   ', predicted)

    if is_generate:
        latent_dims = latent_dim if isinstance(latent_dim, list) else [latent_dim]
        generations = {}
        for dim in latent_dims:
            key = 'd={}'.format(dim)
            generations[key] = []
            embedding = LmArielEmbedding(vocabulary=vocabulary, latent_dim=dim, language_model=gpt2, maxlen=maxlen)

            print('\nGenerate!')
            random_code = np.random.rand(n_generations, dim)
            decoder = embedding.getDecoder()
            prediction = decoder.decode(random_code, return_indices=True)
            unique_rows = np.unique(prediction, axis=0)

            for token_ids in prediction:
                text = tokenizer.decode(token_ids)
                print(r'generation:   {}'.format(repr(text)))
                generations[key].append(repr(text).replace('<|endoftext|>', ''))

            # Share of distinct generated rows, as an actual percentage.
            diversity = 100 * unique_rows.shape[0] / prediction.shape[0]
            print('generation diversity: {:.1f}%'.format(diversity))

        # Make sure the output directory exists before writing into it.
        os.makedirs(text_dir, exist_ok=True)
        generations_path = os.path.join(text_dir, 'generations.json')

        with open(generations_path, 'w', encoding='utf-8') as f:
            json.dump(generations, f)

        # Read the file back so that what is printed is what was stored.
        del generations
        with open(generations_path, encoding='utf-8') as f:
            generations = json.load(f)
        print(generations)

        print('\n\n')
        for key, stored_sentences in generations.items():
            print(key)
            for stored in stored_sentences:
                print(stored)

    # Creates <exp_dir>.zip containing the contents of exp_dir.
    shutil.make_archive(exp_dir, 'zip', exp_dir)
