import os
import json
import random
import datasets
from collections import defaultdict


def _load_questions(json_path):
    """Return the set of ``question`` strings from the JSON list stored at *json_path*.

    The file is expected to hold a JSON list of objects, each with a
    ``question`` key; a missing key raises ``KeyError``.
    """
    # Explicit UTF-8 so reading does not depend on the platform's locale encoding.
    with open(json_path, "r", encoding="utf-8") as f:
        records = json.load(f)
    return {item["question"] for item in records}


def is_json_subset(test_json_path, calib_json_path):
    """Check whether every calibration question also appears in the test file.

    Both files must contain a JSON list of objects with a ``question`` key.
    Questions are compared as exact strings, so duplicates within a file
    collapse into one. A two-line summary is printed to stdout.

    Args:
        test_json_path: Path to the test-set JSON file.
        calib_json_path: Path to the calibration-set JSON file.

    Returns:
        A tuple ``(is_subset, missing_count)``:

        - ``is_subset`` is True when every distinct calibration question is
          present in the test file.
        - ``missing_count`` is the number of distinct calibration questions
          absent from it.

        For two disjoint splits (no leakage), ``missing_count`` equals the
        number of distinct calibration questions.
    """
    test_questions = _load_questions(test_json_path)
    calib_questions = _load_questions(calib_json_path)

    missing_count = len(calib_questions - test_questions)
    is_subset = missing_count == 0

    print(f"Is calibration JSON a subset of test JSON? {is_subset}")
    print(f"Number of questions in calibration.json not in test.json: {missing_count}")

    return is_subset, missing_count

def _to_record(item):
    """Project one GSM-Plus row onto the fields kept in the output JSON.

    The dataset's ``answer`` column is written out as ``ground_truth``.
    """
    return {
        "question": item["question"],
        "solution": item["solution"],
        "ground_truth": item["answer"],
        "perturbation_type": item["perturbation_type"]
    }


def _write_json(path, records):
    """Write *records* to *path* as JSON with 4-space indentation."""
    # json.dump's default ensure_ascii=True keeps the output pure ASCII;
    # the explicit encoding just removes any dependence on the locale.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(records, f, indent=4)


def create_gsm_json(output_dir, calib_ratio=0.1, seed=42):
    """Split the GSM-Plus ``test`` split into calibration and test JSON files.

    Rows are grouped by ``perturbation_type``, and each group is shuffled
    independently. From every group, the first
    ``max(1, int(len(group) * calib_ratio))`` shuffled rows go to the
    calibration set and the rest go to the test set. Both outputs therefore
    keep roughly the per-type proportions of the source data.

    Note that every group contributes at least one calibration row, even when
    ``calib_ratio`` is 0. A single-row group contributes nothing to the test
    set.

    Two files are written into ``output_dir``, which is created if missing:

    - ``calibration_gsm.json``
    - ``test_gsm.json``

    Each holds a JSON list of objects with keys ``question``, ``solution``,
    ``ground_truth`` and ``perturbation_type``.

    Args:
        output_dir: Directory that receives the two JSON files.
        calib_ratio: Fraction of each perturbation group assigned to the
            calibration set; must lie in [0, 1].
        seed: Seed for the shuffling RNG, making the split reproducible.

    Raises:
        ValueError: If ``calib_ratio`` is outside [0, 1].
    """
    if not 0 <= calib_ratio <= 1:
        raise ValueError(f"calib_ratio must be in [0, 1], got {calib_ratio}")

    dataset = datasets.load_dataset("qintongli/GSM-Plus")
    test_data = list(dataset["test"])

    # Group rows by perturbation type so the split is stratified per type.
    perturbation_groups = defaultdict(list)
    for item in test_data:
        perturbation_groups[item["perturbation_type"]].append(item)

    # A private generator seeded with `seed` produces exactly the same
    # shuffles as `random.seed(seed)` would. It does so without resetting the
    # module-level RNG that other code in the process may rely on.
    rng = random.Random(seed)
    calibration_data = []
    remaining_test_data = []
    for items in perturbation_groups.values():
        rng.shuffle(items)
        n_calib = max(1, int(len(items) * calib_ratio))
        calibration_data.extend(items[:n_calib])
        remaining_test_data.extend(items[n_calib:])

    final_calibration = [_to_record(item) for item in calibration_data]
    final_test = [_to_record(item) for item in remaining_test_data]

    os.makedirs(output_dir, exist_ok=True)
    _write_json(os.path.join(output_dir, "calibration_gsm.json"), final_calibration)
    _write_json(os.path.join(output_dir, "test_gsm.json"), final_test)

if __name__ == "__main__":
    # Build the calibration/test split first, then report how many
    # calibration questions also occur in the test file (a leakage check).
    create_gsm_json(output_dir="data/gsm_plus")
    is_json_subset(
        test_json_path="data/gsm_plus/test_gsm.json",
        calib_json_path="data/gsm_plus/calibration_gsm.json",
    )
