
import re
import json
import math
from typing import Dict, Any, Optional, Tuple, List


def _clamp_coordinate_value(value: Any) -> Optional[int]:
    """Convert one coordinate component to an int clamped into ``[0, 10000]``.

    Accepts ints, floats and numeric strings; floats are truncated toward zero
    by ``int()``. Returns ``None`` for NaN/infinity, unparseable strings and
    any other type. An int too large to convert to float makes ``math.isnan``
    raise OverflowError; callers are expected to catch it.
    """
    if isinstance(value, str):
        try:
            value = float(value)
        except ValueError:
            return None
    elif not isinstance(value, (int, float)):
        return None
    if math.isnan(value) or math.isinf(value):
        return None
    return max(0, min(10000, int(value)))


def safe_normalize_coordinate(
    coord: Any, default: Optional[List[int]] = None
) -> Optional[List[int]]:
    """Coerce a loosely formatted coordinate into ``[x, y]`` clamped to ``[0, 10000]``.

    Accepted inputs:
      * a list/tuple whose first two items are numbers or numeric strings;
        nested wrappers are unwrapped by repeatedly following the first
        element (``[[x, y]]`` -> ``[x, y]``);
      * a JSON string that decodes to any accepted input (``"[x, y]"``);
      * any other string containing at least two numbers (``"(x, y)"``),
        from which the first two numbers are taken.

    Each component is truncated toward zero with ``int()`` and clamped into
    ``[0, 10000]``.

    Args:
        coord: The raw coordinate value.
        default: Unused. It is kept only so existing callers that pass it keep
            working; invalid input always yields ``None``.

    Returns:
        ``[x, y]`` on success. ``None`` for missing or empty input, fewer than
        two values, NaN/inf, non-numeric items, bare scalars, or any other type.
    """
    try:
        if isinstance(coord, (list, tuple)):
            # Unwrap nested wrappers such as [[x, y]] by following the first element.
            while coord and isinstance(coord[0], (list, tuple)):
                coord = coord[0]
            if len(coord) < 2:
                return None
            values = [_clamp_coordinate_value(v) for v in coord[:2]]
            if None in values:
                return None
            return values

        if isinstance(coord, str):
            try:
                parsed = json.loads(coord)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                # Not JSON: fall back to the first two numbers in free text, e.g. "(x, y)".
                numbers = re.findall(r'-?\d+\.?\d*', coord)
                if len(numbers) < 2:
                    return None
                x = _clamp_coordinate_value(numbers[0])
                y = _clamp_coordinate_value(numbers[1])
                if x is None or y is None:
                    return None
                return [x, y]
            # The string was JSON (a list, number, nested string, ...): normalize the decoded value.
            return safe_normalize_coordinate(parsed, default)

        # None, bare numbers, dicts, etc. cannot describe an (x, y) point.
        return None
    except Exception:
        # Deliberately best-effort. Examples: OverflowError from an int too large
        # for float, or RecursionError from pathological nesting, both mean "unusable".
        return None

def get_qwen3_action_type(tool_call: Dict[str, Any]) -> int:
    """Map a Qwen3 tool-call dict to its integer action-type code.

    Codes: click=1, type/answer=2, swipe=3, system_button=4, terminate=5,
    wait=6, long_press=7. A falsy or non-dict ``tool_call``, or a missing or
    unrecognized ``action``, yields 0.
    """
    if not tool_call:
        return 0
    if not isinstance(tool_call, dict):
        return 0

    action_codes = {
        "click": 1,
        "type": 2,
        "answer": 2,
        "swipe": 3,
        "system_button": 4,
        "terminate": 5,
        "wait": 6,
        "long_press": 7,
    }
    try:
        return action_codes.get(tool_call.get("action", ""), 0)
    except TypeError:
        # An unhashable action value (e.g. a list) matches no known name.
        return 0


def is_qwen3_action_type_match(pred: Dict[str, Any], gt: Dict[str, Any]) -> bool:
    """Return True when both tool calls carry the same *recognized* action type.

    Types come from ``get_qwen3_action_type``. Two unrecognized actions
    (code 0) never match, while "type" and "answer" match each other because
    they share a code.
    """
    pred_code = get_qwen3_action_type(pred)
    gt_code = get_qwen3_action_type(gt)
    if pred_code == 0:
        return False
    return pred_code == gt_code


