from scripts.utils import load_single_dataset, save_dataset
import pandas as pd
import os
import datasets
import argparse

# Columns retained in every dataset once its per-row transform has run;
# all other columns are removed before the datasets are concatenated.
SAVE_COLUMNS = ["data_source", "prompt", "ability", "reward_model", "extra_info"]

# System message used as the first turn of every prompt built in this file.
# It lists the bracketed action names and the step format the model is told to follow.
SYSTEM_PROMPT = (
    "\nWhen tackling complex reasoning tasks, you have access to the following actions. "
    "Use them as needed to progress through your thought process.\n\n"
    "[ASSESS]\n\n[ADVANCE]\n\n[VERIFY]\n\n[SIMPLIFY]\n\n[SYNTHESIZE]\n\n[PIVOT]\n\n[OUTPUT]\n\n"
    "You should strictly follow the format below:\n\n"
    "[ACTION NAME]\n\n"
    "# Your action step 1\n\n"
    "# Your action step 2\n\n"
    "# Your action step 3\n\n"
    "...\n\n"
    "Next action: [NEXT ACTION NAME]\n\n"
)

# Suffix appended to the question text of every user turn, asking for a \boxed{} answer.
LAST_SENTENCE = "\n\nPresent the answer in LaTex format: \\boxed{Your answer}"
# Tags accepted by --valid-tags; any other tag makes the script raise RuntimeError.
VALID_TAGS = ["amc", "aime", "math500", "minerva", "olympiadbench", "svamp", "gsm8k", "asdiv", "multiarith"]

def deal_with_amc(row):
    """Rewrite one AMC record into the shared validation schema.

    Reads ``question``, ``answer`` and ``id`` from ``row``, stores the five
    output fields on that same mapping, and returns it.
    """
    system_turn = {"content": SYSTEM_PROMPT, "role": "system"}
    user_turn = {"content": row["question"] + LAST_SENTENCE, "role": "user"}
    # The answer and id are stringified so every source yields the same field types.
    target = {"ground_truth": str(row["answer"]), "style": "rule"}
    meta = {"index": str(row["id"]), "split": "amc-valid"}

    row["data_source"] = "aimo-amc"
    row["prompt"] = [system_turn, user_turn]
    row["ability"] = "math"
    row["reward_model"] = target
    row["extra_info"] = meta
    return row

def deal_with_aime(row):
    """Rewrite one AIME record into the shared validation schema.

    Reads ``question``, ``answer`` and ``id`` from ``row``; the output fields
    are written back onto ``row``, which is then returned.
    """
    conversation = [
        {"content": SYSTEM_PROMPT, "role": "system"},
        {"content": row["question"] + LAST_SENTENCE, "role": "user"},
    ]
    converted = {
        "data_source": "aimo-aime2024",
        "prompt": conversation,
        "ability": "math",
        "reward_model": {"ground_truth": str(row["answer"]), "style": "rule"},
        "extra_info": {"index": str(row["id"]), "split": "aime-valid"},
    }
    # Assign key by key, in schema order, onto the incoming mapping.
    for column, value in converted.items():
        row[column] = value
    return row

def deal_with_math(row):
    """Rewrite one MATH-500 record into the shared validation schema.

    Unlike most other sources, the question text is read from ``problem`` and
    the reference answer from ``expected_answer``; ``id`` supplies the index.
    The same ``row`` mapping is updated and returned.
    """
    question_text = row["problem"] + LAST_SENTENCE
    reference = str(row["expected_answer"])
    record_index = str(row["id"])

    row["data_source"] = "math-500"
    row["prompt"] = [
        {"content": SYSTEM_PROMPT, "role": "system"},
        {"content": question_text, "role": "user"},
    ]
    row["ability"] = "math"
    row["reward_model"] = {"ground_truth": reference, "style": "rule"}
    row["extra_info"] = {"index": record_index, "split": "math-500"}
    return row

def deal_with_minerva(row):
    """Rewrite one Minerva Math record into the shared validation schema.

    Reads ``question``, ``answer`` and ``id`` from ``row``, writes the output
    fields onto it, and returns the same mapping.
    """
    source_name = "minerva-math"  # used for both data_source and the split label
    user_turn = {"content": row["question"] + LAST_SENTENCE, "role": "user"}

    row["data_source"] = source_name
    row["prompt"] = [{"content": SYSTEM_PROMPT, "role": "system"}, user_turn]
    row["ability"] = "math"
    row["reward_model"] = {"ground_truth": str(row["answer"]), "style": "rule"}
    row["extra_info"] = {"index": str(row["id"]), "split": source_name}
    return row

