#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
使用分块并行方式评估test.json数据（不使用vLLM）
基于 transformers 和 multiprocessing 实现
适配 OS-Atlas 模型和数据格式

数据格式（OS-Atlas）：
- images: 图片路径列表
- messages: 包含system和user消息（可能包含assistant消息作为ground truth）
- label: groundtruth action（OS-Atlas格式，例如 "action:\nPRESS_BACK"）
- episode_id, step_id, data_type
"""

import os
import sys
import json
import math
import copy
import argparse
import traceback
from typing import Dict, Any, List, Optional, Tuple
from operator import itemgetter
from itertools import groupby
from collections import defaultdict
from tqdm import tqdm

# 添加项目路径
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
sys.path.insert(0, current_dir)
sys.path.insert(0, project_root)

# 导入评估工具函数
from qwen3_action_mapper import (
    get_qwen3_action_type,
    is_qwen3_action_type_match,
    is_qwen3_action_match,
    calculate_click_distance,
    parse_model_output_to_qwen3,
)

# 导入 OS-Atlas 相关模块
try:
    from utils.initial_agent import init_agent
    from utils.result_preprocess import OS_ATLAS_RES_PRE_PROCESS
    OS_ATLAS_AVAILABLE = True
except ImportError:
    OS_ATLAS_AVAILABLE = False
    print("警告: OS-Atlas 模块不可用，将无法处理 OS-Atlas 格式数据")


def convert_os_atlas_action_to_qwen3(os_atlas_action: str, image_size: Optional[Tuple[int, int]] = None) -> Optional[Dict[str, Any]]:
    """
    Convert an OS-Atlas action string into a qwen3-style action dict.

    Recognised actions (case-insensitive): LONG_CLICK / CLICK with a ``[[x, y]]``
    point (optionally wrapped in ``<point>``), ``TYPE [text]``, ``SCROLL [direction]``,
    PRESS_BACK, PRESS_HOME, a bare ENTER, and COMPLETE.

    Args:
        os_atlas_action: OS-Atlas action string, e.g. "CLICK <point>[[101, 872]]</point>".
        image_size: screenshot size ``(width, height)``. When given, point coordinates
            are divided by it and scaled to 0-1000 (a non-positive dimension leaves that
            coordinate unchanged). When omitted, coordinates are assumed to be already
            normalised and are returned as floats unchanged.

    Returns:
        The qwen3 action dict, or None when the input is empty / not a string, a point
        cannot be parsed as exactly two numbers, or no known action matches.
    """
    import re

    if not os_atlas_action or not isinstance(os_atlas_action, str):
        return None

    action_str = os_atlas_action.strip()

    def _to_point(coords_str):
        # "x, y" -> [x, y], rescaled to 0-1000 when the image size is known.
        x, y = map(float, coords_str.split(','))
        if image_size:
            width, height = image_size
            if width > 0:
                x = (x / width) * 1000.0
            if height > 0:
                y = (y / height) * 1000.0
        return [x, y]

    # LONG_CLICK has to be tried before CLICK: the CLICK pattern is not anchored,
    # so it also matches the "CLICK [[...]]" tail of "LONG_CLICK [[...]]" and would
    # report every long press as a plain click.
    for action_name, pattern in (
        ("long_press", r"LONG_CLICK\s*(?:<point>)?\[\[([^\]]+)\]\]"),
        ("click", r"CLICK\s*(?:<point>)?\[\[([^\]]+)\]\]"),
    ):
        point_match = re.search(pattern, action_str, re.IGNORECASE)
        if point_match:
            try:
                return {"action": action_name, "coordinate": _to_point(point_match.group(1))}
            except (ValueError, IndexError):
                return None

    # TYPE [text]
    type_match = re.search(r"TYPE\s*\[([^\]]+)\]", action_str, re.IGNORECASE)
    if type_match:
        return {"action": "type", "text": type_match.group(1).strip()}

    # SCROLL [direction] -> fixed swipe (start, end) in 0-1000 space.
    scroll_match = re.search(r"SCROLL\s*\[([^\]]+)\]", action_str, re.IGNORECASE)
    if scroll_match:
        swipes = {
            "DOWN": ([500, 500], [500, 700]),
            "UP": ([500, 700], [500, 500]),
            "RIGHT": ([500, 500], [700, 500]),
            "LEFT": ([700, 500], [500, 500]),
        }
        # Unknown directions fall back to a downward swipe.
        start, end = swipes.get(scroll_match.group(1).strip().upper(), swipes["DOWN"])
        return {"action": "swipe", "coordinate": start, "coordinate2": end}

    if re.search(r"PRESS_BACK", action_str, re.IGNORECASE):
        return {"action": "system_button", "button": "back"}

    if re.search(r"PRESS_HOME", action_str, re.IGNORECASE):
        return {"action": "system_button", "button": "home"}

    # ENTER only counts when it is the whole action string.
    if re.search(r"^ENTER$", action_str, re.IGNORECASE):
        return {"action": "system_button", "button": "enter"}

    if re.search(r"COMPLETE", action_str, re.IGNORECASE):
        return {"action": "terminate", "status": "finished"}

    return None


def convert_gt_action_to_qwen3(gt_action: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return the ground-truth action in qwen3 format.

    ``gt_action`` in the standard test data is already qwen3-formatted, so no
    conversion is applied: a shallow copy is returned so callers can modify the
    top-level keys without touching the sample. Falsy input (None, {}) gives {}.
    """
    return gt_action.copy() if gt_action else {}


