#!/usr/bin/env python3
"""
MATH preprocessing for verl multi-turn "Ver@K retry with verifier feedback".

This script writes:
  <local_save_dir>/train.parquet
  <local_save_dir>/test.parquet

Output schema (per row):
  - data_source: str
  - prompt: List[{"role": str, "content": str}]
  - ability: str                             ("math")
  - reward_model: {"style": "rule", "ground_truth": str}
  - extra_info:
      - split: str                           ("train" | "test")
      - index: int
      - interaction_kwargs:
          - name: str                        (your interaction name)
          - query: str                       (content of the last prompt message: problem + instruction)
          - ground_truth: str                (answer string)
          - max_attempts: int                (K)
"""

from __future__ import annotations

import argparse
import os
from typing import Any, Dict, Optional

import datasets

from verl.utils.reward_score.math_reward import last_boxed_only_string, remove_boxed


# Default Hugging Face dataset id; used as the default for both --data_source
# (the value stored in each row) and --dataset_name (the dataset to load).
DEFAULT_DATA_SOURCE = "DigitalLearningGmbH/MATH-lighteval"
# Value written to the "ability" column of every output row.
DEFAULT_ABILITY = "math"
# Default for --instruction; appended after each problem to form the user message.
DEFAULT_INSTRUCTION = "Let's think step by step and put the final answer in \\boxed{}."


def _extract_math_final_answer(solution_field: str) -> str:
    """Pull the final answer string out of a MATH ``solution`` field.

    The boxed span is located with ``last_boxed_only_string`` and unwrapped
    with ``remove_boxed``; surrounding whitespace is stripped from the result.

    Raises:
        TypeError: if ``solution_field`` is not a ``str``.
        ValueError: if no boxed answer could be located in the solution.
    """
    if isinstance(solution_field, str):
        boxed_span = last_boxed_only_string(solution_field)
        if boxed_span is not None:
            return remove_boxed(boxed_span).strip()
        raise ValueError(f"Could not extract boxed answer from MATH solution field: {solution_field!r}")

    raise TypeError(f"Expected MATH 'solution' to be str, got {type(solution_field)}")


def _build_prompt_messages(
    question: str,
    instruction: str,
    system_prompt: Optional[str] = None,
) -> list[dict]:
    """Assemble the chat messages for one problem.

    The user message is the stripped problem text followed by a space and the
    instruction (the combined string is stripped again). A system message is
    placed first only when ``system_prompt`` is truthy.

    Raises:
        ValueError: if the problem is missing or blank after stripping.
    """
    problem_text = question.strip() if question else ""
    if not problem_text:
        raise ValueError("Empty problem encountered in MATH example.")

    user_turn = {"role": "user", "content": f"{problem_text} {instruction}".strip()}
    if system_prompt:
        return [{"role": "system", "content": system_prompt}, user_turn]
    return [user_turn]


def _make_map_fn(
    *,
    split: str,
    data_source: str,
    ability: str,
    instruction: str,
    system_prompt: Optional[str],
    interaction_name: str,
    k_max_attempts: int,
):
    """Create the per-example converter for one split.

    All arguments are keyword-only and are captured by the returned function,
    which takes ``(example, idx)`` and produces one output row built from the
    example's ``problem`` and ``solution`` fields.
    """

    def _process(example: Dict[str, Any], idx: int) -> Dict[str, Any]:
        problem_text, solution_text = example["problem"], example["solution"]

        # Answer extraction runs before prompt construction, so a bad solution
        # is reported ahead of an empty problem.
        answer = _extract_math_final_answer(solution_text)
        messages = _build_prompt_messages(problem_text, instruction, system_prompt=system_prompt)

        # The interaction receives the final (user) message content as its query.
        interaction_kwargs: Dict[str, Any] = {
            "name": interaction_name,
            "query": messages[-1]["content"],
            "ground_truth": answer,
            "max_attempts": int(k_max_attempts),
        }
        extra_info: Dict[str, Any] = {
            "split": split,
            "index": idx,
            "interaction_kwargs": interaction_kwargs,
        }
        return {
            "data_source": data_source,
            "prompt": messages,
            "ability": ability,
            "reward_model": {"style": "rule", "ground_truth": answer},
            "extra_info": extra_info,
        }

    return _process


