#!/usr/bin/env python3
"""
GSM8K preprocessing for verl multi-turn "Ver@K retry with verifier feedback".

This script writes:
  <local_save_dir>/train.parquet
  <local_save_dir>/test.parquet

Output schema (per row):
  - data_source: str                         (e.g., "openai/gsm8k")
  - prompt: List[{"role": str, "content": str}]
  - ability: str                             ("math")
  - reward_model: {"style": "rule", "ground_truth": str}
  - extra_info:
      - split: str                           ("train" | "test")
      - index: int
      - interaction_kwargs:
          - name: str                        (your interaction name)
          - query: str                       (question text you want interaction to see)
          - ground_truth: str                (answer string)
          - max_attempts: int                (K)

Why interaction_kwargs is inside extra_info:
  verl.utils.dataset.rl_dataset.RLHFDataset.__getitem__ pulls:
      interaction_kwargs = row_dict["extra_info"]["interaction_kwargs"]
  and then sets row_dict["interaction_kwargs"] for AgentLoop/rollout. (So store it in extra_info!)

Usage:
  python examples/data_preprocess/gsm8k_ver_k_retry.py \
    --local_save_dir ~/data/gsm8k_ver_k \
    --interaction_name ver_k_retry \
    --k_max_attempts 4 \
    --num_proc 8

Optional:
  --hdfs_save_dir hdfs://...
"""

from __future__ import annotations

import argparse
import os
import re
from typing import Any, Dict, Optional

import datasets


# Default for both --data_source (value stored in the 'data_source' column)
# and --dataset_name (HF dataset that gets loaded).
DEFAULT_DATA_SOURCE = "openai/gsm8k"
# Default for --dataset_config (HF dataset config/subset).
DEFAULT_DATA_CONFIG = "main"
# Value written to the 'ability' column of every row; not configurable via CLI.
DEFAULT_ABILITY = "math"
# Default for --instruction; appended after the question in the user message.
DEFAULT_INSTRUCTION = "Let's think step by step and put the final answer in \\boxed{}."


def _extract_gsm8k_final_answer(answer_field: str) -> str:
    """
    GSM8K "answer" field typically contains a worked solution ending with:
        #### <number>
    We extract <number> and normalize commas.

    Matches negative, decimal, and comma-formatted numbers.
    """
    if not isinstance(answer_field, str):
        raise TypeError(f"Expected GSM8K 'answer' to be str, got {type(answer_field)}")

    m = re.search(r"####\s*(-?[0-9\.,]+)", answer_field)
    if m is None:
        # Fail fast: you do NOT want silent bad ground_truth in RL training.
        raise ValueError(f"Could not extract final answer from GSM8K answer field: {answer_field!r}")

    return m.group(1).replace(",", "").strip()


def _build_prompt_messages(
    question: str,
    instruction: str,
    system_prompt: Optional[str] = None,
) -> list[dict]:
    """
    Construct HF-chat-style messages.

    By default:
      - optional system message
      - one user message containing the question + formatting instruction
    """
    q = (question or "").strip()
    if not q:
        raise ValueError("Empty question encountered in GSM8K example.")

    user_content = f"{q} {instruction}".strip()

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_content})
    return messages


def _make_map_fn(
    *,
    split: str,
    data_source: str,
    ability: str,
    instruction: str,
    system_prompt: Optional[str],
    interaction_name: str,
    k_max_attempts: int,
) :
    """
    Returns a function suitable for datasets.map(with_indices=True).

    We intentionally store interaction_kwargs inside extra_info:
      extra_info["interaction_kwargs"] = {...}
    because RLHFDataset reads it from there at __getitem__ time.
    """
    def _process(example: Dict[str, Any], idx: int) -> Dict[str, Any]:
        question = example["question"]
        answer = example["answer"]

        ground_truth = _extract_gsm8k_final_answer(answer)
        prompt = _build_prompt_messages(question, instruction, system_prompt=system_prompt)

        # NOTE: reward_manager expects ground_truth + data_source available in non_tensor_batch;
        # the canonical verl GSM8K preprocess stores it in reward_model. Keep that structure.
        row: Dict[str, Any] = {
            "data_source": data_source,
            "prompt": prompt,
            "ability": ability,
            "reward_model": {
                "style": "rule",
                "ground_truth": ground_truth,
            },
            "extra_info": {
                "split": split,
                "index": idx,

                # Interaction-specific payload for your Ver@K retry agent:
                # RLHFDataset will later read extra_info.interaction_kwargs
                # and expose it as row_dict["interaction_kwargs"] during rollout.
                "interaction_kwargs": {
                    "name": interaction_name,
                    # Keep 'query' consistent with what your interaction agent expects.
                    # Here we pass the same user-visible content used in prompt (question + instruction).
                    "query": prompt[-1]["content"],
                    "ground_truth": ground_truth,
                    "max_attempts": int(k_max_attempts),
                },
            },
        }
        return row

    return _process


def _maybe_limit_dataset(ds: datasets.Dataset, max_samples: int, seed: Optional[int]) -> datasets.Dataset:
    if max_samples is None or max_samples <= 0:
        return ds
    if max_samples >= len(ds):
        return ds
    if seed is not None:
        ds = ds.shuffle(seed=seed)
    return ds.select(range(max_samples))


