#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import os
import sys

import datasets
from jinja2 import Template

# Constants
# Instruction appended to every user prompt. The leading and trailing newlines
# are part of the literal, so they are part of the emitted prompt text. The
# doubled backslashes yield a literal "\boxed" in the string.
PROMPT_TEMPLATE_UC = """
Please reason step by step. If confident based on reliable knowledge, provide a clear answer and box it with \\boxed{}.
If the question lacks clarity, exceeds your knowledge, involves speculation, prediction, opinion, or any uncertainty, do not guess. State your limitation and output \\boxed{uncertainty}.
"""


def _value_or_default(example, key, default):
    """Return ``example[key]``, or *default* if the key is absent or its value is ``None``.

    ``dict.get(key, default)`` alone is not enough: a dataset row has every
    column, and a record that lacked the field carries ``None``, which
    ``.get`` would return unchanged.
    """
    value = example.get(key)
    return default if value is None else value


def get_map_functions(split_name: str):
    """
    Build the per-example function used with ``datasets.Dataset.map``.

    Args:
        split_name (str): Split label recorded in each output row's
            ``extra_info['split']`` (e.g. 'train', 'test').

    Returns:
        function: ``process_item(example, idx)``, meant for ``map`` with
        ``with_indices=True``. It returns a dict with the keys
        ``data_source``, ``prompt`` (a one-message chat list whose content is
        the raw prompt followed by ``PROMPT_TEMPLATE_UC``), ``prompt_text``,
        ``label``, ``reward_model`` and ``extra_info``.
    """
    def process_item(example, idx):
        # Missing or null fields fall back to defaults instead of leaking None
        # into the prompt text ("None\n...") or the data_source column.
        raw_prompt = _value_or_default(example, 'prompt', '')
        data_source_type = _value_or_default(example, 'data_type', 'math_dapo')
        # The label may legitimately be None; it is passed through unchanged.
        label = example.get('label')

        return {
            "data_source": data_source_type,
            "prompt": [
                {
                    "role": "user",
                    "content": f"{raw_prompt}\n{PROMPT_TEMPLATE_UC}",
                }
            ],
            "prompt_text": raw_prompt,
            "label": label,
            "reward_model": {
                "style": "orm",
                "ground_truth": label,
            },
            "extra_info": {
                'split': split_name,
                'index': idx,
            },
        }

    return process_item


def _derive_output_path(local_file):
    """Return the Parquet path written for *local_file*.

    The file's extension (if any) is replaced by ``_uc.parquet`` in the same
    directory, e.g. ``data/foo.json`` -> ``data/foo_uc.parquet``. Unlike
    ``str.replace('.json', ...)``, this never edits directory names, handles
    ``.jsonl`` inputs, and can never return the input path itself (which would
    overwrite the source file).
    """
    stem, _extension = os.path.splitext(local_file)
    return f"{stem}_uc.parquet"


def main(argv=None):
    """Convert a JSON dataset file into the UC-prompt Parquet format.

    Args:
        argv (list[str] | None): Arguments to parse instead of the real
            command line. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, so ``main()`` behaves as a normal CLI entry point.

    Raises:
        SystemExit: With a non-zero status (message on stderr) when
            ``--local_file`` is not an existing file, or from argparse on
            invalid arguments.
    """
    # Setup argument parser
    parser = argparse.ArgumentParser(description="Process JSON dataset to Parquet for UC format.")
    parser.add_argument(
        '--local_file',
        type=str,
        default='math_dapo_17k_processed.json',
        help='Path to the input JSON file',
    )
    parser.add_argument(
        '--hdfs_dir',
        type=str,
        default=None,
        help='Optional HDFS directory path',
    )
    args = parser.parse_args(argv)

    # Verify the input is a regular file. Exit non-zero so callers and shell
    # pipelines notice the failure (the original returned with status 0).
    if not os.path.isfile(args.local_file):
        sys.exit(f"Error: File {args.local_file} not found.")

    # The option is accepted for CLI compatibility, but this script implements
    # no HDFS upload; say so rather than silently ignoring it.
    if args.hdfs_dir is not None:
        print(
            f"Warning: --hdfs_dir={args.hdfs_dir} is not supported by this script and is ignored.",
            file=sys.stderr,
        )

    # Load dataset; the json builder is asked for its 'train' split.
    print(f"Loading dataset from: {args.local_file}")
    dataset = datasets.load_dataset('json', data_files=args.local_file, split='train')

    # Apply transformation
    print("Mapping dataset transformations...")
    processed_dataset = dataset.map(
        get_map_functions('train'),
        with_indices=True,
        remove_columns=dataset.column_names,  # drop raw columns; keep only the mapped schema
    )

    # Save to Parquet next to the input file.
    output_path = _derive_output_path(args.local_file)
    processed_dataset.to_parquet(output_path)
    print(f"Successfully saved processed data to: {output_path}")


if __name__ == '__main__':
    # Run the CLI only when this file is executed directly, not when it is imported.
    main()