def _first_answer(value):
    """更稳健地取最终答案：兼容 list / str / 数值。"""
    if isinstance(value, list) and value:
        return value[0]
    return value

def deal_with_olympiad(row):
    """Rewrite one OlympiadBench record into the shared validation schema.

    Reads ``question`` and ``id`` from ``row``. The reference answer comes
    from ``final_answer`` via ``row.get`` (so a missing key yields ``None``),
    reduced by ``_first_answer`` and then stringified. The same ``row``
    mapping is updated and returned.
    """
    source_name = "olympiad-bench"  # used for both data_source and the split label
    row["data_source"] = source_name
    row["prompt"] = [
        {"content": SYSTEM_PROMPT, "role": "system"},
        {"content": row["question"] + LAST_SENTENCE, "role": "user"},
    ]
    row["ability"] = "math"
    reference = str(_first_answer(row.get("final_answer")))
    row["reward_model"] = {"ground_truth": reference, "style": "rule"}
    row["extra_info"] = {"index": str(row["id"]), "split": source_name}
    return row

def deal_with_svamp(row):
    """Rewrite one SVAMP record into the shared validation schema.

    The user turn is ``Body``, a newline, ``Question`` and ``LAST_SENTENCE``.
    A period is appended to the body when its last character is a letter, so
    the body reads as a finished sentence before the question starts.

    Reads ``Body``, ``Question``, ``Answer`` and ``ID`` (capitalised keys) from
    ``row``, writes the output fields onto it, and returns the same mapping.
    An empty ``Body`` no longer raises ``IndexError``; the prompt then consists
    of the question alone.
    """
    body = row["Body"]
    # Slicing gives "" for an empty body, and "".isalpha() is False, so no
    # index error. A "." is never alphabetic, so no separate check is needed.
    if body[-1:].isalpha():
        body += "."
    context = body + "\n" if body else ""
    user_text = context + row["Question"] + LAST_SENTENCE

    row["data_source"] = "svamp"
    row["prompt"] = [
        {"content": SYSTEM_PROMPT, "role": "system"},
        {"content": user_text, "role": "user"},
    ]
    row["ability"] = "math"
    row["reward_model"] = {"ground_truth": str(row["Answer"]), "style": "rule"}
    row["extra_info"] = {"index": str(row["ID"]), "split": "svamp"}
    return row

def deal_with_gsm8k(row):
    """Rewrite one GSM8K record into the shared validation schema.

    The question text is read from ``problem``; ``answer`` and ``id`` supply
    the reference answer and index. The same ``row`` mapping is updated and
    returned.
    """
    source_name = "gsm8k"  # used for both data_source and the split label
    conversation = [
        {"content": SYSTEM_PROMPT, "role": "system"},
        {"content": row["problem"] + LAST_SENTENCE, "role": "user"},
    ]
    converted = {
        "data_source": source_name,
        "prompt": conversation,
        "ability": "math",
        "reward_model": {"ground_truth": str(row["answer"]), "style": "rule"},
        "extra_info": {"index": str(row["id"]), "split": source_name},
    }
    for column, value in converted.items():
        row[column] = value
    return row

def deal_with_asdiv(row):
    """Rewrite one ASDiv record into the shared validation schema.

    The user turn is ``body`` followed by ``question`` and ``LAST_SENTENCE``.
    The two fields used to be concatenated directly, which fused the last
    sentence of the body with the question ("...basket.How many..."). A
    newline is now inserted between them, matching ``deal_with_svamp``, unless
    one side is empty or whitespace already separates them.

    Reads ``body``, ``question``, ``answer`` and ``id`` from ``row``, writes
    the output fields onto it, and returns the same mapping.
    """
    body = row["body"]
    question = row["question"]
    already_separated = (
        not body
        or not question
        or body[-1].isspace()
        or question[0].isspace()
    )
    separator = "" if already_separated else "\n"
    user_text = body + separator + question + LAST_SENTENCE

    # Keep only the text before the first " (" -- presumably a trailing unit
    # annotation such as "9 (apples)"; verify against the dataset.
    reference = row["answer"].partition(" (")[0]

    row["data_source"] = "asdiv"
    row["prompt"] = [
        {"content": SYSTEM_PROMPT, "role": "system"},
        {"content": user_text, "role": "user"},
    ]
    row["ability"] = "math"
    row["reward_model"] = {"ground_truth": reference, "style": "rule"}
    row["extra_info"] = {"index": str(row["id"]), "split": "asdiv"}
    return row

