from vllm import LLM, SamplingParams
import jsonlines
from tqdm import tqdm
from transformers import AutoTokenizer
import os
import sys
import re

sys.path.append("../../")
from bench.dataset.data_loading import load_test, load_articles, get_full_texts, get_titles

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Opening of a LaTeX ``\boxed{...}`` command. The argument is extracted by
# brace matching instead of a non-greedy ``\boxed\{(.*?)\}`` regex, which
# stops at the first closing brace (``\boxed{\text{Paris}}`` would give
# ``\text{Paris``) and cannot span newlines.
BOXED_PREFIX = "\\boxed{"

# Context tokens withheld on top of the title-based reservation so the model
# has extra room to reason before it writes the boxed answer.
REASONING_MARGIN_TOKENS = 1024


def extract_boxed_answers(text):
    """Return the argument of every balanced ``\\boxed{...}`` in *text*, in order.

    An unterminated ``\\boxed{`` (e.g. a generation cut off by ``max_tokens``)
    is skipped, and the search resumes inside it so that a complete
    ``\\boxed{...}`` nested after it is still found.
    """
    answers = []
    start = text.find(BOXED_PREFIX)
    while start != -1:
        body_start = start + len(BOXED_PREFIX)
        depth = 1
        pos = body_start
        while pos < len(text) and depth:
            char = text[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
            pos += 1
        if depth:
            # Ran off the end without closing this group; look for a later one.
            start = text.find(BOXED_PREFIX, body_start)
            continue
        answers.append(text[body_start:pos - 1])
        start = text.find(BOXED_PREFIX, pos)
    return answers


def extract_boxed_answer(text):
    """Return the last ``\\boxed{...}`` argument in *text*, or "" if there is none."""
    answers = extract_boxed_answers(text)
    return answers[-1] if answers else ""


def fit_context_to_window(tokenizer, context, instruction, titles, max_window,
                          reasoning_margin=REASONING_MARGIN_TOKENS):
    """Token-truncate *context* so the prompt leaves room for generation.

    The context budget is::

        max_window - len(instruction tokens) - 2 * len(title tokens) - reasoning_margin

    The ``2 * title tokens`` term reserves space for an answer that lists the
    article titles. The budget is clamped at 0 so that slicing never uses a
    negative index, which would drop tokens from the end instead.

    Args:
        tokenizer: object with ``encode(str) -> list`` and ``decode(list) -> str``.
        context: the concatenated article texts.
        instruction: the prompt template with the question already substituted.
        titles: article titles used to size the answer reservation.
        max_window: the model's maximum context length in tokens.
        reasoning_margin: extra tokens withheld from the context.

    Returns:
        ``(context, was_truncated)``.
    """
    n_instruction = len(tokenizer.encode(instruction))
    n_titles = len(tokenizer.encode(", ".join(titles)))
    budget = max(max_window - n_titles * 2 - n_instruction - reasoning_margin, 0)

    context_ids = tokenizer.encode(context)
    if len(context_ids) <= budget:
        return context, False
    print("tokenized_instruction", n_instruction)
    print("tokenized_context", len(context_ids))
    # NOTE(review): the chat-template wrapper and the decode/re-encode
    # round-trip can shift the final prompt length by a few tokens; the
    # reasoning margin is presumably meant to absorb that.
    return tokenizer.decode(context_ids[:budget]), True


def save_samples(path, samples):
    """Write *samples* to *path* as JSON Lines, replacing the file atomically.

    The data is written to a sibling temp file and then ``os.replace``-d over
    *path*. An interruption mid-write therefore cannot leave a truncated
    results file that the next resumed run would load as if it were complete.
    """
    tmp_path = f"{path}.tmp"
    with jsonlines.open(tmp_path, "w") as writer:
        writer.write_all(samples)
    os.replace(tmp_path, path)


# ---------------------------------------------------------------------------
# Script: generate n=3 reasoning traces + boxed answers per test sample,
# resuming from an existing results file when one is present.
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    dataset_dir = "../../bench"
    # model_name_official = "Qwen/Qwen2.5-7B-Instruct-1M" # the fine-tuned model

    model_name_official = "checkpoints_saved/sampled/global_step192_hf"
    generation_folder = "generations_sampled_192"

    model_name_save = "qwen25_7b_instruct_1m_grpo"
    target_mode = "test_full"
    vllm_tensor_parallel_size = 4
    checkpoint_every = 10  # rewrite the results file after this many new generations

    # (sample_level, vLLM context length) presets:
    # sample_level = "64k"
    # vllm_max_model_length = 86016
    # sample_level = "128k"
    # vllm_max_model_length = 233472
    # sample_level = "512k"
    # vllm_max_model_length = 729088
    sample_level = "1024k"
    vllm_max_model_length = 1010000

    os.makedirs(generation_folder, exist_ok=True)
    save_name = f"{generation_folder}/{sample_level}_{target_mode}_{model_name_save}.jsonl"
    if os.path.exists(save_name):
        # Resume: samples that already carry "generations" are skipped below.
        with jsonlines.open(save_name) as reader:
            samples_test = list(reader)
        print("existing results loaded", len(samples_test))
    else:
        samples_test = load_test(prefix=sample_level, samples_folder=dataset_dir + "/dataset/samples/final/")
        print("original samples loaded", len(samples_test))

    llm = LLM(model=model_name_official, max_model_len=vllm_max_model_length,
              tensor_parallel_size=vllm_tensor_parallel_size)
    # max_tokens caps the length of each generated completion.
    # NOTE(review): for truncated 1024k prompts the reserved generation room is
    # only 2 * title tokens + REASONING_MARGIN_TOKENS, which can be smaller
    # than max_tokens, so the context window rather than max_tokens may end a
    # long reasoning trace.
    sampling_params = SamplingParams(n=3, temperature=0.7, top_p=0.8, repetition_penalty=1.05, max_tokens=10240)

    articles_all = load_articles(articles_folder=dataset_dir + "/article/")

    tokenizer = AutoTokenizer.from_pretrained(model_name_official)
    print(tokenizer.chat_template)

    # Read the prompt template once instead of re-opening (and leaking) it per sample.
    with open("../reasoning_instruction.txt", encoding="utf-8") as instruction_file:
        instruction_template = instruction_file.read()

    truncated = 0
    unsaved = 0
    for sample in tqdm(samples_test, desc=f"{model_name_save}_{sample_level}"):
        if "generations" in sample:
            continue  # already generated in a previous run

        question = sample["question"]
        context = "\n".join(get_full_texts(sample, articles_all))
        instruction = instruction_template.replace("<question>", question)

        # Only the 1024k level is truncated: there the articles can exceed the
        # model window (presumably the smaller levels fit their presets).
        if sample_level == "1024k":
            context, was_truncated = fit_context_to_window(
                tokenizer,
                context,
                instruction,
                get_titles(sample, articles_all),
                max_window=vllm_max_model_length,
            )
            if was_truncated:
                truncated += 1
        prompt_content = instruction.replace("<articles>", context)

        conversation = [{"role": "user", "content": prompt_content}]
        text = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )
        request_outputs = llm.generate([text], sampling_params, use_tqdm=False)
        completions = [
            completion.text
            for request_output in request_outputs
            for completion in request_output.outputs
        ]
        for completion in completions:
            print(completion)

        sample["generations"] = [extract_boxed_answer(c) for c in completions]
        sample["reasonings"] = completions
        print([sample["answer"]] + sample["generations"])

        unsaved += 1
        if unsaved >= checkpoint_every:
            save_samples(save_name, samples_test)
            unsaved = 0

    print("truncated", truncated)
    save_samples(save_name, samples_test)
