import json
import os
import re
import numpy as np
from typing import List, Dict, Tuple, Optional, Any

# Directional scroll names that are accepted for both "scroll" and "swipe".
_DIRECTIONAL_SCROLL_VARIANTS = ("scroll_up", "scroll_down", "scroll_left", "scroll_right")

# Action equivalence table: maps an action type to the other action types
# that are treated as the same kind of action.
ACTION_EQUIVALENCE_MAP = {
    "scroll": ["swipe", "drag", *_DIRECTIONAL_SCROLL_VARIANTS],
    "swipe": ["scroll", "drag", *_DIRECTIONAL_SCROLL_VARIANTS],
    "drag": ["swipe", "scroll"],
    "navigate_back": ["press_key", "system_button"],
    "press_key": ["navigate_back"],
}

# Action types that carry screen coordinates.
COORDINATE_ACTIONS = {
    "click",
    "long_press",
    "swipe",
    "scroll",
    "drag",
}


# Fallback pixel tolerance for the start point of swipe/scroll/drag actions,
# used when the step has no candidate bounding boxes.
_SWIPE_START_TOLERANCE_PX = 50
# Fallback pixel tolerance for click/long_press actions.
# NOTE(review): an earlier inline comment described this threshold as 50 px,
# but the code compares against 200 px; the value is kept so scores are stable.
_TAP_TOLERANCE_PX = 200

# Action types with dedicated statistics, in "action_type_stats" order.
_TRACKED_ACTION_TYPES = (
    "click", "swipe", "input_text", "open_app", "wait",
    "system_button", "long_press", "navigate_back", "scroll", "drag",
)
# Order in which the "<type>_match_acc" metrics are reported.
_MATCH_ACC_ORDER = (
    "click", "swipe", "input_text", "open_app", "wait",
    "system_button", "long_press", "scroll", "drag", "navigate_back",
)


def _load_failure_metrics(error: Exception) -> Dict[str, Any]:
    """Build the all-zero metrics dict returned when the inputs cannot be loaded."""
    failed: Dict[str, Any] = {
        "error": f"数据加载失败: {error}",
        "type_match_acc": 0.0,
        "grounding_accuracy": 0.0,
        "step_success_rate": 0.0,
    }
    for name in _MATCH_ACC_ORDER:
        failed[f"{name}_match_acc"] = 0.0
    failed["total_steps"] = 0
    failed["action_type_stats"] = {}
    failed["scene_success"] = False
    return failed


def _hits_target(point, gt_action, step_check, tolerance) -> bool:
    """Return True when ``point`` lands on the ground-truth target.

    Candidate bounding boxes from ``step_check`` take precedence; without
    them the point must be within ``tolerance`` pixels of the ground-truth
    ``x``/``y`` on both axes.
    """
    if step_check and "candidate_bbox" in step_check:
        return any(point_in_bbox(point[0], point[1], bbox)
                   for bbox in step_check["candidate_bbox"])
    gt_x, gt_y = gt_action.get("x", 0), gt_action.get("y", 0)
    return abs(point[0] - gt_x) < tolerance and abs(point[1] - gt_y) < tolerance


def _swipe_direction(start, end) -> str:
    """Return the dominant direction ("left"/"right"/"up"/"down") of start -> end."""
    dx = end[0] - start[0]
    dy = end[1] - start[1]
    if abs(dx) > abs(dy):  # mostly horizontal
        return "right" if dx > 0 else "left"
    return "down" if dy > 0 else "up"  # mostly vertical (ties count as vertical)


def _swipe_grounding_ok(pred, gt_action, step_check) -> bool:
    """Check start position and direction of a predicted swipe/scroll/drag."""
    coords = pred.get("resolved_coords") or {}
    if "start" not in coords or "end" not in coords:
        return False
    start, end = coords["start"], coords["end"]

    start_ok = _hits_target(start, gt_action, step_check, _SWIPE_START_TOLERANCE_PX)

    # Without a ground-truth direction there is nothing to compare against.
    direction_ok = True
    gt_direction = gt_action.get("direction", "")
    if gt_direction:
        direction_ok = _swipe_direction(start, end) == gt_direction.lower()

    return start_ok and direction_ok


def _tap_grounding_ok(pred, gt_action, step_check) -> bool:
    """Check the position of a predicted click/long_press."""
    coords = pred.get("resolved_coords") or {}
    if "pos" not in coords:
        return False
    pred_x, pred_y = coords["pos"]
    return _hits_target((pred_x, pred_y), gt_action, step_check, _TAP_TOLERANCE_PX)


