#!/usr/bin/env python3
"""
NuminaMath-CoT preprocessing for verl multi-turn "Ver@K retry".

Writes:
  <local_save_dir>/train.parquet
  <local_save_dir>/test.parquet

HF dataset:
  AI-MO/NuminaMath-CoT
Only the `problem` and `solution` columns are read; the remaining columns are
dropped from the output unless --keep_original_columns is passed.
"""

from __future__ import annotations

import argparse
import os
from typing import Any, Dict, Optional

import datasets

from verl.utils.reward_score.math_reward import last_boxed_only_string, remove_boxed

# Default for both --dataset_name (what gets loaded) and --data_source (the
# value stored in each output row's "data_source" field).
DEFAULT_DATA_SOURCE = "AI-MO/NuminaMath-CoT"
# Stored in every output row's "ability" field; there is no CLI override.
DEFAULT_ABILITY = "math"
# Default for --instruction; appended after the problem text in the user message.
DEFAULT_INSTRUCTION = "Let's think step by step and put the final answer in \\boxed{}."


def _extract_final_answer(solution_field: str, strict: bool) -> Optional[str]:
    if not isinstance(solution_field, str):
        if strict:
            raise TypeError(f"Expected 'solution' to be str, got {type(solution_field)}")
        return None
    boxed = last_boxed_only_string(solution_field)
    if boxed is None:
        if strict:
            raise ValueError(f"Could not extract boxed answer from solution: {solution_field[:200]!r}...")
        return None
    try:
        return remove_boxed(boxed).strip()
    except Exception:
        if strict:
            raise
        return None


def _build_prompt_messages(question: str, instruction: str, system_prompt: Optional[str] = None) -> list[dict]:
    q = (question or "").strip()
    if not q:
        raise ValueError("Empty problem encountered.")
    user_content = f"{q} {instruction}".strip()
    msgs = []
    if system_prompt:
        msgs.append({"role": "system", "content": system_prompt})
    msgs.append({"role": "user", "content": user_content})
    return msgs


def _make_map_fn(
    *,
    split: str,
    data_source: str,
    ability: str,
    instruction: str,
    system_prompt: Optional[str],
    interaction_name: str,
    k_max_attempts: int,
    strict: bool,
):
    def _process(example: Dict[str, Any], idx: int) -> Dict[str, Any]:
        problem = example["problem"]
        solution = example["solution"]

        gt = _extract_final_answer(solution, strict=strict)
        prompt = _build_prompt_messages(problem, instruction, system_prompt=system_prompt)
        if gt is None:
            # Mark for filtering, but keep schema consistent for datasets.map workers.
            return {
                "__skip__": True,
                "data_source": data_source,
                "prompt": prompt,
                "ability": ability,
                "reward_model": {"style": "rule", "ground_truth": ""},
                "extra_info": {
                    "split": split,
                    "index": idx,
                    "interaction_kwargs": {
                        "name": interaction_name,
                        "query": prompt[-1]["content"],
                        "ground_truth": "",
                        "max_attempts": int(k_max_attempts),
                    },
                },
            }

        return {
            "__skip__": False,
            "data_source": data_source,
            "prompt": prompt,
            "ability": ability,
            "reward_model": {"style": "rule", "ground_truth": gt},
            "extra_info": {
                "split": split,
                "index": idx,
                "interaction_kwargs": {
                    "name": interaction_name,
                    "query": prompt[-1]["content"],
                    "ground_truth": gt,
                    "max_attempts": int(k_max_attempts),
                },
            },
        }

    return _process


def _maybe_limit_dataset(ds: datasets.Dataset, max_samples: int, seed: Optional[int]) -> datasets.Dataset:
    if not max_samples or max_samples <= 0 or max_samples >= len(ds):
        return ds
    if seed is not None:
        ds = ds.shuffle(seed=seed)
    return ds.select(range(max_samples))


def _filter_skipped(ds: datasets.Dataset, num_proc: int) -> datasets.Dataset:
    if "__skip__" not in ds.column_names:
        return ds
    ds = ds.filter(lambda x: not x["__skip__"], num_proc=max(1, num_proc))
    return ds.remove_columns(["__skip__"])


def main():
    """Command-line entry point: load the dataset, convert both splits, write parquet.

    Steps, each applied to "train" and then "test":
      1. Optionally limit the split (--max_train_samples / --max_test_samples, --seed).
      2. Map every row into the output schema.
      3. Drop rows flagged as skipped.
      4. Write ``<out_dir>/<split>.parquet`` and report the row count.

    Every stage finishes for both splits before the next stage starts, so no
    file is written until both splits have been mapped and filtered.
    """
    parser = argparse.ArgumentParser()
    # (flag, type, default) for the plain value options, in help-output order.
    value_options = (
        ("--local_save_dir", str, None),
        ("--data_source", str, DEFAULT_DATA_SOURCE),
        ("--dataset_name", str, DEFAULT_DATA_SOURCE),
        ("--instruction", str, DEFAULT_INSTRUCTION),
        ("--system_prompt", str, None),
        ("--interaction_name", str, "ver_k_retry"),
        ("--k_max_attempts", int, 4),
        ("--num_proc", int, 8),
        ("--max_train_samples", int, 0),
        ("--max_test_samples", int, 0),
        ("--seed", int, None),
    )
    for flag, kind, fallback in value_options:
        parser.add_argument(flag, type=kind, default=fallback)
    parser.add_argument("--keep_original_columns", action="store_true")
    parser.add_argument("--strict", action="store_true", help="Fail if any row is missing a boxed answer.")
    args = parser.parse_args()

    # Without --local_save_dir the output directory name encodes k_max_attempts.
    save_dir = args.local_save_dir
    if not save_dir:
        save_dir = f"./data/numinamath_cot_ver_k_retry_k{args.k_max_attempts}"
    out_dir = os.path.expanduser(save_dir)
    os.makedirs(out_dir, exist_ok=True)

    raw = datasets.load_dataset(args.dataset_name)
    sample_caps = {"train": args.max_train_samples, "test": args.max_test_samples}
    subsets = {
        name: _maybe_limit_dataset(raw[name], cap, args.seed)
        for name, cap in sample_caps.items()
    }

    # Everything except the split label is identical for the two converters.
    shared_settings = {
        "data_source": args.data_source,
        "ability": DEFAULT_ABILITY,
        "instruction": args.instruction,
        "system_prompt": args.system_prompt,
        "interaction_name": args.interaction_name,
        "k_max_attempts": args.k_max_attempts,
        "strict": args.strict,
    }
    converters = {name: _make_map_fn(split=name, **shared_settings) for name in subsets}

    # By default every source column is dropped so only the new schema remains.
    drop_columns = {
        name: [] if args.keep_original_columns else subset.column_names
        for name, subset in subsets.items()
    }

    workers = max(1, args.num_proc)
    mapped = {
        name: subset.map(
            converters[name],
            with_indices=True,
            num_proc=workers,
            remove_columns=drop_columns[name],
        )
        for name, subset in subsets.items()
    }
    kept = {name: _filter_skipped(result, args.num_proc) for name, result in mapped.items()}

    out_paths = {name: os.path.join(out_dir, f"{name}.parquet") for name in kept}
    for name, result in kept.items():
        result.to_parquet(out_paths[name])

    for name, result in kept.items():
        print(f"[OK] Wrote: {out_paths[name]} ({len(result)} rows)")


# Run the conversion only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
