import tools
import os
from tqdm import tqdm
# Number of sampled reasoning outputs per (dataset subset, model); the script
# reads output_0.jsonl through output_{SAMPLE_NUM-1}.jsonl for each of them.
SAMPLE_NUM = 20

def _ensure_think_prefix(response):
    """Return *response* stripped, starting with a ``<think>`` tag.

    If the stripped text does not already begin with ``<think>``, the tag
    followed by a space and a newline is prepended.
    """
    response = response.strip()
    if not response.startswith("<think>"):
        response = "<think> \n" + response
    return response


def _build_judge_input(dataset_name, dataset_subset, model,
                       reasoning_input_jsons):
    """Write ``judge_input.jsonl`` for one (dataset, subset, model) triple.

    Nothing is written when the last sample file
    (``output_{SAMPLE_NUM-1}.jsonl``) is missing or when the judge input
    file already exists.

    Args:
        dataset_name: Directory name of the dataset under ``../output``.
        dataset_subset: Subset name, e.g. ``"train_ANS"``.
        model: Model name, used as the innermost directory name.
        reasoning_input_jsons: Records of the reasoning input file; each
            record is read through its ``id``, ``content`` and ``answer``
            keys.

    Raises:
        ValueError: If a sample file holds fewer records than the input.
    """
    reasoning_output_dir_path = (
        f"../output/{dataset_name}/{dataset_subset}/{model}")

    # The last sample file is used as the marker that sampling is complete.
    last_sample_path = (
        f"{reasoning_output_dir_path}/output_{SAMPLE_NUM-1}.jsonl")
    if not os.path.exists(last_sample_path):
        return

    judge_input_dir_path = (
        f"../judge/input/{dataset_name}/{dataset_subset}/{model}/")
    judge_input_file_path = os.path.join(
        judge_input_dir_path, "judge_input.jsonl")

    if os.path.exists(judge_input_file_path):
        print(
            f"Judge input jsonl already exists: {judge_input_dir_path}/judge_input.jsonl, skipping...")
        return

    reasoning_output_sample_jsons = [
        tools.read_jsonl(f"{reasoning_output_dir_path}/output_{idx}.jsonl")
        for idx in tqdm(range(SAMPLE_NUM))
    ]

    # Validate before touching the output file: failing halfway through the
    # write loop would leave a partial judge_input.jsonl behind, which later
    # runs would then skip as "already exists".
    expected_rows = len(reasoning_input_jsons)
    for sample_id, sample_jsons in enumerate(reasoning_output_sample_jsons):
        if len(sample_jsons) < expected_rows:
            raise ValueError(
                f"{reasoning_output_dir_path}/output_{sample_id}.jsonl has "
                f"{len(sample_jsons)} records, expected at least "
                f"{expected_rows}")

    os.makedirs(judge_input_dir_path, exist_ok=True)
    # NOTE(review): check_existence is an opaque project helper; it is called
    # exactly as before, on the not-yet-existing output path.
    tools.check_existence(judge_input_file_path)

    for rank_id, reasoning_input_json in tqdm(
            enumerate(reasoning_input_jsons), total=expected_rows):
        original_id = reasoning_input_json['id']
        question = reasoning_input_json['content']
        answer = reasoning_input_json['answer']

        for sample_id in range(SAMPLE_NUM):
            sample_response = _ensure_think_prefix(
                reasoning_output_sample_jsons[sample_id][rank_id]['text'])
            judge_input_json = {
                "id": f"{original_id}_{sample_id}_0",
                "rank_id": f"{rank_id}_{sample_id}_0",
                "assembly_question": question,
                "assembly_reasoning": sample_response,
                "assembly_answer": answer
            }
            # Presumably write_jsonl appends one record per call (the script
            # has always called it once per record) -- confirm in tools.
            tools.write_jsonl(judge_input_json, judge_input_file_path)


def main():
    """Build judge input files for every dataset subset and backbone model."""
    datasets = [
        {"datasets--allenai--reward-bench-2": [
            "train_NOANS", "train_ANS", "validation_NOANS"]},
        {"datasets--RUC-NLPIR--FlashRAG_datasets@hotpotqa_RAG": [
            "train_NOANS", "train_ANS", "validation_NOANS"]}
    ]
    models = tools.BACKBONE_MODELS

    for dataset in datasets:
        for dataset_name, dataset_subsets in dataset.items():
            for dataset_subset in dataset_subsets:
                # "_ANS" subsets read the matching "_NOANS" input file. The
                # substitution is limited to the subset name so a dataset
                # name containing "_ANS" can never be rewritten by accident.
                input_subset = dataset_subset.replace("_ANS", "_NOANS")
                reasoning_input_file_path = (
                    f"../data/{dataset_name}/{input_subset}.jsonl")
                reasoning_input_jsons = tools.read_jsonl(
                    reasoning_input_file_path)

                for model in models:
                    print(dataset_name, dataset_subset, model)
                    _build_judge_input(
                        dataset_name, dataset_subset, model,
                        reasoning_input_jsons)


if __name__ == "__main__":
    main()
