# Copyright 2025 Xiaomi Corporation.

import re
import json
import os
import sys
import math
from typing import Dict, Any, List, Optional, Tuple
import numpy as np

from qwen3_action_mapper import (
    parse_model_output_to_qwen3,
    is_qwen3_action_match,
    is_qwen3_action_type_match,
    get_swipe_direction,
    calculate_click_distance,
    safe_normalize_coordinate,
)

# Name under which this reward is identified.
# NOTE(review): presumably read by the training framework that loads this module — confirm.
REWARD_NAME = "faithful_grpo"
# NOTE(review): "batch" presumably signals that compute_score() receives a list of
# reward inputs in a single call (which is how compute_score is written) — confirm.
REWARD_TYPE = "batch"


def parse_output_text(output_text: str) -> Tuple[Optional[str], Optional[str], Optional[Dict[str, Any]]]:
    """Split a model response into its thought, action and tool-call parts.

    The response is searched (case-insensitively) for three sections:

    * ``Thought: ...`` up to the next ``Action:`` line, ``<tool_call>`` line,
      or the end of the text;
    * ``Action: ...`` up to the next ``<tool_call>`` line or the end of the text;
    * ``<tool_call>{...}</tool_call>`` whose payload is decoded as JSON.

    Args:
        output_text: Raw text produced by the model.

    Returns:
        A ``(thought, action, tool_call)`` tuple. Each element is ``None`` when
        the corresponding section is missing (or, for the tool call, when its
        payload is not valid JSON). When the decoded payload is a dict with an
        ``"arguments"`` key, the value of that key is returned instead of the
        wrapper dict.
    """
    if not output_text or not isinstance(output_text, str):
        return None, None, None

    flags = re.DOTALL | re.IGNORECASE

    thought = None
    thought_match = re.search(r'Thought:\s*(.+?)(?=\nAction:|\n<tool_call>|\Z)', output_text, flags)
    if thought_match:
        thought = thought_match.group(1).strip()

    action = None
    action_match = re.search(r'Action:\s*(.+?)(?=\n<tool_call>|\Z)', output_text, flags)
    if action_match:
        action = action_match.group(1).strip()

    tool_call = None
    tool_call_match = re.search(r'<tool_call>\s*(\{.*?\})\s*</tool_call>', output_text, flags)
    if tool_call_match:
        try:
            tool_call = json.loads(tool_call_match.group(1))
        except json.JSONDecodeError:
            # Malformed payload: report "no tool call" rather than failing.
            tool_call = None
        else:
            # Unwrap {"name": ..., "arguments": {...}} style payloads. A wrapper
            # without "arguments" is returned unchanged.
            if isinstance(tool_call, dict) and "arguments" in tool_call:
                tool_call = tool_call["arguments"]

    return thought, action, tool_call


def check_format(output_text: str) -> bool:
    """Report whether the response contains all three required section markers.

    The markers ``Thought:``, ``Action:`` and ``<tool_call>`` are looked up
    case-insensitively anywhere in the text; their order and content are not
    validated.

    Args:
        output_text: Raw text produced by the model.

    Returns:
        ``True`` only when ``output_text`` is a non-empty string containing
        every marker, ``False`` otherwise.
    """
    if not isinstance(output_text, str) or not output_text:
        return False

    required_markers = (r'Thought:\s*', r'Action:\s*', r'<tool_call>')
    return all(
        re.search(marker, output_text, re.IGNORECASE) is not None
        for marker in required_markers
    )



def normalized_edit_distance(text1: str, text2: str) -> float:
    """Return the Levenshtein distance between two strings scaled to [0, 1].

    The raw edit distance (unit cost for insertion, deletion and substitution)
    is divided by the length of the longer input.

    Args:
        text1: First string.
        text2: Second string.

    Returns:
        ``0.0`` when both inputs are empty, ``1.0`` when exactly one of them is
        empty, otherwise ``distance / max(len(text1), len(text2))``.
    """
    if not text1 and not text2:
        return 0.0
    if not (text1 and text2):
        return 1.0

    len_a, len_b = len(text1), len(text2)

    # Only the previous DP row is needed to compute the current one.
    previous = list(range(len_b + 1))
    for row, char_a in enumerate(text1, start=1):
        current = [row]
        for col, char_b in enumerate(text2, start=1):
            if char_a == char_b:
                current.append(previous[col - 1])
            else:
                current.append(
                    1 + min(previous[col], current[col - 1], previous[col - 1])
                )
        previous = current

    longest = max(len_a, len_b)
    return previous[len_b] / longest if longest > 0 else 0.0


