
import re
import json
import math
from typing import Dict, Any, Optional, Tuple, List


def qwen3_to_atlas(tool_call: Dict[str, Any]) -> str:
    """Convert a Qwen3 tool-call dict into an Atlas action string.

    Mapping implemented here:
        * ``click`` / ``long_press`` -> ``CLICK <point>[[x, y]]</point>``
          (coordinates truncated to integers; missing -> ``[0, 0]``).
        * ``type`` / ``answer``      -> ``TYPE [<text>]``.
        * ``swipe``                  -> ``SCROLL [UP|DOWN|LEFT|RIGHT]`` based
          on the dominant axis of ``coordinate2 - coordinate``.
        * ``system_button``          -> ``PRESS_BACK`` (for ``Back`` and
          ``Menu``), ``PRESS_HOME``, ``ENTER``, otherwise
          ``PRESS_<BUTTON upper-cased>``.
        * ``terminate``              -> ``COMPLETE``.
        * ``wait`` and anything unrecognised -> ``""``.

    Args:
        tool_call: Dict carrying an ``"action"`` key plus action-specific
            fields (``coordinate``, ``coordinate2``, ``text``, ``button``).

    Returns:
        The Atlas action string, or ``""`` when ``tool_call`` is empty, not a
        dict, or has no Atlas equivalent.
    """
    if not tool_call or not isinstance(tool_call, dict):
        return ""

    action = tool_call.get("action", "")

    if action in ("click", "long_press"):
        # Both pointer actions are emitted as a plain CLICK.
        coord = tool_call.get("coordinate", [0, 0])
        x, y = int(coord[0]), int(coord[1])
        return f"CLICK <point>[[{x}, {y}]]</point>"

    if action in ("type", "answer"):
        text = tool_call.get("text", "")
        return f"TYPE [{text}]"

    if action == "swipe":
        start = tool_call.get("coordinate", [500, 500])
        end = tool_call.get("coordinate2", [500, 500])
        # Use float deltas rather than truncating each endpoint to int first:
        # truncation collapses fractional / normalized coordinates to a
        # zero-length swipe and would disagree with get_swipe_direction().
        dx = float(end[0]) - float(start[0])
        dy = float(end[1]) - float(start[1])

        # Ties (including a zero-length swipe) fall on the horizontal axis.
        if abs(dy) > abs(dx):
            return "SCROLL [UP]" if dy < 0 else "SCROLL [DOWN]"
        return "SCROLL [LEFT]" if dx < 0 else "SCROLL [RIGHT]"

    if action == "system_button":
        button = tool_call.get("button", "")
        known_buttons = {
            "Back": "PRESS_BACK",
            "Home": "PRESS_HOME",
            "Enter": "ENTER",
            "Menu": "PRESS_BACK",
        }
        mapped = known_buttons.get(button)
        if mapped is not None:
            return mapped
        return f"PRESS_{button.upper()}"

    if action == "terminate":
        return "COMPLETE"

    # "wait" and unknown actions have no Atlas representation.
    return ""


