import jsonlines
import os
import re
import json
import random
from vllm import LLM, SamplingParams
from tqdm import tqdm
from transformers import AutoTokenizer

if __name__ == "__main__":
    # Extracts the payload of a \boxed{...} span from a generation.
    # NOTE: the pattern is non-greedy and compiled without re.DOTALL, so a box
    # containing nested braces (e.g. \boxed{\frac{1}{2}}) is cut at the first "}",
    # and a box spanning several lines is not matched at all.
    COMPILED_REGEX = re.compile(r"\\boxed\{(.*?)\}")

    # Filters applied to every sample record read below.
    valid_length_levels = ["64k", "128k"]
    valid_sql_types = ['multi_ran_filtering_foa', 'multi_ran_organizing', 'multi_simple', 'multi_ran_filtering_ofo',
                       'multi_ran_aggregating', 'multi_ran_filtering_foo']
    # Maximum number of whitespace-separated tokens allowed in a sample's SQL.
    valid_sql_complexity_threshold = 15
    valid_question_focuses = ['author_list', 'title_word_count', 'title_entire', 'author_count', 'author_relationship']

    # articles: indexed later by the ids listed in each sample's "articles" field.
    # Read explicitly as UTF-8 so non-ASCII text (e.g. author names) does not
    # depend on the platform's locale encoding.
    with open("../../bench/article/papers_final.json", encoding="utf-8") as f:
        articles_all = json.load(f)

    # samples
    samples_training = []
    samples_dev = []
    for prefix in valid_length_levels:
        samples_training_tmp = []
        samples_dev_tmp = []
        samples_path = os.path.join("../../bench/dataset/samples/final/", f"{prefix}_samples_target.jsonl")
        with jsonlines.open(samples_path) as reader:
            for line in reader:
                # Drop empty questions and questions that mention "database"
                # (case-sensitive substring check).
                if line["question"] == "" or "database" in line["question"]:
                    continue

                question_focus = line["focus"]
                sql_type = line["sql_type"]
                sql = line["sql"]
                if (len(sql.split()) > valid_sql_complexity_threshold
                        or question_focus not in valid_question_focuses
                        or sql_type not in valid_sql_types):
                    continue

                if line["label"] == "training":
                    samples_training_tmp.append(line)
                elif line["label"] == "dev":
                    samples_dev_tmp.append(line)

        print(prefix, f"training={len(samples_training_tmp)}, dev={len(samples_dev_tmp)}")
        samples_training.extend(samples_training_tmp)
        samples_dev.extend(samples_dev_tmp)

    print("all data", len(samples_training), len(samples_dev))

    # sampling answers
    # The original (not fine-tuned) model that produces the reasoning traces.
    model_name_official = "Qwen/Qwen2.5-7B-Instruct-1M"
    # Upper bound on generated tokens per completion (vLLM `max_tokens`).
    max_new_tokens = 10240
    # Context window = 128k tokens of prompt room + the generation budget
    # (131072 + 10240 = 141312).
    # NOTE(review): prompts for "128k" samples also carry the instruction text and
    # chat-template overhead; confirm they stay below this limit.
    llm = LLM(
        model=model_name_official,
        max_model_len=128 * 1024 + max_new_tokens,
        tensor_parallel_size=4,
    )
    sampling_params = SamplingParams(
        n=8,  # completions sampled per prompt
        temperature=0.7,
        top_p=0.8,
        repetition_penalty=1.05,
        max_tokens=max_new_tokens,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name_official)
    # Log the chat template so the run output records the exact prompt format.
    print(tokenizer.chat_template)

    def sampling_answers(samples, tmp_save):
        """Sample several reasoning traces per question and checkpoint them as JSONL.

        For each sample, the markdown of every article id in ``sample["articles"]``
        is joined with newlines into the context. The ``<question>`` and
        ``<articles>`` placeholders of ``../reasoning_instruction.txt`` are then
        filled in (in that order). The result is wrapped with the tokenizer's chat
        template and passed to ``llm.generate`` with the module-level
        ``sampling_params``. Each completion becomes
        ``{"answer": <last \\boxed{...} payload or "">, "reasoning": <full text>}``.

        Args:
            samples: list of sample dicts. Mutated in place: each entry gains a
                ``"sampled_answers"`` list.
            tmp_save: output JSONL path. The entire ``samples`` list, including
                entries not yet processed, is rewritten after every 10th sample
                (indices 0, 10, 20, ...) and once more at the end.
        """

        def _dump(records):
            # Overwrite ``tmp_save`` with ``records``, one JSON object per line.
            with jsonlines.open(tmp_save, "w") as writer:
                writer.write_all(records)

        # Read the prompt template once; it is identical for every sample.
        with open("../reasoning_instruction.txt", encoding="utf-8") as template_file:
            instruction_template = template_file.read()

        for sample_index, sample in tqdm(enumerate(samples), total=len(samples),
                                         desc="sampling"):
            question = sample["question"]
            # The context must come from *this* sample's article ids, not from any
            # loop variable left over from loading the dataset.
            context = "\n".join(articles_all[paper_id]["markdown"] for paper_id in sample["articles"])
            prompt = instruction_template.replace("<question>", question).replace("<articles>", context)

            text = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                tokenize=False,
                add_generation_prompt=True,
            )
            conversation_outputs = llm.generate([text], sampling_params, use_tqdm=False)

            samplings = []
            for conversation_output in conversation_outputs:
                for completion in conversation_output.outputs:
                    # The last boxed payload is taken as the final answer.
                    matches = COMPILED_REGEX.findall(completion.text)
                    answer = matches[-1] if matches else ""
                    samplings.append({"answer": answer, "reasoning": completion.text})

            # ``sample`` is the very object stored in ``samples``, so this mutation
            # is what the checkpoint writes below.
            sample["sampled_answers"] = samplings

            if sample_index % 10 == 0:
                _dump(samples)

        _dump(samples)

    # Generate samplings for each split in turn: training first, then dev.
    for split_samples, save_path in (
        (samples_training, "train_original.jsonl"),
        (samples_dev, "dev_original.jsonl"),
    ):
        sampling_answers(split_samples, save_path)