def compute_click_arg_reward(pred: Dict[str, Any], gt: Dict[str, Any], tau: float = 60.0) -> float:
    """Score a click prediction with an exponential decay over its distance to the target.

    The distance comes from ``calculate_click_distance(pred, gt)`` and the
    reward is ``exp(-distance / tau)`` clamped to ``[0, 1]``.

    Args:
        pred: Predicted action dict.
        gt: Ground-truth action dict.
        tau: Decay scale; values ``<= 0`` are replaced by ``60.0``.

    Returns:
        The clamped reward, or ``0.0`` when no distance is available or any
        error occurs while computing it.
    """
    try:
        gap = calculate_click_distance(pred, gt)
        if gap is None:
            return 0.0

        scale = 60.0 if tau <= 0 else tau
        decayed = math.exp(-gap / scale)
        return max(0.0, min(1.0, decayed))
    except Exception:
        # Best-effort scoring: any failure counts as "no reward".
        return 0.0


def compute_type_arg_reward(pred: Dict[str, Any], gt: Dict[str, Any]) -> float:
    """Score how closely the predicted ``"text"`` argument matches the ground truth.

    Args:
        pred: Predicted action dict; its ``"text"`` entry is compared.
        gt: Ground-truth action dict; its ``"text"`` entry is the reference.

    Returns:
        ``1.0`` when both texts are empty/missing, ``0.0`` when only one is,
        otherwise ``1 - normalized_edit_distance(pred_text, gt_text)``.
    """
    predicted = pred.get("text", "")
    expected = gt.get("text", "")

    if predicted and expected:
        return 1.0 - normalized_edit_distance(predicted, expected)

    # At least one side is empty: perfect only if both are.
    return 0.0 if (predicted or expected) else 1.0


def compute_swipe_arg_reward(pred: Dict[str, Any], gt: Dict[str, Any]) -> float:
    """Score a swipe prediction by direction agreement and swipe-length similarity.

    A matching direction (as reported by ``get_swipe_direction``) is worth
    0.5. Up to another 0.5 is added in proportion to the ratio between the
    shorter and the longer of the two swipe lengths.

    Args:
        pred: Predicted action dict with ``"coordinate"`` and ``"coordinate2"``.
        gt: Ground-truth action dict with the same two keys.

    Returns:
        A value in ``[0, 1]``. ``0.0`` when the directions differ or are
        unavailable, when any endpoint is missing or is not a list/tuple of at
        least two items, or on unexpected errors. ``0.5`` when the directions
        agree but the lengths cannot be compared.
    """
    try:
        predicted_direction = get_swipe_direction(pred)
        expected_direction = get_swipe_direction(gt)
        if not (predicted_direction and expected_direction
                and predicted_direction == expected_direction):
            return 0.0

        endpoints = [
            pred.get("coordinate", None),
            pred.get("coordinate2", None),
            gt.get("coordinate", None),
            gt.get("coordinate2", None),
        ]
        for point in endpoints:
            if point is None or not isinstance(point, (list, tuple)) or len(point) < 2:
                return 0.0
        pred_start, pred_end, gt_start, gt_end = endpoints

        def swipe_length(start, end):
            delta_x = float(end[0]) - float(start[0])
            delta_y = float(end[1]) - float(start[1])
            return math.sqrt(delta_x ** 2 + delta_y ** 2)

        score = 0.5  # direction already agreed
        try:
            pred_length = swipe_length(pred_start, pred_end)
            gt_length = swipe_length(gt_start, gt_end)
            if gt_length > 0 and pred_length > 0:
                score += 0.5 * min(pred_length / gt_length, gt_length / pred_length)
        except (ValueError, TypeError, IndexError, ZeroDivisionError):
            # Non-numeric coordinates: keep the direction-only score.
            pass

        return max(0.0, min(1.0, score))
    except Exception:
        return 0.0