def main():
    """
    CLI entry point: load GSM8K, convert both splits, and write parquet files.

    Steps:
      1. Parse arguments and validate them BEFORE any side effect.
      2. If --hdfs_save_dir is set, import the verl HDFS helpers up front so a
         missing verl install fails immediately rather than after processing.
      3. Load the HF dataset, optionally subsample each split, and map every
         example through the function built by _make_map_fn.
      4. Write <local_save_dir>/train.parquet and test.parquet, optionally
         copy them to HDFS, and print train[0] as a sanity check.

    Raises:
        ValueError: if --k_max_attempts is not >= 1.
        RuntimeError: if --hdfs_save_dir is set but verl.utils.hdfs_io cannot
            be imported.
    """
    parser = argparse.ArgumentParser(description="Preprocess GSM8K for verl Ver@K retry interaction training.")
    parser.add_argument("--local_save_dir", type=str, default=None,
                        help="Local output directory for train.parquet and test.parquet. "
                        "Defaults to ./data/gsm8k_ver_k_retry_k<k_max_attempts>.")
    parser.add_argument("--hdfs_save_dir", type=str, default=None,
                        help="Optional HDFS output dir. If set, parquet files are copied there too.")
    parser.add_argument("--data_source", type=str, default=DEFAULT_DATA_SOURCE,
                        help="Value stored in the 'data_source' column. Keep 'openai/gsm8k' to use built-in gsm8k reward routing.")
    parser.add_argument("--dataset_name", type=str, default=DEFAULT_DATA_SOURCE,
                        help="HF dataset name to load. For GSM8K, usually 'openai/gsm8k'.")
    parser.add_argument("--dataset_config", type=str, default=DEFAULT_DATA_CONFIG,
                        help="HF dataset config/subset. For GSM8K, usually 'main'.")
    parser.add_argument("--instruction", type=str, default=DEFAULT_INSTRUCTION,
                        help="Instruction appended to each question (controls output format, e.g., \\boxed{} answer).")
    parser.add_argument("--system_prompt", type=str, default=None,
                        help="Optional system prompt. If not set, we only use a user message.")
    parser.add_argument("--interaction_name", type=str, default="ver_k_retry",
                        help="Must match the 'name' your interaction registry/config expects.")
    parser.add_argument("--k_max_attempts", type=int, default=4,
                        help="K: maximum retries/turns for your Ver@K interaction agent.")
    parser.add_argument("--num_proc", type=int, default=8,
                        help="Multiprocessing workers for datasets.map.")
    parser.add_argument("--max_train_samples", type=int, default=0,
                        help="If >0, subsample train split to this many examples.")
    parser.add_argument("--max_test_samples", type=int, default=0,
                        help="If >0, subsample test split to this many examples.")
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed used only when subsampling with shuffle.")
    parser.add_argument("--keep_original_columns", action="store_true",
                        help="If set, keep original GSM8K columns (question/answer) in output parquet. Default removes them.")
    args = parser.parse_args()

    # Validate first: the default output directory name embeds K, so creating
    # it before this check would leave a stray directory for an invalid K.
    if args.k_max_attempts <= 0:
        raise ValueError("--k_max_attempts must be >= 1")

    # Resolve the HDFS helpers before the expensive download/map work so a
    # missing verl install is reported immediately.
    hdfs_copy = None
    hdfs_makedirs = None
    if args.hdfs_save_dir:
        try:
            from verl.utils.hdfs_io import copy as hdfs_copy
            from verl.utils.hdfs_io import makedirs as hdfs_makedirs
        except Exception as e:
            raise RuntimeError(
                "You set --hdfs_save_dir but verl.utils.hdfs_io could not be imported. "
                "Run inside verl repo/env or remove --hdfs_save_dir."
            ) from e

    default_save_dir = f"./data/gsm8k_ver_k_retry_k{args.k_max_attempts}"
    local_save_dir = os.path.expanduser(args.local_save_dir or default_save_dir)
    os.makedirs(local_save_dir, exist_ok=True)

    # Load GSM8K from HF
    raw = datasets.load_dataset(args.dataset_name, args.dataset_config)

    split_names = ("train", "test")
    sample_limits = {"train": args.max_train_samples, "test": args.max_test_samples}

    # Look up both splits before doing any work so a missing split fails early.
    splits = {name: raw[name] for name in split_names}

    # Optional subsampling (useful for debugging)
    splits = {
        name: _maybe_limit_dataset(ds, sample_limits[name], args.seed)
        for name, ds in splits.items()
    }

    # Convert every split; nothing is written until all splits have been
    # processed successfully.
    outputs = {}
    for name in split_names:
        ds = splits[name]
        map_fn = _make_map_fn(
            split=name,
            data_source=args.data_source,
            ability=DEFAULT_ABILITY,
            instruction=args.instruction,
            system_prompt=args.system_prompt,
            interaction_name=args.interaction_name,
            k_max_attempts=args.k_max_attempts,
        )
        outputs[name] = ds.map(
            map_fn,
            with_indices=True,
            num_proc=max(1, args.num_proc),
            remove_columns=[] if args.keep_original_columns else ds.column_names,
            desc=f"Processing GSM8K {name} split",
        )

    paths = {name: os.path.join(local_save_dir, f"{name}.parquet") for name in split_names}

    for name in split_names:
        outputs[name].to_parquet(paths[name])

    for name in split_names:
        print(f"[OK] Wrote: {paths[name]} ({len(outputs[name])} rows)")

    # Optional HDFS copy (helpers were imported above when requested)
    if args.hdfs_save_dir:
        hdfs_makedirs(args.hdfs_save_dir)
        for name in split_names:
            hdfs_copy(src=paths[name], dst=os.path.join(args.hdfs_save_dir, f"{name}.parquet"))
        print(f"[OK] Copied parquet files to HDFS dir: {args.hdfs_save_dir}")

    # Quick sanity print
    print("\nSample row (train[0]):")
    print(outputs["train"][0])


# Run the preprocessing only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