def extract_messages_from_sample(sample: Dict[str, Any], is_os_atlas_format: bool = False) -> tuple:
    """
    Build a chat-template-ready message list from a raw sample.

    Messages are read from ``messages`` (falling back to ``Messages``). Turns whose
    content is already a list are passed through unchanged. String-content system
    turns become a single text part. String-content user turns become a text part,
    plus an image part when the first entry of ``images`` exists on disk. String
    turns from any other role are dropped.

    Args:
        sample: raw sample dict.
        is_os_atlas_format: when True, assistant turns (the ground truth) are
            removed before conversion so they never reach the prompt.

    Returns:
        ``(messages_list, image_path)``; ``image_path`` is the first entry of
        ``images`` or None.
    """
    raw_messages = sample.get("messages", sample.get("Messages", []))
    image_list = sample.get("images", [])
    image_path = image_list[0] if image_list else None

    if is_os_atlas_format:
        raw_messages = [m for m in raw_messages if m.get("role") != "assistant"]

    converted = []
    for msg in raw_messages:
        role = msg.get("role", "")
        content = msg.get("content", "")

        if isinstance(content, list):
            # Already structured (UI-TARS style): keep it as-is.
            converted.append({"role": role, "content": content})
            continue

        if role == "system":
            converted.append({"role": "system", "content": [{"type": "text", "text": content}]})
            continue

        if role != "user":
            continue

        parts = [{"type": "text", "text": content}]
        if image_path and os.path.exists(image_path):
            parts.append({"type": "image", "image": image_path})
        converted.append({"role": "user", "content": parts})

    return converted, image_path