def compute_system_button_arg_reward(pred: Dict[str, Any], gt: Dict[str, Any]) -> float:
    """Return 1.0 when both actions name the same ``"button"``, else 0.0.

    A missing ``"button"`` entry is treated as the empty string, so two
    actions that both lack the key compare as equal.
    """
    same_button = pred.get("button", "") == gt.get("button", "")
    return 1.0 if same_button else 0.0


def compute_terminate_arg_reward(pred: Dict[str, Any], gt: Dict[str, Any]) -> float:
    """Return 1.0 when both actions carry the same ``"status"``, else 0.0.

    A missing ``"status"`` entry is treated as the empty string, so two
    actions that both lack the key compare as equal.
    """
    same_status = pred.get("status", "") == gt.get("status", "")
    return 1.0 if same_status else 0.0


def _has_usable_params(action_name: str, pred_action: Dict[str, Any], gt_action: Dict[str, Any]) -> bool:
    """Tell whether both actions carry the arguments needed to score ``action_name``."""
    if action_name in ["click", "long_press"]:
        points = (pred_action.get("coordinate"), gt_action.get("coordinate"))
        if any(point is None for point in points):
            return False
        return not any(safe_normalize_coordinate(point, [0, 0]) is None for point in points)

    if action_name == "swipe":
        points = (
            pred_action.get("coordinate"),
            pred_action.get("coordinate2"),
            gt_action.get("coordinate"),
            gt_action.get("coordinate2"),
        )
        if any(point is None for point in points):
            return False
        return not any(safe_normalize_coordinate(point, [500, 500]) is None for point in points)

    # Remaining scored actions each compare one string-valued argument.
    string_argument = {
        "type": "text",
        "answer": "text",
        "system_button": "button",
        "terminate": "status",
    }.get(action_name)
    if string_argument is not None:
        values = (pred_action.get(string_argument), gt_action.get(string_argument))
        return all(isinstance(value, str) for value in values)

    # Any other action type has no argument pre-check.
    return True


def _score_arguments(
    pred_name: str,
    gt_name: str,
    pred_action: Dict[str, Any],
    gt_action: Dict[str, Any],
    click_threshold: float,
    click_tau: float,
) -> float:
    """Pick the argument scorer matching the action type and return its reward."""
    if pred_name in ("click", "long_press"):
        return compute_click_arg_reward(pred_action, gt_action, click_tau)

    text_like = ["type", "answer"]
    if pred_name in text_like and gt_name in text_like:
        return compute_type_arg_reward(pred_action, gt_action)

    same_name_scorers = {
        "swipe": compute_swipe_arg_reward,
        "system_button": compute_system_button_arg_reward,
        "terminate": compute_terminate_arg_reward,
    }
    if pred_name == gt_name and pred_name in same_name_scorers:
        return same_name_scorers[pred_name](pred_action, gt_action)

    # No dedicated scorer: fall back to the all-or-nothing matcher.
    return 1.0 if is_qwen3_action_match(pred_action, gt_action, click_threshold) else 0.0


