import jsonlines
import os
import json
import random

import numpy as np

# Locations of the inputs this script reads, relative to the working directory.
ARTICLES_PATH = "../../bench/article/papers_final.json"
SAMPLES_DIR = "../../bench/dataset/samples/final/"
INSTRUCTION_PATH = "../../test_full/full_instruction.txt"

# One "<prefix>_samples_target.jsonl" file is read per prefix.
SAMPLE_PREFIXES = ("64k", "128k")

# Only samples whose "focus" is one of these are exported.
VALID_QUESTION_FOCUSES = (
    "author_list",
    "title_word_count",
    "title_entire",
    "author_count",
    "author_relationship",
)

# Upper bound on the number of dev samples written to dev.jsonl.
MAX_DEV_SAMPLES = 200


def samples_path(prefix):
    """Return the path of the target-samples JSONL file for *prefix*."""
    return os.path.join(SAMPLES_DIR, f"{prefix}_samples_target.jsonl")


def is_valid_sample(line, valid_focuses=VALID_QUESTION_FOCUSES):
    """Return True if the sample record *line* should be exported.

    A record is rejected when its question is empty, when the question
    contains the substring "database", or when its focus is not in
    *valid_focuses*.
    """
    question = line["question"]
    if question == "" or "database" in question:
        return False
    return line["focus"] in valid_focuses


def build_sample(line, articles, instruction_template):
    """Build one chat-format sample from a sample record.

    Args:
        line: sample record providing "question", "answer" and "articles"
            (an iterable of keys into *articles*).
        articles: mapping from paper id to a record with a "markdown" entry.
        instruction_template: prompt text containing the "<question>" and
            "<articles>" placeholders.

    Returns:
        ``{"messages": [user_message, assistant_message]}`` where the user
        content is the filled-in template and the assistant content is the
        record's answer.
    """
    context = "\n".join(articles[paper_id]["markdown"] for paper_id in line["articles"])
    # The question is substituted first, then the articles placeholder, so the
    # article text itself is never scanned for "<question>".
    prompt = instruction_template.replace("<question>", line["question"])
    prompt = prompt.replace("<articles>", context)
    return {
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": line["answer"]},
        ]
    }


def cap_dev_samples(samples, limit=MAX_DEV_SAMPLES):
    """Return a random selection of at most *limit* items from *samples*.

    ``random.sample`` raises ValueError when asked for more items than the
    population holds, so the requested size is clamped to ``len(samples)``.
    """
    return random.sample(samples, min(limit, len(samples)))


if __name__ == "__main__":
    # articles
    with open(ARTICLES_PATH, encoding="utf-8") as f:
        articles_all = json.load(f)

    # Report which question focuses occur in the data next to the ones kept.
    question_focuses = set()
    for prefix in SAMPLE_PREFIXES:
        with jsonlines.open(samples_path(prefix)) as reader:
            for line in reader:
                question_focuses.add(line["focus"])
    print(question_focuses)
    valid_question_focuses = list(VALID_QUESTION_FOCUSES)
    print(valid_question_focuses)

    # The prompt template is the same for every sample: read it once and
    # close the handle instead of re-opening the file per sample.
    with open(INSTRUCTION_PATH, encoding="utf-8") as f:
        instruction_template = f.read()

    # samples
    samples_training = []
    samples_dev = []
    for prefix in SAMPLE_PREFIXES:
        samples_training_tmp = []
        samples_dev_tmp = []
        with jsonlines.open(samples_path(prefix)) as reader:
            for line in reader:
                if not is_valid_sample(line, valid_question_focuses):
                    continue

                sample = build_sample(line, articles_all, instruction_template)

                # Records with any other label are dropped.
                if line["label"] == "training":
                    samples_training_tmp.append(sample)
                elif line["label"] == "dev":
                    samples_dev_tmp.append(sample)

        print(prefix, f"training={len(samples_training_tmp)}, dev={len(samples_dev_tmp)}")
        samples_training.extend(samples_training_tmp)
        samples_dev.extend(samples_dev_tmp)

    random.shuffle(samples_training)
    random.shuffle(samples_dev)
    print(len(samples_training), len(samples_dev))
    # The training set is written in full; only the dev set is subsampled.
    samples_dev = cap_dev_samples(samples_dev)
    print(len(samples_training), len(samples_dev))

    with jsonlines.open("train.jsonl", "w") as writer:
        writer.write_all(samples_training)

    with jsonlines.open("dev.jsonl", "w") as writer:
        writer.write_all(samples_dev)


