#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
批量评估多个模型
"""
import json
import os
import sys
import subprocess
# Put './' on the import path so the project-local `utils` package resolves.
# NOTE: './' is relative to the current working directory, not to this file's
# location, so the script must be launched from the directory containing `utils`.
sys.path.append('./')
from utils.logging_utils import setup_logger_to_stdout

# Module-level logger shared by every function in this script
# (presumably writes to stdout, judging by the helper's name — confirm in utils).
logger = setup_logger_to_stdout()

# Model configurations.
# Every entry shares the same model root and baseline directory; only the
# model name and the dataset file differ, so entries are built by a helper.
_MODELS_ROOT = "/data/models"
_BASELINE_ROOT = "/home/chengpengzhou/hhw/rft_data/baseline"


def _model_config(model_name, dataset_file):
    """Build the evaluation config dict for one model."""
    return {
        "model_name": model_name,
        "model_path": f"{_MODELS_ROOT}/{model_name}",
        "dataset_path": f"{_BASELINE_ROOT}/data/{dataset_file}",
        "result_dir": f"{_BASELINE_ROOT}/results",
    }


MODEL_CONFIGS = [
    # Disabled entries; uncomment a line to include that model in the batch.
    # _model_config("GUI-Owl-7B", "test_gui_owl.json"),
    # _model_config("OS-Atlas-Pro-7B", "test_os_atlas.json"),
    # _model_config("Qwen2.5-VL-7B-Instruct", "test_gui_owl.json"),  # reuses the GUI-Owl dataset
    _model_config("UI-TARS-1.5-7B", "test_ui_tars.json"),
]

def run_evaluation(config, device_ids="[0]", num_process=1):
    """Run the evaluation script for one model and stream its output live.

    Args:
        config: Mapping with the keys ``model_name``, ``model_path``,
            ``dataset_path`` and ``result_dir``.
        device_ids: Device-id list as a string (e.g. ``"[0,1]"``); forwarded
            unchanged to the child script via ``--deviceIds``.
        num_process: Worker count forwarded via ``--num_process``.

    Returns:
        ``True`` if the child process exits with code 0. ``False`` if the
        model or dataset path is missing, the child exits non-zero, or
        launching/streaming the child raises an exception.
    """
    model_name = config["model_name"]
    model_path = config["model_path"]
    dataset_path = config["dataset_path"]
    result_dir = config["result_dir"]

    # Create the results directory up front.
    os.makedirs(result_dir, exist_ok=True)

    # Result file name is derived from the model name only (no timestamp).
    result_filename = f"{model_name.replace('-', '_').replace('.', '_')}.json"
    result_path = os.path.join(result_dir, result_filename)

    logger.info("=" * 80)
    logger.info(f"开始评估模型: {model_name}")
    logger.info(f"模型路径: {model_path}")
    logger.info(f"数据集路径: {dataset_path}")
    logger.info(f"结果保存路径: {result_path}")
    logger.info("=" * 80)

    # Bail out early when the model is missing.
    if not os.path.exists(model_path):
        logger.error(f"模型路径不存在: {model_path}")
        return False

    # Bail out early when the dataset is missing.
    if not os.path.exists(dataset_path):
        logger.error(f"数据集路径不存在: {dataset_path}")
        return False

    # Pick the agent name the child script should receive. Per the original
    # author's notes, initial_agent.py selects the agent by substring match:
    #   - "OS_Atlas" (with an underscore)
    #   - "UI-TARS"
    #   - "GUI-Owl"
    if "Qwen2.5" in model_name:
        # Qwen2.5(-VL) is handled by the GUI-Owl agent (same data format).
        agent_model_name = "GUI-Owl-7B"
    elif "OS-Atlas" in model_name or "OS_Atlas" in model_name:
        # Make sure the name contains "OS_Atlas" with an underscore.
        agent_model_name = "OS_Atlas-Pro-7B"
    else:
        # UI-TARS-1.5-7B and GUI-Owl-7B are expected to match as-is.
        agent_model_name = model_name

    cmd = [
        sys.executable,
        "evaluate_memory_vs_reasoning_mp.py",
        "--model_path", model_path,
        "--model_name", agent_model_name,
        "--dataset_path", dataset_path,
        "--result_path", result_path,
        "--num_process", str(num_process),
        "--deviceIds", device_ids,
    ]

    logger.info(f"执行命令: {' '.join(cmd)}")

    try:
        logger.info("开始运行评估，输出将实时显示...")
        # The context manager guarantees the stdout pipe is closed and the
        # child is reaped, even if streaming fails part-way through.
        with subprocess.Popen(
            cmd,
            # The evaluation script is resolved relative to this file.
            cwd=os.path.dirname(os.path.abspath(__file__)),
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # interleave stderr with stdout
            text=True,
            errors="replace",  # undecodable bytes must not abort streaming
            bufsize=1,  # line-buffered
        ) as process:
            try:
                # Iterating the pipe yields lines as they arrive, until EOF.
                for line in process.stdout:
                    sys.stdout.write(line)
                    sys.stdout.flush()
            except Exception:
                # Do not leave the child running (or block in __exit__)
                # when we can no longer relay its output.
                process.kill()
                raise
            returncode = process.wait()
    except Exception as e:
        logger.error(f"✗ {model_name} 评估出错: {e}")
        return False

    if returncode == 0:
        logger.info(f"✓ {model_name} 评估完成")
        logger.info(f"结果保存在: {result_path}")
        return True

    logger.error(f"✗ {model_name} 评估失败，返回码: {returncode}")
    return False

def _parse_device_ids(device_ids):
    """Parse a device-id string such as ``"[0,1]"``; return None if unusable.

    ``ast.literal_eval`` is used instead of ``eval`` so a command-line value
    can never execute arbitrary code. Anything that is not a non-empty
    list/tuple/set literal yields ``None``.
    """
    import ast

    try:
        parsed = ast.literal_eval(device_ids)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return None
    if isinstance(parsed, (list, tuple, set)) and parsed:
        return parsed
    return None


def main():
    """Parse CLI options, evaluate the selected models and log a summary.

    Models come from ``MODEL_CONFIGS``, optionally filtered by ``--models``.
    Each model is evaluated through ``run_evaluation`` and its status
    ("success", "failed" or "skipped") is reported at the end.
    """
    import argparse

    parser = argparse.ArgumentParser(description='批量评估多个模型')
    parser.add_argument('--deviceIds', type=str, default="[0]",
                        help='GPU设备ID列表，例如 [0] 或 [0,1,2,3]')
    parser.add_argument('--num_process', type=int, default=None,
                        help='并行进程数。如果不指定，将根据processes_per_gpu自动计算')
    parser.add_argument('--processes_per_gpu', type=int, default=3,
                        help='每个GPU卡运行的进程数（默认3）')
    parser.add_argument('--models', type=str, nargs='+', default=None,
                        help='指定要评估的模型名称列表，如果不指定则评估所有模型')
    parser.add_argument('--skip_existing', action='store_true',
                        help='跳过已有结果的模型')

    args = parser.parse_args()

    # Parse the device list once. None means "could not be parsed" or
    # "empty", so the process count can never silently become 0.
    device_list = _parse_device_ids(args.deviceIds)

    # Derive the process count from the GPU count when not given explicitly.
    if args.num_process is None:
        if device_list is not None:
            args.num_process = len(device_list) * args.processes_per_gpu
            logger.info(f"自动计算进程数: {len(device_list)} 个GPU × {args.processes_per_gpu} 进程/GPU = {args.num_process} 个进程")
        else:
            logger.warning("无法解析deviceIds，使用默认进程数1")
            args.num_process = 1

    # Select the models to evaluate.
    configs_to_run = MODEL_CONFIGS
    if args.models:
        requested = set(args.models)
        configs_to_run = [c for c in MODEL_CONFIGS if c["model_name"] in requested]
        if not configs_to_run:
            logger.error(f"未找到指定的模型: {args.models}")
            logger.info(f"可用的模型: {[c['model_name'] for c in MODEL_CONFIGS]}")
            return
        # Report requested names that matched nothing instead of dropping
        # them silently.
        known = {c["model_name"] for c in MODEL_CONFIGS}
        unknown = [name for name in args.models if name not in known]
        if unknown:
            logger.warning(f"以下模型未在配置中找到，已忽略: {unknown}")

    logger.info(f"准备评估 {len(configs_to_run)} 个模型")
    if device_list is not None:
        logger.info(f"使用设备: {args.deviceIds} ({len(device_list)} 个GPU), 总进程数: {args.num_process} (每个GPU约 {args.num_process // len(device_list)} 个进程)")
    else:
        logger.info(f"使用设备: {args.deviceIds}, 总进程数: {args.num_process}")

    results = {}
    for i, config in enumerate(configs_to_run, 1):
        model_name = config['model_name']
        logger.info(f"\n[{i}/{len(configs_to_run)}] 处理模型: {model_name}")

        # Optionally skip models that already have a result file.
        if args.skip_existing:
            result_dir = config["result_dir"]
            if os.path.isdir(result_dir):
                # NOTE(review): this is a prefix match, so a result file of a
                # longer model name sharing this prefix also counts as
                # "existing". Kept as-is so older result files are still
                # recognised.
                prefix = model_name.replace("-", "_").replace(".", "_")
                existing_files = [f for f in os.listdir(result_dir)
                                  if f.startswith(prefix)]
                if existing_files:
                    logger.info(f"跳过 {model_name}，已有结果文件: {existing_files[0]}")
                    results[model_name] = "skipped"
                    continue

        success = run_evaluation(config, args.deviceIds, args.num_process)
        results[model_name] = "success" if success else "failed"

    # Per-model summary.
    logger.info("\n" + "=" * 80)
    logger.info("评估总结")
    logger.info("=" * 80)
    symbols = {"success": "✓", "skipped": "⊘"}
    for model_name, status in results.items():
        logger.info(f"{symbols.get(status, '✗')} {model_name}: {status}")

    # Totals per status.
    statuses = list(results.values())
    success_count = statuses.count("success")
    failed_count = statuses.count("failed")
    skipped_count = statuses.count("skipped")

    logger.info(f"\n总计: {len(results)} 个模型")
    logger.info(f"成功: {success_count}, 失败: {failed_count}, 跳过: {skipped_count}")

# Run the batch evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
