import os
import json
import glob
import argparse
import time  # 导入时间模块用于计时
from llm_eval import LLMEvaluator
from utils import get_instruction, load_data_items, prepare_result_filenames, load_existing_results, compute_and_save_metrics
from utils import parse_action_reward, compute_binary_classification_metrics, analyze_group_performance, parse_llm_output
from model_config import *
import re


# Command-line interface.
parser = argparse.ArgumentParser(description="Evaluate cuarewardbench trajectories")
parser.add_argument("--eval_mode", default="sewsm", help="evaluation mode (zerogui, sewsm, sewsm_targetar, opencua_reflect, opencua_fulltraj)")
parser.add_argument("--model_config", default="qwen2p5vl_32b_tione_config", help="in src/eval_traj/model_config.py")
parser.add_argument("--num_items", type=int, default=0, help="Number of items to evaluate. Set to 0 to evaluate all.")
parser.add_argument("--only_metric", action="store_true", help="not evaluate, only compute metric")
parser.add_argument("--temperature", type=float, default=0.0, help="LLM temperature")
parser.add_argument("--img_scale", type=float, default=1.0, help="Image scale factor (default: 1.0)")
parser.add_argument("--output_dir", default="./", help="Output directory")
parser.add_argument("--max_screenshots", type=int, default=20, help="Maximum number of screenshots to use per task (0 for all)")
# The help text here used to be a copy of --output_dir's.
parser.add_argument("--exp_suffix", default="", help="Experiment suffix")
parser.add_argument("--reward_type", default="or_ar", help="reward type: or, ar, or_ar")

args = parser.parse_args()

# Some evaluation modes only support one reward type; for those the
# --reward_type given on the command line is overridden.
if args.eval_mode == "zerogui":
    args.reward_type = "or"
elif args.eval_mode in ("opencua_reflect", "opencua_fulltraj"):
    args.reward_type = "ar"
# The prompt file name is derived from the evaluation mode.
args.prompt_file = f"osworld_{args.eval_mode}.json"

# Base paths.
BASE_DIR = "data/osworld-verified_traces"
JSONL_PATH = "data/cuarewardbench/cuarewardbench-v0.4.json"
# NOTE(review): absolute, user-specific path -- presumably has to be adapted
# per machine; consider making it relative or configurable.
PROMPT_DIR = "/Users/xxx/Documents/codes/cuarewardbench/config/prompt"


# Resolve the configuration object by name among the names star-imported from
# model_config. This deliberately replaces the former eval(args.model_config),
# which would execute any expression typed on the command line.
try:
    model_config = globals()[args.model_config]
except KeyError:
    parser.error(
        f"unknown --model_config {args.model_config!r}: "
        "expected the name of a config defined in model_config.py"
    )
# From here on args.model_config holds the resolved object, not its name.
args.model_config = model_config


def evaluate_cuarewardbench():
    """Run the cuarewardbench evaluation end to end and save its metrics.

    The steps run in this order: build the LLM evaluator, load the data
    items, resolve the output file names, load results saved by a previous
    run, evaluate the tasks (skipped when ``--only_metric`` is set), parse
    the raw LLM outputs, then compute and save the metrics.
    """
    # The judge is configured from the selected model config plus the prompt
    # file that matches the evaluation mode.
    judge = LLMEvaluator(
        **model_config,
        prompt_file=args.prompt_file,
        prompt_dir=PROMPT_DIR
    )

    benchmark_items = load_data_items(JSONL_PATH, args)
    detailed_path, metrics_name = prepare_result_filenames(JSONL_PATH, args)

    # Results already on disk are passed on so evaluated tasks can be skipped.
    previous_results, done_task_keys = load_existing_results(detailed_path)

    if args.only_metric:
        # Metrics-only run: reuse what is on disk, evaluate nothing new.
        combined = previous_results
    else:
        fresh = process_tasks(
            judge, benchmark_items, previous_results, done_task_keys,
            detailed_path
        )
        combined = previous_results + fresh

    # Extract the trajectory-level reward and the action rewards from llm_output.
    parsed = parse_llm_output(combined, judge, args)

    compute_and_save_metrics(
        parsed,
        detailed_path, metrics_name, args
    )


