import jsonlines
import os
import json
import random
from tqdm import tqdm
from transformers import AutoTokenizer


def count_jsonl_records(path):
    """Return the number of JSON records in the JSON Lines file at *path*.

    Records are streamed and counted one at a time instead of being collected
    into a list first. The message files can hold full long-context prompts,
    so building a list only to call ``len()`` would keep every record in
    memory at once.

    Propagates ``FileNotFoundError`` when *path* does not exist, and any error
    the ``jsonlines`` reader raises on a line it cannot decode.
    """
    with jsonlines.open(path) as reader:
        return sum(1 for _ in reader)


if __name__ == "__main__":
    # NOTE(review): the commented-out pipeline below is the step that built and
    # wrote raw_data/train.jsonl and raw_data/val.jsonl. It is kept for
    # reference only; re-enable it to regenerate those files.
    # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct-1M")
    # # articles
    # with open("../../bench/article/papers_final.json") as f:
    #     articles_all = json.load(f)
    #
    # valid_length_levels = ["64k", "128k"]
    # valid_sql_types = ['multi_ran_filtering_foa', 'multi_ran_organizing', 'multi_simple', 'multi_ran_filtering_ofo',
    #                    'multi_ran_aggregating', 'multi_ran_filtering_foo']
    # valid_sql_complexity_threshold = 15
    # valid_question_focuses = ['author_list', 'title_word_count', 'title_entire', 'author_count', 'author_relationship']
    #
    # # samples
    # samples_training = []
    # samples_dev = []
    # for prefix in valid_length_levels:
    #     samples_training_tmp = []
    #     samples_dev_tmp = []
    #     with jsonlines.open(
    #             os.path.join("../../bench/dataset/samples/final/", f"{prefix}_samples_target.jsonl")) as reader:
    #         for line in reader:
    #             valid_sample = True
    #             if line["question"] == "":
    #                 valid_sample = False
    #             if "database" in line["question"]:
    #                 valid_sample = False
    #
    #             if line["answer"] == "NULL":
    #                 valid_sample = False
    #
    #             if valid_sample:
    #                 question_focus = line["focus"]
    #                 sql_type = line["sql_type"]
    #                 sql = line["sql"]
    #                 if len(sql.split()) <= valid_sql_complexity_threshold and question_focus in valid_question_focuses and sql_type in valid_sql_types:
    #                     if line["label"] == "training":
    #                         samples_training_tmp.append(line)
    #
    #                     if line["label"] == "dev":
    #                         samples_dev_tmp.append(line)
    #     samples_training.extend(samples_training_tmp)
    #     samples_dev.extend(samples_dev_tmp)
    # print(len(samples_training), len(samples_dev))
    #
    # all_train = samples_training
    # all_dev = samples_dev
    #
    # def process_data(samples):
    #     message_data = []
    #     for sample in tqdm(samples):
    #         answer = sample["answer"]
    #
    #         question = sample["question"]
    #         markdowns = []
    #         for paper_id in sample["articles"]:
    #             paper = articles_all[paper_id]
    #             markdowns.append(paper["markdown"])
    #         context = "\n".join(markdowns)
    #         instruction = open("../reasoning_instruction.txt").read()
    #         instruction = instruction.replace("<question>", question)
    #         prompt = instruction.replace("<articles>", context)
    #
    #         # messages = [{"role": "user", "content": prompt},
    #         #             {"role": "assistant", "content": answer}]
    #         # sample = {"messages": messages}
    #
    #         messages = [{"role": "user", "content": prompt}]
    #
    #         text = tokenizer.apply_chat_template(
    #             conversation=messages,
    #             tokenize=False,
    #             add_generation_prompt=True
    #             )
    #         tokenized_text = tokenizer.encode(text)
    #
    #         if len(tokenized_text) <= 101532: # only on Spartan
    #             tmp = {"messages": messages, "answer": answer}
    #             message_data.append(tmp)
    #     return message_data
    #
    # train_valid_data = process_data(all_train)
    # dev_valid_data = process_data(all_dev)
    # print(len(train_valid_data), len(dev_valid_data))
    # # train_message_data = random.sample(train_valid_data, 2000)
    # train_message_data = train_valid_data
    # dev_message_data = random.sample(dev_valid_data, 100)
    #
    # random.shuffle(train_message_data)
    # random.shuffle(dev_message_data)
    # print(len(train_message_data), len(dev_message_data)) # 3200, 100
    #
    # if not os.path.exists("raw_data"):
    #     os.mkdir("raw_data")
    # with jsonlines.open("raw_data/train.jsonl", "w") as writer:
    #     writer.write_all(train_message_data)
    # with jsonlines.open("raw_data/val.jsonl", "w") as writer:
    #     writer.write_all(dev_message_data)

    # Sanity check: report how many records each split file holds, labelled
    # by path so the four counts can be told apart.
    # NOTE(review): sampled_data/ is not written anywhere in this file;
    # presumably a separate sampling step produces it — confirm.
    for jsonl_path in (
        "raw_data/val.jsonl",
        "raw_data/train.jsonl",
        "sampled_data/val.jsonl",
        "sampled_data/train.jsonl",
    ):
        print(jsonl_path, count_jsonl_records(jsonl_path))