def deal_with_multiarith(row):
    """Rewrite one MultiArith record into the shared validation schema.

    Reads ``question`` and ``id`` from ``row``; the reference answer comes
    from ``final_ans``. The same ``row`` mapping is updated and returned.
    """
    source_name = "multiarith"  # used for both data_source and the split label
    system_turn = {"content": SYSTEM_PROMPT, "role": "system"}
    user_turn = {"content": row["question"] + LAST_SENTENCE, "role": "user"}

    row["data_source"] = source_name
    row["prompt"] = [system_turn, user_turn]
    row["ability"] = "math"
    row["reward_model"] = {"ground_truth": str(row["final_ans"]), "style": "rule"}
    row["extra_info"] = {"index": str(row["id"]), "split": source_name}
    return row

if __name__ == "__main__":
    # Build a merged validation set from the benchmarks named in --valid-tags
    # and write it to --output-file.
    parser = argparse.ArgumentParser(description='Merge JSONL or JSON files.')
    # Required: without it the tag list below would be None and .split() would
    # fail with an unhelpful AttributeError.
    parser.add_argument('--valid-tags', type=str, required=True, help='Comma-separated list')
    # NOTE(review): left optional as before; whether save_dataset accepts a
    # None path is not visible from this file -- confirm and consider required=True.
    parser.add_argument('--output-file', type=str,  help='Path to the output merged file')
    args = parser.parse_args()

    # Strip whitespace so that "svamp, gsm8k" is accepted as well as "svamp,gsm8k".
    valid_tags = [tag.strip() for tag in args.valid_tags.split(",")]
    for valid_tag in valid_tags:
        if valid_tag not in VALID_TAGS:
            raise RuntimeError(f"TAG {valid_tag} not in {VALID_TAGS}")

    def _load_math500():
        """Read the MATH-500 JSON file through pandas and wrap it as a Dataset."""
        math_df = pd.read_json(
            os.path.expanduser("~/prime_eval/data/math500/math_test_cleaned.json")
        )
        return datasets.Dataset.from_list(math_df.to_dict(orient="records"))

    # (tag, loader, per-row transform). The list order is the order in which
    # the selected datasets are concatenated, independent of --valid-tags order.
    sources = [
        (
            "amc",
            lambda: load_single_dataset(
                "~/prime_eval/data/AI-MO/aimo-validation-amc/aimo-validation-amc.jsonl"
            ),
            deal_with_amc,
        ),
        (
            "aime",
            lambda: load_single_dataset(
                "~/prime_eval/data/AI-MO/aimo-validation-aime/aimo-validation-aime.jsonl"
            ),
            deal_with_aime,
        ),
        ("math500", _load_math500, deal_with_math),
        (
            "minerva",
            lambda: load_single_dataset("~/datasets/math-ai-minervamath/test.jsonl"),
            deal_with_minerva,
        ),
        (
            "olympiadbench",
            lambda: load_single_dataset("~/prime_eval/data/olympiadbench/test.jsonl"),
            deal_with_olympiad,
        ),
        (
            "svamp",
            lambda: load_single_dataset("~/datasets/ChilleD-SVAMP/test.json"),
            deal_with_svamp,
        ),
        (
            "gsm8k",
            lambda: load_single_dataset(
                "~/datasets/pss0204-gsm8k_test/gsm8k_test/test-00000-of-00001.parquet", "train"
            ),
            deal_with_gsm8k,
        ),
        (
            "asdiv",
            lambda: load_single_dataset(
                "~/datasets/EleutherAI-asdiv/asdiv/validation-00000-of-00001.parquet", "train"
            ),
            deal_with_asdiv,
        ),
        (
            "multiarith",
            lambda: load_single_dataset("~/datasets/ChilleD-MultiArith/test.json"),
            deal_with_multiarith,
        ),
    ]

    parts = []
    for tag, load, transform in sources:
        if tag not in valid_tags:
            continue
        ds = load()
        # The transforms read an "id" field; give sources that lack one a
        # positional id. This is a no-op when the column already exists.
        if "id" not in ds.column_names:
            ds = ds.add_column("id", list(range(len(ds))))
        ds = ds.map(transform)
        # Drop everything except the shared schema so the parts can be concatenated.
        ds = ds.remove_columns([c for c in ds.column_names if c not in SAVE_COLUMNS])
        parts.append(ds)

    # Merge the selected datasets and save the result.
    validation_ds = datasets.concatenate_datasets(parts)
    save_dataset(validation_ds, args.output_file)


# ~/verl_cs/.conda/bin/python ~/verl_cs/scripts/construct_validation.py --valid-tags svamp,gsm8k,asdiv,multiarith --output-file ~/LLaMA-Factory-250514/saves_shuyan/prime_math_easy_valid.parquet