#!/usr/bin/env python3
"""
AIME preprocessing for verl multi-turn "Ver@K retry".

Recommended HF dataset:
  AI-MO/aimo-validation-aime (90 rows) with problem/solution/answer columns.
"""

from __future__ import annotations

import argparse
import os
from typing import Any, Dict, Optional

import datasets

# Default for both --dataset_name (what gets loaded) and --data_source
# (the "data_source" tag written into every output record).
DEFAULT_DATA_SOURCE = "AI-MO/aimo-validation-aime"
# Value written into the "ability" field of every output record.
DEFAULT_ABILITY = "math"
# Default for --instruction; appended after the problem text in the user message.
DEFAULT_INSTRUCTION = "Let's think step by step and put the final answer in \\boxed{}."


def _build_prompt_messages(problem: str, instruction: str, system_prompt: Optional[str] = None) -> list[dict]:
    p = (problem or "").strip()
    if not p:
        raise ValueError("Empty problem encountered.")
    # Optional: AIME-specific hint
    user_content = f"{p} {instruction} (AIME answer is an integer from 0 to 999.)".strip()
    msgs = []
    if system_prompt:
        msgs.append({"role": "system", "content": system_prompt})
    msgs.append({"role": "user", "content": user_content})
    return msgs


def _maybe_limit_dataset(ds: datasets.Dataset, max_samples: int, seed: Optional[int]) -> datasets.Dataset:
    if not max_samples or max_samples <= 0 or max_samples >= len(ds):
        return ds
    if seed is not None:
        ds = ds.shuffle(seed=seed)
    return ds.select(range(max_samples))


def main():
    """Convert an AIME-style dataset into a single ``test.parquet`` file.

    Parses the command-line options, loads ``--dataset_name`` with
    ``datasets.load_dataset``, optionally subsamples it, maps every row to an
    output record (prompt messages, ground truth and interaction kwargs) and
    writes the result to ``<out_dir>/test.parquet``.

    Raises:
        ValueError: If a row has a missing/empty problem or answer.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--local_save_dir", type=str, default=None)
    ap.add_argument("--data_source", type=str, default=DEFAULT_DATA_SOURCE)
    ap.add_argument("--dataset_name", type=str, default=DEFAULT_DATA_SOURCE)
    ap.add_argument("--instruction", type=str, default=DEFAULT_INSTRUCTION)
    ap.add_argument("--system_prompt", type=str, default=None)
    ap.add_argument("--interaction_name", type=str, default="ver_k_retry")
    ap.add_argument("--k_max_attempts", type=int, default=4)
    ap.add_argument("--num_proc", type=int, default=8)
    ap.add_argument("--max_test_samples", type=int, default=0)
    ap.add_argument("--seed", type=int, default=None)
    ap.add_argument("--keep_original_columns", action="store_true")
    args = ap.parse_args()

    # Default output directory encodes the attempt budget in its name.
    out_dir = os.path.expanduser(args.local_save_dir or f"./data/aime_ver_k_retry_k{args.k_max_attempts}")
    os.makedirs(out_dir, exist_ok=True)

    raw = datasets.load_dataset(args.dataset_name)
    # This dataset is often "train" only; treat it as our evaluation set.
    split_name = "train" if "train" in raw else list(raw.keys())[0]
    ds = _maybe_limit_dataset(raw[split_name], args.max_test_samples, args.seed)

    def _process(ex: Dict[str, Any], idx: int) -> Dict[str, Any]:
        """Map one source row (with its index in ``ds``) to an output record."""
        problem = ex["problem"]
        raw_answer = ex["answer"]
        # str(None) is the non-empty string "None"; treat a missing answer as
        # empty so it is rejected instead of being written as ground truth.
        answer = "" if raw_answer is None else str(raw_answer).strip()
        if not answer:
            raise ValueError("Empty answer encountered.")
        prompt = _build_prompt_messages(problem, args.instruction, system_prompt=args.system_prompt)
        return {
            "data_source": args.data_source,
            "prompt": prompt,
            "ability": DEFAULT_ABILITY,
            "reward_model": {"style": "rule", "ground_truth": answer},
            "extra_info": {
                "split": "test",
                # Index within the (possibly shuffled and truncated) dataset.
                "index": idx,
                "interaction_kwargs": {
                    "name": args.interaction_name,
                    # The last message is always the user turn.
                    "query": prompt[-1]["content"],
                    "ground_truth": answer,
                    "max_attempts": int(args.k_max_attempts),
                },
            },
        }

    # Drop the source columns unless the caller asked to keep them.
    remove_cols = [] if args.keep_original_columns else ds.column_names
    out = ds.map(_process, with_indices=True, num_proc=max(1, args.num_proc), remove_columns=remove_cols)

    test_path = os.path.join(out_dir, "test.parquet")
    out.to_parquet(test_path)
    print(f"[OK] Wrote: {test_path} ({len(out)} rows)")


# Run the preprocessing only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