def _grounding_ok(gt_action_type, gt_action, pred, step_check) -> bool:
    """Grounding verdict for a step whose action type already matched."""
    if gt_action_type in ("swipe", "scroll", "drag"):
        return _swipe_grounding_ok(pred, gt_action, step_check)
    if gt_action_type in ("click", "long_press"):
        return _tap_grounding_ok(pred, gt_action, step_check)
    # Non-coordinate actions are considered grounded once the type matches.
    return True


def _lowered_param(container, key) -> str:
    """Return ``container[key]`` as a lower-cased string ("" if missing or None)."""
    value = container.get(key, "")
    if value is None:
        return ""
    return str(value).lower()


def _is_exact_match(gt_action_type, gt_action, pred, pred_action,
                    type_match, grounding_ok) -> bool:
    """Per-action-type exact-match verdict for one step."""
    if not type_match:
        return False
    if gt_action_type in ("click", "swipe", "long_press", "scroll", "drag"):
        # Grounding already covers position (and direction for swipes).
        return grounding_ok
    if gt_action_type == "wait":
        return True

    params = pred.get("action_params") or {}
    if gt_action_type == "input_text":
        return _lowered_param(params, "text") == _lowered_param(gt_action, "text")
    if gt_action_type == "open_app":
        return _lowered_param(params, "app_name") == _lowered_param(gt_action, "app_name")
    if gt_action_type == "system_button":
        return (_lowered_param(params, "button_name")
                == _lowered_param(gt_action, "button_name"))
    if gt_action_type == "navigate_back":
        pred_name = pred_action.lower()
        return ("back" in _lowered_param(params, "button_name")
                or "navigate_back" in pred_name
                or "press_key" in pred_name)
    return False