def atlas_to_qwen3(action: str) -> Dict[str, Any]:
    """Convert an Atlas action string into a Qwen3 tool-call dict.

    Recognised forms (matched case-insensitively, checked in this order):
        * ``CLICK <point>[[x, y]]</point>`` -> ``click`` with integer coords.
        * ``TYPE [text]``                   -> ``type`` with the text.
        * ``SCROLL [UP|DOWN|LEFT|RIGHT]``   -> ``swipe`` between fixed points
          400 units apart, centred on (500, 500).
        * ``ENTER`` / ``PRESS_HOME`` / ``PRESS_BACK`` -> ``system_button``.
        * ``COMPLETE``                      -> ``terminate`` with
          ``status="success"``.

    Args:
        action: The Atlas action string; surrounding whitespace is ignored.

    Returns:
        The equivalent tool-call dict, or ``{}`` when ``action`` is empty, not
        a string, or matches none of the forms above.
    """
    if not action or not isinstance(action, str):
        return {}

    action = action.strip()

    click_match = re.search(
        r'CLICK\s+<point>\[\[(\d+),\s*(\d+)\]\]</point>', action, re.IGNORECASE
    )
    if click_match:
        x, y = int(click_match.group(1)), int(click_match.group(2))
        return {"action": "click", "coordinate": [x, y]}

    # When the whole string is a single TYPE command, take everything between
    # the outermost brackets so text that itself contains ']' (e.g.
    # "TYPE [a[1]b]") is not cut short at the first ']'.
    whole_type = re.fullmatch(
        r'TYPE\s+\[(.*)\]', action, re.IGNORECASE | re.DOTALL
    )
    if whole_type:
        return {"action": "type", "text": whole_type.group(1)}

    # Otherwise the command may be embedded in surrounding text; stop at the
    # first closing bracket.
    type_match = re.search(r'TYPE\s+\[(.*?)\]', action, re.IGNORECASE | re.DOTALL)
    if type_match:
        return {"action": "type", "text": type_match.group(1)}

    scroll_match = re.search(
        r'SCROLL\s+\[(UP|DOWN|LEFT|RIGHT)\]', action, re.IGNORECASE
    )
    if scroll_match:
        # Built per call so callers always receive fresh, mutable lists.
        swipe_endpoints = {
            "UP": ([500, 700], [500, 300]),
            "DOWN": ([500, 300], [500, 700]),
            "LEFT": ([700, 500], [300, 500]),
            "RIGHT": ([300, 500], [700, 500]),
        }
        start, end = swipe_endpoints[scroll_match.group(1).upper()]
        return {"action": "swipe", "coordinate": start, "coordinate2": end}

    # Bare keyword commands must make up the entire string.
    keyword = action.upper().strip()
    keyword_actions = {
        "ENTER": {"action": "system_button", "button": "Enter"},
        "PRESS_HOME": {"action": "system_button", "button": "Home"},
        "PRESS_BACK": {"action": "system_button", "button": "Back"},
        "COMPLETE": {"action": "terminate", "status": "success"},
    }
    if keyword in keyword_actions:
        return keyword_actions[keyword]

    return {}


def get_qwen3_action_type(tool_call: Dict[str, Any]) -> int:
    """Map a Qwen3 tool call to a small integer action-type code.

    Codes: 1 click, 2 type/answer, 3 swipe, 4 system_button, 5 terminate,
    6 wait, 7 long_press. ``0`` is returned for an empty or non-dict
    ``tool_call`` and for any action name not listed here.
    """
    if not (tool_call and isinstance(tool_call, dict)):
        return 0

    name = tool_call.get("action", "")

    # Compared with == (not hashed) so odd values such as lists fall through.
    type_codes = (
        ("click", 1),
        ("type", 2),
        ("answer", 2),
        ("swipe", 3),
        ("system_button", 4),
        ("terminate", 5),
        ("wait", 6),
        ("long_press", 7),
    )
    for known_name, code in type_codes:
        if name == known_name:
            return code

    return 0


def is_qwen3_action_type_match(pred: Dict[str, Any], gt: Dict[str, Any]) -> bool:
    """Return True when both calls share the same recognised action type.

    Two unrecognised actions (type code 0) never count as a match.
    """
    predicted_kind = get_qwen3_action_type(pred)
    expected_kind = get_qwen3_action_type(gt)
    if predicted_kind == 0:
        return False
    return predicted_kind == expected_kind