def _wait_times_match(pred_time: Any, gt_time: Any, tolerance: float = 0.1) -> bool:
    """Compare two optional ``wait`` durations.

    Both absent counts as a match. Only one present does not. Otherwise both
    values are converted with ``float()`` (so numeric strings like "2" work)
    and must differ by less than ``tolerance``. Unconvertible values do not
    match.
    """
    if pred_time is None or gt_time is None:
        return pred_time is None and gt_time is None
    try:
        return abs(float(pred_time) - float(gt_time)) < tolerance
    except (TypeError, ValueError, OverflowError):
        return False


def is_qwen3_action_match(
    pred: Dict[str, Any], 
    gt: Dict[str, Any], 
    click_threshold: float = 140.0
) -> bool:
    """Return True when ``pred`` matches ``gt`` in both action type and arguments.

    Per-action rules, applied after ``is_qwen3_action_type_match`` succeeds:
      * click / long_press: the distance between the two normalized
        ``coordinate`` values (from ``calculate_click_distance``) is at most
        ``click_threshold``;
      * type / answer (interchangeable): ``text`` values are exactly equal;
      * swipe: both resolve to the same ``get_swipe_direction`` result;
      * system_button: ``button`` values are equal;
      * terminate: ``status`` values are equal;
      * wait: both omit ``time``, or both times differ by less than 0.1.

    Args:
        pred: Predicted tool-call arguments.
        gt: Ground-truth tool-call arguments.
        click_threshold: Maximum coordinate distance for click/long_press.

    Returns:
        True on a full match, otherwise False.
    """
    if not is_qwen3_action_type_match(pred, gt):
        return False

    # A type match guarantees both are dicts with recognized actions. For every
    # type except "type"/"answer" it also guarantees the action names are equal.
    pred_action = pred.get("action", "")
    gt_action = gt.get("action", "")

    if pred_action in ("click", "long_press") and gt_action == pred_action:
        distance = calculate_click_distance(pred, gt)
        return distance is not None and distance <= click_threshold

    if pred_action in ("type", "answer") and gt_action in ("type", "answer"):
        return pred.get("text", "") == gt.get("text", "")

    if pred_action == "swipe" and gt_action == "swipe":
        # NOTE(review): two swipes whose direction cannot be determined both
        # yield "" and therefore count as a match. Confirm this is intended.
        return get_swipe_direction(pred) == get_swipe_direction(gt)

    if pred_action == "system_button" and gt_action == "system_button":
        return pred.get("button", "") == gt.get("button", "")

    if pred_action == "terminate" and gt_action == "terminate":
        return pred.get("status", "") == gt.get("status", "")

    if pred_action == "wait" and gt_action == "wait":
        # Tolerant comparison: a non-numeric time (e.g. the string "2") used to
        # raise TypeError here, unlike every other branch.
        return _wait_times_match(pred.get("time", None), gt.get("time", None))

    return False


def get_swipe_direction(tool_call: Dict[str, Any]) -> str:
    """Classify a swipe by its dominant axis of motion.

    Reads ``coordinate`` (start) and ``coordinate2`` (end) and normalizes both
    with ``safe_normalize_coordinate``. When the vertical displacement is
    strictly larger than the horizontal one, the result is "UP" (y decreases)
    or "DOWN". Otherwise it is "LEFT" (x decreases) or "RIGHT", so ties go to
    the horizontal axis.

    Returns "" for a falsy or non-dict input, missing or unparseable
    endpoints, a displacement below 1 on both axes, or any unexpected error.
    """
    try:
        if not tool_call or not isinstance(tool_call, dict):
            return ""

        start_raw = tool_call.get("coordinate", None)
        end_raw = tool_call.get("coordinate2", None)
        if start_raw is None or end_raw is None:
            return ""

        start = safe_normalize_coordinate(start_raw, [500, 500])
        end = safe_normalize_coordinate(end_raw, [500, 500])
        if start is None or end is None:
            return ""
        if len(start) < 2 or len(end) < 2:
            return ""

        delta_x = end[0] - start[0]
        delta_y = end[1] - start[1]
        if any(math.isnan(d) or math.isinf(d) for d in (delta_x, delta_y)):
            return ""

        # No meaningful movement on either axis.
        if abs(delta_x) < 1 and abs(delta_y) < 1:
            return ""

        vertical = abs(delta_y) > abs(delta_x)
        if vertical:
            return "UP" if delta_y < 0 else "DOWN"
        return "LEFT" if delta_x < 0 else "RIGHT"
    except Exception:
        return ""


