# Copyright 2025 Xiaomi Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Faithful-GRPO 奖励函数（适配 EasyR1 框架）

实现3种奖励组件：
1. R_fmt: 格式守门员（检查Thought, Action, tool_call格式）
2. R_action_match: 动作匹配奖励（检查生成的动作是否与ground truth一致，使用Qwen3动作匹配标准）
3. R_embedding: Embedding相似度奖励（计算Thought和Action之间的余弦相似度，使用Qwen3-VL模型的hidden states）

数据格式（来自train_llamafactory.json）：
- output_text: "Thought: ...\nAction: ...\n<tool_call>\n{...}\n</tool_call>"
- thought: 推理过程
- action: 动作描述
- tool_call: 实际的动作（groundtruth）
- label: Ground Truth动作

注意：
- 模型输出格式应该是：Thought + Action + tool_call
- R_action_match 使用 Qwen3 动作匹配标准
- R_embedding 使用 Qwen3-VL 模型的 hidden states 计算 Thought 和 Action 之间的余弦相似度
- 可以通过 embedding_model_path 参数指定 Qwen3-VL 模型路径，或通过环境变量 QWEN3_VL_MODEL_PATH 设置
"""

import re
import json
import os
import sys
import math
from typing import Dict, Any, List, Optional, Tuple
import numpy as np
from PIL import Image

from qwen3_action_mapper import (
    parse_model_output_to_qwen3,
    is_qwen3_action_match,
    is_qwen3_action_type_match,
    get_swipe_direction,
    calculate_click_distance,
    safe_normalize_coordinate,
)

# Reward metadata.
# NOTE(review): presumably read by the EasyR1 reward loader — confirm against the framework.
# REWARD_NAME: identifier of this reward function.
REWARD_NAME = "faithful_grpo"
# REWARD_TYPE: "batch" is consistent with compute_score receiving the whole list
# of reward inputs in a single call.
REWARD_TYPE = "batch"

# ============================================================================
# ============================================================================

def get_image_size_from_reward_input(reward_input: Dict[str, Any]) -> Tuple[int, int]:
    """
    Resolve the image size associated with one reward input.

    Explicit ``image_width`` / ``image_height`` entries win. Whatever is still
    missing is taken from the first entry of ``images`` (or from ``images``
    itself when it is not a list/tuple), and finally from the defaults.

    Args:
        reward_input: Reward input dictionary, which may contain:
            - images: a path, or a list/tuple whose first element is a path or
              an image-like object exposing integer ``width`` / ``height``
              attributes (e.g. an already loaded PIL image)
            - image_width: image width, if supplied directly
            - image_height: image height, if supplied directly

    Returns:
        ``(image_width, image_height)``; falls back to ``(1080, 2400)`` for any
        dimension that cannot be determined.
    """
    image_width = reward_input.get("image_width")
    image_height = reward_input.get("image_height")

    if image_width is None or image_height is None:
        images = reward_input.get("images")
        if isinstance(images, (list, tuple)):
            candidate = images[0] if len(images) > 0 else None
        else:
            candidate = images

        measured = None
        if isinstance(candidate, (str, bytes, os.PathLike)):
            # Only path-like values are handed to os.path.exists; arbitrary
            # objects would make it raise TypeError.
            if candidate and os.path.exists(candidate):
                try:
                    # The context manager closes the underlying file handle.
                    with Image.open(candidate) as img:
                        measured = (img.width, img.height)
                except Exception:
                    # Best effort: an unreadable image falls back to defaults.
                    measured = None
        elif candidate is not None:
            width_attr = getattr(candidate, "width", None)
            height_attr = getattr(candidate, "height", None)
            if isinstance(width_attr, int) and isinstance(height_attr, int):
                measured = (width_attr, height_attr)

        if measured is not None:
            # Fill in only the dimension(s) the caller did not supply.
            if image_width is None:
                image_width = measured[0]
            if image_height is None:
                image_height = measured[1]

    if image_width is None:
        image_width = 1080
    if image_height is None:
        image_height = 2400

    return image_width, image_height


# ============================================================================
# ============================================================================

def parse_output_text(output_text: str) -> Tuple[Optional[str], Optional[str], Optional[Dict[str, Any]]]:
    """
    Split a model output into its Thought, Action and tool_call parts.

    Args:
        output_text: Full text produced by the model, expected to look like
            ``Thought: ...\\nAction: ...\\n<tool_call>\\n{...}\\n</tool_call>``.

    Returns:
        ``(thought, action, tool_call)``. Each element is ``None`` when the
        corresponding section is missing; ``tool_call`` is also ``None`` when
        the JSON payload cannot be decoded. When the decoded payload is a dict
        with an ``"arguments"`` key, the value of that key is returned instead
        of the wrapper.
    """
    if not output_text or not isinstance(output_text, str):
        return None, None, None

    flags = re.DOTALL | re.IGNORECASE

    # Thought runs up to the Action line, the tool_call block, or end of text.
    thought_match = re.search(
        r'Thought:\s*(.+?)(?=\nAction:|\n<tool_call>|\Z)', output_text, flags
    )
    thought = thought_match.group(1).strip() if thought_match else None

    # Action runs up to the tool_call block or end of text.
    action_match = re.search(r'Action:\s*(.+?)(?=\n<tool_call>|\Z)', output_text, flags)
    action = action_match.group(1).strip() if action_match else None

    tool_call = None
    tool_call_match = re.search(r'<tool_call>\s*(\{.*?\})\s*</tool_call>', output_text, flags)
    if tool_call_match:
        try:
            tool_call = json.loads(tool_call_match.group(1))
        except json.JSONDecodeError:
            tool_call = None
        else:
            # Unwrap {"name": ..., "arguments": {...}} style payloads. A single
            # check suffices: the former extra `"name"` branch could never fire
            # because any dict carrying "arguments" is already handled here.
            if isinstance(tool_call, dict) and "arguments" in tool_call:
                tool_call = tool_call["arguments"]

    return thought, action, tool_call


def check_format(output_text: str) -> bool:
    """
    Tell whether the output carries all three required markers.

    The check is case-insensitive and only looks for the presence of
    ``Thought:``, ``Action:`` and ``<tool_call>`` anywhere in the text.

    Returns:
        True when all three markers are present, False otherwise (including
        for empty or non-string input).
    """
    if not output_text or not isinstance(output_text, str):
        return False

    required_markers = (r'Thought:\s*', r'Action:\s*', r'<tool_call>')
    return all(
        re.search(marker, output_text, re.IGNORECASE) is not None
        for marker in required_markers
    )


# ============================================================================
# ============================================================================

def normalized_edit_distance(text1: str, text2: str) -> float:
    """
    Compute the Levenshtein distance between two texts, normalized by the
    length of the longer one.

    Only two rows of the dynamic-programming table are kept at any time, so
    memory use is proportional to the shorter input instead of to the product
    of both lengths.

    Returns:
        A value in [0, 1]: 0 means the texts are identical, 1 means completely
        different. Two empty inputs give 0.0; exactly one empty input gives 1.0.
    """
    if not text1 and not text2:
        return 0.0
    if not text1 or not text2:
        return 1.0

    # The distance is symmetric, so iterate rows over the longer text and keep
    # the row buffers as short as possible.
    if len(text1) < len(text2):
        text1, text2 = text2, text1
    longest = len(text1)

    previous = list(range(len(text2) + 1))
    for i, ch1 in enumerate(text1, start=1):
        current = [i]
        for j, ch2 in enumerate(text2, start=1):
            if ch1 == ch2:
                current.append(previous[j - 1])
            else:
                # deletion, insertion, substitution
                current.append(1 + min(previous[j], current[j - 1], previous[j - 1]))
        previous = current

    return previous[-1] / longest


def compute_click_arg_reward(
    pred: Dict[str, Any],
    gt: Dict[str, Any],
    tau: float = 60.0,
    image_width: Optional[int] = None,
    image_height: Optional[int] = None
) -> float:
    """
    Argument reward for click / long_press actions, decaying exponentially
    with the distance reported by ``calculate_click_distance``.

    Args:
        pred: Predicted action.
        gt: Ground-truth action.
        tau: Decay constant of ``exp(-distance / tau)``; non-positive values
            are replaced by 60.0. The original author describes it as living
            in the normalized (0-1000) coordinate space.
        image_width: Image width, forwarded to ``calculate_click_distance``.
        image_height: Image height, forwarded to ``calculate_click_distance``.

    Returns:
        r_arg clamped to [0, 1]; 0.0 when no distance is available or when
        anything raises.
    """
    try:
        dist = calculate_click_distance(pred, gt, image_width, image_height)
        if dist is None:
            return 0.0

        decay = 60.0 if tau <= 0 else tau
        score = math.exp(-dist / decay)
        return max(0.0, min(1.0, score))
    except Exception:
        return 0.0


def compute_type_arg_reward(pred: Dict[str, Any], gt: Dict[str, Any]) -> float:
    """
    Argument reward for type / answer actions, based on text similarity.

    Args:
        pred: Predicted action; its ``"text"`` entry is compared.
        gt: Ground-truth action; its ``"text"`` entry is compared.

    Returns:
        1.0 when both texts are empty, 0.0 when exactly one is empty,
        otherwise ``1 - normalized_edit_distance(pred_text, gt_text)``.
    """
    predicted = pred.get("text", "")
    expected = gt.get("text", "")

    if not predicted and not expected:
        return 1.0
    if not (predicted and expected):
        return 0.0

    return 1.0 - normalized_edit_distance(predicted, expected)


def compute_swipe_arg_reward(pred: Dict[str, Any], gt: Dict[str, Any]) -> float:
    """
    Argument reward for swipe actions.

    A matching direction (as reported by ``get_swipe_direction``) is worth
    0.5; a similar swipe length adds up to another 0.5, scaled by the ratio
    of the shorter length to the longer one.

    Args:
        pred: Predicted action (reads ``"coordinate"`` and ``"coordinate2"``).
        gt: Ground-truth action (reads ``"coordinate"`` and ``"coordinate2"``).

    Returns:
        r_arg in [0, 1]. 0.0 when the directions are missing or differ, when
        any endpoint is not a list/tuple with at least two items, or when an
        unexpected error occurs.
    """
    try:
        predicted_direction = get_swipe_direction(pred)
        expected_direction = get_swipe_direction(gt)

        if not (predicted_direction and expected_direction) or predicted_direction != expected_direction:
            return 0.0

        endpoints = (
            pred.get("coordinate", None),
            pred.get("coordinate2", None),
            gt.get("coordinate", None),
            gt.get("coordinate2", None),
        )
        # A missing endpoint (None) fails the isinstance test as well.
        for point in endpoints:
            if not isinstance(point, (list, tuple)) or len(point) < 2:
                return 0.0
        pred_start, pred_end, gt_start, gt_end = endpoints

        def _length(start, end):
            # Euclidean length of the swipe from start to end.
            dx = float(end[0]) - float(start[0])
            dy = float(end[1]) - float(start[1])
            return math.sqrt(dx ** 2 + dy ** 2)

        # Base score for the matching direction.
        score = 0.5
        try:
            pred_length = _length(pred_start, pred_end)
            gt_length = _length(gt_start, gt_end)

            if gt_length > 0 and pred_length > 0:
                ratio = min(pred_length / gt_length, gt_length / pred_length)
                score += 0.5 * ratio
        except (ValueError, TypeError, IndexError, ZeroDivisionError):
            # Non-numeric coordinates: keep the direction-only score.
            pass

        return max(0.0, min(1.0, score))
    except Exception:
        return 0.0


def compute_system_button_arg_reward(pred: Dict[str, Any], gt: Dict[str, Any]) -> float:
    """
    Argument reward for system_button actions.

    Args:
        pred: Predicted action; its ``"button"`` entry is compared.
        gt: Ground-truth action; its ``"button"`` entry is compared.

    Returns:
        1.0 when both ``"button"`` values are equal (a missing entry counts as
        the empty string), otherwise 0.0.
    """
    same_button = pred.get("button", "") == gt.get("button", "")
    return 1.0 if same_button else 0.0


def compute_terminate_arg_reward(pred: Dict[str, Any], gt: Dict[str, Any]) -> float:
    """
    Argument reward for terminate actions.

    Args:
        pred: Predicted action; its ``"status"`` entry is compared.
        gt: Ground-truth action; its ``"status"`` entry is compared.

    Returns:
        1.0 when both ``"status"`` values are equal (a missing entry counts as
        the empty string), otherwise 0.0.
    """
    same_status = pred.get("status", "") == gt.get("status", "")
    return 1.0 if same_status else 0.0


# ============================================================================
# ============================================================================

def soft_high_sim_penalty(sim: float, tau: float = 0.40, temp: float = 0.08) -> float:
    """Return a soft gate in [0, 1]: ``sigmoid((sim - tau) / temp)``.

    The gate is 0.5 at ``sim == tau``, approaches 1 for similarities well
    above ``tau`` and 0 well below it; it is only meaningful as a penalty
    when ``sim > tau``.

    All arithmetic is done on Python floats. The sigmoid is evaluated in its
    numerically stable form so that a similarity far below ``tau`` combined
    with a small ``temp`` yields a gate of (almost) 0 instead of raising
    ``OverflowError`` from ``math.exp``.

    Args:
        sim: Similarity value.
        tau: Centre of the gate.
        temp: Temperature (steepness) of the gate; must be non-zero.

    Returns:
        The gate value as a float.

    Raises:
        ZeroDivisionError: If ``temp`` is zero.
    """
    sim = float(sim)
    tau = float(tau)
    temp = float(temp)
    x = (sim - tau) / temp
    if x >= 0:
        # exp(-x) <= 1 here, so it cannot overflow.
        return 1.0 / (1.0 + math.exp(-x))
    # For negative x use the equivalent form whose exponent stays negative.
    e = math.exp(x)
    return e / (1.0 + e)

def emb_reward(sim: float, lam: float = 0.10, tau: float = 0.40, temp: float = 0.08) -> float:
    """Return the soft embedding penalty ``-lam * gate**2``.

    ``gate`` is ``soft_high_sim_penalty(sim, tau, temp)``, so the result is
    always non-positive for a non-negative ``lam``.

    Args:
        sim: Similarity value.
        lam: Penalty strength.
        tau: Centre of the gate.
        temp: Temperature of the gate.
    """
    sim_value = float(sim)
    strength = float(lam)
    centre = float(tau)
    temperature = float(temp)
    gate = soft_high_sim_penalty(sim_value, tau=centre, temp=temperature)
    return float(-strength * (gate * gate))


# ============================================================================
# ============================================================================

def compute_format_reward(output_text: str) -> float:
    """
    R_fmt: format gatekeeper.

    Scoring:
    - well-formed output (``check_format`` accepts it): +0.1
    - anything else, including empty or non-string input: -1.0

    Args:
        output_text: Full text produced by the model.
    """
    well_formed = (
        bool(output_text)
        and isinstance(output_text, str)
        and check_format(output_text)
    )
    return 0.1 if well_formed else -1.0


def compute_action_match_reward(
    output_text: str,
    ground_truth_action: Optional[Dict[str, Any]] = None,
    click_threshold: float = 140.0,
    use_continuous_reward: bool = True,
    click_tau: float = 60.0,
    image_width: Optional[int] = None,
    image_height: Optional[int] = None,
) -> float:
    """
    R_action_match: action-matching reward (the most important reward).

    Checks whether the generated action agrees with the ground truth, using
    the Qwen3 action-matching helpers.

    With use_continuous_reward=True a "fixed base score" scheme is used:
    - action type does not match: 0.0
    - action type matches but the parameters needed for scoring are missing
      or malformed on either side: 0.0
    - action type matches: 0.3 + 0.7 * r_arg, i.e. a value in [0.3, 1.0],
      where r_arg in [0, 1] is the per-type argument score
      (click/long_press, type/answer, swipe, system_button and terminate have
      dedicated scorers; any other type gets r_arg = 1.0 when
      is_qwen3_action_match accepts the pair, else 0.0)

    With use_continuous_reward=False a three-level reward is used:
    - full match (type and content): 1.0
    - type matches but content does not: 0.3
    - type does not match: 0.0

    Args:
        output_text: Full model output (format: Thought + Action + tool_call).
        ground_truth_action: Ground-truth action as a Qwen3-format dict, or a
            JSON string that decodes to one.
        click_threshold: Distance threshold forwarded to
            is_qwen3_action_match (described by the original author as
            pixels), default 140.0.
        use_continuous_reward: Whether to use the continuous score, default True.
        click_tau: Distance decay constant forwarded to
            compute_click_arg_reward, default 60.0.
            NOTE(review): the original text calls this pixels, while
            compute_click_arg_reward describes it as normalized (0-1000)
            coordinates — confirm which unit applies.
        image_width: Image width, forwarded to the click scorer / matcher.
        image_height: Image height, forwarded to the click scorer / matcher.

    Returns:
        Action-match score in [0, 1] (not normalized further). 0.0 is also
        returned when either input is empty, when the prediction cannot be
        parsed, when a string ground truth is not valid JSON, or when scoring
        raises.
    """
    if not output_text or not ground_truth_action:
        return 0.0
    
    # Parse the prediction. This call is outside any try block, so an
    # exception raised by the parser propagates to the caller.
    pred_action = None
    if parse_model_output_to_qwen3 is not None:
        pred_action = parse_model_output_to_qwen3(output_text)
    
    if not pred_action:
        return 0.0
    
    # Accept the ground truth as a JSON string as well as a dict.
    if isinstance(ground_truth_action, str):
        try:
            ground_truth_action = json.loads(ground_truth_action)
        except json.JSONDecodeError:
            return 0.0
    
    gt_action = ground_truth_action
    
    # Fallback when the matcher helpers are unavailable (None): compare only
    # the lower-cased "action" names and grant the base score on equality.
    if is_qwen3_action_match is None or is_qwen3_action_type_match is None:
        pred_action_type = pred_action.get("action", "").lower() if isinstance(pred_action, dict) else ""
        gt_action_type = gt_action.get("action", "").lower() if isinstance(gt_action, dict) else ""
        if pred_action_type == gt_action_type:
            return 0.3
        else:
            return 0.0
    
    if use_continuous_reward:
        try:
            # Gate 1: the action types must match.
            if not is_qwen3_action_type_match(pred_action, gt_action):
                return 0.0
            
            pred_action_name = pred_action.get("action", "").lower()
            gt_action_name = gt_action.get("action", "").lower()
            
            # Gate 2: both sides must carry usable parameters for the
            # predicted action type; otherwise no reward is given at all.
            can_parse_params = True
            if pred_action_name in ["click", "long_press"]:
                pred_coord = pred_action.get("coordinate")
                gt_coord = gt_action.get("coordinate")
                if pred_coord is None or gt_coord is None:
                    can_parse_params = False
                else:
                    # A None result from safe_normalize_coordinate is treated
                    # as "coordinate not usable".
                    if (safe_normalize_coordinate(pred_coord, [0, 0]) is None or
                        safe_normalize_coordinate(gt_coord, [0, 0]) is None):
                        can_parse_params = False
            elif pred_action_name in ["type", "answer"]:
                pred_text = pred_action.get("text")
                gt_text = gt_action.get("text")
                if pred_text is None or gt_text is None:
                    can_parse_params = False
                elif not isinstance(pred_text, str) or not isinstance(gt_text, str):
                    can_parse_params = False
            elif pred_action_name == "swipe":
                pred_coord1 = pred_action.get("coordinate")
                pred_coord2 = pred_action.get("coordinate2")
                gt_coord1 = gt_action.get("coordinate")
                gt_coord2 = gt_action.get("coordinate2")
                if (pred_coord1 is None or pred_coord2 is None or
                    gt_coord1 is None or gt_coord2 is None):
                    can_parse_params = False
                else:
                    if (safe_normalize_coordinate(pred_coord1, [500, 500]) is None or
                        safe_normalize_coordinate(pred_coord2, [500, 500]) is None or
                        safe_normalize_coordinate(gt_coord1, [500, 500]) is None or
                        safe_normalize_coordinate(gt_coord2, [500, 500]) is None):
                        can_parse_params = False
            elif pred_action_name == "system_button":
                pred_button = pred_action.get("button")
                gt_button = gt_action.get("button")
                if pred_button is None or gt_button is None:
                    can_parse_params = False
                elif not isinstance(pred_button, str) or not isinstance(gt_button, str):
                    can_parse_params = False
            elif pred_action_name == "terminate":
                pred_status = pred_action.get("status")
                gt_status = gt_action.get("status")
                if pred_status is None or gt_status is None:
                    can_parse_params = False
                elif not isinstance(pred_status, str) or not isinstance(gt_status, str):
                    can_parse_params = False
            
            if not can_parse_params:
                return 0.0
            
            # Per-type argument score; any failure inside a scorer counts as
            # r_arg = 0.0, which still yields the 0.3 base score below.
            r_arg = 0.0
            try:
                if pred_action_name == "click" or pred_action_name == "long_press":
                    r_arg = compute_click_arg_reward(pred_action, gt_action, click_tau, image_width, image_height)
                elif pred_action_name in ["type", "answer"] and gt_action_name in ["type", "answer"]:
                    r_arg = compute_type_arg_reward(pred_action, gt_action)
                elif pred_action_name == "swipe" and gt_action_name == "swipe":
                    r_arg = compute_swipe_arg_reward(pred_action, gt_action)
                elif pred_action_name == "system_button" and gt_action_name == "system_button":
                    r_arg = compute_system_button_arg_reward(pred_action, gt_action)
                elif pred_action_name == "terminate" and gt_action_name == "terminate":
                    r_arg = compute_terminate_arg_reward(pred_action, gt_action)
                else:
                    # Remaining action types: all-or-nothing via the full matcher.
                    if is_qwen3_action_match(pred_action, gt_action, click_threshold, image_width, image_height):
                        r_arg = 1.0
                    else:
                        r_arg = 0.0
            except Exception:
                r_arg = 0.0
            
            
            # Fixed base score for a correct type plus the scaled argument score.
            final_score = 0.3 + 0.7 * r_arg
            
            return final_score
        except Exception:
            return 0.0
    else:
        # Discrete three-level reward.
        try:
            if is_qwen3_action_match(pred_action, gt_action, click_threshold, image_width, image_height):
                return 1.0
            elif is_qwen3_action_type_match(pred_action, gt_action):
                return 0.3
            else:
                return 0.0
        except Exception:
            return 0.0


def compute_embedding_reward(
    input_action_sim: float,
    action_match_score: Optional[float],
    tau: float = 0.40,
    lam: float = 0.08,
    full_match_thresh: float = 0.85,
):
    """
    Minimal embedding penalty.

    - Only an unusually high input-action similarity is penalized.
    - The penalty is switched off entirely on a full match
      (``action_match_score >= full_match_thresh``).
    - The result is a fixed step (``-lam`` or ``0.0``), so it adds no
      continuous noise and does not take part in fine-grained ranking.

    Args:
        input_action_sim: Similarity between input and action.
        action_match_score: Action-match score, or None if unavailable.
        tau: Similarity threshold, default 0.40.
        lam: Penalty strength, default 0.08.
        full_match_thresh: Full-match threshold, default 0.85; at or above it
            no penalty is applied.

    Returns:
        ``-lam`` when the penalty applies, otherwise ``0.0``.
    """
    fully_matched = (
        action_match_score is not None
        and action_match_score >= full_match_thresh
    )
    if fully_matched:
        return 0.0

    return -lam if input_action_sim > tau else 0.0


# ============================================================================
# ============================================================================

def compute_score(
    reward_inputs: List[Dict[str, Any]],
    format_weight: float = 0.0,
    action_match_weight: float = 0.7,
    embedding_weight: float = 0.3,
    click_threshold: float = 140.0,
    use_embedding_reward: bool = True,
    use_continuous_reward: bool = True,
    click_tau: float = 60.0,
    use_action_type_weight: bool = False,
    emb_tau: float = 0.40,
    emb_lam: float = 0.08,
    full_match_thresh: float = 0.85,
) -> List[Dict[str, float]]:
    """
    Reward-function entry point for the EasyR1 framework.

    For every reward input the overall score is

        overall = action_match_weight * action_match_score + embedding_score

    Weights are used as given (no normalization). ``embedding_score`` is a
    penalty (``-emb_lam`` or ``0.0``) added directly to the total; it is not
    scaled by any weight.

    The embedding penalty is deliberately minimal:
    - it fires only when the cosine similarity between the input embedding
      and the action embedding exceeds ``emb_tau``;
    - it is disabled when ``action_match_score >= full_match_thresh``;
    - when either embedding is missing or unusable the similarity is treated
      as 0.0, so no penalty is applied.

    Embeddings are expected to be pre-computed and present in each reward
    input (per the original author, the framework fills them in when its
    ``compute_embeddings`` option is enabled).

    Args:
        reward_inputs: List of reward inputs. Each element is a dict with:
            - response: model output text (required)
            - ground_truth: ground-truth information, either a dict or a JSON
              string. When the (decoded) value is a dict containing
              ``label``, that entry is used as the ground-truth action;
              otherwise the value itself is used. A string that is not valid
              JSON is parsed with ``parse_model_output_to_qwen3``.
            - input_embedding / action_embedding: optional pre-computed
              embeddings (anything ``numpy.asarray`` can convert to a float
              vector)
            - images / image_width / image_height: optional, see
              ``get_image_size_from_reward_input``
        format_weight: Unused; kept for backward compatibility.
        action_match_weight: Weight of the action-match reward, default 0.7.
        embedding_weight: Unused; kept for backward compatibility.
        click_threshold: Distance threshold for click matching, default 140.0.
        use_embedding_reward: Whether to apply the embedding penalty. When
            False, ``embedding`` is always 0.0.
        use_continuous_reward: Whether the action-match reward is continuous
            (0.3 + 0.7 * r_arg) rather than three-level, default True.
        click_tau: Distance decay constant for click rewards, default 60.0.
        use_action_type_weight: Unused (not implemented).
        emb_tau: Similarity threshold of the embedding penalty, default 0.40.
        emb_lam: Strength of the embedding penalty, default 0.08.
        full_match_thresh: Action-match score at or above which the embedding
            penalty is disabled, default 0.85.

    Returns:
        One dict per reward input with the keys:
            - overall: total reward (0.0 if it would be NaN or infinite)
            - action_match: action-match reward
            - embedding: embedding penalty (``-emb_lam`` or 0.0)

    Raises:
        KeyError: If a reward input has no ``response`` entry.
    """
    scores = []

    for reward_input in reward_inputs:
        response = reward_input["response"]
        ground_truth = reward_input.get("ground_truth", "")
        input_embedding = reward_input.get("input_embedding", None)
        action_embedding = reward_input.get("action_embedding", None)

        image_width, image_height = get_image_size_from_reward_input(reward_input)

        # Resolve the ground-truth action ("label") from either representation.
        label = None
        if ground_truth:
            if isinstance(ground_truth, str):
                try:
                    label = json.loads(ground_truth)
                except json.JSONDecodeError:
                    # Not JSON: treat it as raw model-style output text.
                    if parse_model_output_to_qwen3 is not None:
                        label = parse_model_output_to_qwen3(ground_truth)
                else:
                    # Same unwrapping as for dict input, so that a serialized
                    # {"label": {...}} wrapper is not mistaken for the action.
                    if isinstance(label, dict) and "label" in label:
                        label = label["label"]
            elif isinstance(ground_truth, dict):
                label = ground_truth.get("label", ground_truth)

        try:
            action_match_score = compute_action_match_reward(
                response,
                label,
                click_threshold,
                use_continuous_reward=use_continuous_reward,
                click_tau=click_tau,
                image_width=image_width,
                image_height=image_height,
            )
        except Exception:
            action_match_score = 0.0

        embedding_score = 0.0
        if use_embedding_reward:
            try:
                input_action_sim = 0.0
                if input_embedding is not None and action_embedding is not None:
                    try:
                        # Convert explicitly so plain lists work and the norm
                        # is accumulated in float64 rather than in a narrow
                        # input dtype.
                        input_vec = np.asarray(input_embedding, dtype=np.float64)
                        action_vec = np.asarray(action_embedding, dtype=np.float64)
                        input_unit = input_vec / (np.linalg.norm(input_vec) + 1e-8)
                        action_unit = action_vec / (np.linalg.norm(action_vec) + 1e-8)
                        input_action_sim = float(np.dot(input_unit, action_unit))
                    except Exception:
                        # Unusable embeddings (wrong shape/type): no penalty.
                        input_action_sim = 0.0

                embedding_score = compute_embedding_reward(
                    input_action_sim=input_action_sim,
                    action_match_score=action_match_score,
                    tau=emb_tau,
                    lam=emb_lam,
                    full_match_thresh=full_match_thresh,
                )
            except Exception:
                embedding_score = 0.0

        try:
            overall = action_match_weight * action_match_score + embedding_score
            # Never hand a non-finite reward back to the trainer.
            if math.isnan(overall) or math.isinf(overall):
                overall = 0.0
        except Exception:
            overall = 0.0

        scores.append({
            "overall": overall,
            "action_match": action_match_score,
            "embedding": embedding_score,
        })

    return scores
