#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Qwen3-VL 动作格式转换工具（适配新的动作空间）

支持的动作类型：
- key: 按键事件
- click: 点击
- long_press: 长按
- swipe: 滑动
- type: 输入文本
- system_button: 系统按钮
- open: 打开应用
- wait: 等待
- terminate: 终止任务

用于评估时将模型输出转换为统一格式进行比较。
"""

import re
import json
import math
from typing import Dict, Any, Optional, Tuple, List, Union


# ---------------------------------------------------------------------------
# 动作类型判断
# ---------------------------------------------------------------------------
def get_qwen3_action_type(tool_call: Dict[str, Any]) -> int:
    """
    获取 Qwen3 tool_call 的动作类型编号。
    
    Returns:
        int: 动作类型编号
            1: click
            2: type
            3: swipe
            4: system_button
            5: terminate
            6: wait
            7: long_press
            8: key
            9: open
            0: 未知类型
    """
    if not tool_call or not isinstance(tool_call, dict):
        return 0
    
    action = tool_call.get("action", "")
    
    if action == "click":
        return 1
    elif action == "type":
        return 2
    elif action == "swipe":
        return 3
    elif action == "system_button":
        return 4
    elif action == "terminate":
        return 5
    elif action == "wait":
        return 6
    elif action == "long_press":
        return 7
    elif action == "key":
        return 8
    elif action == "open":
        return 9
    
    return 0


# ---------------------------------------------------------------------------
# 坐标归一化辅助函数
# ---------------------------------------------------------------------------
def normalize_coordinate(coord: List[float], image_width: int, image_height: int) -> List[float]:
    """
    将坐标归一化到 0-1000 范围。
    
    Args:
        coord: 原始坐标 [x, y]（可能是像素坐标或已归一化坐标）
        image_width: 图片宽度
        image_height: 图片高度
    
    Returns:
        归一化后的坐标 [x, y]（0-1000 范围）
    """
    if len(coord) < 2:
        return coord
    
    x, y = coord[0], coord[1]
    
    # 如果坐标已经在 0-1000 范围内，且图片尺寸合理，可能已经是归一化坐标
    # 但为了统一处理，我们总是基于图片尺寸进行归一化
    # 如果图片尺寸无效，假设坐标已经是归一化的
    if image_width > 0 and image_height > 0:
        # 归一化到 0-1000 范围
        x_norm = (x / image_width) * 1000.0
        y_norm = (y / image_height) * 1000.0
        return [x_norm, y_norm]
    else:
        # 如果图片尺寸无效，假设坐标已经是归一化的，直接返回
        return [float(x), float(y)]


def safe_normalize_coordinate(coord: Any, default_coord: Optional[List[float]] = None) -> Optional[List[float]]:
    """
    安全地检查坐标是否可解析（用于参数验证）。
    
    这个函数用于检查坐标是否存在且格式正确，不进行实际的归一化计算。
    
    Args:
        coord: 要检查的坐标（可能是列表、元组或其他类型）
        default_coord: 默认坐标（如果coord无效时返回，用于兼容性）
    
    Returns:
        如果坐标可解析，返回坐标列表；否则返回 None
    """
    if coord is None:
        return None
    
    # 尝试转换为列表
    try:
        if isinstance(coord, (list, tuple)):
            if len(coord) >= 2:
                # 尝试转换为浮点数
                x = float(coord[0])
                y = float(coord[1])
                # 检查是否为有效数值
                if math.isnan(x) or math.isnan(y) or math.isinf(x) or math.isinf(y):
                    return None
                return [x, y]
            else:
                return None
        else:
            return None
    except (ValueError, TypeError, IndexError):
        return None


# ---------------------------------------------------------------------------
# 动作匹配评估
# ---------------------------------------------------------------------------
def is_qwen3_action_type_match(pred: Dict[str, Any], gt: Dict[str, Any]) -> bool:
    """
    判断两个 Qwen3 tool_call 的动作类型是否匹配。
    """
    pred_type = get_qwen3_action_type(pred)
    gt_type = get_qwen3_action_type(gt)
    return pred_type == gt_type and pred_type != 0


def is_qwen3_action_match(
    pred: Dict[str, Any], 
    gt: Dict[str, Any], 
    click_threshold: float = 140.0,
    image_width: Optional[int] = None,
    image_height: Optional[int] = None
) -> bool:
    """
    判断两个 Qwen3 tool_call 是否完全匹配（类型和内容都匹配）。
    
    Args:
        pred: 预测的 tool_call
        gt: 真实的 tool_call
        click_threshold: click 动作的距离阈值（归一化后的距离，0-1000 范围）
        image_width: 图片宽度（用于坐标归一化）
        image_height: 图片高度（用于坐标归一化）
    
    Returns:
        bool: 是否匹配
    """
    # 首先检查类型是否匹配
    if not is_qwen3_action_type_match(pred, gt):
        return False
    
    pred_action = pred.get("action", "")
    gt_action = gt.get("action", "")
    
    # click: 检查坐标距离（先归一化到 0-1000 后再计算距离）
    if pred_action == "click" and gt_action == "click":
        pred_coord = pred.get("coordinate", [0, 0])
        gt_coord = gt.get("coordinate", [0, 0])
        
        # 归一化坐标到 0-1000 范围
        if image_width and image_height:
            pred_coord_norm = normalize_coordinate(pred_coord, image_width, image_height)
            gt_coord_norm = normalize_coordinate(gt_coord, image_width, image_height)
        else:
            # 如果没有图片尺寸，假设坐标已经是归一化的
            pred_coord_norm = [float(pred_coord[0]), float(pred_coord[1])]
            gt_coord_norm = [float(gt_coord[0]), float(gt_coord[1])]
        
        # 在归一化后的坐标空间中计算距离
        distance = math.sqrt(
            (pred_coord_norm[0] - gt_coord_norm[0]) ** 2 + 
            (pred_coord_norm[1] - gt_coord_norm[1]) ** 2
        )
        return distance <= click_threshold
    
    # long_press: 检查坐标距离（先归一化到 0-1000 后再计算距离）
    elif pred_action == "long_press" and gt_action == "long_press":
        pred_coord = pred.get("coordinate", [0, 0])
        gt_coord = gt.get("coordinate", [0, 0])
        
        # 归一化坐标到 0-1000 范围
        if image_width and image_height:
            pred_coord_norm = normalize_coordinate(pred_coord, image_width, image_height)
            gt_coord_norm = normalize_coordinate(gt_coord, image_width, image_height)
        else:
            # 如果没有图片尺寸，假设坐标已经是归一化的
            pred_coord_norm = [float(pred_coord[0]), float(pred_coord[1])]
            gt_coord_norm = [float(gt_coord[0]), float(gt_coord[1])]
        
        # 在归一化后的坐标空间中计算距离
        distance = math.sqrt(
            (pred_coord_norm[0] - gt_coord_norm[0]) ** 2 + 
            (pred_coord_norm[1] - gt_coord_norm[1]) ** 2
        )
        # 检查时间是否匹配（如果提供了 time 字段）
        pred_time = pred.get("time", None)
        gt_time = gt.get("time", None)
        time_match = True
        if pred_time is not None and gt_time is not None:
            time_match = abs(pred_time - gt_time) < 0.1  # 允许 0.1 秒的误差
        return distance <= click_threshold and time_match
    
    # type: 检查文本内容
    elif pred_action == "type" and gt_action == "type":
        pred_text = pred.get("text", "")
        gt_text = gt.get("text", "")
        return pred_text == gt_text
    
    # swipe: 检查方向是否一致
    elif pred_action == "swipe" and gt_action == "swipe":
        # 计算两个 swipe 的方向
        pred_dir = get_swipe_direction(pred)
        gt_dir = get_swipe_direction(gt)
        return pred_dir == gt_dir
    
    # system_button: 检查 button 是否一致
    elif pred_action == "system_button" and gt_action == "system_button":
        return pred.get("button", "") == gt.get("button", "")
    
    # terminate: 检查 status 是否一致
    elif pred_action == "terminate" and gt_action == "terminate":
        # 必须 status 字段匹配才算成功
        pred_status = pred.get("status", "")
        gt_status = gt.get("status", "")
        return pred_status == gt_status
    
    # wait: 检查时间是否一致（如果提供了 time 字段）
    elif pred_action == "wait" and gt_action == "wait":
        pred_time = pred.get("time", None)
        gt_time = gt.get("time", None)
        # 如果都没有 time 字段，或者 time 字段相同，则认为匹配
        if pred_time is None and gt_time is None:
            return True
        if pred_time is not None and gt_time is not None:
            return abs(pred_time - gt_time) < 0.1  # 允许 0.1 秒的误差
        return False
    
    # key: 检查按键文本是否一致
    elif pred_action == "key" and gt_action == "key":
        pred_text = pred.get("text", "")
        gt_text = gt.get("text", "")
        return pred_text == gt_text
    
    # open: 检查应用名称是否一致
    elif pred_action == "open" and gt_action == "open":
        pred_text = pred.get("text", "")
        gt_text = gt.get("text", "")
        return pred_text == gt_text
    
    return False


def get_swipe_direction(tool_call: Dict[str, Any]) -> str:
    """
    获取 swipe 动作的方向。
    
    Returns:
        "UP", "DOWN", "LEFT", "RIGHT", 或 ""
    """
    coord1 = tool_call.get("coordinate", [500, 500])
    coord2 = tool_call.get("coordinate2", [500, 500])
    
    dx = coord2[0] - coord1[0]
    dy = coord2[1] - coord1[1]
    
    if abs(dy) > abs(dx):
        return "UP" if dy < 0 else "DOWN"
    else:
        return "LEFT" if dx < 0 else "RIGHT"


def calculate_click_distance(
    pred: Dict[str, Any], 
    gt: Dict[str, Any],
    image_width: Optional[int] = None,
    image_height: Optional[int] = None
) -> Optional[float]:
    """
    计算两个 click 或 long_press 动作的坐标距离（在归一化后的坐标空间中）。
    
    Args:
        pred: 预测的 tool_call
        gt: 真实的 tool_call
        image_width: 图片宽度（用于坐标归一化）
        image_height: 图片高度（用于坐标归一化）
    
    Returns:
        归一化后的距离值（0-1000 范围），如果不是 click 或 long_press 动作返回 None
    """
    pred_action = pred.get("action", "")
    gt_action = gt.get("action", "")
    
    if pred_action in ["click", "long_press"] and gt_action in ["click", "long_press"]:
        pred_coord = pred.get("coordinate", [0, 0])
        gt_coord = gt.get("coordinate", [0, 0])
        
        # 归一化坐标到 0-1000 范围
        if image_width and image_height:
            pred_coord_norm = normalize_coordinate(pred_coord, image_width, image_height)
            gt_coord_norm = normalize_coordinate(gt_coord, image_width, image_height)
        else:
            # 如果没有图片尺寸，假设坐标已经是归一化的
            pred_coord_norm = [float(pred_coord[0]), float(pred_coord[1])]
            gt_coord_norm = [float(gt_coord[0]), float(gt_coord[1])]
        
        # 在归一化后的坐标空间中计算距离
        return math.sqrt(
            (pred_coord_norm[0] - gt_coord_norm[0]) ** 2 + 
            (pred_coord_norm[1] - gt_coord_norm[1]) ** 2
        )
    return None


# ---------------------------------------------------------------------------
# 解析辅助函数
# ---------------------------------------------------------------------------
def parse_model_output_to_qwen3(output_text: str) -> Optional[Dict[str, Any]]:
    """
    从模型输出中解析 Qwen3 tool_call。
    
    支持两种格式：
    1. <tool_call>{"name": "mobile_use", "arguments": {...}}</tool_call>
    2. 直接的 JSON 字符串
    
    Returns:
        解析后的 arguments 字典，如果解析失败返回 None
    """
    if not output_text:
        return None
    
    try:
        # 尝试从 <tool_call> 标签中提取
        if "<tool_call>" in output_text and "</tool_call>" in output_text:
            # 使用正则表达式找到所有 tool_call 标签
            import re
            tool_call_pattern = r'<tool_call>(.*?)</tool_call>'
            matches = re.findall(tool_call_pattern, output_text, re.DOTALL)
            
            if matches:
                # 尝试解析每个匹配的tool_call，选择第一个有效的
                for tool_call_str in matches:
                    try:
                        # 清理字符串：移除特殊字符、emoji等
                        cleaned_str = tool_call_str.strip()
                        cleaned_str = re.sub(r'[^\x20-\x7E\n\r\t]', '', cleaned_str)
                        
                        parsed = json.loads(cleaned_str)
                        
                        # 如果是 {"name": "mobile_use", "arguments": {...}} 格式
                        if "arguments" in parsed:
                            return parsed["arguments"]
                        # 如果直接是 {"action": "...", ...} 格式
                        elif "action" in parsed:
                            return parsed
                    except (json.JSONDecodeError, Exception):
                        continue
            
            # 如果正则匹配失败，尝试原始方法
            json_str = output_text.split("<tool_call>")[1].split("</tool_call>")[0].strip()
            # 清理特殊字符
            json_str = re.sub(r'[^\x20-\x7E\n\r\t]', '', json_str)
            parsed = json.loads(json_str)
            
            if "arguments" in parsed:
                return parsed["arguments"]
            elif "action" in parsed:
                return parsed
            
    except (IndexError, json.JSONDecodeError, Exception):
        pass
    
    return None