def evaluate_action_performance(predictions_path: str,
                                ground_truth_episode: Dict[str, Any],
                                output_dir: Optional[str] = None) -> Dict[str, Any]:
    """Score predicted actions against one ground-truth episode.

    Predictions are read as JSON from ``predictions_path`` and compared step
    by step with ``ground_truth_episode["actions"]``. Only the first
    ``min(len(predictions), len(actions))`` steps are compared.

    Args:
        predictions_path: Path to a JSON file holding a list of predictions;
            each prediction is read via its ``"action_type"``,
            ``"resolved_coords"`` and ``"action_params"`` entries.
        ground_truth_episode: Episode dict. ``"actions"`` is required;
            ``"step_check_pams"`` (per-step dicts that may carry
            ``"candidate_bbox"``) is optional.
        output_dir: When given, the metrics, per-step results, the predictions
            path and the episode are written to
            ``<output_dir>/evaluation_results.json``.

    Returns:
        A dict with the percentages ``type_match_acc``, ``grounding_accuracy``
        and ``step_success_rate``, one ``<type>_match_acc`` per tracked action
        type, ``total_steps``, ``action_type_stats`` and ``scene_success``.
        If the inputs cannot be loaded, the same keys are returned zeroed,
        plus an ``"error"`` message.

    Note:
        ``scene_success`` is True when no compared step failed, which
        includes the case where zero steps were compared.
    """
    # 1. Load and validate the inputs.
    try:
        with open(predictions_path, 'r', encoding='utf-8') as f:
            predictions = json.load(f)

        gt_actions = ground_truth_episode["actions"]
        step_check_pams = ground_truth_episode.get("step_check_pams", [])

        # Anything but a sequence cannot be walked step by step; report it as
        # a load failure instead of crashing inside the step loop.
        if not isinstance(predictions, list):
            raise TypeError(
                f"predictions 必须是列表, 实际类型为 {type(predictions).__name__}")
        if not isinstance(gt_actions, (list, tuple)):
            raise TypeError(
                f"actions 必须是列表, 实际类型为 {type(gt_actions).__name__}")
    except Exception as e:
        print(f"加载数据失败: {e}")
        return _load_failure_metrics(e)

    # 2. Initialise counters.
    type_match_count = 0
    grounding_correct_count = 0
    step_success_count = 0

    total_actions = min(len(predictions), len(gt_actions))
    results = []
    action_type_stats = {
        name: {"count": 0, "exact_match": 0} for name in _TRACKED_ACTION_TYPES
    }
    scene_success = True  # flipped as soon as one step fails

    # 3. Evaluate every step.
    for i in range(total_actions):
        step_result = {"step": i + 1}

        try:
            gt_action = gt_actions[i]
            gt_action_type = gt_action["action_type"]
            step_result["gt_action"] = gt_action_type

            # Count the step for its type right away so that a step that
            # later errors out still weighs on that type's accuracy.
            if gt_action_type in action_type_stats:
                action_type_stats[gt_action_type]["count"] += 1

            pred = predictions[i]
            pred_action = pred["action_type"]
            step_result["pred_action"] = pred_action

            # A. Type match: identical, or listed as equivalent to the
            #    ground-truth type.
            type_match = (
                gt_action_type == pred_action
                or pred_action in ACTION_EQUIVALENCE_MAP.get(gt_action_type, ())
            )

            # B. Grounding: only meaningful once the type matched.
            grounding_ok = False
            if type_match:
                step_check = step_check_pams[i] if i < len(step_check_pams) else None
                grounding_ok = _grounding_ok(gt_action_type, gt_action, pred, step_check)

            # C. Step success: coordinate actions need type and grounding,
            #    all other actions only need the type.
            if gt_action_type in COORDINATE_ACTIONS:
                step_success = type_match and grounding_ok
            else:
                step_success = type_match

            # D. Per-type exact match.
            exact_match = _is_exact_match(gt_action_type, gt_action, pred,
                                          pred_action, type_match, grounding_ok)
        except Exception as e:
            print(f"处理步骤 {i + 1} 时出错: {e}")
            step_result["error"] = str(e)
            step_result["type_match"] = False
            step_result["grounding_ok"] = False
            step_result["step_success"] = False
            step_result["exact_match"] = False
            scene_success = False
        else:
            # Commit the counters only after the whole step was evaluated, so
            # the aggregates always agree with the per-step records.
            if type_match:
                type_match_count += 1
            if grounding_ok:
                grounding_correct_count += 1
            if step_success:
                step_success_count += 1
            if exact_match and gt_action_type in action_type_stats:
                action_type_stats[gt_action_type]["exact_match"] += 1

            step_result["exact_match"] = exact_match
            step_result["type_match"] = type_match
            step_result["grounding_ok"] = grounding_ok
            step_result["step_success"] = step_success

            if not step_success:
                scene_success = False

        results.append(step_result)

    # 4. Aggregate metrics (percentages).
    metrics = {
        "type_match_acc": safe_divide(type_match_count, total_actions) * 100,
        "grounding_accuracy": safe_divide(grounding_correct_count, total_actions) * 100,
        "step_success_rate": safe_divide(step_success_count, total_actions) * 100,
    }
    for name in _MATCH_ACC_ORDER:
        type_stats = action_type_stats[name]
        metrics[f"{name}_match_acc"] = (
            safe_divide(type_stats["exact_match"], type_stats["count"]) * 100
        )
    metrics["total_steps"] = total_actions
    metrics["action_type_stats"] = action_type_stats
    metrics["scene_success"] = scene_success

    # Per-type accuracy is stored inside the stats dict shared with `metrics`.
    for type_stats in action_type_stats.values():
        type_stats["accuracy"] = (
            safe_divide(type_stats["exact_match"], type_stats["count"]) * 100
        )

    # 5. Optionally persist the evaluation.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        eval_file = os.path.join(output_dir, "evaluation_results.json")
        with open(eval_file, 'w', encoding='utf-8') as f:
            json.dump({
                "metrics": metrics,
                "detailed_results": results,
                "predictions_path": predictions_path,
                "ground_truth_episode": ground_truth_episode
            }, f, indent=2, ensure_ascii=False)
        print(f"评估结果已保存至: {eval_file}")

    return metrics


def safe_divide(numerator, denominator):
    """Divide ``numerator`` by ``denominator``, returning 0.0 for a zero denominator."""
    if denominator != 0:
        return numerator / denominator
    return 0.0


def point_in_bbox(x: float, y: float, bbox: List[float]) -> bool:
    """Return True when the point ``(x, y)`` lies inside ``bbox`` (edges included).

    Args:
        x: Horizontal coordinate of the point.
        y: Vertical coordinate of the point.
        bbox: At least four values ``[x1, y1, x2, y2]``; the two corners may
            be given in either order and any extra values are ignored.

    Returns:
        False for a missing, too short or non-numeric ``bbox`` and for a
        non-numeric point; otherwise whether the point is inside the box.
    """
    try:
        # All validation happens inside the try so that unsized, unsliceable
        # or non-numeric inputs yield False instead of raising.
        if bbox is None or len(bbox) < 4:
            return False
        x1, y1, x2, y2 = (float(c) for c in bbox[:4])
        px, py = float(x), float(y)
    except (ValueError, TypeError, KeyError, OverflowError):
        # KeyError: slicing a mapping raises it on Python 3.12+.
        return False

    # Normalise the corners so either ordering is accepted.
    x_min, x_max = min(x1, x2), max(x1, x2)
    y_min, y_max = min(y1, y2), max(y1, y2)

    return (x_min <= px <= x_max) and (y_min <= py <= y_max)