def _maybe_limit_dataset(ds: datasets.Dataset, max_samples: int, seed: Optional[int]) -> datasets.Dataset:
    """Return at most ``max_samples`` rows of ``ds``.

    ``ds`` is returned untouched when ``max_samples`` is ``None``, non-positive,
    or not smaller than ``len(ds)``. Otherwise the dataset is shuffled with
    ``seed`` (only when a seed is given) and the first ``max_samples`` rows are
    selected.
    """
    needs_subset = max_samples is not None and 0 < max_samples < len(ds)
    if not needs_subset:
        return ds

    # Without a seed the first ``max_samples`` rows are taken in original order.
    pool = ds if seed is None else ds.shuffle(seed=seed)
    return pool.select(range(max_samples))


def _load_dataset(dataset_name: str, dataset_config: Optional[str], local_dataset_path: Optional[str]):
    """Load the raw dataset via ``datasets.load_dataset``.

    A truthy ``local_dataset_path`` takes precedence and is loaded on its own
    (name and config are ignored). Otherwise ``dataset_name`` is loaded, with
    ``dataset_config`` passed as the second positional argument when truthy.
    """
    if local_dataset_path:
        load_args = (local_dataset_path,)
    elif dataset_config:
        load_args = (dataset_name, dataset_config)
    else:
        load_args = (dataset_name,)
    return datasets.load_dataset(*load_args)