class TestDataEvaluator:
    """使用分块并行方式评估test.json数据（不使用vLLM）"""
    
    def __init__(self, config):
        self.config = config
        self.test_json_path = config.test_json_path
        self.result_save_path = getattr(config, 'result_save_path', "evaluation_results.json")
        self.model_path = getattr(config, 'model_path', "/data2/models/Qwen3-VL/Qwen3-VL-8B-Instruct")
        self.base_model_path = getattr(config, 'base_model_path', None)  # 基础模型路径（用于加载配置文件）
        self.max_new_tokens = getattr(config, 'max_new_tokens', 512)
        self.click_threshold = getattr(config, 'click_threshold', 140.0)
        self.model_name = getattr(config, 'model_name', None)  # 模型名称，用于判断是否使用 UI-TARS
        
        # 检测数据格式（检查是否有 label 字段且格式为 OS-Atlas）
        self.is_os_atlas_format = False
        if self.test_json_path and os.path.exists(self.test_json_path):
            try:
                with open(self.test_json_path, 'r', encoding='utf-8') as f:
                    sample_data = json.load(f)
                    if isinstance(sample_data, list) and len(sample_data) > 0:
                        first_sample = sample_data[0]
                        # 如果有 label 字段且没有 gt_action 字段，检查是否为 OS-Atlas 格式
                        if 'label' in first_sample and 'gt_action' not in first_sample:
                            label = first_sample.get('label', '')
                            # OS-Atlas 格式通常包含 "action:\n" 前缀和 CLICK, TYPE, SCROLL 等动作
                            if 'action:' in label.lower() and any(keyword in label.upper() for keyword in ['CLICK', 'TYPE', 'SCROLL', 'PRESS_BACK', 'PRESS_HOME', 'COMPLETE']):
                                self.is_os_atlas_format = True
                                print("检测到 OS-Atlas 数据格式（包含 label 字段且格式匹配）")
            except Exception as e:
                print(f"警告: 无法检测数据格式: {e}")
        
        # 如果模型名称包含 OS-Atlas 或 OS_Atlas，强制使用 OS-Atlas 格式
        if self.model_name and ("OS-Atlas" in self.model_name or "OS_Atlas" in self.model_name):
            self.is_os_atlas_format = True
            print(f"根据模型名称 {self.model_name} 判断为 OS-Atlas 格式")
        
        # CUDA 设备
        device_ids_str = getattr(config, 'device_ids', "[0]")
        try:
            device_ids = eval(device_ids_str)
        except:
            device_ids = [0]
        
        self.agent_count = getattr(config, 'agent_count', 1)
        agent_device_count = math.ceil(len(device_ids) / self.agent_count)
        self.device_ids = [
            device_ids[i * agent_device_count: (i + 1) * agent_device_count] 
            for i in range(self.agent_count)
        ]
        
        # 批处理大小（用于HF推理加速）
        self.batch_size = getattr(config, 'batch_size', 8)
        
        print(f"设备分配: 总共 {len(device_ids)} 个设备ID, {self.agent_count} 个 agent")
        for i, dev_ids in enumerate(self.device_ids):
            print(f"  Agent {i}: {dev_ids}")
        print(f"批处理大小: {self.batch_size}")
    
    def run_inference(
        self,
        model,
        processor,
        messages: List[Dict[str, Any]],
        max_new_tokens: int = 512,
    ) -> str:
        """运行单次推理（保留用于兼容性）"""
        responses = self.run_inference_batch(
            model, processor, [messages], max_new_tokens
        )
        return responses[0]
    
    def run_inference_batch(
        self,
        model,
        processor,
        messages_list: List[List[Dict[str, Any]]],
        max_new_tokens: int = 512,
    ) -> List[str]:
        """批量运行推理（性能优化）"""
        import torch
        from qwen_vl_utils import process_vision_info
        
        if not messages_list:
            return []
        
        # 批量处理所有messages
        texts = []
        all_image_inputs = []
        
        for messages in messages_list:
            # Apply chat template
            text = processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            texts.append(text)
            
            # 处理视觉信息（只处理图片，不处理视频）
            image_inputs, _ = process_vision_info(
                messages,
                return_video_kwargs=False,
            )
            all_image_inputs.append(image_inputs)
        
        # 准备批量输入（不传递视频相关参数）
        processor_kwargs = {
            "text": texts,
            "images": all_image_inputs,
            "padding": True,
            "return_tensors": "pt",
        }
        
        inputs = processor(**processor_kwargs)
        inputs = inputs.to(model.device)
        
        # 批量生成
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )
        
        # 批量解码：只解码生成的部分（去掉输入部分）
        # 获取输入序列长度，只解码生成的新token
        input_ids = inputs.input_ids
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(input_ids, generated_ids)
        ]
        
        # 解码生成的部分
        responses = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        
        return responses
    
    def process_chunk(
        self, 
        agent_id, 
        device_ids, 
        chunk, 
        chunk_start,
        pred_results, 
        list_lock, 
        progress_counter, 
        progress_lock,
        model_loaded_event, 
        status_queue, 
        model_path,
        base_model_path,
        max_new_tokens,
        click_threshold,
        batch_size,
        is_os_atlas_format=False,
        model_name=None,
    ):
        """Evaluate one chunk of samples inside a worker process.

        OS-Atlas data is delegated to ``_process_chunk_os_atlas`` when that backend
        is importable. Otherwise the HF model is loaded once, and samples (ordered by
        ``episode_id`` then ``step_id``) go through batched greedy inference. Each
        prediction is scored against the sample's ``gt_action``, and one result
        record per sample is appended to ``pred_results``.

        Args:
            agent_id: index of this worker, used in logs and status messages.
            device_ids: device ids assigned to this worker (not read on the HF path).
            chunk: list of sample dicts; sorted in place by ``episode_id``.
            chunk_start: not read on the HF path; kept for interface compatibility.
            pred_results, list_lock: shared result list and the lock guarding it.
            progress_counter, progress_lock: shared counter (+1 per sample) and its lock.
            model_loaded_event: set once the model has loaded.
            status_queue: receives ``(agent_id, "error", msg)`` if loading fails and
                ``(agent_id, "success", n_samples)`` when the chunk is finished,
                matching ``_process_chunk_os_atlas``.
            model_path: checkpoint providing the weights.
            base_model_path: optional checkpoint providing config/processor files.
            max_new_tokens: generation budget per sample.
            click_threshold: threshold forwarded to ``is_qwen3_action_match``.
            batch_size: samples per ``generate`` call.
            is_os_atlas_format: route to the OS-Atlas agent instead of HF.
            model_name: forwarded to the OS-Atlas path.
        """
        import torch

        torch.cuda.empty_cache()
        torch.cuda.init()

        if is_os_atlas_format and OS_ATLAS_AVAILABLE:
            return self._process_chunk_os_atlas(
                agent_id, device_ids, chunk, chunk_start, pred_results,
                list_lock, progress_counter, progress_lock, model_loaded_event,
                status_queue, model_path, max_new_tokens, click_threshold, model_name
            )

        try:
            model, processor = self._hf_load_model_and_processor(
                agent_id, model_path, base_model_path
            )
            model_loaded_event.set()
            print(f"Agent {agent_id} 模型加载完成")
        except Exception as e:
            error_msg = f"Agent {agent_id} 模型加载失败: {e}"
            print(f"ERROR: {error_msg}")
            print(traceback.format_exc())
            status_queue.put((agent_id, "error", f"{e} at {traceback.format_exc()}"))
            return

        # Order samples by episode, then by step within each episode.
        chunk.sort(key=itemgetter("episode_id"))
        all_samples = []
        for _, episode_records in groupby(chunk, key=itemgetter("episode_id")):
            all_samples.extend(sorted(episode_records, key=itemgetter("step_id")))

        def publish(record):
            # Append one finished record and advance the shared progress counter.
            with list_lock:
                pred_results.append(record)
            with progress_lock:
                progress_counter.value += 1

        for batch_start in range(0, len(all_samples), batch_size):
            batch_samples = all_samples[batch_start:batch_start + batch_size]

            # One deep-copied record per sample. A None in batch_messages marks a
            # sample rejected during preparation; its record already holds the error.
            batch_records = [copy.deepcopy(sample) for sample in batch_samples]
            batch_messages = [
                self._hf_prepare_sample(sample, record)
                for sample, record in zip(batch_samples, batch_records)
            ]

            valid_messages = [m for m in batch_messages if m is not None]
            if valid_messages:
                try:
                    batch_responses = self.run_inference_batch(
                        model=model,
                        processor=processor,
                        messages_list=valid_messages,
                        max_new_tokens=max_new_tokens,
                    )
                except Exception as e:
                    print(f"WARNING: 批量推理失败: {e}")
                    print(traceback.format_exc())
                    batch_responses = [""] * len(valid_messages)
            else:
                batch_responses = []

            # batch_responses lines up, in order, with the non-None batch_messages.
            response_idx = 0
            for sample, record, messages in zip(batch_samples, batch_records, batch_messages):
                if messages is None:
                    publish(record)
                    continue

                try:
                    response = batch_responses[response_idx]
                    response_idx += 1

                    gt_qwen3 = convert_gt_action_to_qwen3(sample.get("gt_action", {}))
                    pred_qwen3 = parse_model_output_to_qwen3(response)

                    if pred_qwen3 and gt_qwen3:
                        is_type_match = is_qwen3_action_type_match(pred_qwen3, gt_qwen3)
                        is_success = is_qwen3_action_match(pred_qwen3, gt_qwen3, click_threshold)
                        click_distance = calculate_click_distance(pred_qwen3, gt_qwen3)
                    else:
                        is_type_match = False
                        is_success = False
                        click_distance = None

                    record.update({
                        "predicted_response": response,
                        "predicted_qwen3": pred_qwen3,
                        "ground_truth_qwen3": gt_qwen3,
                        "gt_type": get_qwen3_action_type(gt_qwen3) if gt_qwen3 else 0,
                        "pred_type": get_qwen3_action_type(pred_qwen3) if pred_qwen3 else 0,
                        "is_type_match": is_type_match,
                        "is_success": is_success,
                    })
                    if click_distance is not None:
                        record["click_distance"] = click_distance

                    publish(record)

                except Exception as e:
                    print(f"WARNING: 处理样本失败: {e}")
                    print(traceback.format_exc())
                    record.update({
                        "predicted_qwen3": None,
                        "is_type_match": False,
                        "is_success": False,
                        "evaluation_error": str(e)
                    })
                    publish(record)

        status_queue.put((agent_id, "success", len(all_samples)))

    @staticmethod
    def _hf_prepare_sample(sample, record):
        """Return the prompt messages for ``sample``, or None after writing the failure into ``record``."""
        try:
            messages, _ = extract_messages_from_sample(sample, is_os_atlas_format=False)
            if not messages:
                error = "Missing messages field"
            elif not convert_gt_action_to_qwen3(sample.get("gt_action", {})):
                error = "Missing gt_action field"
            else:
                return messages
            record.update({
                "evaluation_error": error,
                "predicted_response": "",
                "predicted_qwen3": None,
                "ground_truth_qwen3": None,
                "is_type_match": False,
                "is_success": False,
                "click_distance": None,
            })
            return None
        except Exception as e:
            print(f"WARNING: 准备样本失败: {e}")
            record.update({
                "predicted_qwen3": None,
                "is_type_match": False,
                "is_success": False,
                "evaluation_error": str(e)
            })
            return None

    @staticmethod
    def _hf_fix_rope_scaling(rope_scaling):
        """Return a copy of a ``rope_scaling`` dict that always carries ``rope_type``.

        The legacy ``type`` value is mirrored into ``rope_type``. When neither key
        exists, ``'dynamic'`` is assumed. Non-dict values (incl. None) are returned as-is.
        """
        if not isinstance(rope_scaling, dict):
            return rope_scaling
        fixed = rope_scaling.copy()
        if 'rope_type' not in fixed:
            fixed['rope_type'] = fixed.get('type', 'dynamic')
        return fixed

    def _hf_load_config(self, config_source_path, scratch_dir):
        """Load the AutoConfig for ``config_source_path`` with ``rope_scaling`` patched.

        When a local ``config.json`` needs patching, the fixed copy is written to
        ``scratch_dir`` and loaded from there. The caller owns and removes that directory.
        """
        from transformers import AutoConfig

        config_path = os.path.join(config_source_path, 'config.json')
        if not os.path.exists(config_path):
            # No local config.json: load it directly, then patch the object in memory.
            config = AutoConfig.from_pretrained(config_source_path, trust_remote_code=True)
            if isinstance(getattr(config, 'rope_scaling', None), dict):
                config.rope_scaling = self._hf_fix_rope_scaling(config.rope_scaling)
            text_config = getattr(config, 'text_config', None)
            if text_config is not None and isinstance(getattr(text_config, 'rope_scaling', None), dict):
                text_config.rope_scaling = self._hf_fix_rope_scaling(text_config.rope_scaling)
            return config

        with open(config_path, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)
        before = json.dumps(config_dict, sort_keys=True)

        if 'rope_scaling' in config_dict:
            config_dict['rope_scaling'] = self._hf_fix_rope_scaling(config_dict['rope_scaling'])
        # Qwen3-VL keeps rope_scaling under text_config.
        text_config = config_dict.get('text_config')
        if isinstance(text_config, dict) and 'rope_scaling' in text_config:
            text_config['rope_scaling'] = self._hf_fix_rope_scaling(text_config['rope_scaling'])

        if json.dumps(config_dict, sort_keys=True) == before:
            return AutoConfig.from_pretrained(config_source_path, trust_remote_code=True)

        patched_path = os.path.join(scratch_dir, 'config.json')
        with open(patched_path, 'w', encoding='utf-8') as f:
            json.dump(config_dict, f, indent=2, ensure_ascii=False)
        return AutoConfig.from_pretrained(patched_path, trust_remote_code=True)

    def _hf_load_model_and_processor(self, agent_id, model_path, base_model_path):
        """Load ``(model, processor)``; config and processor may come from ``base_model_path``.

        Weights always come from ``model_path``. Exceptions from the transformers
        loaders propagate to the caller.
        """
        import shutil
        import tempfile
        import torch
        from transformers import AutoProcessor, AutoModelForImageTextToText

        if base_model_path and os.path.exists(base_model_path):
            print(f"Agent {agent_id} 使用混合加载模式:")
            print(f"  配置文件来源: {base_model_path}")
            print(f"  权重文件来源: {model_path}")
            config_source_path = base_model_path
        else:
            print(f"Agent {agent_id} 正在加载模型: {model_path}")
            config_source_path = model_path
        weights_source_path = model_path

        # Scratch space for a patched config.json. It is removed in `finally`, so a
        # failed load does not leak the directory.
        scratch_dir = tempfile.mkdtemp()
        try:
            config = self._hf_load_config(config_source_path, scratch_dir)

            processor = AutoProcessor.from_pretrained(config_source_path, trust_remote_code=True)
            # Batched decoder-only generation needs left padding.
            if hasattr(processor, 'tokenizer'):
                processor.tokenizer.padding_side = 'left'
            if hasattr(processor, 'text_tokenizer'):
                processor.text_tokenizer.padding_side = 'left'

            model = AutoModelForImageTextToText.from_pretrained(
                weights_source_path,
                config=config,
                dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
            )
            model.eval()
        finally:
            shutil.rmtree(scratch_dir, ignore_errors=True)

        return model, processor
    
    def _process_chunk_os_atlas(
        self,
        agent_id,
        device_ids,
        chunk,
        chunk_start,
        pred_results,
        list_lock,
        progress_counter,
        progress_lock,
        model_loaded_event,
        status_queue,
        model_path,
        max_new_tokens,
        click_threshold,
        model_name,
    ):
        """Evaluate a chunk of OS-Atlas-format samples with the OS-Atlas agent.

        Samples are handled one at a time, ordered by ``episode_id`` then ``step_id``.
        The agent is called per sample rather than in batches. For each sample:
        - the ground truth comes from ``label``;
        - the prediction comes from ``agent.get_action``;
        - both are reduced with ``OS_ATLAS_RES_PRE_PROCESS.extract_action``, converted
          to qwen3 format against the screenshot size (falling back to 1080x1920),
          and compared.

        One record per sample is appended to ``pred_results``, and
        ``(agent_id, "success", n_samples)`` is put on ``status_queue`` at the end.
        If loading the agent fails, ``(agent_id, "error", msg)`` is put instead.
        ``device_ids``, ``chunk_start`` and ``max_new_tokens`` are not read here.
        """
        import torch
        from PIL import Image

        fallback_size = (1080, 1920)  # (width, height) when the screenshot can't be read

        class Args:
            # Attribute bag handed to init_agent and agent.get_action.
            def __init__(self):
                self.model_path = model_path
                self.model_name = model_name or "OS_Atlas-Pro-7B"
                self.thought = "false"
                self.probing_method = "none"
                self.dataset_type = "default"
                self.mask_object_ratio = 50

        args = Args()
        device = torch.device('cuda')

        try:
            agent = init_agent(args, device, True, args.model_name)
            model_loaded_event.set()
            print(f"Agent {agent_id} OS-Atlas 模型加载完成")
        except Exception as exc:
            print(f"ERROR: Agent {agent_id} OS-Atlas 模型加载失败: {exc}")
            print(traceback.format_exc())
            status_queue.put((agent_id, "error", f"{exc} at {traceback.format_exc()}"))
            return

        extractor = OS_ATLAS_RES_PRE_PROCESS()

        def publish(record):
            # Append one finished record and advance the shared progress counter.
            with list_lock:
                pred_results.append(record)
            with progress_lock:
                progress_counter.value += 1

        def screenshot_size(images):
            # (width, height) of the first screenshot; fallback_size if absent/unreadable.
            if not (images and len(images) > 0):
                return fallback_size
            try:
                first = images[0] if isinstance(images, list) else images
                with Image.open(first) as img:
                    return img.size
            except Exception:
                return fallback_size

        chunk.sort(key=itemgetter("episode_id"))
        ordered = []
        for _, episode_records in groupby(chunk, key=itemgetter("episode_id")):
            ordered.extend(sorted(episode_records, key=itemgetter("step_id")))

        for sample in ordered:
            record = copy.deepcopy(sample)

            try:
                image_size = screenshot_size(sample.get("images", []))

                label = sample.get("label", "")
                if not label:
                    record.update({
                        "evaluation_error": "Missing label field",
                        "predicted_response": "",
                        "predicted_qwen3": None,
                        "ground_truth_qwen3": None,
                        "is_type_match": False,
                        "is_success": False,
                        "click_distance": None,
                    })
                    publish(record)
                    continue

                gt_qwen3 = convert_os_atlas_action_to_qwen3(
                    extractor.extract_action(label), image_size
                )

                obs = {
                    "messages": sample.get("messages", []),
                    "images": sample.get("images", []),
                    "label": label,
                    "image_size": image_size,
                }
                raw_prediction = agent.get_action(obs, args)

                pred_os_atlas = extractor.extract_action(raw_prediction) if raw_prediction else ""
                pred_qwen3 = (
                    convert_os_atlas_action_to_qwen3(pred_os_atlas, image_size)
                    if pred_os_atlas
                    else None
                )

                is_type_match = is_success = False
                click_distance = None
                if pred_qwen3 and gt_qwen3:
                    is_type_match = is_qwen3_action_type_match(pred_qwen3, gt_qwen3)
                    is_success = is_qwen3_action_match(pred_qwen3, gt_qwen3, click_threshold)
                    click_distance = calculate_click_distance(pred_qwen3, gt_qwen3)

                record.update({
                    "predicted_response": raw_prediction or "",
                    "predicted_qwen3": pred_qwen3,
                    "ground_truth_qwen3": gt_qwen3,
                    "gt_type": get_qwen3_action_type(gt_qwen3) if gt_qwen3 else 0,
                    "pred_type": get_qwen3_action_type(pred_qwen3) if pred_qwen3 else 0,
                    "is_type_match": is_type_match,
                    "is_success": is_success,
                })
                if click_distance is not None:
                    record["click_distance"] = click_distance

                publish(record)

            except Exception as exc:
                print(f"WARNING: 处理样本失败: {exc}")
                print(traceback.format_exc())
                record.update({
                    "predicted_qwen3": None,
                    "is_type_match": False,
                    "is_success": False,
                    "evaluation_error": str(exc)
                })
                publish(record)

        status_queue.put((agent_id, "success", len(ordered)))
    
    def evaluate_mp(self, test_data: List[Dict]) -> List[Dict]:
        """Evaluate ``test_data`` in parallel, one "spawn"-ed worker process per chunk.

        Samples without a usable ground truth are dropped first: non-OS-Atlas data
        needs a non-null ``gt_action``, OS-Atlas data needs a non-empty ``label``.
        The remaining samples are grouped by ``episode_id``, so an episode never
        straddles two workers. They are split into at most ``self.agent_count``
        chunks and handed to ``self.process_chunk``.

        Args:
            test_data: list of sample dicts; every kept sample must carry an
                ``episode_id`` key.

        Returns:
            All records appended by the workers, sorted by ``(episode_id, step_id)``.
            Returns an empty list (without starting any process) when no sample
            survives filtering.
        """
        import time

        # Filter out samples without a usable ground truth.
        original_count = len(test_data)
        if not self.is_os_atlas_format:
            test_data = [sample for sample in test_data if sample.get("gt_action") is not None]
            filtered_count = len(test_data)
            print(f"过滤掉 {original_count - filtered_count} 个gt_action为null的样本，剩余: {filtered_count}")
        else:
            # OS-Atlas format: drop samples whose label is empty.
            test_data = [sample for sample in test_data if sample.get("label")]
            filtered_count = len(test_data)
            print(f"过滤掉 {original_count - filtered_count} 个label为空的样本，剩余: {filtered_count}")

        if not test_data:
            # Nothing to evaluate. Returning early also avoids a zero chunk size,
            # which would make range(..., 0) raise ValueError.
            return []

        # Group by episode_id (groupby needs the input sorted by the same key).
        test_data.sort(key=itemgetter("episode_id"))
        trajs = {k: list(v) for k, v in groupby(test_data, key=itemgetter("episode_id"))}
        sorted_episode_ids = sorted(trajs)

        agent_count = max(1, self.agent_count)
        chunk_size = math.ceil(len(sorted_episode_ids) / agent_count)
        # Each chunk is the flat list of samples of chunk_size consecutive episodes.
        chunks = [
            [sample for episode_id in sorted_episode_ids[i:i + chunk_size] for sample in trajs[episode_id]]
            for i in range(0, len(sorted_episode_ids), chunk_size)
        ]

        print(f"数据分片: 总共 {len(trajs)} 个episode, 分成 {len(chunks)} 个chunk")
        for i, chunk in enumerate(chunks):
            print(f"  Chunk {i}: {len(chunk)} 个样本")

        import torch.multiprocessing as mp
        ctx = mp.get_context("spawn")
        manager = mp.Manager()
        pred_results = manager.list([])
        progress_counter = manager.Value('i', 0)
        list_lock = manager.Lock()
        progress_lock = manager.Lock()
        status_queue = manager.Queue()
        # One "model loaded" event per chunk actually launched. Events for agents
        # that never start would never be set.
        model_loaded_events = [manager.Event() for _ in chunks]

        processes = []
        for i, chunk in enumerate(chunks):
            # NOTE(review): chunk_start is an offset in *episodes* (i * chunk_size),
            # while `chunk` is a flat list of samples; confirm how process_chunk uses it.
            chunk_start = i * chunk_size

            # Set the CUDA devices; a spawn-ed child inherits os.environ as of start().
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, self.device_ids[i]))

            p = ctx.Process(
                target=self.process_chunk,
                args=(
                    i, self.device_ids[i], chunk, chunk_start, pred_results,
                    list_lock, progress_counter, progress_lock,
                    model_loaded_events[i], status_queue, self.model_path,
                    self.base_model_path, self.max_new_tokens, self.click_threshold, self.batch_size,
                    self.is_os_atlas_format, self.model_name,
                ),
                daemon=False
            )
            p.start()
            processes.append(p)

        # Wait (at most 300 s in total) until ANY agent reports its model loaded.
        # Stop early if every worker has already exited.
        print("等待至少一个 agent 加载模型...")
        deadline = time.monotonic() + 300
        any_loaded = any(event.is_set() for event in model_loaded_events)
        while not any_loaded and time.monotonic() < deadline and any(p.is_alive() for p in processes):
            time.sleep(0.5)
            any_loaded = any(event.is_set() for event in model_loaded_events)
        if not any_loaded:
            print("ERROR: 没有 agent 在超时时间内完成加载")
        else:
            print("至少一个 agent 加载完成，开始进度条")

        # Progress bar, fed from the shared counter the workers increment.
        with tqdm(total=len(test_data), desc=f"评估 {self.test_json_path}") as pbar:
            last_progress = 0

            def _refresh_progress():
                # Advance the bar by whatever the workers finished since the last poll.
                nonlocal last_progress
                with progress_lock:
                    current_progress = progress_counter.value
                if current_progress > last_progress:
                    pbar.update(current_progress - last_progress)
                    last_progress = current_progress

            while any(p.is_alive() for p in processes):
                _refresh_progress()
                time.sleep(0.5)

            for p in processes:
                p.join(timeout=1080000)
                if p.is_alive():
                    print(f"WARNING: Process {p.pid} 超时，终止")
                    p.terminate()
                    p.join()

            # Pick up increments made after the last poll but before the workers exited.
            _refresh_progress()

            statuses = []
            while not status_queue.empty():
                statuses.append(status_queue.get())
            for sid, status, info in statuses:
                print(f"[Agent {sid}] Status: {status}, Info: {info}")

            failed_agents = [s for s in statuses if s[1] != "success"]
            if failed_agents:
                print(f"\nWARNING: {len(failed_agents)} agents 失败")

            # A worker that crashed hard never posts a status; report those too.
            reported_ids = {s[0] for s in statuses}
            silent_agents = [i for i in range(len(processes)) if i not in reported_ids]
            if silent_agents:
                print(f"WARNING: agents {silent_agents} 没有上报状态（进程可能异常退出）")

        all_pred_results = list(pred_results)
        all_pred_results.sort(key=itemgetter("episode_id", "step_id"))

        return all_pred_results
    
    def _compute_single_metrics(self, results: List[Dict], type_to_name: Dict[int, str]) -> Dict:
        """计算单个结果集的评估指标"""
        # 初始化计数器
        action_counts = {name: 0 for name in type_to_name.values()}
        action_counts["TOTAL"] = 0
        
        action_success = {name: 0 for name in type_to_name.values()}
        action_success["TOTAL"] = 0
        
        action_type_match = {name: 0 for name in type_to_name.values()}
        action_type_match["TOTAL"] = 0
        
        episode_count = 0
        success_episode_count = 0
        
        # 按 episode 分组（如果有episode_id字段）
        has_episode_id = any("episode_id" in r for r in results) if results else False
        
        if has_episode_id:
            results.sort(key=itemgetter("episode_id", "step_id") if "step_id" in results[0] else itemgetter("episode_id"))
            trajs = {k: list(v) for k, v in groupby(results, key=itemgetter("episode_id"))}
            
            for episode_id, episode_records in trajs.items():
                episode_count += 1
                episode_success = True
                
                for record in episode_records:
                    is_success = record.get("is_success", False)
                    is_type_match = record.get("is_type_match", False)
                    gt_type = record.get("gt_type", 0)
                    
                    # 获取动作名称（如果类型为 0 或不在映射中，跳过）
                    if gt_type in type_to_name:
                        action_name = type_to_name[gt_type]
                        action_counts[action_name] += 1
                        if is_success:
                            action_success[action_name] += 1
                        if is_type_match:
                            action_type_match[action_name] += 1
                    
                    # 统计总数（包括所有类型，包括未知类型）
                    action_counts["TOTAL"] += 1
                    if is_success:
                        action_success["TOTAL"] += 1
                    if is_type_match:
                        action_type_match["TOTAL"] += 1
                    
                    if not is_success:
                        episode_success = False
                
                if episode_success:
                    success_episode_count += 1
        else:
            # 没有episode_id，按样本统计
            for record in results:
                is_success = record.get("is_success", False)
                is_type_match = record.get("is_type_match", False)
                gt_type = record.get("gt_type", 0)
                
                # 获取动作名称（如果类型为 0 或不在映射中，跳过）
                if gt_type in type_to_name:
                    action_name = type_to_name[gt_type]
                    action_counts[action_name] += 1
                    if is_success:
                        action_success[action_name] += 1
                    if is_type_match:
                        action_type_match[action_name] += 1
                
                # 统计总数（包括所有类型，包括未知类型）
                action_counts["TOTAL"] += 1
                if is_success:
                    action_success["TOTAL"] += 1
                if is_type_match:
                    action_type_match["TOTAL"] += 1
        
        # 计算 TMR 和 AMR
        metrics = {}
        all_action_names = list(type_to_name.values()) + ["TOTAL"]
        
        for action_name in all_action_names:
            count = action_counts[action_name]
            tmr = action_type_match[action_name] / count if count > 0 else 0
            amr = action_success[action_name] / count if count > 0 else 0
            
            metrics[action_name] = {
                "count": count,
                "type_match": action_type_match[action_name],
                "action_match": action_success[action_name],
                "TMR": round(tmr, 4),
                "AMR": round(amr, 4)
            }
        
        # Episode 成功率
        if has_episode_id:
            episode_success_rate = success_episode_count / episode_count if episode_count > 0 else 0
            metrics["episode_success_rate"] = round(episode_success_rate, 4)
            metrics["episode_count"] = episode_count
            metrics["success_episode_count"] = success_episode_count
        
        return metrics
    
    def compute_metrics(self, all_results: List[Dict]) -> Dict:
        """Compute metrics in the Qwen3 action space, overall and per ``data_type``.

        Args:
            all_results: per-sample result records.

        Returns:
            ``{"overall": <metrics>, "by_datatype": {"data_type_<value>": <metrics>, ...}}``.
            Each ``<metrics>`` comes from ``self._compute_single_metrics``. Records
            without a ``data_type`` key are grouped under ``-1``. The ``by_datatype``
            entries are ordered numbers first (ascending), then other values, then
            ``None``.
        """
        # Qwen3 action-type id -> bucket name.
        type_to_name = {
            1: "click",           # click
            2: "type",             # type / answer
            3: "swipe",            # swipe (SCROLL)
            4: "system_button",    # system_button (all system buttons counted together)
            5: "terminate",        # terminate (COMPLETE)
            6: "wait",             # wait
            7: "long_press",       # long_press
        }

        overall_metrics = self._compute_single_metrics(all_results, type_to_name)

        results_by_datatype = defaultdict(list)
        for item in all_results:
            results_by_datatype[item.get("data_type", -1)].append(item)

        def _datatype_order(value):
            # A plain sorted() raises TypeError on mixed keys such as {1, None}
            # (explicit JSON null) or {1, "A"}. This key keeps numeric order for
            # numbers and never compares values of unrelated types.
            if value is None:
                return (2, "", 0)
            if isinstance(value, (int, float)):
                return (0, "", value)
            return (1, type(value).__name__, value)

        metrics_by_datatype = {
            f"data_type_{data_type}": self._compute_single_metrics(results_by_datatype[data_type], type_to_name)
            for data_type in sorted(results_by_datatype, key=_datatype_order)
        }

        return {
            "overall": overall_metrics,
            "by_datatype": metrics_by_datatype
        }
    
    def evaluate(self):
        """主评估函数"""
        # 加载测试数据
        print(f"加载测试数据: {self.test_json_path}")
        with open(self.test_json_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)
        
        print(f"测试数据总数: {len(test_data)}")
        
        # 评估
        print("开始多进程评估...")
        all_pred_results = self.evaluate_mp(test_data)
        
        print(f"评估完成，共 {len(all_pred_results)} 条结果")
        
        # 计算指标
        print("计算评估指标...")
        metrics = self.compute_metrics(all_pred_results)
        
        # 保存结果
        print(f"保存结果到: {self.result_save_path}")
        result_dir = os.path.dirname(self.result_save_path)
        os.makedirs(result_dir, exist_ok=True)
        
        output_data = {
            "metrics": metrics,
            "config": {
                "model_path": self.model_path,
                "max_new_tokens": self.max_new_tokens,
                "click_threshold": self.click_threshold,
                "action_space": "qwen3",
                "action_types": {
                    "1": "click",
                    "2": "type",
                    "3": "swipe",
                    "4": "system_button",
                    "5": "terminate",
                    "6": "wait",
                    "7": "long_press"
                }
            },
            "detailed_results": all_pred_results
        }
        
        with open(self.result_save_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, ensure_ascii=False, indent=2)
        
        # 保存按 data_type 分类的统计结果
        statistics_output = {}
        
        # 先添加按 data_type 分类的统计
        for data_type_key, type_metrics in metrics["by_datatype"].items():
            statistics_output[data_type_key] = {
                "metrics": {k: v for k, v in type_metrics.items() if k not in ["episode_success_rate", "episode_count", "success_episode_count"]},
                "episode_success_rate": type_metrics.get("episode_success_rate", 0),
                "episode_count": type_metrics.get("episode_count", 0),
                "success_episode_count": type_metrics.get("success_episode_count", 0)
            }
        
        # 最后添加整体统计
        statistics_output["overall"] = {
            "metrics": {k: v for k, v in metrics["overall"].items() if k not in ["episode_success_rate", "episode_count", "success_episode_count"]},
            "episode_success_rate": metrics["overall"].get("episode_success_rate", 0),
            "episode_count": metrics["overall"].get("episode_count", 0),
            "success_episode_count": metrics["overall"].get("success_episode_count", 0)
        }
        
        statistics_file = self.result_save_path.replace('.json', '_statistics_by_datatype.json')
        with open(statistics_file, 'w', encoding='utf-8') as f:
            json.dump(statistics_output, f, ensure_ascii=False, indent=2)
        
        # 打印指标
        print("\n" + "=" * 60)
        print("评估结果")
        print("=" * 60)
        
        # 统计信息
        total = len(all_pred_results)
        success_count = sum(1 for r in all_pred_results if r.get("is_success", False))
        type_match_count = sum(1 for r in all_pred_results if r.get("is_type_match", False))
        error_count = sum(1 for r in all_pred_results if "evaluation_error" in r)
        
        print(f"总样本数: {total}")
        print(f"成功评估: {total - error_count} ({100*(total-error_count)/total:.2f}%)")
        print(f"失败: {error_count} ({100*error_count/total:.2f}%)")
        
        # 打印整体统计
        overall_metrics = metrics["overall"]
        print(f"\n--- 整体动作类型指标（Qwen3 动作空间）---")
        for action_name in ["click", "type", "swipe", "system_button", "terminate", "wait", "long_press", "TOTAL"]:
            m = overall_metrics.get(action_name, {})
            if m.get("count", 0) > 0:
                print(f"{action_name}: Count={m['count']}, TMR={m['TMR']:.4f}, AMR={m['AMR']:.4f}")
        
        if "episode_success_rate" in overall_metrics:
            print(f"\n整体 Episode 成功率: {overall_metrics['episode_success_rate']:.4f} ({overall_metrics['success_episode_count']}/{overall_metrics['episode_count']})")
        
        # 打印按 data_type 分类的统计
        if metrics["by_datatype"]:
            print(f"\n--- 按 Data Type 分类统计 ---")
            # 按 data_type 分组计算样本数量
            results_by_datatype = defaultdict(list)
            for item in all_pred_results:
                data_type = item.get("data_type", -1)
                results_by_datatype[data_type].append(item)
            
            for data_type_key in sorted(metrics["by_datatype"].keys()):
                type_metrics = metrics["by_datatype"][data_type_key]
                data_type_num = data_type_key.replace("data_type_", "")
                data_type_int = int(data_type_num) if data_type_num.lstrip('-').isdigit() else -1
                sample_count = len(results_by_datatype.get(data_type_int, []))
                
                print(f"\nData Type {data_type_num}:")
                print(f"  样本数量: {sample_count}")
                total_metric = type_metrics.get("TOTAL", {})
                if total_metric.get("count", 0) > 0:
                    print(f"  动作总数: {total_metric['count']}")
                    print(f"  类型匹配率 (TMR): {total_metric['TMR']:.4f}")
                    print(f"  动作匹配率 (AMR): {total_metric['AMR']:.4f}")
                if "episode_success_rate" in type_metrics:
                    print(f"  Episode成功率: {type_metrics['episode_success_rate']:.4f}")
                    print(f"  Episode总数: {type_metrics['episode_count']}")
                    print(f"  成功Episode数: {type_metrics['success_episode_count']}")
                print(f"  各动作类型详细统计:")
                for action_name in ["click", "type", "swipe", "system_button", "terminate", "wait", "long_press"]:
                    m = type_metrics.get(action_name, {})
                    if m.get("count", 0) > 0:
                        print(f"    {action_name}: count={m['count']}, TMR={m['TMR']:.4f}, AMR={m['AMR']:.4f}")
        
        print(f"\n输出文件: {self.result_save_path}")
        print(f"统计文件: {statistics_file}")
        
        return metrics


