import jsonlines
import os
import json
import random

from transformers import AutoTokenizer

import numpy as np

if __name__ == "__main__":
    # Build a small debug train/val split of chat-format records from the
    # 64k/128k benchmark samples. All paths are relative to the current working
    # directory, so run this script from its own directory.
    import re

    ARTICLES_PATH = "../../bench/article/papers_final.json"
    SAMPLES_DIR = "../../bench/dataset/samples/final/"
    SAMPLE_PREFIXES = ["64k", "128k"]
    INSTRUCTION_PATH = "../reasoning_instruction.txt"
    OUTPUT_DIR = "debug_data"
    NUM_TRAIN = 200
    NUM_DEV = 50
    # Fixed seed so the sampled debug subset is the same on every run.
    SHUFFLE_SEED = 42

    # NOTE: the tokenizer-based prompt-length filter is disabled, so the
    # Qwen tokenizer is intentionally not loaded (it was unused and required
    # a hub download on every run).

    def is_valid_sample(sample):
        """Return True if the sample's question is non-empty and does not mention "database"."""
        question = sample["question"]
        return question != "" and "database" not in question

    # articles: looked up by the ids listed in each sample's "articles" field.
    with open(ARTICLES_PATH, encoding="utf-8") as f:
        articles_all = json.load(f)

    # samples: keep valid ones, split by their "label" field, preserving file order.
    samples_training = []
    samples_dev = []
    for prefix in SAMPLE_PREFIXES:
        samples_path = os.path.join(SAMPLES_DIR, f"{prefix}_samples_target.jsonl")
        with jsonlines.open(samples_path) as reader:
            for sample in reader:
                if not is_valid_sample(sample):
                    continue
                if sample["label"] == "training":
                    samples_training.append(sample)
                elif sample["label"] == "dev":
                    samples_dev.append(sample)
    print(len(samples_training), len(samples_dev))

    # Read the prompt template once instead of re-opening it per sample.
    with open(INSTRUCTION_PATH, encoding="utf-8") as f:
        instruction_template = f.read()
    placeholder_pattern = re.compile(r"<question>|<articles>")

    def build_prompt(question, context):
        """Fill the <question> and <articles> placeholders of the instruction template.

        A single regex pass with a callable replacement is used so that text
        inserted for one placeholder is never re-scanned for the other, and so
        backslashes in the inserted text (e.g. LaTeX in markdown) are taken
        literally.
        """
        replacements = {"<question>": question, "<articles>": context}
        return placeholder_pattern.sub(lambda m: replacements[m.group(0)], instruction_template)

    def process_data(samples):
        """Convert benchmark samples into {"messages": [...], "answer": ...} records.

        Invalid samples are skipped. Returns (message_data, prompt_lengths);
        prompt_lengths is always empty while the tokenizer length filter is
        disabled, and is kept only so callers' tuple unpacking still works.
        """
        prompt_lengths = []
        message_data = []
        for sample in samples:
            if not is_valid_sample(sample):
                continue
            context = "\n".join(
                articles_all[paper_id]["markdown"] for paper_id in sample["articles"]
            )
            prompt = build_prompt(sample["question"], context)
            message_data.append({
                "messages": [{"role": "user", "content": prompt}],
                "answer": sample["answer"],
            })
        return message_data, prompt_lengths

    # Subsample BEFORE building prompts: each prompt embeds full paper
    # markdowns, so materialising prompts for every sample is wasted work.
    rng = random.Random(SHUFFLE_SEED)
    rng.shuffle(samples_training)
    rng.shuffle(samples_dev)
    train_message_data, _ = process_data(samples_training[:NUM_TRAIN])
    dev_message_data, _ = process_data(samples_dev[:NUM_DEV])
    print(len(train_message_data), len(dev_message_data))

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    with jsonlines.open(os.path.join(OUTPUT_DIR, "train.jsonl"), "w") as writer:
        writer.write_all(train_message_data)
    with jsonlines.open(os.path.join(OUTPUT_DIR, "val.jsonl"), "w") as writer:
        writer.write_all(dev_message_data)