def main():
    """Parse CLI arguments, convert the MATH train/test splits, and write parquet files.

    Steps: validate ``--k_max_attempts``, load the raw dataset, optionally
    subsample each split, map every example to the output row layout, write
    ``train.parquet`` and ``test.parquet`` under the local save dir, optionally
    copy both files to ``--hdfs_save_dir``, and print one sample train row.

    Raises:
        ValueError: if ``--k_max_attempts`` is not >= 1.
        KeyError: if the loaded dataset lacks a ``train`` or ``test`` split.
        RuntimeError: if ``--hdfs_save_dir`` is set but ``verl.utils.hdfs_io``
            cannot be imported.
    """
    parser = argparse.ArgumentParser(description="Preprocess MATH for verl Ver@K retry interaction training.")
    parser.add_argument("--local_save_dir", type=str, default=None,
                        help="Local output directory for train.parquet and test.parquet. "
                        "Defaults to ./data/math_ver_k_retry_k<k_max_attempts>.")
    parser.add_argument("--hdfs_save_dir", type=str, default=None,
                        help="Optional HDFS output dir. If set, parquet files are copied there too.")
    parser.add_argument("--data_source", type=str, default=DEFAULT_DATA_SOURCE,
                        help="Value stored in the 'data_source' column. Keep the MATH source name for reward routing.")
    parser.add_argument("--dataset_name", type=str, default=DEFAULT_DATA_SOURCE,
                        help="HF dataset name to load. For MATH, usually 'DigitalLearningGmbH/MATH-lighteval'.")
    parser.add_argument("--dataset_config", type=str, default=None,
                        help="HF dataset config/subset if needed.")
    parser.add_argument("--local_dataset_path", type=str, default=None,
                        help="Local path to raw dataset. Overrides dataset_name/config when set.")
    parser.add_argument("--instruction", type=str, default=DEFAULT_INSTRUCTION,
                        help="Instruction appended to each problem (controls output format, e.g., \\boxed{} answer).")
    parser.add_argument("--system_prompt", type=str, default=None,
                        help="Optional system prompt. If not set, we only use a user message.")
    parser.add_argument("--interaction_name", type=str, default="ver_k_retry",
                        help="Must match the 'name' your interaction registry/config expects.")
    parser.add_argument("--k_max_attempts", type=int, default=4,
                        help="K: maximum retries/turns for your Ver@K interaction agent.")
    parser.add_argument("--num_proc", type=int, default=8,
                        help="Multiprocessing workers for datasets.map.")
    parser.add_argument("--max_train_samples", type=int, default=0,
                        help="If >0, subsample train split to this many examples.")
    parser.add_argument("--max_test_samples", type=int, default=0,
                        help="If >0, subsample test split to this many examples.")
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed used only when subsampling with shuffle.")
    parser.add_argument("--keep_original_columns", action="store_true",
                        help="If set, keep original MATH columns (problem/solution) in output parquet.")
    args = parser.parse_args()

    # Validate before touching the filesystem so an invalid K does not leave
    # behind an empty "..._k0" (or negative-K) output directory.
    if args.k_max_attempts <= 0:
        raise ValueError("--k_max_attempts must be >= 1")

    default_save_dir = f"./data/math_ver_k_retry_k{args.k_max_attempts}"
    local_save_dir = os.path.expanduser(args.local_save_dir or default_save_dir)
    os.makedirs(local_save_dir, exist_ok=True)

    raw = _load_dataset(args.dataset_name, args.dataset_config, args.local_dataset_path)

    # Both splits go through the same pipeline; only the sample cap differs.
    split_limits = {"train": args.max_train_samples, "test": args.max_test_samples}
    missing_splits = [name for name in split_limits if name not in raw]
    if missing_splits:
        raise KeyError(f"Loaded dataset is missing required split(s): {missing_splits}")

    num_proc = max(1, args.num_proc)

    # Map every split first and write afterwards, so a failure while processing
    # any split leaves no partial parquet output behind.
    processed = {}
    for split, max_samples in split_limits.items():
        split_ds = _maybe_limit_dataset(raw[split], max_samples, args.seed)
        map_fn = _make_map_fn(
            split=split,
            data_source=args.data_source,
            ability=DEFAULT_ABILITY,
            instruction=args.instruction,
            system_prompt=args.system_prompt,
            interaction_name=args.interaction_name,
            k_max_attempts=args.k_max_attempts,
        )
        processed[split] = split_ds.map(
            map_fn,
            with_indices=True,
            num_proc=num_proc,
            remove_columns=[] if args.keep_original_columns else split_ds.column_names,
            desc=f"Processing MATH {split} split",
        )

    output_paths = {}
    for split, split_out in processed.items():
        out_path = os.path.join(local_save_dir, f"{split}.parquet")
        split_out.to_parquet(out_path)
        output_paths[split] = out_path

    for split, out_path in output_paths.items():
        print(f"[OK] Wrote: {out_path} ({len(processed[split])} rows)")

    if args.hdfs_save_dir:
        # Imported lazily so the script still runs without the HDFS helpers
        # when no HDFS destination is requested.
        try:
            from verl.utils.hdfs_io import copy as hdfs_copy
            from verl.utils.hdfs_io import makedirs as hdfs_makedirs
        except ImportError as e:
            raise RuntimeError(
                "You set --hdfs_save_dir but verl.utils.hdfs_io could not be imported. "
                "Run inside verl repo/env or remove --hdfs_save_dir."
            ) from e

        hdfs_makedirs(args.hdfs_save_dir)
        for split, out_path in output_paths.items():
            hdfs_copy(src=out_path, dst=os.path.join(args.hdfs_save_dir, f"{split}.parquet"))
        print(f"[OK] Copied parquet files to HDFS dir: {args.hdfs_save_dir}")

    train_out = processed["train"]
    # Guard against an empty train split, where indexing row 0 would raise.
    if len(train_out) > 0:
        print("\nSample row (train[0]):")
        print(train_out[0])


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
