#!/usr/bin/env python3
"""
Convert JSONL-format training data into the parquet format required by GRPO training.

Input format:
{
    "conversations": [
        {"from": "system", "value": "..."},
        {"from": "human", "value": "..."},
        {"from": "gpt", "value": "..."}
    ],
    "images": ["path/to/image.png"]
}

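Output format (parquet), one row per successfully converted record:
{
    "prompt": [
        {"role": "system", "content": "..."},
        {"role": "user", "content": "..."}
    ],
    "data_source": "...",
    "reward_model": {"ground_truth": "text inside the last <answer>...</answer> of the gpt reply"},
    "extra_info": {"index": 0, "original_index": 0},
    "images": [{"image_url": "file://path/to/image.png"}],
    "ability": "..."
}

"images" and "ability" are only written for records whose input contains them.
Records without an extractable <answer>...</answer> in a gpt reply are skipped.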
"""

import json
import re
import sys
from pathlib import Path
from typing import List, Dict, Any

# Select the parquet writer backend: prefer the `datasets` library and fall
# back to pandas when `datasets` cannot be imported.
try:
    import datasets
    HAS_DATASETS = True
except ImportError:
    try:
        import pandas as pd
        HAS_DATASETS = False
        # NOTE: HAS_PANDAS is assigned only on this fallback path; the name is
        # not defined when `datasets` imported successfully.
        HAS_PANDAS = True
    except ImportError:
        # Neither backend is importable, so no parquet file can be written:
        # tell the user what to install and exit with status 1.
        print("Error: you need to install the 'datasets' or 'pandas' library")
        print("Install command: pip install datasets OR pip install pandas pyarrow")
        sys.exit(1)


def convert_conversations_to_prompt(conversations: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """
    Build the chat-style prompt from a ShareGPT-style conversation.

    Only turns whose "from" tag is "system" or "human" (compared
    case-insensitively) are kept; "human" is renamed to "user". Every other
    turn, including the "gpt" reply, is left out so the prompt never contains
    the answer.

    Args:
        conversations: Turns shaped like {"from": "...", "value": "..."}.
            A missing "from" makes the turn be ignored; a missing "value"
            becomes an empty string.

    Returns:
        The retained turns, in their original order, shaped like
        {"role": "...", "content": "..."}.
    """
    # Source tag -> chat role, restricted to the tags allowed in a prompt.
    prompt_roles = {"system": "system", "human": "user"}

    messages = []
    for turn in conversations:
        speaker = turn.get("from", "").lower()
        if speaker not in prompt_roles:
            continue
        messages.append({"role": prompt_roles[speaker], "content": turn.get("value", "")})
    return messages


# Captures the payload of an <answer>...</answer> pair; tags match in any case
# and the payload may span several lines.
_ANSWER_PATTERN = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)

# Image references that already carry a usable scheme and therefore must not
# be prefixed with "file://".
_PASSTHROUGH_IMAGE_PREFIXES = ("file://", "http://", "https://", "data:")


def _extract_ground_truth(conversations: List[Dict[str, Any]]):
    """Return the last <answer> payload of the first gpt turn that has one, else None."""
    for turn in conversations:
        # Compare the role case-insensitively, like convert_conversations_to_prompt does.
        if str(turn.get("from", "")).lower() != "gpt":
            continue
        answers = _ANSWER_PATTERN.findall(turn.get("value", ""))
        if answers:
            return answers[-1].strip()
    return None


def _format_image_entry(image: Any) -> Any:
    """Convert one image reference into the {"image_url": ...} dict form."""
    if isinstance(image, dict):
        # Already in dict form: keep untouched.
        return image
    if isinstance(image, str):
        if image.startswith(_PASSTHROUGH_IMAGE_PREFIXES):
            return {"image_url": image}
        return {"image_url": f"file://{image}"}
    return {"image_url": str(image)}


def _format_images(images: Any) -> Any:
    """Normalise the input "images" value into a list of image dicts when possible."""
    if isinstance(images, (str, dict)):
        # A single reference is treated as a one-element list so the column
        # keeps one consistent type across records.
        images = [images]
    if isinstance(images, list):
        return [_format_image_entry(image) for image in images]
    # Any other truthy value is passed through unchanged.
    return images


def _align_record_keys(records: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Give every record the same keys (first-seen order), filling gaps with None."""
    columns: Dict[str, None] = {}
    for record in records:
        for key in record:
            columns.setdefault(key, None)
    return [{key: record.get(key) for key in columns} for record in records]


def _print_example(example: Any) -> None:
    """Print a short preview of one saved record (a dict or a pandas row)."""
    print(f"  prompt: {example['prompt']}")
    if "images" in example:
        print(f"  images: {example['images']}")
    print(f"  data_source: {example['data_source']}")


def convert_jsonl_to_grpo_format(
    input_file: str,
    output_file: str,
    data_source: str = "geometry_instantiation"
):
    """
    Convert a JSONL file into the parquet format required for GRPO training.

    Each non-blank input line is parsed as JSON. A record is written when it
    has a non-empty "conversations" list, at least one system/human turn, and
    a gpt turn containing <answer>...</answer>; otherwise the line is reported
    and counted as skipped. Lines that fail to parse or raise while being
    processed are skipped the same way, so one bad line never aborts the run.

    Every written record holds "prompt", "data_source",
    "reward_model.ground_truth" and "extra_info" ("index" is the zero-based
    line number, "original_index" is the input's "index" if present). "images"
    (as a list of {"image_url": ...} dicts) and "ability" are added when the
    input provides them.

    Args:
        input_file: Path to the input JSONL file.
        output_file: Path to the output parquet file; missing parent
            directories are created.
        data_source: Identifier string for the data source.

    Raises:
        SystemExit: With status 1 when no record could be converted.
    """
    converted_data = []
    skipped_count = 0

    print(f"Reading file: {input_file}")

    with open(input_file, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue

            try:
                data = json.loads(line)

                # Check required fields
                if "conversations" not in data:
                    print(f"Warning: Line {line_num} is missing 'conversations' field, skipping")
                    skipped_count += 1
                    continue

                conversations = data["conversations"]
                if not isinstance(conversations, list) or not conversations:
                    print(f"Warning: Line {line_num} has empty or invalid 'conversations', skipping")
                    skipped_count += 1
                    continue

                # Convert to prompt format
                prompt = convert_conversations_to_prompt(conversations)
                if not prompt:
                    print(f"Warning: Line {line_num} has no valid prompt messages, skipping")
                    skipped_count += 1
                    continue

                # The reward target is the answer embedded in the gpt reply.
                ground_truth = _extract_ground_truth(conversations)
                if ground_truth is None:
                    print(f"Warning: Line {line_num} could not extract ground_truth, skipping")
                    skipped_count += 1
                    continue

                # Build output record
                output_data = {
                    "prompt": prompt,
                    "data_source": data_source,
                    "reward_model": {
                        "ground_truth": ground_truth
                    },
                    "extra_info": {
                        "index": line_num - 1,
                        "original_index": data.get("index", line_num - 1)
                    }
                }

                # verl expected format: [{"image_url": "file:///path/to/image.png"}, ...]
                images = data.get("images")
                if images:
                    output_data["images"] = _format_images(images)

                # Keep other potentially useful fields (optional)
                if "ability" in data:
                    output_data["ability"] = data["ability"]

                converted_data.append(output_data)

            except json.JSONDecodeError as e:
                print(f"Error: JSON parse failed on line {line_num}: {e}")
                skipped_count += 1
                continue
            except Exception as e:
                # Deliberate best-effort: report the bad line and keep going.
                print(f"Error: processing failed on line {line_num}: {e}")
                skipped_count += 1
                continue

    if not converted_data:
        print("Error: No records were successfully converted!")
        sys.exit(1)

    print("\nConversion summary:")
    print(f"  Successfully converted: {len(converted_data)} records")
    print(f"  Skipped: {skipped_count} records")

    # Ensure output directory exists
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Save as parquet
    if HAS_DATASETS:
        print("\nCreating Dataset and saving as parquet...")
        # Dataset.from_list derives its columns from the first record only, so
        # optional keys ("images", "ability") that first appear in a later
        # record would be dropped silently. Align the keys beforehand.
        dataset = datasets.Dataset.from_list(_align_record_keys(converted_data))
        dataset.to_parquet(output_file)
        print(f"✓ Successfully saved to: {output_file}")
        print(f"✓ Total records: {len(dataset)}")

        # Show an example record
        if len(dataset) > 0:
            print("\nExample record (first item):")
            _print_example(dataset[0])
    else:
        # Use pandas
        print("\nSaving as parquet using pandas...")
        df = pd.DataFrame(converted_data)
        df.to_parquet(output_file, index=False)
        print(f"✓ Successfully saved to: {output_file}")
        print(f"✓ Total records: {len(df)}")

        # Show an example record
        if len(df) > 0:
            print("\nExample record (first item):")
            _print_example(df.iloc[0])


if __name__ == "__main__":
    # Positional arguments only: <input_jsonl> [output_parquet] [data_source].
    cli_args = sys.argv[1:]

    # Without an input path there is nothing to do: show the usage text and
    # exit with a failure status.
    if not cli_args:
        usage_lines = (
            "Usage: python convert_to_grpo_parquet.py <input_jsonl> [output_parquet] [data_source]",
            "\nArguments:",
            "  input_jsonl:    Path to the input JSONL file",
            "  output_parquet: Path to the output parquet file (optional, defaults to <input>.parquet)",
            "  data_source:    Data source identifier (optional, default: 'geometry_instantiation')",
            "\nExamples:",
            "  python convert_to_grpo_parquet.py train.jsonl train.parquet",
            "  python convert_to_grpo_parquet.py train.jsonl",
        )
        for usage_line in usage_lines:
            print(usage_line)
        sys.exit(1)

    input_file = cli_args[0]
    source_path = Path(input_file)

    # Default output: next to the input, same stem, ".parquet" extension.
    output_file = (
        cli_args[1]
        if len(cli_args) >= 2
        else str(source_path.parent / f"{source_path.stem}.parquet")
    )

    data_source = cli_args[2] if len(cli_args) >= 3 else "geometry_instantiation"

    # Fail early with a clear message instead of an open() traceback.
    if not source_path.exists():
        print(f"Error: input file does not exist: {input_file}")
        sys.exit(1)

    convert_jsonl_to_grpo_format(input_file, output_file, data_source)

