"""
Prepare the MATH (hendrycks/competition_math) dataset for training.

The dataset contains 5 levels (Level 1-5).

This script filters the dataset's predefined train/test splits by difficulty level
and writes each filtered split to a parquet file.
By default, only Level 1-3 problems are included for easier training.
"""

import argparse
import os
from collections import defaultdict

import datasets


def show_level_distribution(dataset, name="Dataset"):
    """Print how many examples fall into each difficulty level.

    Integer levels are reported as "Level N"; any other level value is
    reported under its whitespace-stripped text. Labels are printed in sorted
    order, followed by the total number of examples.

    Args:
        dataset: Sized iterable of mappings that each carry a "level" entry.
        name: Heading label printed above the counts.
    """
    tally = defaultdict(int)
    for row in dataset:
        raw_level = row["level"]
        # Fold integer levels and "Level X" strings into one label format.
        label = f"Level {raw_level}" if isinstance(raw_level, int) else raw_level.strip()
        tally[label] += 1

    print(f"\n{name} distribution:")
    # Labels are unique, so sorting the (label, count) pairs orders by label.
    for label, count in sorted(tally.items()):
        print(f"  {label}: {count} samples")
    print(f"  Total: {len(dataset)} samples")


def extract_boxed_answer(solution: str) -> str:
    """Return the contents of the last \\boxed{...} group in a solution.

    Args:
        solution: Full solution text, expected to contain \\boxed{answer}.

    Returns:
        The text between the braces of the last \\boxed{...}, with nested
        braces kept intact. The whole solution is returned unchanged when
        there is no \\boxed{ marker or its closing brace is never found.
    """
    marker = "\\boxed{"
    marker_pos = solution.rfind(marker)
    if marker_pos == -1:
        return solution

    content_start = marker_pos + len(marker)
    # The marker's own "{" is already open, so scanning begins at depth 1.
    depth = 1
    for pos in range(content_start, len(solution)):
        char = solution[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return solution[content_start:pos]

    # The braces never balanced: fall back to the untouched solution.
    return solution


def process_example(example, idx, split, data_source):
    """Convert one raw dataset example into the training record layout.

    Args:
        example: Mapping with "problem", "solution" and "level" entries, and
            optionally "answer" and "subject" (or "type").
        idx: Index of the example, stored under extra_info["index"].
        split: Split name, stored under extra_info["split"].
        data_source: Dataset identifier, stored under "data_source".

    Returns:
        A dict with the keys "data_source", "prompt", "env_class",
        "reward_model" and "extra_info".
    """
    problem = example["problem"]
    solution = example["solution"]
    level = example["level"]

    # nlile/hendrycks-MATH-benchmark has 'subject'; older datasets use 'type'.
    legacy_subject = example.get("type", "unknown")
    subject = example.get("subject", legacy_subject)

    # Prefer the dataset's pre-extracted answer; otherwise pull it out of the
    # boxed expression in the solution text.
    ground_truth_answer = (
        example["answer"] if "answer" in example else extract_boxed_answer(solution)
    )

    # Instruction appended to every problem statement.
    instruction = "Solve the problem step by step and box the final answer using \\boxed{}."
    user_message = {
        "role": "user",
        "content": " ".join((problem, instruction)),
    }

    reward_spec = {
        "method": "rule",
        "ground_truth": ground_truth_answer,
        # Use Minerva normalization for more lenient answer matching.
        "strict_box_verify": False,
    }

    extra_info = {
        "split": split,
        "index": idx,
        "problem": problem,
        "solution": solution,
        "level": level,
        "subject": subject,
    }

    return {
        "data_source": data_source,
        "prompt": [user_message],
        # Using AIME environment for Competition MATH evaluation.
        "env_class": "aime",
        "reward_model": reward_spec,
        "extra_info": extra_info,
    }


def main():
    """Filter the MATH train/test splits by level and save them as parquet.

    Parses --output_dir and --max_level, loads the
    nlile/hendrycks-MATH-benchmark dataset, keeps only examples whose level is
    between 1 and --max_level in both splits, converts every example with
    process_example, and writes train.parquet and validation.parquet into the
    output directory.

    Exits through argparse's error path when --max_level is outside 1-5.
    """
    parser = argparse.ArgumentParser(
        description="Prepare MATH dataset for training"
    )
    parser.add_argument(
        "--output_dir",
        default="~/data/competition_math",
        help="Output directory for processed dataset"
    )
    parser.add_argument(
        "--max_level",
        type=int,
        default=3,
        help="Maximum difficulty level to include (1-5). Default: 3 (includes Level 1, 2, and 3)"
    )

    args = parser.parse_args()
    # Reject values outside the documented range up front; a value below 1
    # would otherwise silently produce empty output files.
    if not 1 <= args.max_level <= 5:
        parser.error(f"--max_level must be between 1 and 5, got {args.max_level}")
    max_level = args.max_level
    output_dir = os.path.expanduser(args.output_dir)

    # Use nlile/hendrycks-MATH-benchmark which has explicit train/test splits
    # This avoids data contamination from shuffling
    data_source = "nlile/hendrycks-MATH-benchmark"

    print(f"Loading dataset from {data_source}...")
    print("Using explicit train/test splits to avoid contamination")
    dataset = datasets.load_dataset(data_source)

    # Filter by difficulty level for both train and test splits
    print(f"\nFiltering for Level 1-{max_level} problems only...")

    train_full = dataset["train"]
    test_full = dataset["test"]

    print(f"Original sizes: train={len(train_full)}, test={len(test_full)}")

    def within_max_level(example):
        """Return True when the example's level is between 1 and max_level."""
        # nlile/hendrycks-MATH-benchmark uses integer levels (1-5), not "Level X" strings
        return 1 <= example["level"] <= max_level

    # One shared predicate keeps the train and test filters in sync.
    train_dataset = train_full.filter(within_max_level)
    test_dataset = test_full.filter(within_max_level)

    print(f"Filtered sizes: train={len(train_dataset)}, test={len(test_dataset)}")

    print("\nDataset sizes after split:")
    print(f"  Train: {len(train_dataset)}")
    print(f"  Test: {len(test_dataset)}")

    # Show level distribution
    print("\n" + "=" * 60)
    show_level_distribution(train_dataset, "TRAIN")
    show_level_distribution(test_dataset, "TEST")
    print("=" * 60)

    print("\n" + "=" * 60)
    print("Processing datasets...")
    print("=" * 60)

    # Convert every example into the training record format
    train_dataset = train_dataset.map(
        lambda example, idx: process_example(example, idx, "train", data_source),
        with_indices=True
    )
    test_dataset = test_dataset.map(
        lambda example, idx: process_example(example, idx, "test", data_source),
        with_indices=True
    )

    # Save to parquet; the test split is written as validation.parquet
    os.makedirs(output_dir, exist_ok=True)
    train_path = os.path.join(output_dir, "train.parquet")
    val_path = os.path.join(output_dir, "validation.parquet")

    print("\nSaving datasets:")
    print(f"  Train: {train_path} ({len(train_dataset)} samples)")
    print(f"  Test: {val_path} ({len(test_dataset)} samples)")

    train_dataset.to_parquet(train_path)
    test_dataset.to_parquet(val_path)

    print("\nDataset preparation complete!")
    print("\nFinal sizes:")
    print(f"  Train: {len(train_dataset)} samples")
    print(f"  Test: {len(test_dataset)} samples")


if __name__ == "__main__":
    main()