def compute_action_match_reward(
    output_text: str,
    ground_truth_action: Optional[Dict[str, Any]] = None,
    click_threshold: float = 140.0,
    data_type: int = 0,
    use_continuous_reward: bool = True,
    lambda_arg: Optional[float] = None,
    click_tau: float = 60.0,
) -> float:
    """Reward how well the action in ``output_text`` matches the ground truth.

    The predicted action is obtained with ``parse_model_output_to_qwen3``.

    * Discrete mode (``use_continuous_reward=False``): ``1.0`` for a full
      match according to ``is_qwen3_action_match``, ``0.3`` when only the
      action type matches, ``0.0`` otherwise.
    * Continuous mode: a type match is worth ``0.3`` and the argument reward
      ``r_arg`` in ``[0, 1]`` is weighted by ``lambda_arg``; the result is
      ``(0.3 + lambda_arg * r_arg) / (0.3 + lambda_arg)`` capped at ``1.0``.
      A type mismatch, or arguments that cannot be parsed, yields ``0.0``.

    Args:
        output_text: Raw model response.
        ground_truth_action: Reference action as a dict or as a JSON string.
        click_threshold: Forwarded to ``is_qwen3_action_match``.
        data_type: Selects the default ``lambda_arg``: ``3.0`` when it equals
            ``1``, ``1.0`` otherwise.
        use_continuous_reward: Chooses continuous or discrete mode.
        lambda_arg: Weight of the argument reward; ``None`` selects the
            default derived from ``data_type``.
        click_tau: Decay scale forwarded to ``compute_click_arg_reward``.

    Returns:
        A reward in ``[0, 1]``; ``0.0`` for missing inputs, an unparsable
        prediction or ground truth, or any error raised while scoring.
    """
    if not output_text or not ground_truth_action:
        return 0.0

    pred_action = (
        parse_model_output_to_qwen3(output_text)
        if parse_model_output_to_qwen3 is not None
        else None
    )
    if not pred_action:
        return 0.0

    gt_action = ground_truth_action
    if isinstance(gt_action, str):
        try:
            gt_action = json.loads(gt_action)
        except json.JSONDecodeError:
            return 0.0

    # Without the matcher helpers only the action type can be compared.
    if is_qwen3_action_match is None or is_qwen3_action_type_match is None:
        pred_kind = pred_action.get("action", "").lower() if isinstance(pred_action, dict) else ""
        gt_kind = gt_action.get("action", "").lower() if isinstance(gt_action, dict) else ""
        return 0.3 if pred_kind == gt_kind else 0.0

    if not use_continuous_reward:
        try:
            if is_qwen3_action_match(pred_action, gt_action, click_threshold):
                return 1.0
            if is_qwen3_action_type_match(pred_action, gt_action):
                return 0.3
            return 0.0
        except Exception:
            return 0.0

    try:
        if not is_qwen3_action_type_match(pred_action, gt_action):
            return 0.0
        r_type = 0.3

        pred_name = pred_action.get("action", "").lower()
        gt_name = gt_action.get("action", "").lower()

        if not _has_usable_params(pred_name, pred_action, gt_action):
            return 0.0

        try:
            r_arg = _score_arguments(
                pred_name, gt_name, pred_action, gt_action, click_threshold, click_tau
            )
        except Exception:
            r_arg = 0.0

        weight = lambda_arg
        if weight is None:
            weight = 3.0 if data_type == 1 else 1.0

        raw_score = r_type + weight * r_arg
        max_score = 0.3 + weight * 1.0
        if max_score > 0:
            return min(raw_score / max_score, 1.0)
        return 0.0
    except Exception:
        return 0.0


# Keyword tables for the thought/action consistency heuristic. Matching is a
# plain substring test against the lower-cased thought, so short or generic
# entries (e.g. "task", "in the", "visible") match very broadly.

_TRAP_OBSCURED_KEYWORDS = (
    "obscured", "obstructed", "blocked",
    "no longer visible", "not visible", "cannot see",
    "longer visible", "visible current", "visible altered",
    "target element", "intended action", "element intended",
    "action longer", "longer visible current", "visible current screen",
    "target element intended", "element intended action", "intended action longer",
    "action longer visible", "longer visible altered", "visible altered current",
    "altered current", "altered current screen",
    "impossible to perform", "cannot perform", "impossible",
    "has been altered", "changed", "altered",
    "screen is currently", "screen shows", "current screen is",
    "making it impossible", "so the task cannot", "task cannot be completed",
    "app is not visible", "element is not visible",
)

_TRAP_UNRELATED_KEYWORDS = (
    "unrelated to", "not related to", "does not match",
    "wrong screen", "incorrect screen",
    "different task", "other task",
    "current screen is for", "current screen is showing",
    "current screen related", "screen related",
    "which is unrelated", "which is not relevant", "not relevant to",
    "currently in", "but the task is", "but the task of",
    "screen showing", "showing search", "for searching", "for filtering",
)

_TRAP_NAVIGATE_KEYWORDS = (
    "navigate back", "go back", "return to",
    "need navigate", "need navigate back",
    "previous screen", "previous interface",
    "main screen", "home screen",
    "need to access", "need to find", "need to navigate",
)

_TYPE_INTENT_KEYWORDS = (
    "type", "enter", "input", "search term",
    "type text", "enter text", "input text",
    "search for", "search",
    "search bar", "bar keyboard", "keyboard active", "keyboard visible",
    "search bar keyboard", "bar keyboard active", "keyboard active indicating",
    "active indicating", "indicating enter", "enter search", "search query",
    "indicating enter search", "enter search query", "search bar active",
    "bar active", "active step", "keyboard visible indicating",
    "see search", "see search bar", "need type", "type search",
    "input field", "field active", "message input field",
    "need to type", "should type", "will type", "should enter",
    "i can see the search bar", "search bar is active",
    "keyboard is visible", "indicating type", "step enter",
)

