import jsonlines
import os
import json
import random

from transformers import AutoTokenizer

if __name__ == "__main__":
    # Tokenizer that the prompt-building step below uses to count tokens.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct-1M")

    # Article store; entries are looked up by paper id when prompts are built.
    with open("../../bench/article/papers_final.json") as articles_file:
        articles_all = json.load(articles_file)

    # Train / dev samples, one JSON object per line.
    with jsonlines.open("../processing/train_original.jsonl") as train_reader:
        samples_train_processed = list(train_reader)

    with jsonlines.open("../processing/dev_original.jsonl") as dev_reader:
        samples_dev_processed = list(dev_reader)

    all_train, all_dev = samples_train_processed, samples_dev_processed
    # Counts recorded by the original author:
    # 4154, 234 (filtering based on dimensions)
    print(len(all_train), len(all_dev))


    def process_data(samples, max_prompt_tokens=101532,
                     instruction_path="../reasoning_instruction.txt"):
        """Turn raw samples into chat-message records, split by sampled-answer agreement.

        Samples whose ``"answer"`` is the string ``"NULL"`` are skipped. For
        every other sample the prompt is built by substituting the question
        for ``<question>`` and the newline-joined markdown of the sample's
        articles for ``<articles>`` in the instruction template. The prompt
        is wrapped as a single user message, rendered with the tokenizer's
        chat template and tokenised; records longer than
        ``max_prompt_tokens`` tokens are dropped entirely.

        Relies on ``tokenizer`` and ``articles_all`` from the enclosing scope.

        Args:
            samples: Iterable of dicts with the keys ``"answer"``,
                ``"question"``, ``"articles"`` (ids into ``articles_all``)
                and ``"sampled_answers"`` (dicts carrying an ``"answer"``
                string).
            max_prompt_tokens: Largest allowed token count for the rendered
                prompt. The default is the value the original script used
                (annotated there as "only on Spartan").
            instruction_path: Path of the instruction template file.

        Returns:
            ``(message_data, invalid_data)``. Both are lists of
            ``{"messages": [...], "answer": ...}`` dicts. ``message_data``
            holds the samples for which at least one but fewer than 8 of the
            sampled answers equal the reference answer after whitespace and
            comma normalisation; ``invalid_data`` holds the remaining
            samples that fit the token budget.
        """
        # The template is identical for every sample: read it once and close
        # the handle, instead of re-opening (and leaking) it per sample.
        with open(instruction_path, encoding="utf-8") as template_file:
            instruction_template = template_file.read()

        def normalize(raw_answer):
            # Collapse all whitespace runs to single spaces, then force
            # exactly ", " between comma-separated parts.
            collapsed = " ".join(raw_answer.split())
            return ", ".join(part.strip() for part in collapsed.split(","))

        message_data = []
        invalid_data = []
        for sample in samples:
            answer = sample["answer"]
            if answer == "NULL":
                continue

            # Number of sampled answers that exactly match the reference.
            num_correct = sum(
                1
                for sampled_answer in sample["sampled_answers"]
                if normalize(sampled_answer["answer"]) == answer
            )

            context = "\n".join(
                articles_all[paper_id]["markdown"]
                for paper_id in sample["articles"]
            )
            # Substitution order is kept: question first, then articles.
            prompt = instruction_template.replace("<question>", sample["question"])
            prompt = prompt.replace("<articles>", context)

            messages = [{"role": "user", "content": prompt}]
            record = {"messages": messages, "answer": answer}

            text = tokenizer.apply_chat_template(
                conversation=messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            if len(tokenizer.encode(text)) > max_prompt_tokens:
                # Over the token budget: the sample is dropped from both lists.
                continue

            # NOTE(review): 8 is presumably the number of sampled answers per
            # question, so this keeps samples that are neither always wrong
            # nor always right -- confirm against the sampling script.
            if 0 < num_correct < 8:
                message_data.append(record)
            else:
                invalid_data.append(record)

        return message_data, invalid_data


    # Build the message records for both splits and report how many fall
    # into the "valid" bucket returned first by process_data.
    train_valid_data, train_invalid_data = process_data(all_train)
    dev_valid_data, dev_invalid_data = process_data(all_dev)
    print(len(train_valid_data), len(dev_valid_data))

    # Both buckets are written out together; the dev split is subsampled.
    train_message_data = train_valid_data + train_invalid_data
    dev_message_data = dev_valid_data + dev_invalid_data
    # random.sample raises ValueError when asked for more items than the
    # population holds, so cap the request at what is actually available.
    dev_sample_size = min(100, len(dev_message_data))
    dev_message_data = random.sample(dev_message_data, dev_sample_size)
    # Counts recorded by the original author: 3200, 100
    print(len(train_message_data), len(dev_message_data))

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs("sampled_data", exist_ok=True)
    with jsonlines.open("sampled_data/train.jsonl", "w") as writer:
        writer.write_all(train_message_data)
    with jsonlines.open("sampled_data/val.jsonl", "w") as writer:
        writer.write_all(dev_message_data)