def is_qwen3_action_match(
    pred: Dict[str, Any],
    gt: Dict[str, Any],
    click_threshold: float = 140.0
) -> bool:
    """Decide whether a predicted tool call fully matches the ground truth.

    The action types must match first; the remaining criterion depends on the
    action:
        * click / long_press: Euclidean distance between the two
          ``coordinate`` values is at most ``click_threshold``.
        * type / answer: ``text`` values are equal.
        * swipe: both swipes have the same direction.
        * system_button: ``button`` values are equal.
        * terminate: ``status`` values are equal.
        * wait: both omit ``time``, or both give times less than 0.1 apart.

    Args:
        pred: Predicted tool-call dict.
        gt: Ground-truth tool-call dict.
        click_threshold: Maximum allowed distance for pointer actions.

    Returns:
        True on a full match, False otherwise.
    """
    if not is_qwen3_action_type_match(pred, gt):
        return False

    predicted = pred.get("action", "")
    expected = gt.get("action", "")

    # Pointer actions: same action on both sides, compared by distance.
    if predicted == expected and predicted in ("click", "long_press"):
        pred_point = pred.get("coordinate", [0, 0])
        gt_point = gt.get("coordinate", [0, 0])
        offset_x = pred_point[0] - gt_point[0]
        offset_y = pred_point[1] - gt_point[1]
        return math.sqrt(offset_x ** 2 + offset_y ** 2) <= click_threshold

    # Text actions: "type" and "answer" are interchangeable.
    text_actions = ("type", "answer")
    if predicted in text_actions and expected in text_actions:
        return pred.get("text", "") == gt.get("text", "")

    if predicted == "swipe" and expected == "swipe":
        return get_swipe_direction(pred) == get_swipe_direction(gt)

    if predicted == "system_button" and expected == "system_button":
        return pred.get("button", "") == gt.get("button", "")

    if predicted == "terminate" and expected == "terminate":
        return pred.get("status", "") == gt.get("status", "")

    if predicted == "wait" and expected == "wait":
        pred_wait = pred.get("time", None)
        gt_wait = gt.get("time", None)
        if pred_wait is None or gt_wait is None:
            # Only a match when neither side specifies a duration.
            return pred_wait is None and gt_wait is None
        return abs(pred_wait - gt_wait) < 0.1

    return False


def get_swipe_direction(tool_call: Dict[str, Any]) -> str:
    """Classify a swipe as ``"UP"``, ``"DOWN"``, ``"LEFT"`` or ``"RIGHT"``.

    The direction follows the dominant axis of ``coordinate2 - coordinate``
    (each defaulting to ``[500, 500]``). When the vertical change is not
    strictly larger than the horizontal one, the swipe is treated as
    horizontal, and a non-negative horizontal change is ``"RIGHT"``.
    """
    start = tool_call.get("coordinate", [500, 500])
    end = tool_call.get("coordinate2", [500, 500])

    delta_x = end[0] - start[0]
    delta_y = end[1] - start[1]

    if abs(delta_y) > abs(delta_x):
        if delta_y < 0:
            return "UP"
        return "DOWN"

    if delta_x < 0:
        return "LEFT"
    return "RIGHT"


def calculate_click_distance(pred: Dict[str, Any], gt: Dict[str, Any]) -> Optional[float]:
    """Return the Euclidean distance between two pointer actions.

    Both ``pred`` and ``gt`` must be ``click`` or ``long_press`` actions (not
    necessarily the same one); otherwise ``None`` is returned. A missing
    ``coordinate`` is treated as ``[0, 0]``.
    """
    pointer_actions = ["click", "long_press"]
    predicted = pred.get("action", "")
    expected = gt.get("action", "")

    if not (predicted in pointer_actions and expected in pointer_actions):
        return None

    pred_point = pred.get("coordinate", [0, 0])
    gt_point = gt.get("coordinate", [0, 0])
    offset_x = pred_point[0] - gt_point[0]
    offset_y = pred_point[1] - gt_point[1]
    return math.sqrt(offset_x ** 2 + offset_y ** 2)


