import os
import json
from datasets import load_dataset
import re


def add_ground_truth(ex):
    """Attach the short final answer parsed from ``ex["response"]``.

    Intended as a ``Dataset.map`` callback: sets ``ex["ground_truth"]`` in
    place and hands back the same dict.

    Raises:
        ValueError: propagated from ``short_from_metamath`` when the response
            has no "The answer is:" marker.
    """
    response = ex["response"]
    extracted = short_from_metamath(response)
    ex["ground_truth"] = extracted
    return ex

# Compiled once: used for every row of the MetaMath dataset.
_METAMATH_ANSWER_RE = re.compile(r"The answer is:\s*([^\n]+)", re.IGNORECASE)


def short_from_metamath(ans: str):
    """Return the text after the first "The answer is:" marker in *ans*.

    Matching is case-insensitive. Whitespace after the colon may include a
    newline. The captured answer runs to the end of its line and is stripped.

    Raises:
        ValueError: if *ans* contains no such marker.
    """
    found = _METAMATH_ANSWER_RE.search(ans)
    if found is None:
        raise ValueError(f"Cannot extract answer from: {ans!r}")
    return found[1].strip()

# GSM8K solutions end with "#### <final answer>"; compiled once for reuse.
_GSM_ANSWER_RE = re.compile(r"####\s*([^\n]+)\s*$")


def short_from_gsm(ans: str):
    """Return the final answer following the trailing ``####`` in *ans*.

    *ans* is stripped first. The ``####`` line must be the last line of the
    stripped text.

    Raises:
        ValueError: if no trailing ``####`` answer is present.
    """
    trimmed = ans.strip()
    found = _GSM_ANSWER_RE.search(trimmed)
    if found is None:
        raise ValueError(f"Cannot extract answer from: {ans!r}")
    return found[1].strip()

def attach_orig(ex, gsm_q2a):
    """Look up the original GSM8K solution for a MetaMath row.

    Args:
        ex: a MetaMath row; only ``ex["original_question"]`` is read.
        gsm_q2a: mapping from GSM8K question text to
            ``[full_answer, short_ground_truth]``.

    Returns:
        A dict with ``"original_answer"`` and ``"original_ground_truth"``.
        Both are ``None`` when the question has no exact match in *gsm_q2a*.
        Callers (the ``if oq and oa`` guard) skip such rows instead of
        crashing the whole ``Dataset.map``.
    """
    # Single lookup; previously ``.get(q)[0]`` raised TypeError on a miss.
    match = gsm_q2a.get(ex["original_question"])
    if match is None:
        return {"original_answer": None, "original_ground_truth": None}
    return {"original_answer": match[0], "original_ground_truth": match[1]}
    
def _build_records(split, seen_original):
    """Flatten one MetaMath split into a list of JSON-serializable records.

    Each row yields a ``"metamath"`` record. The first time a given
    ``original_question`` is seen, it is followed by a ``"gsm"`` record for
    the underlying GSM8K problem. Rows whose original question or original
    ground truth is empty/``None`` get no ``"gsm"`` record.

    *seen_original* is updated in place. Passing the same set to later calls
    keeps them from repeating originals already emitted by an earlier call.
    """
    records = []
    for item in split:
        # The (possibly rewritten) MetaMath item itself.
        records.append({
            "source": "metamath",
            "type": item["type"],
            "original_question": item["original_question"],
            "question": item["query"],
            "answer": item["response"],
            "ground_truth": item["ground_truth"],
        })

        # The original GSM8K item, added only once per original_question.
        oq = item["original_question"]
        oa = item["original_ground_truth"]
        if oq and oa and oq not in seen_original:
            records.append({
                "source": "gsm",
                "question": oq,
                "answer": item["original_answer"],
                "ground_truth": oa,
            })
            seen_original.add(oq)
    return records


def _dump_json(records, path):
    """Write *records* to *path* as indented JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(records, f, indent=4)


def create_metamath_json(output_dir, calib_ratio=0.2, seed=42,
                         calib_limit=1000, test_limit=10000):
    """Build ``calibration_metamath.json`` and ``test_metamath.json``.

    Loads MetaMath-GSM240K and attaches each row's parsed ground truth. Each
    row is joined to its original GSM8K (train) problem by exact question
    text. The rows are split with ``train_test_split``: the ``"test"`` part
    (fraction *calib_ratio*) feeds the calibration file and the ``"train"``
    part feeds the test file. GSM8K originals already emitted into the
    calibration records are not repeated in the test records.

    Args:
        output_dir: directory for the two JSON files (created if missing).
        calib_ratio: fraction of MetaMath rows assigned to calibration.
        seed: seed for the train/test split.
        calib_limit: keep only the first N calibration records; ``None``
            keeps all.
        test_limit: keep only the first N test records; ``None`` keeps all.
    """
    # load MetaMath-GSM240K (original note: 240000 rows)
    mm = load_dataset("fxmeng/MetaMath-GSM240K")["train"]
    mm = mm.map(add_ground_truth)

    # load GSM8K (train) and build a question -> [answer, short answer] map.
    # NOTE(review): original comment said 14946 rows; GSM8K "main" train is
    # commonly reported as 7473 — verify.
    gsm = load_dataset("gsm8k", "main")["train"]
    gsm_q2a = {ex["question"]: [ex["answer"], short_from_gsm(ex["answer"])] for ex in gsm}
    mm_with_orig = mm.map(attach_orig, fn_kwargs={"gsm_q2a": gsm_q2a})

    splits = mm_with_orig.train_test_split(test_size=calib_ratio, seed=seed)

    # Shared across both splits so GSM originals are emitted at most once overall.
    seen_original = set()

    # prepare calibration (original note: 55180 records before truncation).
    # The full split is processed even when truncating, so that seen_original
    # covers every calibration original before the test split is built.
    calibration = _build_records(splits["test"], seen_original)[:calib_limit]

    os.makedirs(output_dir, exist_ok=True)
    _dump_json(calibration, os.path.join(output_dir, "calibration_metamath.json"))

    # prepare test (original note: 192210 records before truncation)
    test = _build_records(splits["train"], seen_original)[:test_limit]
    _dump_json(test, os.path.join(output_dir, "test_metamath.json"))

    print(len(calibration), len(test))

if __name__ == "__main__":
    # Script entry point: write both JSON files under the default folder.
    create_metamath_json(output_dir="data/metamath")