def calculate_click_distance(pred: Dict[str, Any], gt: Dict[str, Any]) -> Optional[float]:
    """Return the Euclidean distance between two pointer actions' coordinates.

    Both ``pred`` and ``gt`` must be non-empty dicts whose ``action`` is
    "click" or "long_press" (in any combination). Their ``coordinate`` values
    are normalized with ``safe_normalize_coordinate``.

    Returns:
        The distance as a float, or ``None`` when either input is unusable,
        either action is not a pointer action, a coordinate is missing or
        invalid, the result is not finite, or any unexpected error occurs.
    """
    pointer_actions = ("click", "long_press")
    try:
        if not pred or not isinstance(pred, dict):
            return None
        if not gt or not isinstance(gt, dict):
            return None

        pred_is_pointer = pred.get("action", "") in pointer_actions
        gt_is_pointer = gt.get("action", "") in pointer_actions
        if not (pred_is_pointer and gt_is_pointer):
            return None

        raw_pred = pred.get("coordinate", None)
        raw_gt = gt.get("coordinate", None)
        if raw_pred is None or raw_gt is None:
            return None

        p = safe_normalize_coordinate(raw_pred, [0, 0])
        g = safe_normalize_coordinate(raw_gt, [0, 0])
        if p is None or g is None or len(p) < 2 or len(g) < 2:
            return None

        off_x = p[0] - g[0]
        off_y = p[1] - g[1]
        distance = math.sqrt(off_x * off_x + off_y * off_y)
        return distance if math.isfinite(distance) else None
    except Exception:
        return None


def parse_model_output_to_qwen3(output_text: str) -> Optional[Dict[str, Any]]:
    """Extract the action dict from the first ``<tool_call>...</tool_call>`` span.

    The payload between the tags must be a JSON object.
      * If it has an ``"arguments"`` key, that value is returned. A
        JSON-encoded string is decoded first, and the result must be an object.
      * Otherwise the object itself is returned if it has an ``"action"`` key.

    Returns:
        The action dict, or ``None`` when the text is empty, either tag is
        missing, the payload is not valid JSON or not a JSON object, the
        arguments are not an object, or neither key is present.
    """
    if not output_text:
        return None

    open_tag, close_tag = "<tool_call>", "</tool_call>"
    if open_tag not in output_text or close_tag not in output_text:
        return None

    # Take the text after the first opening tag, stopping at the next opening
    # tag or the first closing tag, whichever comes first.
    json_str = output_text.split(open_tag)[1].split(close_tag)[0].strip()
    try:
        parsed = json.loads(json_str)
    except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
        return None

    # Lists, strings and numbers carry no action. They would also break the
    # membership and indexing checks below (e.g. `"arguments" in 5`).
    if not isinstance(parsed, dict):
        return None

    if "arguments" in parsed:
        arguments = parsed["arguments"]
        # Some tool-call emitters serialize arguments as a JSON string.
        if isinstance(arguments, str):
            try:
                arguments = json.loads(arguments)
            except (ValueError, RecursionError):
                return None
        return arguments if isinstance(arguments, dict) else None

    if "action" in parsed:
        return parsed

    return None


def evaluate_qwen3_action(
    pred: Dict[str, Any], 
    gt: Dict[str, Any], 
    click_threshold: float = 140.0
) -> Dict[str, Any]:
    """Score one predicted tool call against its ground truth.

    Returns a dict with these keys, in this order:
      * ``pred_type`` / ``gt_type``: codes from ``get_qwen3_action_type``;
      * ``type_match``: result of ``is_qwen3_action_type_match``;
      * ``full_match``: result of ``is_qwen3_action_match`` at ``click_threshold``;
      * ``click_distance``: added only when ``calculate_click_distance``
        returns a value rather than ``None``.
    """
    scores: Dict[str, Any] = {
        "pred_type": get_qwen3_action_type(pred),
        "gt_type": get_qwen3_action_type(gt),
        "type_match": is_qwen3_action_type_match(pred, gt),
        "full_match": is_qwen3_action_match(pred, gt, click_threshold),
    }

    distance = calculate_click_distance(pred, gt)
    if distance is None:
        return scores
    scores["click_distance"] = distance
    return scores
