#!/usr/bin/env python3

import os
import sys
import json
import argparse
import re
from typing import Dict, Any, List, Optional
from pathlib import Path

# Put the "utils" directory that sits next to this script at the front of
# sys.path. Presumably this is what lets the `qwen3_mobile_use` import below
# resolve without installing it as a package — confirm against the repo layout.
current_dir = os.path.dirname(os.path.abspath(__file__))
# Alias: the preprocessing root is the directory containing this script.
preprocess_dir = current_dir
utils_dir = os.path.join(preprocess_dir, "utils")
# Index 0 so this directory takes precedence over other sys.path entries.
sys.path.insert(0, utils_dir)

from qwen3_mobile_use import MOBILE_USE_TOOL_SCHEMA, build_mobile_use_system_prompt, build_user_query


def parse_tool_call_from_answer(answer: str) -> Optional[Dict[str, Any]]:
    """Return the ``arguments`` value of the tool call embedded in *answer*.

    The text between the first ``<tool_call>`` tag and the following
    ``</tool_call>`` tag is decoded as JSON, and the value stored under its
    ``"arguments"`` key is returned.

    Args:
        answer: Model answer text that may contain a
            ``<tool_call>...</tool_call>`` section.

    Returns:
        The ``"arguments"`` value of the decoded tool call, or ``None`` when
        *answer* is not a string, either tag is missing, the payload is not
        valid JSON, the payload is not a JSON object, or the object has no
        ``"arguments"`` key.
    """
    if not isinstance(answer, str):
        return None
    if "<tool_call>" not in answer or "</tool_call>" not in answer:
        return None

    # The opening tag is known to be present, so index 1 always exists.
    json_str = answer.split("<tool_call>")[1].split("</tool_call>")[0].strip()
    try:
        tool_call = json.loads(json_str)
    except json.JSONDecodeError:
        return None

    # Valid JSON is not necessarily an object: numbers, strings and arrays
    # would make the key lookup below raise TypeError instead of returning None.
    if not isinstance(tool_call, dict):
        return None
    return tool_call.get("arguments")


def extract_thought_and_action(answer: str) -> tuple[str, str]:
    """Pull the ``Thought:`` and ``Action:`` snippets out of *answer*.

    Each snippet is the text after its label, up to the next newline or the
    next marker (``Action:`` for the thought, ``<tool_call>`` for the action),
    with surrounding whitespace stripped. A label that is missing, or whose
    text is not followed by a newline or marker, yields an empty string.

    Returns:
        A ``(thought, action)`` tuple of strings.
    """
    def _first_capture(pattern: str) -> str:
        # Stripped first capture group, or "" when the pattern does not match.
        found = re.search(pattern, answer, re.DOTALL)
        return found.group(1).strip() if found else ""

    return (
        _first_capture(r'Thought:\s*(.+?)(?:\n|Action:)'),
        _first_capture(r'Action:\s*(.+?)(?:\n|<tool_call>)'),
    )


def format_action_history(action_history: Any) -> str:
    """Render an action history as a single string.

    A string is returned unchanged. A list becomes
    ``"Step 1: ...; Step 2: ..."``; non-string entries are left out but still
    consume their step number. Any other value, and any empty value, gives
    ``""``.
    """
    if not action_history:
        return ""

    if isinstance(action_history, str):
        return action_history

    if not isinstance(action_history, list):
        return ""

    return "; ".join(
        f"Step {step_no}: {entry}"
        for step_no, entry in enumerate(action_history, 1)
        if isinstance(entry, str)
    )


