#!/usr/bin/env python3
"""
Convert Extracted Bench data to verl/FileAgent format.

This script converts an extracted-bench JSON file (a list of task items, e.g.
extracted_bench-v2.json) into the parquet format expected by
recipe/fileagent/rl_dataset.py (CustomRLHFDataset).

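Usage:
    python convert_extracted_bench.py \
        --input /path/to/extracted_bench-v2.json \
        --output /path/to/output.parquet \
        --system_prompt /path/to/system_prompt.md \
        [--max_items N]

    --max_items is optional and limits the conversion to the first N items (for testing).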
"""

import json
import argparse
from pathlib import Path
import pandas as pd
from typing import Dict, List, Any


def load_system_prompt(prompt_path: str) -> str:
    """Read the UTF-8 prompt file at *prompt_path* and return its text with surrounding whitespace removed."""
    prompt_text = Path(prompt_path).read_text(encoding='utf-8')
    return prompt_text.strip()


def convert_item(item: Dict[str, Any], system_prompt: str) -> Dict[str, Any]:
    """
    Convert a single extracted_bench item to verl format.

    Expected input format (every field is optional):
    {
        "task_id": "extracted_bench/task_123",   # default "unknown"
        "level": 2,                              # default 0
        "question": "...",                       # "formatted_question" wins if non-empty
        "answer": "42",                          # stored as a string (null -> "")
        "prewrites": [                           # a list is passed through unchanged
            {"filename": "...", "fpath": "...", "binary": false},
            ...
        ]
        # A {"file1.txt": "content1", ...} dict is also accepted and becomes
        # [{"filename": "file1.txt", "content": "content1", "binary": False}, ...]
    }

    Output format (for CustomRLHFDataset):
    {
        "data_source": "extracted_bench",
        "prompt": [
            {"role": "system", "content": "<system_prompt>"},
            {"role": "user", "content": "<question>"}
        ],
        "reward_model": {"style": "model", "ground_truth": "<answer>"},
        "extra_info": {
            "task_id": "...",
            "level": 2,
            "question": "<question>",
            "need_tools_kwargs": True,
            "tools_kwargs": {
                "global_tool": {
                    "create_kwargs": {
                        "prewrites": [...],
                        "ground_truth": "<answer>"
                    }
                }
            }
        },
        "ground_truth": "<answer>"
    }

    Args:
        item: One record of the input JSON list.
        system_prompt: Text placed in the system message of ``prompt``.

    Returns:
        The verl-formatted dict shown above.

    Raises:
        AttributeError: if ``item`` has no ``.get`` (i.e. it is not a dict).
    """
    task_id = item.get("task_id", "unknown")
    level = item.get("level", 0)
    # Prefer the pre-formatted question, then fall back to "question". Use "" when
    # both are missing or null, so downstream string handling never sees None.
    question = item.get("formatted_question") or item.get("question") or ""

    # Normalize the answer to str. JSON answers may be numbers in some items and
    # strings in others, and a parquet column mixing int and str cannot be written.
    raw_answer = item.get("answer")
    answer = "" if raw_answer is None else str(raw_answer)

    # Keep the list format so per-file metadata (filename, fpath, binary) reaches
    # the tool, which needs fpath to read and write files in the sandbox.
    raw_prewrites = item.get("prewrites", [])
    if isinstance(raw_prewrites, list):
        prewrites = raw_prewrites
    elif isinstance(raw_prewrites, dict):
        # Legacy {filename: content} form, converted to the list format.
        # NOTE(review): these entries carry no "fpath"; confirm the tool handles that.
        prewrites = [
            {"filename": name, "content": content, "binary": False}
            for name, content in raw_prewrites.items()
        ]
    else:
        prewrites = []

    return {
        "data_source": "extracted_bench",
        "prompt": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ],
        "reward_model": {
            "style": "model",
            "ground_truth": answer,
        },
        "extra_info": {
            "task_id": task_id,
            "level": level,
            "question": question,  # Required by FileAgent reward_score.py
            "need_tools_kwargs": True,
            "tools_kwargs": {
                "global_tool": {
                    "create_kwargs": {
                        "prewrites": prewrites,
                        "ground_truth": answer,
                    }
                }
            },
        },
        # ground_truth for reward computation (kept for backward compatibility)
        "ground_truth": answer,
    }


def main():
    """Command-line entry point: convert an extracted-bench JSON list to parquet.

    Reads the system prompt and the input JSON, which must be a list of items.
    Each item is converted with ``convert_item`` and the result is written with
    ``DataFrame.to_parquet``. Items that raise during conversion are reported and
    skipped. The script exits with an error when the input is not a list, when
    ``--max_items`` is negative, or when every input item failed to convert.
    """
    parser = argparse.ArgumentParser(description="Convert Extracted Bench to verl format")
    # NOTE: the default paths contain a literal `<your_username>` placeholder and
    # will not exist as-is; pass explicit paths on the command line.
    parser.add_argument(
        "--input",
        type=str,
        default="/mnt/bn/fileagent-storage/users/<your_username>/verl/data/all_pdfs_short_1000_verl/dataset_original.json",
        help="Path to input dataset_original.json"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="/mnt/bn/fileagent-storage/users/<your_username>/verl/data/all_pdfs_short_1000_verl/train.parquet",
        help="Path to output parquet file"
    )
    parser.add_argument(
        "--system_prompt",
        type=str,
        default="/mnt/bn/fileagent-storage/users/<your_username>/verl/recipe/fileagent/prompts/extracted_bench_sp.md",
        help="Path to system prompt file"
    )
    parser.add_argument(
        "--max_items",
        type=int,
        default=None,
        help="Maximum number of items to convert (for testing)"
    )

    args = parser.parse_args()
    if args.max_items is not None and args.max_items < 0:
        # A negative slice bound would silently drop items from the end instead.
        parser.error("--max_items must be >= 0")

    # Load system prompt
    print(f"Loading system prompt from: {args.system_prompt}")
    system_prompt = load_system_prompt(args.system_prompt)
    print(f"System prompt loaded ({len(system_prompt)} characters)")

    # Load input data
    print(f"\nLoading input data from: {args.input}")
    with open(args.input, 'r', encoding='utf-8') as f:
        input_data = json.load(f)
    if not isinstance(input_data, list):
        raise SystemExit(
            f"Expected a JSON list of items in {args.input}, "
            f"got {type(input_data).__name__}"
        )

    total_items = len(input_data)
    print(f"Total items in input: {total_items}")

    # Compare against None so that `--max_items 0` means zero items, not "no limit".
    if args.max_items is not None:
        input_data = input_data[:args.max_items]
        print(f"Processing first {len(input_data)} items")

    # Convert items
    print("\nConverting items...")
    verl_data = []
    for i, item in enumerate(input_data):
        try:
            verl_item = convert_item(item, system_prompt)
        except Exception as e:
            # Best-effort: report the bad item and keep converting the rest.
            print(f"  Error converting item {i}: {e}")
            continue
        verl_data.append(verl_item)

        if (i + 1) % 100 == 0:
            print(f"  Processed {i + 1}/{len(input_data)} items")

    print(f"Successfully converted {len(verl_data)}/{len(input_data)} items")
    if input_data and not verl_data:
        # Writing a column-less parquet would only move the failure downstream.
        raise SystemExit("Every item failed to convert; no parquet file written.")

    # Create output directory if it doesn't exist
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Save as parquet
    print(f"\nSaving to: {args.output}")
    df = pd.DataFrame(verl_data)
    df.to_parquet(args.output, index=False)

    print("✅ Conversion complete!")
    print(f"   Output: {args.output}")
    print(f"   Items: {len(verl_data)}")

    # Print sample
    if verl_data:
        sample = verl_data[0]
        extra_info = sample['extra_info']
        prewrites = extra_info['tools_kwargs']['global_tool']['create_kwargs']['prewrites']
        print("\n📄 Sample item:")
        print(f"   Task ID: {extra_info['task_id']}")
        print(f"   Level: {extra_info['level']}")
        # str() guards against a None question coming from the input JSON.
        print(f"   Question (first 100 chars): {str(extra_info['question'])[:100]}...")
        print(f"   Ground truth: {sample['ground_truth']}")
        print(f"   Prewrites: {len(prewrites)} files")


if __name__ == "__main__":
    # Run the conversion only when executed as a script, not when imported.
    main()

