import os
from tokenizers import Tokenizer
import numpy as np
from util import load_models, sample_tokens
import jax
import random
import string


def test_talker_independency(reasoner_snap, talker_snap, input_text_to_reasoner, input_text_to_talker,
                             replace_reasoner_output_with_noise: bool = False, *, talker_prompt_len: int = 10):
    """Generate Talker text conditioned on a Reasoner latent, redrawing the terminal each step.

    Manual probe of how strongly the Talker depends on the Reasoner output. Callers vary the
    Reasoner input, the Talker prompt, or swap the latent for noise, then inspect the printed
    continuation.

    Args:
        reasoner_snap: Reasoner snapshot path, forwarded to ``load_models``.
        talker_snap: Talker snapshot path, forwarded to ``load_models``.
        input_text_to_reasoner: Text tokenized and fed (as a batch of one) to the Reasoner.
        input_text_to_talker: Text whose first ``talker_prompt_len`` tokens seed the Talker.
            Its total token count also sets how many tokens are generated.
        replace_reasoner_output_with_noise: If True, the Reasoner output is replaced by
            ``jax.random.normal`` noise of the same shape and dtype (fixed key 0, so the
            noise is identical across runs).
        talker_prompt_len: Number of leading Talker tokens used as the prompt (keyword-only;
            default 10, the value previously hard-coded).
    """
    tk = Tokenizer.from_file('tokenizer_250725.json')

    # Tokenize each input once; the Talker ids provide both the prompt and the step budget.
    talker_ids = tk.encode(input_text_to_talker).ids
    reasoner_ids = tk.encode(input_text_to_reasoner).ids

    # Same budget as before (len - 11 for a 10-token prompt): the final sequence ends one
    # token short of the reference text. NOTE(review): confirm the extra -1 is intended.
    generation_steps = len(talker_ids) - (talker_prompt_len + 1)
    if generation_steps <= 0:
        # Previously this case silently generated nothing after loading both models.
        print(f'Talker input has only {len(talker_ids)} tokens; need more than '
              f'{talker_prompt_len + 1} to generate anything. Skipping.')
        return

    reasoner, talker = load_models(reasoner_snap, talker_snap, vocab_size=30000, size='large')

    reasoner_token_ids = np.array([reasoner_ids], dtype=np.int32)  # leading batch axis of size 1
    reasoner_output = reasoner.reason(reasoner.embed(reasoner_token_ids))
    if replace_reasoner_output_with_noise:
        reasoner_output = jax.random.normal(key=jax.random.PRNGKey(0),
                                            shape=reasoner_output.shape,
                                            dtype=reasoner_output.dtype)

    # Slicing copies, so appending below never touches talker_ids.
    talker_input_tokens = talker_ids[:talker_prompt_len]
    for _ in range(generation_steps):
        logits = talker(reasoner_output, input_tokens=np.array([talker_input_tokens], dtype=np.int32))
        tokens = sample_tokens(logits, temperature=0.8, top_k=40, top_p=0.95, min_p=0.05, repetition_penalty=1.1)
        # Clear the terminal so the growing text is redrawn in place.
        os.system('cls' if os.name == 'nt' else 'clear')
        # Keep only the newest sampled token of the single batch row.
        generated = tokens.tolist()[0][-1]
        talker_input_tokens.append(generated)
        decoded = tk.decode(talker_input_tokens)
        print('|=============================|\n')
        print(decoded)
        print('\n|=============================|')


if __name__ == '__main__':
    # Snapshot directories shared by every scenario below.
    REASONER_SNAPSHOT = 'snapshot/natural_language/reasoner/sst/sst'
    TALKER_SNAPSHOT = 'snapshot/natural_language/talker/adaptive/train'

    # Out-of-distribution inputs: random lowercase/digit strings drawn from the OS entropy
    # source (random.SystemRandom), so they differ on every run.
    _sys_rng = random.SystemRandom()
    _alphabet = string.ascii_lowercase + string.digits
    reasoner_random_string = ''.join(_sys_rng.choice(_alphabet) for _ in range(64))
    talker_random_string = ''.join(_sys_rng.choice(_alphabet) for _ in range(32))

    input_text_reasoner = [
        "Francis Bacon was an English philosopher and statesman who served as Attorney General and Lord Chancellor of England under King James I. Bacon argued for the importance of natural philosophy, guided by the scientific, his works remained influential",
        # The Reasoner gets its own 64-char random string (previously the Talker's was reused).
        reasoner_random_string,
    ]

    input_text_talker = [
        "Francis Bacon was an English philosopher and statesman who served as Attorney General and Lord Chancellor of England under King James I. Bacon argued for the importance of natural philosophy, guided by the scientific, his works remained influential",
        "Jean-Paul Sartre was a French philosopher, political activist, biographer, and literary critic. Sartre was one of the key figures in the philosophy of existentialism (and phenomenology).",
        talker_random_string,
    ]

    # Each scenario: (reasoner text, talker text, replace reasoner output with noise).
    scenarios = [
        # Normal situation: clean input latent from Reasoner
        (input_text_reasoner[0], input_text_talker[0], False),
        # Gaussian noise as input latent
        (input_text_reasoner[0], input_text_talker[0], True),
        # Random string in Reasoner
        (input_text_reasoner[-1], input_text_talker[0], False),
        # Random initial tokens for Talker
        (input_text_reasoner[0], input_text_talker[-1], False),
        # Semantic mismatch between Reasoner and Talker inputs
        (input_text_reasoner[0], input_text_talker[1], False),
    ]

    for reasoner_text, talker_text, use_noise in scenarios:
        test_talker_independency(REASONER_SNAPSHOT,
                                 TALKER_SNAPSHOT,
                                 reasoner_text,
                                 talker_text,
                                 replace_reasoner_output_with_noise=use_noise)