def parse_model_output_to_qwen3(output_text: str) -> Optional[Dict[str, Any]]:
    """Extract a Qwen3 tool-call dict from raw model output.

    The JSON payload is taken from between the first ``<tool_call>`` tag and
    the following ``</tool_call>`` tag. If the decoded object has an
    ``"arguments"`` key, that value is returned; otherwise, if it has an
    ``"action"`` key, the object itself is returned.

    Args:
        output_text: Raw text produced by the model.

    Returns:
        The tool-call dict, or ``None`` when the text is empty, the tags are
        missing, the payload is not valid JSON, the payload (or its
        ``"arguments"`` value) is not a JSON object, or neither expected key
        is present.
    """
    if not output_text:
        return None

    if "<tool_call>" not in output_text or "</tool_call>" not in output_text:
        return None

    try:
        json_str = output_text.split("<tool_call>")[1].split("</tool_call>")[0].strip()
        parsed = json.loads(json_str)
    except (IndexError, json.JSONDecodeError):
        return None

    # json.loads can legitimately yield a list, string, number, bool or None;
    # membership tests / indexing on those would raise or give nonsense.
    if not isinstance(parsed, dict):
        return None

    if "arguments" in parsed:
        arguments = parsed["arguments"]
        # Honour the declared return type: only hand back a real dict.
        return arguments if isinstance(arguments, dict) else None

    if "action" in parsed:
        return parsed

    return None


def evaluate_qwen3_action(
    pred: Dict[str, Any],
    gt: Dict[str, Any],
    click_threshold: float = 140.0
) -> Dict[str, Any]:
    """Summarise how a predicted tool call compares with the ground truth.

    Args:
        pred: Predicted tool-call dict.
        gt: Ground-truth tool-call dict.
        click_threshold: Distance threshold forwarded to the full-match check.

    Returns:
        A dict with ``pred_type``, ``gt_type``, ``type_match`` and
        ``full_match``; ``click_distance`` is added only when a distance can
        be computed for the pair.
    """
    summary = {
        "pred_type": get_qwen3_action_type(pred),
        "gt_type": get_qwen3_action_type(gt),
        "type_match": is_qwen3_action_type_match(pred, gt),
        "full_match": is_qwen3_action_match(pred, gt, click_threshold),
    }

    distance = calculate_click_distance(pred, gt)
    if distance is None:
        return summary

    summary["click_distance"] = distance
    return summary


if __name__ == "__main__":
    # Manual smoke run: prints each conversion and one matching example.
    # Nothing is asserted here; the output is meant to be inspected by eye.
    print("=== Testing qwen3_to_atlas ===")
    for sample_call in (
        {"action": "click", "coordinate": [101, 872]},
        {"action": "type", "text": "Shanghai shopping mall"},
        {"action": "swipe", "coordinate": [500, 700], "coordinate2": [500, 300]},
        {"action": "swipe", "coordinate": [500, 300], "coordinate2": [500, 700]},
        {"action": "system_button", "button": "Back"},
        {"action": "system_button", "button": "Home"},
        {"action": "system_button", "button": "Enter"},
        {"action": "terminate", "status": "success"},
    ):
        print(f"{sample_call} -> {qwen3_to_atlas(sample_call)}")

    print("\n=== Testing atlas_to_qwen3 ===")
    for atlas_action in (
        "CLICK <point>[[101, 872]]</point>",
        "TYPE [Shanghai shopping mall]",
        "SCROLL [UP]",
        "SCROLL [DOWN]",
        "PRESS_BACK",
        "PRESS_HOME",
        "ENTER",
        "COMPLETE",
    ):
        print(f"{atlas_action} -> {atlas_to_qwen3(atlas_action)}")

    print("\n=== Testing action matching ===")
    predicted_call = {"action": "click", "coordinate": [100, 100]}
    expected_call = {"action": "click", "coordinate": [120, 110]}
    print(f"pred={predicted_call}, gt={expected_call}")
    print(f"  type_match: {is_qwen3_action_type_match(predicted_call, expected_call)}")
    print(f"  full_match: {is_qwen3_action_match(predicted_call, expected_call)}")
    print(f"  distance: {calculate_click_distance(predicted_call, expected_call)}")

    print("\nAll tests passed!")