def main():
    """Command-line entry point: parse options, build the evaluator, run it.

    Returns:
        The value returned by ``TestDataEvaluator.evaluate`` (the metrics dict).
    """
    cli = argparse.ArgumentParser(description='Evaluate test.json using chunked parallel processing (no vLLM)')

    # (flag, add_argument keyword arguments), listed in the order shown by --help.
    option_specs = [
        ('--test_json_path', dict(
            type=str,
            default="/home/chengpengzhou/hhw/rft_data/faithful_dataset/test_0102.json",
            help='测试数据路径',
        )),
        ('--result_save_path', dict(
            type=str,
            default="/home/chengpengzhou/hhw/rft_data/faithful_dataset/results_new/results_grpo_faithful_0113_3ep_base.json",
            help='结果保存路径',
        )),
        ('--base_model_path', dict(
            type=str,
            default=None,
            help='基础模型路径（用于加载配置文件，如果提供则使用基础模型的配置+训练后模型的权重，通常不需要）',
        )),
        ('--model_path', dict(
            type=str,
            default="/data/cpz/hhw/experiments/grpo_faithful_0113_3ep_base/model",
            help='模型路径',
        )),
        ('--max_new_tokens', dict(type=int, default=512, help='最大生成 token 数')),
        ('--click_threshold', dict(type=float, default=140.0, help='CLICK 动作的距离阈值（像素）')),
        ('--device_ids', dict(
            type=str,
            default='[2,2,3,3,4,4,5,5,6,6,7,7]',
            help='CUDA 设备 ID 列表，如 "[0]" 或 "[0,1]"',
        )),
        ('--agent_count', dict(type=int, default=12, help='并行 agent 数量')),
        ('--batch_size', dict(type=int, default=8, help='HF推理批处理大小（用于加速）')),
        ('--model_name', dict(type=str, default=None, help='模型名称（用于判断是否使用 UI-TARS agent）')),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    parsed = cli.parse_args()
    evaluator = TestDataEvaluator(parsed)
    return evaluator.evaluate()


if __name__ == "__main__":
    main()
