import jsonlines
import os
import json
import random
from transformers import AutoTokenizer
import numpy as np

# Upper bound on the number of dev samples written to dev.jsonl.
DEV_SAMPLE_SIZE = 300


def is_valid_sample(line: dict) -> bool:
    """Return True when the record's question is usable.

    A record is rejected when its question is the empty string or when the
    question text contains the substring "database".
    """
    question = line["question"]
    return question != "" and "database" not in question


def build_sample(line: dict, articles_all: dict, instruction_template: str) -> dict:
    """Build one chat-format sample from a dataset record.

    Args:
        line: Record providing "question", "answer" and "articles" (a list of
            keys into ``articles_all``).
        articles_all: Mapping from paper id to a dict holding a "markdown" entry.
        instruction_template: Prompt text containing the ``<question>`` and
            ``<articles>`` placeholders.

    Returns:
        ``{"messages": [user_message, assistant_message]}`` where the user
        message is the filled-in template and the assistant message is the
        record's answer.

    Raises:
        KeyError: If a referenced paper id or an expected field is missing.
    """
    context = "\n".join(
        articles_all[paper_id]["markdown"] for paper_id in line["articles"]
    )
    # The question is substituted before the articles, matching the order the
    # template was always filled in.
    prompt = instruction_template.replace("<question>", line["question"])
    prompt = prompt.replace("<articles>", context)
    return {
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": line["answer"]},
        ]
    }


def cap_dev_samples(samples: list, limit: int = DEV_SAMPLE_SIZE) -> list:
    """Return a random subset of ``samples`` holding at most ``limit`` items.

    ``random.sample`` raises ValueError when asked for more items than the
    population holds, so the requested size is clamped to ``len(samples)``.
    """
    return random.sample(samples, min(limit, len(samples)))


if __name__ == "__main__":
    # articles
    with open("../../bench/article/papers_final.json", encoding="utf-8") as f:
        articles_all = json.load(f)

    # The instruction template is identical for every sample: read it once and
    # close the handle instead of re-opening it per record.
    with open("../../test_full/full_instruction.txt", encoding="utf-8") as f:
        instruction_template = f.read()

    # samples
    samples_training = []
    samples_dev = []
    for prefix in ["64k", "128k"]:
        samples_training_tmp = []
        samples_dev_tmp = []
        samples_path = os.path.join(
            "../../bench/dataset/samples/final/", f"{prefix}_samples_target.jsonl"
        )
        with jsonlines.open(samples_path) as reader:
            for line in reader:
                if not is_valid_sample(line):
                    continue

                sample = build_sample(line, articles_all, instruction_template)

                # Records with any other label are dropped.
                if line["label"] == "training":
                    samples_training_tmp.append(sample)
                elif line["label"] == "dev":
                    samples_dev_tmp.append(sample)

        print(prefix, f"training={len(samples_training_tmp)}, dev={len(samples_dev_tmp)}")
        samples_training.extend(samples_training_tmp)
        samples_dev.extend(samples_dev_tmp)

    random.shuffle(samples_training)
    # random.sample already returns the chosen items in random order, so the
    # dev split needs no separate shuffle.
    samples_dev = cap_dev_samples(samples_dev)
    print(len(samples_training), len(samples_dev))
    # 9000 300  (counts noted by the original author for the full dataset)
    with jsonlines.open("train.jsonl", "w") as writer:
        writer.write_all(samples_training)

    with jsonlines.open("dev.jsonl", "w") as writer:
        writer.write_all(samples_dev)


