import os
import json
import random
from datasets import load_dataset

def _build_mcq_record(item):
    """Convert one raw dataset row into a lettered multiple-choice record.

    The four answer options are shuffled with the module-level ``random``
    generator, so reproducibility is controlled by the caller's
    ``random.seed`` call.

    Args:
        item: Mapping with the keys ``question``, ``correct_answer`` and
            ``incorrect_answer_1`` .. ``incorrect_answer_3``.

    Returns:
        A dict with ``question`` (the question text followed by one
        ``"<letter>. <option>"`` line per option) and ``ground_truth``
        (the letter assigned to the correct answer).
    """
    question_text = item["question"]
    options = [
        item["correct_answer"],
        item["incorrect_answer_1"],
        item["incorrect_answer_2"],
        item["incorrect_answer_3"],
    ]
    random.shuffle(options)
    # chr(65) is "A": options are labelled by their position after the shuffle.
    answer_letter = chr(65 + options.index(item["correct_answer"]))
    options_str = "\n".join(
        f"{chr(65 + i)}. {opt}" for i, opt in enumerate(options)
    )
    return {
        "question": f"{question_text}\n{options_str}",
        "ground_truth": answer_letter,
    }


def _dump_json(records, path):
    """Write *records* to *path* as JSON indented by four spaces."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(records, f, indent=4)


def create_nasa_history_json(output_dir, calibration_size=1000, seed=42):
    """Build calibration and test JSON files from the NASA-History-MCQ dataset.

    Loads the ``train`` split of ``patrickfleith/NASA-History-MCQ``, takes the
    first ``calibration_size`` rows as the calibration set and the remainder
    as the test set, and writes them to ``calibration_nasa_history.json`` and
    ``test_nasa_history.json`` inside ``output_dir``.

    Args:
        output_dir: Directory for the two output files; created if missing.
        calibration_size: Number of leading rows used for calibration.
        seed: Seed passed to ``random.seed``. It only affects the order of the
            answer options within each question.

    Note:
        The rows themselves are NOT shuffled: the split follows dataset order.
        This function reseeds the global ``random`` generator.
    """
    dataset = load_dataset("patrickfleith/NASA-History-MCQ")  # 7469
    full_data = list(dataset["train"])
    random.seed(seed)

    calibration_data = full_data[:calibration_size]  # 1000
    test_data = full_data[calibration_size:]  # 6469

    # 1. create calibration set
    final_calibration = [_build_mcq_record(item) for item in calibration_data]

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)
    _dump_json(
        final_calibration,
        os.path.join(output_dir, "calibration_nasa_history.json"),
    )

    # 2. create test set (built after calibration so the shuffle sequence,
    # and therefore the generated files, stay reproducible for a given seed)
    final_test = [_build_mcq_record(item) for item in test_data]
    _dump_json(final_test, os.path.join(output_dir, "test_nasa_history.json"))

def filter_json_by_category(input_file, output_dir, output_file, category):
    """Keep only the entries of a JSON list whose ``category`` matches.

    Args:
        input_file: Path of the JSON file to read.
        output_dir: Directory for the result; created if it does not exist.
        output_file: File name of the result inside ``output_dir``.
        category: Value an entry's ``"category"`` key must equal to be kept.

    Returns:
        The list of matching entries, which is also written to disk as
        JSON indented by four spaces.
    """
    with open(input_file, "r") as source:
        records = json.load(source)

    # Entries without a "category" key yield None from .get() and are dropped
    # unless None itself is the requested category.
    matching = []
    for record in records:
        if record.get("category") == category:
            matching.append(record)

    os.makedirs(output_dir, exist_ok=True)
    destination = os.path.join(output_dir, output_file)
    with open(destination, "w") as sink:
        json.dump(matching, sink, indent=4)

    print(f"Filtering complete. Saved {len(matching)} items to {destination}")

    return matching

if __name__ == "__main__":
    # Script entry point: generate both splits under data/nasa_history.
    create_nasa_history_json(
        output_dir="data/nasa_history",
        calibration_size=1000,
    )
