#!/usr/bin/env python3
"""
Omni-MATH preprocessing for verl multi-turn "Ver@K retry".

HF dataset:
  KbsdJames/Omni-MATH (test split, ~4.43k rows); columns include problem, solution, and answer.
"""

from __future__ import annotations

import argparse
import os
from typing import Any, Dict, Optional

import datasets

# Default for both --data_source (the tag written into each record's
# "data_source" field) and --dataset_name (the HF dataset that is loaded).
DEFAULT_DATA_SOURCE = "KbsdJames/Omni-MATH"
# Value written into every output record's "ability" field.
DEFAULT_ABILITY = "math"
# Default --instruction; appended after the problem text in the user message.
# The doubled backslash yields a literal "\boxed{}" in the prompt text.
DEFAULT_INSTRUCTION = "Let's think step by step and put the final answer in \\boxed{}."


def _build_prompt_messages(problem: str, instruction: str, system_prompt: Optional[str] = None) -> list[dict]:
    """Build the chat-message list for one problem.

    Args:
        problem: Raw problem text; surrounding whitespace is stripped. A falsy
            value is treated as empty.
        instruction: Text appended after the problem, separated by one space.
        system_prompt: Optional system message placed first when truthy.

    Returns:
        A list of ``{"role", "content"}`` dicts: an optional system message
        followed by exactly one user message (always the last element).

    Raises:
        ValueError: If the stripped problem text is empty.
    """
    text = problem.strip() if problem else ""
    if not text:
        raise ValueError("Empty problem encountered.")

    messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
    # Strip again so an empty instruction does not leave a trailing space.
    messages.append({"role": "user", "content": f"{text} {instruction}".strip()})
    return messages


def _maybe_limit_dataset(ds: datasets.Dataset, max_samples: int, seed: Optional[int]) -> datasets.Dataset:
    """Return at most ``max_samples`` rows of ``ds``.

    A falsy or non-positive ``max_samples``, or one that is not smaller than
    ``len(ds)``, means "no limit": ``ds`` itself is returned untouched.
    Otherwise the first ``max_samples`` rows are selected, after a seeded
    ``shuffle`` when ``seed`` is not None (without a seed the head of the
    dataset is taken).
    """
    should_trim = bool(max_samples) and 0 < max_samples < len(ds)
    if not should_trim:
        return ds
    source = ds if seed is None else ds.shuffle(seed=seed)
    return source.select(range(max_samples))


def _filter_skipped(ds: datasets.Dataset, num_proc: int) -> datasets.Dataset:
    """Drop rows whose ``__skip__`` flag is truthy, then remove that column.

    A dataset without a ``__skip__`` column is returned unchanged.
    ``num_proc`` is clamped to at least 1 before being passed to ``filter``.
    """
    if "__skip__" in ds.column_names:
        workers = max(1, num_proc)
        kept = ds.filter(lambda row: not row["__skip__"], num_proc=workers)
        return kept.remove_columns(["__skip__"])
    return ds


def main():
    """Convert one split of an Omni-MATH-style dataset into verl Ver@K-retry parquet.

    Steps: parse CLI args, load ``--dataset_name`` and pick ``--split``
    (default ``test``), optionally subsample it, turn every row into a record
    with a chat ``prompt``, a rule-based ``reward_model`` ground truth and
    ``extra_info.interaction_kwargs`` for the retry interaction, drop rows
    whose answer is empty, and write ``<local_save_dir>/<split>.parquet``.

    Raises:
        ValueError: If any row has an empty problem (raised by
            ``_build_prompt_messages``), which aborts the run.
    """
    ap = argparse.ArgumentParser(description="Convert Omni-MATH into verl Ver@K-retry parquet.")
    ap.add_argument("--local_save_dir", type=str, default=None)
    ap.add_argument("--data_source", type=str, default=DEFAULT_DATA_SOURCE)
    ap.add_argument("--dataset_name", type=str, default=DEFAULT_DATA_SOURCE)
    ap.add_argument("--split", type=str, default="test", help="Split to convert (Omni-MATH ships a test split).")
    ap.add_argument("--instruction", type=str, default=DEFAULT_INSTRUCTION)
    ap.add_argument("--system_prompt", type=str, default=None)
    ap.add_argument("--interaction_name", type=str, default="ver_k_retry")
    ap.add_argument("--k_max_attempts", type=int, default=4)
    ap.add_argument("--num_proc", type=int, default=8)
    ap.add_argument("--max_test_samples", type=int, default=0)
    ap.add_argument("--seed", type=int, default=None)
    ap.add_argument("--keep_original_columns", action="store_true")
    args = ap.parse_args()
    if args.k_max_attempts < 1:
        ap.error("--k_max_attempts must be >= 1")

    out_dir = os.path.expanduser(args.local_save_dir or f"./data/omni_math_ver_k_retry_k{args.k_max_attempts}")

    raw = datasets.load_dataset(args.dataset_name)
    # Omni-MATH is commonly provided as test-only; fail with a readable message
    # (instead of a bare KeyError) when the requested split does not exist.
    if args.split not in raw:
        ap.error(f"split {args.split!r} not found in {args.dataset_name}; available: {sorted(raw.keys())}")
    split_ds = _maybe_limit_dataset(raw[args.split], args.max_test_samples, args.seed)

    max_attempts = int(args.k_max_attempts)

    def _process(ex: Dict[str, Any], idx: int) -> Dict[str, Any]:
        """Map one source row to a verl record; flag rows with no answer for removal."""
        prompt = _build_prompt_messages(ex["problem"], args.instruction, system_prompt=args.system_prompt)
        raw_answer = ex["answer"]
        # str() so a non-string answer (e.g. 0) is neither dropped as falsy nor crashes on .strip().
        answer = "" if raw_answer is None else str(raw_answer).strip()
        # Flagged rows keep the same record shape as kept rows — presumably so
        # `datasets` infers one consistent feature schema; _filter_skipped
        # removes them (and the flag column) afterwards.
        return {
            "__skip__": not answer,
            "data_source": args.data_source,
            "prompt": prompt,
            "ability": DEFAULT_ABILITY,
            "reward_model": {"style": "rule", "ground_truth": answer},
            "extra_info": {
                "split": args.split,
                "index": idx,
                "interaction_kwargs": {
                    "name": args.interaction_name,
                    "query": prompt[-1]["content"],
                    "ground_truth": answer,
                    "max_attempts": max_attempts,
                },
            },
        }

    remove_cols = [] if args.keep_original_columns else split_ds.column_names
    processed = split_ds.map(_process, with_indices=True, num_proc=max(1, args.num_proc), remove_columns=remove_cols)
    n_mapped = len(processed)
    processed = _filter_skipped(processed, args.num_proc)
    n_dropped = n_mapped - len(processed)

    # Create the output directory only once there is something to write.
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{args.split}.parquet")
    processed.to_parquet(out_path)
    if n_dropped:
        print(f"[INFO] Dropped {n_dropped} rows with an empty answer.")
    print(f"[OK] Wrote: {out_path} ({len(processed)} rows)")


if __name__ == "__main__":
    main()