_CLICK_INTENT_KEYWORDS = (
    "click", "tap", "press", "select",
    "click on", "tap on", "press on",
    "select", "choose",
    "need click", "need select", "need access", "need open", "need use",
    "based instruction", "screen based", "screen based instruction",
    "button screen", "button screen based",
    "icon", "button", "link", "option", "menu", "options",
    "should click", "need to click", "will click",
    "i can see", "visible", "can see", "see search", "see add",
    "at the", "in the", "which is", "to add", "to open", "to access",
    "search bar", "search icon", "click search", "click search bar",
    "corner screen", "right corner", "right corner screen",
    "user request", "per user", "per user request",
    "logical step", "search results", "search result",
    "cart need", "email address", "heart icon", "three-dot menu", "bookmark icon",
)

_SCROLL_INTENT_KEYWORDS = (
    "scroll", "swipe",
    "need scroll", "scroll page", "need scroll page",
    "scroll up", "scroll down", "scroll left", "scroll right",
    "swipe up", "swipe down", "swipe left", "swipe right",
    "need scroll down", "scroll down", "down page", "scroll down page",
    "view more", "see more", "explore", "view details", "scroll view",
    "need scroll view", "scroll list", "need scroll list",
    "view product", "read reviews", "reviews need", "reviews need scroll",
    "continue reading", "access app", "app drawer", "access app drawer",
    "shoes need", "shoes need scroll", "scroll find", "need scroll find",
    "view options", "app visible", "visible current", "current screen",
    "need to scroll", "should scroll", "will scroll",
    "to see", "to view", "to find", "to get", "to learn",
    "navigate through", "browse through", "explore more",
)

_SYSTEM_BUTTON_INTENT_KEYWORDS = (
    "press back", "press home", "navigate back", "go back",
    "return to", "back to",
    "back button", "home button",
    "need to navigate", "need to access", "need to find",
    "currently in", "but the task", "but the task of",
    "go back to", "return to the",
)

_TERMINATE_INTENT_KEYWORDS = (
    "finish", "complete", "done", "finished", "completed",
    "task", "task complete", "task adding", "indicating task",
    "indicating task complete", "task is complete", "task is finished", "has been completed",
    "cart complete", "cart proceeding", "proceeding checkout", "checkout completed",
    "added cart", "user sign", "sign page", "page final", "final step",
    "already", "is already", "already in", "already set", "already off",
    "off position", "position indicating", "already off position",
    "toggle", "picture", "apps", "allowed", "location", "added", "message",
    "task of", "as indicated by", "as the user is now", "indicating that",
)

_WAIT_INTENT_KEYWORDS = (
    "loading", "wait", "wait for", "still loading",
    "search results", "results still", "results still loading",
    "still loading need", "loading need wait", "need wait",
    "wait content", "content appear", "wait content appear",
    "need to wait", "should wait", "will wait",
    "wait for the", "wait for content", "wait for the content",
)

# (keywords that signal a direction in the thought, direction name expected
# from get_swipe_direction). Checked in this order; the first group that
# matches the thought decides.
_SWIPE_DIRECTION_HINTS = (
    (("down", "scroll down", "swipe down"), "DOWN"),
    (("up", "scroll up", "swipe up"), "UP"),
    (("left", "scroll left", "swipe left"), "LEFT"),
    (("right", "scroll right", "swipe right"), "RIGHT"),
)


def _mentions(thought_lower: str, keywords) -> bool:
    """Return True when any keyword occurs as a substring of the thought."""
    return any(keyword in thought_lower for keyword in keywords)


def _typed_text_overlap(thought_lower: str, pred_action: Dict[str, Any]) -> float:
    """Bonus for typed words (longer than two characters) that also appear in the thought."""
    typed = pred_action.get("text", "").lower()
    if not typed:
        return 0.0
    words = [word for word in typed.split() if len(word) > 2]
    if not words:
        return 0.0
    matched = sum(1 for word in words if word in thought_lower)
    if matched == 0:
        return 0.0
    return 0.3 + 0.4 * (matched / len(words))


