import jsonlines
import os
import json
import random

import numpy as np

# Input locations, relative to the directory the script is run from.
ARTICLES_PATH = "../../bench/article/papers_final.json"
SAMPLES_DIR = "../../bench/dataset/samples/final/"
INSTRUCTION_PATH = "../../test_full/full_instruction.txt"

# Sample-file prefixes that are merged into the final train/dev sets.
PREFIXES = ("64k", "128k")

# Upper bound on the number of dev samples written to dev.jsonl.
MAX_DEV_SAMPLES = 200


def samples_path(prefix):
    """Return the path of the ``<prefix>_samples_target.jsonl`` file."""
    return os.path.join(SAMPLES_DIR, f"{prefix}_samples_target.jsonl")


def is_valid_sample(line, valid_sql_types):
    """Return True if *line* should be turned into a fine-tuning sample.

    A record is rejected when its question is empty, when the question
    contains the substring "database", or when its ``sql_type`` is not one
    of *valid_sql_types*.
    """
    question = line["question"]
    if question == "" or "database" in question:
        return False
    return line["sql_type"] in valid_sql_types


def build_sample(line, articles, instruction):
    """Build one chat-format sample from a dataset record.

    Args:
        line: Record with ``question``, ``answer`` and ``articles`` (a list
            of keys into *articles*).
        articles: Mapping from article id to a dict holding a ``markdown``
            entry.
        instruction: Prompt template containing the ``<question>`` and
            ``<articles>`` placeholders.

    Returns:
        ``{"messages": [user_message, assistant_message]}`` where the user
        message is the filled-in template and the assistant message is the
        record's answer.

    Raises:
        KeyError: If a referenced article id is missing from *articles*.
    """
    context = "\n".join(
        articles[paper_id]["markdown"] for paper_id in line["articles"]
    )
    # Substitute the question first, then the articles, so the replacement
    # order (and therefore the resulting prompt) is deterministic.
    prompt = instruction.replace("<question>", line["question"])
    prompt = prompt.replace("<articles>", context)
    return {
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": line["answer"]},
        ]
    }


if __name__ == "__main__":
    # articles
    with open(ARTICLES_PATH, encoding="utf-8") as f:
        articles_all = json.load(f)

    # Report every sql_type present in the data next to the accepted ones,
    # so a newly introduced type is easy to spot in the console output.
    sql_types = set()
    for prefix in PREFIXES:
        with jsonlines.open(samples_path(prefix)) as reader:
            for line in reader:
                sql_types.add(line["sql_type"])
    print(sql_types)
    valid_sql_types = ['multi_ran_filtering_foa', 'multi_ran_organizing', 'multi_simple', 'multi_ran_filtering_ofo', 'multi_ran_aggregating', 'multi_ran_filtering_foo']
    print(valid_sql_types)

    # The prompt template is the same for every sample: read it once and
    # close the file, instead of re-opening it per sample.
    with open(INSTRUCTION_PATH, encoding="utf-8") as f:
        instruction = f.read()

    # samples
    samples_training = []
    samples_dev = []
    for prefix in PREFIXES:
        samples_training_tmp = []
        samples_dev_tmp = []
        with jsonlines.open(samples_path(prefix)) as reader:
            for line in reader:
                if not is_valid_sample(line, valid_sql_types):
                    continue

                sample = build_sample(line, articles_all, instruction)
                # Records with any other label are dropped.
                if line["label"] == "training":
                    samples_training_tmp.append(sample)
                if line["label"] == "dev":
                    samples_dev_tmp.append(sample)

        print(prefix, f"training={len(samples_training_tmp)}, dev={len(samples_dev_tmp)}")
        samples_training.extend(samples_training_tmp)
        samples_dev.extend(samples_dev_tmp)

    random.shuffle(samples_training)
    random.shuffle(samples_dev)
    print(len(samples_training), len(samples_dev))
    # random.sample raises ValueError when asked for more items than exist,
    # so cap the request at the number of dev samples actually available.
    samples_dev = random.sample(samples_dev, min(MAX_DEV_SAMPLES, len(samples_dev)))
    print(len(samples_training), len(samples_dev))

    with jsonlines.open("train.jsonl", "w") as writer:
        writer.write_all(samples_training)

    with jsonlines.open("dev.jsonl", "w") as writer:
        writer.write_all(samples_dev)