def process_tasks(evaluator, data_to_process, existing_results, evaluated_tasks,
                 detailed_filepath):
    """Evaluate every not-yet-evaluated task and return the new result records.

    Args:
        evaluator: LLM evaluator passed through to ``evaluate_single_task``.
        data_to_process: Iterable of data items; each must provide
            ``model_setting``, ``task_type`` and ``task_id``.
        existing_results: Result records loaded from a previous run; they are
            written back together with the new ones on every save.
        evaluated_tasks: Container of ``(model_setting, task_id)`` keys that
            are already evaluated and must be skipped.
        detailed_filepath: File that receives the incremental results.

    Returns:
        list: Result dicts produced in this call (skipped and failed tasks
        are not included).
    """
    results = []
    total_tasks = len(data_to_process)

    for current_index, data in enumerate(data_to_process, start=1):
        # Bold yellow text makes the progress line stand out.
        print(f"\033[1;33mProcessing task {current_index}/{total_tasks}\033[0m")
        model_setting = data["model_setting"]
        task_type = data["task_type"]
        task_id = data["task_id"]

        # Skip tasks that already have a saved result.
        task_key = (model_setting, task_id)
        if task_key in evaluated_tasks:
            print(f"Skipping already evaluated task: {model_setting}/{task_id}")
            continue

        # Collect the screenshots and actions of the trajectory under test.
        test_trajectory = prepare_test_trajectory(model_setting, task_type, task_id)

        task_config = {
            "instruction": get_instruction(task_type, task_id)
        }

        result = evaluate_single_task(
            evaluator, task_config, test_trajectory,
            model_setting, task_type, task_id, data
        )

        # Only a non-empty dict is a real result. A bare truthiness check is
        # not enough: a failure sentinel such as the tuple (None, None) is
        # truthy and would be stored (and saved) as if it were a record.
        if not isinstance(result, dict) or not result:
            continue

        results.append(result)
        # Save after every task so an interrupted run can be resumed.
        save_incremental_results(detailed_filepath, existing_results + results)

    return results

def _step_index(path):
    """Return the integer step number encoded in a ``step_<n>...`` file name."""
    name = os.path.basename(path)
    match = re.match(r"step_(\d+)", name)
    if match is None:
        raise ValueError(f"Cannot parse step index from screenshot name: {name}")
    return int(match.group(1))


def _clean_action_text(action_text):
    """Strip docstrings, boilerplate imports and ``#`` comments from an action."""
    # Non-string payloads (e.g. JSON null) used to crash re.sub with TypeError.
    if not isinstance(action_text, str):
        action_text = "" if action_text is None else str(action_text)

    # Remove everything wrapped in ''' ... '''.
    action_text = re.sub(r"'''.*?'''", "", action_text, flags=re.DOTALL)
    # One pass of replace("\n\n", "") leaves no double newline behind, so the
    # former second pass was a no-op.
    action_text = (
        action_text
        .replace('import pyautogui', '')
        .replace("import time", "")
        .replace("\n\n", "")
    )
    # Remove comments from '#' to the end of the line. The newline is optional
    # so that a comment on the last line (no trailing newline) is removed too.
    return re.sub(r"#[^\n]*\n?", "", action_text)


def prepare_test_trajectory(model_setting, task_type, task_id):
    """Collect the screenshots and cleaned action texts of one trajectory.

    Args:
        model_setting: Directory level below ``BASE_DIR``.
        task_type: Directory level below ``model_setting``.
        task_id: Directory level below ``task_type``.

    Returns:
        dict: ``actions`` (one cleaned string per parsable line of the
        trajectory file), ``screenshots`` (step-ordered ``step_*.png`` paths,
        limited to the last ``args.max_screenshots`` when that is positive)
        and ``screenshots_marker`` (step-ordered ``step_*.png`` paths from the
        parallel ``<BASE_DIR>_w_action`` tree).

    Raises:
        ValueError: If the task directory has no ``traj*.jsonl`` file, or a
            screenshot name carries no step number.
    """
    task_dir = os.path.join(BASE_DIR, model_setting, task_type, task_id)
    screenshot_files = sorted(
        glob.glob(os.path.join(task_dir, "step_*.png")), key=_step_index
    )

    task_dir_marker = os.path.join(f"{BASE_DIR}_w_action", model_setting, task_type, task_id)
    screenshot_files_marker = sorted(
        glob.glob(os.path.join(task_dir_marker, "step_*.png")), key=_step_index
    )

    # Cap the number of screenshots by keeping the most recent ones.
    # NOTE(review): only "screenshots" is truncated; "screenshots_marker" and
    # "actions" keep every step -- confirm that consumers expect this.
    if args.max_screenshots > 0:
        screenshot_files = screenshot_files[-args.max_screenshots:]

    # Locate the trajectory file.
    traj_files = glob.glob(os.path.join(task_dir, "traj*.jsonl"))
    if not traj_files:
        print(f"警告: 在目录 {task_dir} 中没有找到轨迹文件")
        raise ValueError("轨迹文件不存在")

    actions = []
    with open(traj_files[0], 'r', encoding='utf-8') as f:
        for step_number, line in enumerate(f, start=1):
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                # Unparsable lines are reported and skipped.
                print(f"JSON解析错误，跳过步骤 {step_number}")
                continue

            action_text = record.get('action', '')
            # Structured actions keep their code under the "command" key.
            if isinstance(action_text, dict):
                action_text = action_text.get('command', '')
            actions.append(_clean_action_text(action_text))

    return {
        "actions": actions,
        "screenshots": screenshot_files,
        "screenshots_marker": screenshot_files_marker,
    }