def _swipe_direction_bonus(thought_lower: str, pred_action: Dict[str, Any]) -> float:
    """Bonus for a swipe whose direction agrees with the one named in the thought."""
    try:
        for hints, expected_direction in _SWIPE_DIRECTION_HINTS:
            if _mentions(thought_lower, hints):
                if get_swipe_direction(pred_action) == expected_direction:
                    return 0.5
                return 0.0
        # The thought names no direction: grant a smaller, neutral bonus.
        return 0.25
    except Exception:
        return 0.0


def _system_button_score(thought_lower: str, pred_action: Dict[str, Any]) -> float:
    """Score a ``system_button`` action against a thought that asks for navigation."""
    button = pred_action.get("button", "").lower()
    if "back" in thought_lower and button == "back":
        return 0.8
    if "home" in thought_lower and button == "home":
        return 0.8
    if button in ["back", "home"]:
        return 0.65
    return 0.4


def _consistency_signal(thought_lower: str, pred_action: Dict[str, Any], action_type: str) -> float:
    """Return the raw consistency signal in [-1, 1]; 0.0 means no intent was detected.

    Intent categories are tested in a fixed priority order and the first one
    whose keywords appear in the thought decides the result.
    """
    # 1. "Trap" situations: the thought says the target is unavailable or the
    #    screen is wrong, so only a back/home system button is consistent.
    if (_mentions(thought_lower, _TRAP_OBSCURED_KEYWORDS)
            or _mentions(thought_lower, _TRAP_UNRELATED_KEYWORDS)
            or _mentions(thought_lower, _TRAP_NAVIGATE_KEYWORDS)):
        if action_type != "system_button":
            return -0.75
        button = pred_action.get("button", "").lower()
        return 0.85 if button in ["back", "home"] else 0.5

    # 2. Typing intent.
    if _mentions(thought_lower, _TYPE_INTENT_KEYWORDS):
        if action_type not in ["type", "answer"]:
            return -0.65
        return max(-1.0, min(1.0, 0.4 + _typed_text_overlap(thought_lower, pred_action)))

    # 3. Click intent.
    if _mentions(thought_lower, _CLICK_INTENT_KEYWORDS):
        if action_type in ["click", "long_press"]:
            return 0.75
        if action_type == "system_button":
            if "back" in thought_lower or "home" in thought_lower or "system" in thought_lower:
                return 0.5
            return 0.2
        return -0.65

    # 4. Scroll / swipe intent.
    if _mentions(thought_lower, _SCROLL_INTENT_KEYWORDS):
        if action_type != "swipe":
            return -0.65
        return max(-1.0, min(1.0, 0.4 + _swipe_direction_bonus(thought_lower, pred_action)))

    # 5. System-button intent.
    if _mentions(thought_lower, _SYSTEM_BUTTON_INTENT_KEYWORDS):
        if action_type != "system_button":
            return -0.55
        return _system_button_score(thought_lower, pred_action)

    # 6. Termination intent.
    if _mentions(thought_lower, _TERMINATE_INTENT_KEYWORDS):
        return 0.8 if action_type == "terminate" else -0.4

    # 7. Waiting intent.
    if _mentions(thought_lower, _WAIT_INTENT_KEYWORDS):
        return 0.75 if action_type == "wait" else -0.3

    return 0.0


def compute_thought_action_consistency_reward(
    output_text: str,
) -> float:
    """Heuristically score whether the emitted tool call agrees with the stated thought.

    The response is split with ``parse_output_text``; the lower-cased thought
    is scanned for intent keywords (trap/navigation, type, click, scroll,
    system button, terminate, wait — in that priority order) and the first
    detected intent is compared with the ``"action"`` of the tool call. The
    resulting signal in ``[-1, 1]`` is mapped linearly to ``[0, 1]``.

    Args:
        output_text: Raw model response.

    Returns:
        A score in ``[0, 1]``. The neutral value ``0.5`` is returned when the
        text is empty or not a string, when the thought or tool call is
        missing, when the tool call is not a dict, when no intent keyword is
        found, or when any error occurs.
    """
    try:
        if not output_text or not isinstance(output_text, str):
            return 0.5

        # The free-text "Action:" section is not used by this heuristic.
        thought, _action, tool_call = parse_output_text(output_text)
        if not thought or not tool_call:
            return 0.5
        if not isinstance(tool_call, dict):
            return 0.5

        thought_lower = thought.lower()
        action_type = tool_call.get("action", "").lower()

        r_ta = _consistency_signal(thought_lower, tool_call, action_type)

        # Map [-1, 1] onto [0, 1].
        return max(0.0, min(1.0, (r_ta + 1.0) / 2.0))
    except Exception:
        return 0.5