def build_messages_format(system_prompt: str, user_query: str, answer: str, image_path: Optional[str] = None) -> List[Dict[str, Any]]:
    """Assemble the system/user/assistant message list used for ``prompt_str``.

    The system and user messages carry a list of typed content parts, while
    the assistant message carries *answer* directly. When *image_path* is
    truthy, an image part with the fixed value ``"placeholder_0"`` is placed
    before the user text; the path itself is not embedded.
    """
    text_part = {"type": "text", "text": user_query}
    if image_path:
        user_parts = [{"type": "image", "image": "placeholder_0"}, text_part]
    else:
        user_parts = [text_part]

    return [
        {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
        {"role": "user", "content": user_parts},
        {"role": "assistant", "content": answer},
    ]


def _action_list_from_history(raw_history: Any) -> Optional[List[str]]:
    """Return the individual past actions from a raw history, or None."""
    if isinstance(raw_history, list):
        string_items = [item for item in raw_history if isinstance(item, str)]
        if not string_items:
            return None
        # Take the actions straight from the list. Joining them with "; " and
        # splitting again would truncate any action that itself contains "; ".
        # Trailing whitespace is trimmed and blank entries are dropped, which
        # is what that string round-trip also did.
        return [item.rstrip() for item in string_items if item.strip()]

    if not isinstance(raw_history, str) or "Step" not in raw_history:
        return None

    # A pre-formatted "Step N: action; Step M: action" string.
    actions = []
    for step in raw_history.split("; "):
        step = step.strip()
        if not step.startswith("Step "):
            continue
        _, separator, action = step.partition(": ")
        if separator and action:
            actions.append(action)
    return actions


def convert_sample(sample: Dict[str, Any], system_prompt: str, data_type: int = 0) -> Optional[Dict[str, Any]]:
    """Convert a single annotated sample to the training format.

    Args:
        sample: Annotated sample. Keys read here: ``images``, ``image_path``,
            ``episode_id``, ``step_id``, ``instruction``, ``action_history``,
            ``answer``, ``gt_action`` and ``dataset_name``; all are optional.
        system_prompt: Text used for the system message.
        data_type: Value stored unchanged under ``"data_type"``.

    Returns:
        The converted sample, or ``None`` when the sample has no (or an
        empty) ``answer``.
    """
    answer = sample.get("answer", "")
    if not answer:
        return None

    raw_history = sample.get("action_history", "")

    # Keys are inserted in a fixed order because it is the key order of the
    # JSON written by the caller.
    new_sample: Dict[str, Any] = {}

    new_sample["images"] = sample.get("images", [])
    if not new_sample["images"] and "image_path" in sample:
        new_sample["images"] = [sample["image_path"]]

    new_sample["episode_id"] = str(sample.get("episode_id", ""))
    new_sample["step_id"] = sample.get("step_id", 0)
    new_sample["instruction"] = sample.get("instruction", "")
    new_sample["action_history"] = format_action_history(raw_history)

    action_history_list = _action_list_from_history(raw_history)

    user_query_text = build_user_query(new_sample["instruction"], action_history_list)
    user_query_text = user_query_text.strip()
    # Make sure the query ends with the screenshot placeholder line.
    if not user_query_text.endswith("\nScreenshot: <image>"):
        user_query_text = f"{user_query_text}\nScreenshot: <image>"

    new_sample["messages"] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_query_text},
        {"role": "assistant", "content": answer},
    ]

    new_sample["gt_action"] = sample.get("gt_action", {})

    new_sample["data_type"] = data_type
    new_sample["dataset_name"] = sample.get("dataset_name", "")

    new_sample["answer"] = answer

    image_path = new_sample["images"][0] if new_sample["images"] else None
    prompt_messages = build_messages_format(system_prompt, user_query_text, answer, image_path)
    new_sample["prompt_str"] = json.dumps(prompt_messages, ensure_ascii=False)

    # Prefer the annotated ground-truth action; otherwise fall back to the
    # tool-call arguments embedded in the answer, and finally to "{}".
    gt_action = new_sample["gt_action"]
    if gt_action:
        new_sample["answer_str"] = json.dumps(gt_action, ensure_ascii=False)
    else:
        tool_call_args = parse_tool_call_from_answer(answer)
        if tool_call_args:
            new_sample["answer_str"] = json.dumps(tool_call_args, ensure_ascii=False)
        else:
            new_sample["answer_str"] = "{}"

    return new_sample


def main():
    """Command-line entry point: load samples, convert them, write the result.

    Reads the JSON file given by ``--input_file``, runs ``convert_sample`` on
    every entry, and writes the converted entries to ``--output_file`` as
    indented JSON. Entries for which ``convert_sample`` returns ``None`` are
    counted as skipped; a warning is printed for the first ten only.
    """
    cli = argparse.ArgumentParser(description='Convert annotated data to training format')
    cli.add_argument('--input_file', type=str, default='/INPUT_FILE',
                     help='Input annotated JSON file path')
    cli.add_argument('--output_file', type=str, default='/OUTPUT_FILE',
                     help='Output JSON file path')
    cli.add_argument('--data_type', type=int, default=0,
                     help='Data type: 0 for general data, 2 for trap data')
    args = cli.parse_args()

    system_prompt = build_mobile_use_system_prompt()

    print(f"Loading data from: {args.input_file}")
    with open(args.input_file, 'r', encoding='utf-8') as src:
        data = json.load(src)

    print(f"Processing {len(data)} samples with data_type={args.data_type}...")
    converted_data = []
    skipped = 0

    for idx, sample in enumerate(data):
        result = convert_sample(sample, system_prompt, data_type=args.data_type)
        if result is not None:
            converted_data.append(result)
            continue
        skipped += 1
        # Cap the warnings so a badly formed input does not flood the console.
        if skipped <= 10:
            print(f"Warning: Skipping sample {idx} (missing answer field)")

    print(f"Converted {len(converted_data)} samples, skipped {skipped} samples")

    # A bare file name has no directory component; use the current directory.
    output_dir = os.path.dirname(args.output_file) or "."
    os.makedirs(output_dir, exist_ok=True)

    print(f"Saving to: {args.output_file}")
    with open(args.output_file, 'w', encoding='utf-8') as dst:
        json.dump(converted_data, dst, ensure_ascii=False, indent=2)

    print("Conversion completed!")


# Run the conversion only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