def evaluate_single_task(evaluator, task_config, test_trajectory,
                        model_setting, task_type, task_id, data):
    """Run the LLM evaluation for one trajectory and build its result record.

    Args:
        evaluator: Object providing ``evaluate_task`` and, for the
            ``opencua_*`` modes, ``evaluate_task_step``.
        task_config: Task configuration handed to the evaluator.
        test_trajectory: Trajectory handed to the evaluator.
        model_setting: Copied into the result record.
        task_type: Copied into the result record.
        task_id: Copied into the result record.
        data: Data item; must contain ``result``, ``step_num``, ``comments``
            and ``instruction``. Missing annotation categories are added to
            it in place as empty lists.

    Returns:
        dict | None: The result record, or ``None`` when the evaluation
        raised. ``pred`` (-2) and ``action_reward_list`` ([]) are written as
        placeholders.

    Raises:
        KeyError, ValueError: If a required field of ``data`` is missing or
            not convertible to ``int`` (raised before the guarded section).
    """
    gt_result = int(data["result"])
    step_num = int(data["step_num"])
    comments = data["comments"]
    instruction = data["instruction"]

    categories = [
        "key_bad_in_fail",
        "key_good_before_bad_in_fail",
        "key_bad_in_success",
        "key_good_after_bad_in_success",
        "key_good_in_success",
        "redundant_in_success",
    ]
    # Make sure every annotation category exists on the data item.
    for category in categories:
        data.setdefault(category, [])

    # This key order differs from `categories` on purpose: it fixes the order
    # in which annotated actions are evaluated and stored.
    action_annos = {
        "key_bad_in_fail": data["key_bad_in_fail"],
        "key_good_before_bad_in_fail": data["key_good_before_bad_in_fail"],
        "key_good_in_success": data["key_good_in_success"],
        "key_bad_in_success": data["key_bad_in_success"],
        "key_good_after_bad_in_success": data["key_good_after_bad_in_success"],
        "redundant_in_success": data["redundant_in_success"],
    }
    # Flatten all annotated actions into one list (non-list values are ignored).
    action_annos_list = [
        action
        for action_list in action_annos.values()
        if isinstance(action_list, list)
        for action in action_list
    ]

    try:
        start_time = time.time()
        print(f'==================="Evaluated {task_id}, {model_setting} ========================')
        if args.eval_mode in ("opencua_reflect", "opencua_fulltraj"):
            # Step-level modes: one LLM call per annotated action, keyed by
            # the action's string form.
            llm_output = {}
            for action in action_annos_list:
                eval_output = evaluator.evaluate_task_step(
                    task_config, test_trajectory, action,
                    eval_mode=args.eval_mode, img_scale=args.img_scale
                )
                llm_output[str(action)] = eval_output["llm_output"]
        elif args.eval_mode == "sewsm_targetar":
            # Same per-action scheme, but through the trajectory-level entry point.
            llm_output = {}
            for action in action_annos_list:
                eval_output = evaluator.evaluate_task(
                    task_config, test_trajectory, action,
                    eval_mode=args.eval_mode, img_scale=args.img_scale
                )
                llm_output[str(action)] = eval_output["llm_output"]
        else:
            # All other modes: a single call for the whole trajectory.
            eval_output = evaluator.evaluate_task(
                task_config, test_trajectory,
                eval_mode=args.eval_mode, img_scale=args.img_scale
            )
            llm_output = eval_output.get("llm_output", "")
        eval_time = time.time() - start_time

        result = {
            "task_id": task_id,
            "task_type": task_type,
            "model_setting": model_setting,
            "reward_fn": model_config['model'],
            "step_num": step_num,
            "gt": gt_result,
            "pred": -2,
            "instruction": instruction,
            "llm_output": llm_output,
            "action_annos": action_annos,
            "action_reward_list": [],
            "eval_time": eval_time,
            "comments": comments
        }

        print(llm_output)
        print(f"Evaluated {task_id}: GT={gt_result}, Time={eval_time:.2f}s")
        print(f"========================================================================")

        return result

    except Exception as e:
        # Best effort: one failing task must not abort the whole run.
        print(f"Error evaluating task {task_id}: {str(e)}")
        # Return a single falsy None. The former (None, None) tuple is truthy,
        # so an `if result:` check treated the failure as a valid record.
        return None

def save_incremental_results(filepath, results):
    """Write ``results`` to ``filepath`` as JSON without risking the old file.

    The data is first written to a sibling ``<filepath>.tmp`` file and then
    moved over the target with ``os.replace``. If serialization or writing
    fails, the previously saved file is left untouched. Opening the target
    directly in ``'w'`` mode would truncate it before anything is written.

    Args:
        filepath: Destination JSON file.
        results: JSON-serializable results to store.

    Errors are reported on stdout and not re-raised (best-effort save).
    """
    tmp_path = f"{filepath}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        # The temp file lives next to the target, so this is a same-directory rename.
        os.replace(tmp_path, filepath)
    except Exception as e:
        print(f"Error saving incremental results: {str(e)}")
        # Do not leave a partial temp file behind.
        try:
            os.remove(tmp_path)
        except OSError:
            pass


# Run the full evaluation when this file is executed as a script.
if __name__ == "__main__":
    evaluate_cuarewardbench()