def _coerce_data_type(raw: Any) -> int:
    """Convert a ``data_type`` field to ``int``, falling back to ``0`` when unusable."""
    if isinstance(raw, str):
        try:
            return int(raw)
        except ValueError:
            return 0
    if isinstance(raw, (int, float)):
        try:
            return int(raw)
        except (ValueError, OverflowError):
            # int() rejects NaN (ValueError) and infinities (OverflowError),
            # e.g. a missing value that arrived as a float NaN.
            return 0
    return 0


def compute_score(
    reward_inputs: List[Dict[str, Any]],
    action_match_weight: float = 0.9,
    consistency_weight: float = 0.1,
    click_threshold: float = 140.0,
    use_consistency_reward: bool = True,
    use_continuous_reward: bool = True,
    click_tau: float = 60.0,
    use_action_type_weight: bool = False,
) -> List[Dict[str, float]]:
    """Score a batch of responses with the action-match and consistency rewards.

    For every input the overall score is
    ``action_match_weight * action_match + consistency_weight * consistency``
    (the consistency term is dropped when ``use_consistency_reward`` is false).

    Args:
        reward_inputs: One dict per sample. ``"response"`` is required;
            ``"ground_truth"`` (JSON string, free-form action string, or dict
            optionally holding ``"label"`` and ``"data_type"``) and
            ``"data_type"`` are optional.
        action_match_weight: Weight of the action-match reward.
        consistency_weight: Weight of the thought/action consistency reward.
        click_threshold: Forwarded to ``compute_action_match_reward``.
        use_consistency_reward: Whether to compute and include the consistency term.
        use_continuous_reward: Forwarded to ``compute_action_match_reward``.
        click_tau: Forwarded to ``compute_action_match_reward``.
        use_action_type_weight: Accepted for interface compatibility; not used
            by this function.

    Returns:
        One dict per input with the keys ``"overall"``, ``"action_match"`` and
        ``"consistency"``.

    Raises:
        KeyError: If an input dict has no ``"response"`` entry.
    """
    scores = []

    for reward_input in reward_inputs:
        response = reward_input["response"]
        ground_truth = reward_input.get("ground_truth", "")
        data_type = _coerce_data_type(reward_input.get("data_type", None))

        label = None
        if ground_truth:
            if isinstance(ground_truth, str):
                try:
                    label = json.loads(ground_truth)
                except json.JSONDecodeError:
                    # Not JSON: try to read it as a model-style action string.
                    if parse_model_output_to_qwen3 is not None:
                        label = parse_model_output_to_qwen3(ground_truth)
            elif isinstance(ground_truth, dict):
                label = ground_truth.get("label", ground_truth)
                # The ground truth's data_type is only consulted when the
                # input itself did not supply a non-zero one.
                if data_type == 0 and "data_type" in ground_truth:
                    data_type = _coerce_data_type(ground_truth["data_type"])

        try:
            action_match_score = compute_action_match_reward(
                response,
                label,
                click_threshold,
                data_type=data_type,
                use_continuous_reward=use_continuous_reward,
                click_tau=click_tau,
            )
        except Exception:
            action_match_score = 0.0

        consistency_score = 0.0
        if use_consistency_reward:
            try:
                consistency_score = compute_thought_action_consistency_reward(response)
            except Exception:
                consistency_score = 0.0

        actual_consistency_weight = consistency_weight if use_consistency_reward else 0.0

        try:
            overall = (
                action_match_weight * action_match_score +
                actual_consistency_weight * consistency_score
            )
            # Never hand a non-finite reward to the trainer.
            if math.isnan(overall) or math.isinf(overall):
                overall = 0.0
        except Exception:
            overall = 0.0

        scores.append({
            "overall": overall,
            "action_match": action_match_score,
            "consistency": consistency_score,
        })

    return scores
