"""
Prepare the DAPO-Math-17k-Processed dataset for JustRL-style training.

Dataset: open-r1/DAPO-Math-17k-Processed (deduplicated version)
Paper: https://arxiv.org/abs/2512.16649

This script:
1. Loads the deduplicated DAPO-Math-17k-Processed dataset
2. Randomly samples --test_size examples (default: 1k) for test/validation
3. Uses the remaining examples for training (~13k for the default 'en' subset, ~16k for 'all')
4. Converts to the format expected by SkyRL training with JustRL-style prompts
"""

import argparse
import os
import random

import datasets


def process_example(example, idx, split):
    """Convert one DAPO-Math-17k-Processed record into the SkyRL record format.

    The DAPO-Math-17k-Processed dataset has:
    - 'prompt': clean problem text (string)
    - 'solution': numerical answer (string)
    - 'source_prompt': original prompt with "Answer:" format instructions
    - 'reward_model': contains ground_truth

    The prompt is rewritten to ask for a \\boxed{} answer to match:
    1. JustRL paper's approach
    2. AIME environment's verification logic

    Args:
        example: Mapping holding one dataset row. 'prompt' is required; all
            other fields are optional. Optional nested fields
            ('reward_model', 'extra_info') may be absent or None.
        idx: Position of the row within its split; stored as
            extra_info['index'] unless the row's own 'extra_info' supplies
            an 'index' key.
        split: Name of the split (e.g. "train" or "test"), stored in
            extra_info['split'].

    Returns:
        A dict with the keys 'data_source', 'prompt', 'env_class',
        'reward_model' and 'extra_info'.

    Raises:
        KeyError: If the row has no 'prompt' field.
    """
    # Get the clean problem text from the processed dataset
    problem_text = example["prompt"]

    # Create new prompt with JustRL-style instruction
    # Following JustRL paper: "Please reason step by step, and put your final answer within \boxed{}."
    justrl_instruction = 'Please reason step by step, and put your final answer within \\boxed{}.'
    new_prompt_content = problem_text.strip() + '\n\n' + justrl_instruction

    prompt = [
        {
            "role": "user",
            "content": new_prompt_content,
        }
    ]

    # Extract ground truth - can be from 'solution' or 'reward_model.ground_truth'.
    # A nested field can be present but None, so `.get(key, {})` alone is not
    # enough: treat None the same as a missing key.
    solution = example.get("solution", "")
    reward_model = example.get("reward_model") or {}
    ground_truth = reward_model.get("ground_truth")
    if ground_truth is None:
        ground_truth = solution
    eval_style = reward_model.get("style")
    if eval_style is None:
        eval_style = "rule-lighteval/MATH_v2"

    # Get metadata
    data_source = example.get("data_source", "dapo_math")
    ability = example.get("ability", "MATH")
    extra_info = example.get("extra_info") or {}

    # Keep source_prompt for reference if available
    source_prompt = example.get("source_prompt", None)

    # Construct the SkyRL-compatible format
    return {
        "data_source": data_source,
        "prompt": prompt,
        "env_class": "aime",  # Using AIME environment for math evaluation
        "reward_model": {
            "method": "rule",
            "ground_truth": ground_truth,
            "strict_box_verify": False,  # Use normalized matching (lenient)
        },
        "extra_info": {
            "split": split,
            "index": idx,
            "ability": ability,
            "original_eval_style": eval_style,
            "original_problem": problem_text,  # Keep original problem text
            "source_prompt": source_prompt,  # Keep original DAPO prompt
            # Unpacked last, so keys already present in the row's own
            # extra_info take precedence over the ones set above.
            **extra_info,
        },
    }


# Script entry point: load the dataset, split it into train/test at random,
# convert both splits with process_example, write them as parquet files and
# print a short summary of the resulting sizes.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Prepare DAPO-Math-17k-Processed dataset for JustRL-style training"
    )
    parser.add_argument(
        "--output_dir",
        default="~/data/dapo_math",
        help="Output directory for processed dataset"
    )
    parser.add_argument(
        "--test_size",
        type=int,
        default=1000,
        help="Number of examples to reserve for test/validation (default: 1000)"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for train/test split (default: 42)"
    )
    parser.add_argument(
        "--subset",
        type=str,
        default="en",
        choices=["all", "en", "cn"],
        help="Dataset subset to use: 'all' (17.4k), 'en' (14.1k English only), or 'cn' (3.28k Chinese only). Default: 'en'"
    )

    args = parser.parse_args()
    # A negative size would be interpreted as a negative slice bound below and
    # silently produce a split of the wrong size, so reject it up front.
    if args.test_size < 0:
        parser.error(f"--test_size must be non-negative, got {args.test_size}")
    args.output_dir = os.path.expanduser(args.output_dir)

    # Set random seed for reproducibility
    random.seed(args.seed)

    data_source = "open-r1/DAPO-Math-17k-Processed"

    print(f"Loading dataset from {data_source} (subset: {args.subset})...")
    dataset = datasets.load_dataset(data_source, args.subset)

    # Get the full training dataset
    full_dataset = dataset["train"]
    print(f"Full dataset size: {len(full_dataset)} samples (subset: {args.subset})")

    # Create indices for random split
    all_indices = list(range(len(full_dataset)))
    random.shuffle(all_indices)

    # Split into test and train: the first test_size shuffled indices form the
    # test split, everything after them the training split.
    test_indices = all_indices[:args.test_size]
    train_indices = all_indices[args.test_size:]

    print("\nSplitting dataset:")
    print(f"  Test: {len(test_indices)} samples")
    print(f"  Train: {len(train_indices)} samples")

    # Create the splits
    test_dataset = full_dataset.select(test_indices)
    train_dataset = full_dataset.select(train_indices)

    print("\n" + "=" * 60)
    print("Processing datasets...")
    print("=" * 60)

    # Process datasets to SkyRL format
    train_dataset = train_dataset.map(
        lambda example, idx: process_example(example, idx, "train"),
        with_indices=True,
        desc="Processing train data"
    )
    test_dataset = test_dataset.map(
        lambda example, idx: process_example(example, idx, "test"),
        with_indices=True,
        desc="Processing test data"
    )

    # Save to parquet (the test split is written as validation.parquet)
    os.makedirs(args.output_dir, exist_ok=True)
    train_path = os.path.join(args.output_dir, "train.parquet")
    val_path = os.path.join(args.output_dir, "validation.parquet")

    print("\nSaving datasets:")
    print(f"  Train: {train_path} ({len(train_dataset)} samples)")
    print(f"  Test: {val_path} ({len(test_dataset)} samples)")

    train_dataset.to_parquet(train_path)
    test_dataset.to_parquet(val_path)

    print("\nDataset preparation complete!")
    print("\nFinal sizes:")
    print(f"  Train: {len(train_dataset)} samples")
    print(f"  Test: {len(test_dataset)} samples")

    # Calculate training stats
    batch_size = 128
    n_batches = len(train_dataset) / batch_size
    steps_for_1500 = 1500

    print(f"\nTraining stats (batch_size={batch_size}):")
    print(f"  Batches per epoch: {n_batches:.1f}")
    # An empty training split (test_size >= dataset size) gives n_batches == 0;
    # skip the division instead of raising ZeroDivisionError after the files
    # have already been written.
    if n_batches > 0:
        epochs = steps_for_1500 / n_batches
        print(f"  Epochs for 1500 steps: {epochs:.2f}")
    else:
        print("  Epochs for 1500 steps: n/a (training split is empty)")
    print(f"  Steps per epoch: {n_batches:.1f}")
