import os
import json
import random
import datasets
import re


def is_json_subset(test_json_path, calib_json_path):
    """Report whether every calibration question also appears in the test file.

    Both paths must point to JSON files holding a list of objects that each
    carry a "question" key. Comparison is done on the set of question strings,
    so duplicate questions within a file are counted once.

    Args:
        test_json_path: Path to the test JSON file.
        calib_json_path: Path to the calibration JSON file.

    Returns:
        A tuple ``(is_subset, missing_count)`` where ``is_subset`` is True when
        no calibration question is absent from the test file, and
        ``missing_count`` is the number of distinct calibration questions that
        are absent from it.
    """
    # Load both files first (test, then calibration) before touching their contents.
    payloads = []
    for path in (test_json_path, calib_json_path):
        with open(path, "r") as handle:
            payloads.append(json.load(handle))
    test_entries, calib_entries = payloads

    known_questions = set(entry["question"] for entry in test_entries)
    wanted_questions = set(entry["question"] for entry in calib_entries)

    # Calibration is a subset exactly when nothing is left after removing test questions.
    absent = wanted_questions.difference(known_questions)
    n_absent = len(absent)
    subset_flag = n_absent == 0

    print(f"Is calibration JSON a subset of test JSON? {subset_flag}")
    print(f"Number of questions in calibration.json not in test.json: {n_absent}")

    return subset_flag, n_absent

# Captures the text following the closing "####" (or "###") marker of an answer.
_GROUND_TRUTH_RE = re.compile(r"(?:####|###)\s*([^\n]+)\s*$")


def _extract_ground_truth(answer):
    """Return the stripped text after the trailing ####/### marker, or None if absent."""
    match = _GROUND_TRUTH_RE.search(answer)
    return match.group(1).strip() if match else None


def _build_records(rows):
    """Convert dataset rows into output dicts, skipping rows with no ground-truth marker."""
    records = []
    for item in rows:
        ground_truth = _extract_ground_truth(item["answer"])
        if ground_truth is None:
            # No parsable final answer: the row is silently left out of the output.
            continue
        records.append({
            "id": item["id"],
            "instance": item["instance"],
            "question": item["question"],
            "answer": item["answer"],
            "ground_truth": ground_truth,
        })
    return records


def create_gsm_symbolic_json(output_dir, calib_ratio=0.1, seed=42):
    """Write shuffled calibration and test JSON files built from GSM-Symbolic.

    Loads the "main" config of the "apple/GSM-Symbolic" dataset, splits its
    "test" split with ``train_test_split`` and writes two files into
    ``output_dir`` (created if missing):

    * ``calibration_gsm_symbolic.json`` - the held-out ``calib_ratio`` portion.
    * ``test_gsm_symbolic.json`` - the remaining rows.

    Each file is a JSON list of objects with the keys ``id``, ``instance``,
    ``question``, ``answer`` and ``ground_truth``; ``ground_truth`` is the text
    after the final ``####``/``###`` marker of ``answer``. Rows without such a
    marker are skipped.

    Args:
        output_dir: Directory the two JSON files are written to.
        calib_ratio: Passed as ``test_size`` to ``train_test_split``; controls
            the size of the calibration portion.
        seed: Seed used for both the split and the record shuffling.

    Returns:
        None. The results are written to disk.
    """
    # NOTE(review): earlier comments cited 12500 / 1250 / 11250 rows; confirm
    # against the actual size of the "main" config before relying on them.
    dataset = datasets.load_dataset("apple/GSM-Symbolic", "main")
    splits = dataset["test"].train_test_split(test_size=calib_ratio, seed=seed)

    # A private generator honours `seed` (the shuffle seed was previously fixed
    # at 42 regardless of the argument) and leaves the global `random` state
    # untouched. With the default seed the shuffle order matches the old output.
    rng = random.Random(seed)

    os.makedirs(output_dir, exist_ok=True)

    # Order matters: calibration is shuffled first, test second, from one RNG stream.
    outputs = (
        ("calibration_gsm_symbolic.json", splits["test"]),
        ("test_gsm_symbolic.json", splits["train"]),
    )
    for filename, rows in outputs:
        records = _build_records(rows)
        rng.shuffle(records)
        with open(os.path.join(output_dir, filename), "w") as f:
            json.dump(records, f, indent=4)

if __name__ == "__main__":
    # Generate both JSON files, then report how the calibration questions
    # overlap with the test questions.
    out_dir = "data/gsm_symbolic"
    create_gsm_symbolic_json(out_dir)
    is_json_subset(
        f"{out_dir}/test_gsm_symbolic.json",
        f"{out_dir}/calibration_gsm_symbolic.json",
    )

