import ast
import json
import os
import re
from collections import defaultdict


# Default location of the OSWorld ``evaluation_examples/examples`` directory.
# Callers on other machines should pass ``base_path`` explicitly.
_DEFAULT_EXAMPLES_DIR = "/Users/xxx/Documents/codes/cuarewardbench/data/OSWorld/evaluation_examples/examples"


def get_instruction(task_type, task_id, base_path=None):
    """Return the ``instruction`` field of an example JSON file.

    The file read is ``<base_path>/<task_type>/<task_id>.json``.

    Args:
        task_type (str): Task category sub-directory, e.g. ``'chrome'``.
        task_id (str): Example id, e.g. ``'0d8b7de3-e8de-4d86-b9fd-dd2dce58a217'``.
        base_path (str, optional): Directory containing one sub-directory per
            task type. Defaults to ``_DEFAULT_EXAMPLES_DIR``.

    Returns:
        str: The ``instruction`` value, or ``""`` when the key is absent.

    Raises:
        FileNotFoundError: If the JSON file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    if base_path is None:
        base_path = _DEFAULT_EXAMPLES_DIR

    json_path = os.path.join(base_path, task_type, f"{task_id}.json")

    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"未找到JSON文件: {json_path}") from e
    except json.JSONDecodeError as e:
        # Re-raise with the offending path while keeping doc/pos information.
        raise json.JSONDecodeError(f"JSON文件解析失败: {json_path}", e.doc, e.pos) from e

    return data.get("instruction", "")

def load_data_items(JSONL_PATH, args):
    """Load evaluation items from a ``.jsonl`` or ``.json`` file.

    ``.jsonl`` files hold one JSON object per non-blank line; ``.json`` files
    must contain a single top-level JSON array.

    Args:
        JSONL_PATH: Path to the data file.
        args: Namespace with ``num_items``; a value <= 0 means "keep all items",
            otherwise only the first ``num_items`` items are returned.

    Raises:
        ValueError: If the extension is unsupported, or a ``.json`` file does
            not contain a list.
    """
    if JSONL_PATH.endswith('.jsonl'):
        with open(JSONL_PATH, 'r', encoding='utf-8') as handle:
            items = [json.loads(row) for row in handle if row.strip()]
    elif JSONL_PATH.endswith('.json'):
        with open(JSONL_PATH, 'r', encoding='utf-8') as handle:
            items = json.load(handle)
        if not isinstance(items, list):
            raise ValueError(f"JSON file should contain a list, got {type(items)}")
    else:
        raise ValueError(f"Unsupported file format: {JSONL_PATH}")

    limit = args.num_items
    if limit <= 0:
        return items
    return items[:limit]

def prepare_result_filenames(JSONL_PATH, args):
    """Build the output paths for one experiment.

    The experiment name joins ``evalmode_<eval_mode>_<data-file stem>_<model>``
    (with ``/`` in the model name replaced by ``_``), then ``maxS<N>`` when
    ``max_screenshots`` differs from 20, then ``exp_suffix`` when it is truthy.

    Returns:
        tuple: ``(detailed_filepath, metrics_filename)`` where the first is
        placed inside ``args.output_dir`` and the second is a bare file name.
    """
    stem = os.path.splitext(os.path.basename(JSONL_PATH))[0]
    safe_model = args.model_config["model"].replace('/', '_')

    name_parts = [f"evalmode_{args.eval_mode}_{stem}_{safe_model}"]
    if args.max_screenshots != 20:
        name_parts.append(f"maxS{args.max_screenshots}")
    if args.exp_suffix:
        name_parts.append(f"{args.exp_suffix}")
    exp_name = "_".join(name_parts)

    return (
        os.path.join(args.output_dir, f"detailed_results_{exp_name}.json"),
        f"metrics_{exp_name}.json",
    )

def load_existing_results(detailed_filepath):
    """Load previously saved detailed results, if any.

    Returns:
        tuple: ``(existing_results, evaluated_tasks)``. ``existing_results`` is
        the parsed JSON content (``[]`` when the file does not exist), and
        ``evaluated_tasks`` is the set of ``(model_setting, task_id)`` pairs
        taken from dict entries that carry both keys.
    """
    if not os.path.exists(detailed_filepath):
        return [], set()

    with open(detailed_filepath, 'r', encoding='utf-8') as handle:
        previous = json.load(handle)
    print(f"Loaded existing results from {os.path.basename(detailed_filepath)}")

    done = {
        (entry['model_setting'], entry['task_id'])
        for entry in previous
        if isinstance(entry, dict) and 'model_setting' in entry and 'task_id' in entry
    }
    return previous, done

def compute_binary_classification_metrics(true_labels, pred_labels):
    """Compute precision, recall, F1 and overall accuracy for binary labels.

    Label 1 is the positive class. Aligned pairs ``(true, pred)`` are counted
    as TP for (1, 1), FP for (0, 1), FN for (1, 0); every other pair is
    counted as TN.

    Args:
        true_labels: Sequence of ground-truth labels.
        pred_labels: Sequence of predicted labels aligned with ``true_labels``.

    Returns:
        dict: ``precision``, ``recall``, ``f1``, ``oa`` (overall accuracy) and
        the raw ``tp``/``fp``/``fn``/``tn``/``total_samples`` counts. Any ratio
        whose denominator is zero is reported as 0; this includes ``oa`` for
        empty input, which previously raised ZeroDivisionError.
    """
    tp = fp = fn = tn = 0
    for true, pred in zip(true_labels, pred_labels):
        if true == 1 and pred == 1:
            tp += 1
        elif true == 0 and pred == 1:
            fp += 1
        elif true == 1 and pred == 0:
            fn += 1
        else:
            tn += 1

    total_samples = len(true_labels)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    # Guard against empty input (e.g. no task carried step annotations).
    oa = (tp + tn) / total_samples if total_samples > 0 else 0

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'oa': oa,
        'tp': tp,
        'fp': fp,
        'fn': fn,
        'tn': tn,
        'total_samples': total_samples
    }

def _as_step_number(value):
    """Return ``value`` as a 1-based step number, or None if it is not a positive int-like value."""
    try:
        step = int(value)
    except (TypeError, ValueError):
        return None
    return step if step >= 1 else None


def parse_action_reward(all_results):
    """Derive per-step action rewards from parsed LLM output and collect labels.

    For every task a predicted list ``task["action_reward_list"]`` with one
    entry per step (1 = step judged good, 0 = judged bad) is built from
    ``task["res_dict"]``:

    * every step from ``First_Error_Step`` (1-based) onward is set to 0;
    * every step listed in ``Redundant`` (a list, a single value, or None) is set to 0;
    * every numeric key of ``res_dict`` whose value is a dict with a falsy
      ``Target_Step_Correct`` or truthy ``Target_Step_Redundant`` is set to 0.

    Step references that are not positive integers or exceed ``step_num`` are
    ignored. Previously, a step of 0 silently zeroed the last step via negative
    indexing.

    When ``task["action_annos"]`` is non-empty, predictions for annotated good
    steps (true label 1) and bad steps (true label 0) are collected, and the
    steps whose prediction disagrees with the annotation are stored in
    ``task["pred_wrong_steps"]``.

    Tasks are mutated in place: ``action_reward_list`` and possibly
    ``pred_wrong_steps`` are written, and ``action_reward`` is reset to ``{}``.

    Returns:
        tuple: ``(all_results, all_true_labels, all_pred_labels)``. The label
        lists hold all good-step entries (across tasks) followed by all
        bad-step entries.
    """
    good_keys = ('key_good_before_bad_in_fail', 'key_good_in_success', 'key_good_after_bad_in_success')
    bad_keys = ('key_bad_in_fail', 'key_bad_in_success')

    all_good_true_labels = []
    all_good_pred_labels = []
    all_bad_true_labels = []
    all_bad_pred_labels = []

    for task in all_results:
        res_dict = task["res_dict"]
        action_annos = task.get('action_annos', None)
        step_num = task["step_num"]

        # 1. Build the predicted reward list (all steps good by default).
        action_reward_list = [1] * step_num

        first_error_step = _as_step_number(res_dict.get("First_Error_Step", None))
        if first_error_step is not None:
            for i in range(first_error_step - 1, step_num):
                action_reward_list[i] = 0

        # "Redundant" may be missing, None (e.g. parsed from "N/A") or a scalar.
        redundant = res_dict.get("Redundant") or []
        if not isinstance(redundant, (list, tuple, set)):
            redundant = [redundant]
        for raw_step in redundant:
            step = _as_step_number(raw_step)
            if step is not None and step <= step_num:
                action_reward_list[step - 1] = 0

        # Per-step verdicts keyed by step number; other keys (e.g. "Correctness")
        # are skipped rather than ending the scan.
        for key, step_info in res_dict.items():
            key_str = str(key)
            if not key_str.isdigit() or not isinstance(step_info, dict):
                continue
            step = int(key_str)
            if not 1 <= step <= step_num:
                continue
            if not step_info.get("Target_Step_Correct", True) or step_info.get("Target_Step_Redundant", False):
                action_reward_list[step - 1] = 0

        task["action_reward_list"] = action_reward_list
        task['action_reward'] = {}  # legacy field, no longer used

        # 2. Compare predictions against the annotated good/bad steps.
        if action_annos:
            pred_wrong_steps = []
            good_steps = [s for k in good_keys for s in (action_annos.get(k) or [])]
            bad_steps = [s for k in bad_keys for s in (action_annos.get(k) or [])]

            for step in good_steps:
                if 1 <= step <= step_num:
                    pred = action_reward_list[step - 1]
                    all_good_true_labels.append(1)
                    all_good_pred_labels.append(pred)
                    if pred != 1:
                        pred_wrong_steps.append(step)

            for step in bad_steps:
                if 1 <= step <= step_num:
                    pred = action_reward_list[step - 1]
                    all_bad_true_labels.append(0)
                    all_bad_pred_labels.append(pred)
                    if pred != 0:
                        pred_wrong_steps.append(step)

            task["pred_wrong_steps"] = pred_wrong_steps

    all_true_labels = all_good_true_labels + all_bad_true_labels
    all_pred_labels = all_good_pred_labels + all_bad_pred_labels

    return all_results, all_true_labels, all_pred_labels

def parse_action_reward_by_task_type(all_results):
    """Parse action rewards and additionally group step labels by task_type.

    This delegates to ``parse_action_reward``, which modifies every task in
    ``all_results`` in place (writes ``action_reward_list`` etc.). It then
    re-reads the annotated good steps (true label 1) and bad steps (true
    label 0) of each task and collects the predicted rewards per task_type.

    Returns:
        tuple: ``(all_results, all_true_labels, all_pred_labels, task_type_labels)``.
        The first three are passed through from ``parse_action_reward``.
        ``task_type_labels`` maps every task_type seen to
        ``{'true_labels': [...], 'pred_labels': [...]}``, possibly with empty lists.
    """
    all_results, all_true_labels, all_pred_labels = parse_action_reward(all_results)

    good_keys = ('key_good_before_bad_in_fail', 'key_good_in_success', 'key_good_after_bad_in_success')
    bad_keys = ('key_bad_in_fail', 'key_bad_in_success')

    task_type_labels = {}
    for task in all_results:
        labels = task_type_labels.setdefault(task['task_type'], {'true_labels': [], 'pred_labels': []})

        action_annos = task.get('action_annos') or {}
        if not action_annos:
            continue
        action_reward_list = task.get('action_reward_list', [])
        step_num = task.get('step_num', 0)

        # Good steps first, then bad steps; this is the same order parse_action_reward uses.
        for keys, true_label in ((good_keys, 1), (bad_keys, 0)):
            for key in keys:
                # A None annotation list is treated as empty.
                for step in action_annos.get(key) or []:
                    if 1 <= step <= step_num:
                        labels['true_labels'].append(true_label)
                        labels['pred_labels'].append(action_reward_list[step - 1])

    return all_results, all_true_labels, all_pred_labels, task_type_labels


def analyze_group_performance(all_results):
    """Compute accuracy, error rate and share of steps for each annotation group.

    A step of a ``key_good_*`` group is correct when its predicted reward
    (``action_reward_list[step - 1]``) is 1; a step of a ``key_bad_*`` group is
    correct when it is 0. Only steps with ``1 <= step <= step_num`` are
    counted. Missing or ``None`` annotations/group lists are treated as empty.

    Returns:
        dict: group name -> ``{'correct', 'total', 'accuracy', 'error_rate',
        'proportion'}``, where ``proportion`` is the group's share of all
        counted steps. Ratios with a zero denominator are 0.
    """
    # Group name -> expected predicted reward for its steps.
    expected_by_group = {
        'key_good_before_bad_in_fail': 1,
        'key_good_in_success': 1,
        'key_good_after_bad_in_success': 1,
        'key_bad_in_fail': 0,
        'key_bad_in_success': 0,
    }
    group_stats = {group: {'correct': 0, 'total': 0} for group in expected_by_group}

    for task in all_results:
        # `or {}` also covers an explicit None (the key may be present but empty).
        action_annos = task.get('action_annos') or {}
        action_reward_list = task.get('action_reward_list', [])
        step_num = task.get('step_num', 0)

        for group, expected in expected_by_group.items():
            for step in action_annos.get(group) or []:
                if 1 <= step <= step_num:
                    group_stats[group]['total'] += 1
                    if action_reward_list[step - 1] == expected:
                        group_stats[group]['correct'] += 1

    for stats in group_stats.values():
        if stats['total'] > 0:
            stats['accuracy'] = stats['correct'] / stats['total']
            stats['error_rate'] = 1 - stats['accuracy']
        else:
            stats['accuracy'] = 0
            stats['error_rate'] = 0

    total_steps = sum(stats['total'] for stats in group_stats.values())
    for stats in group_stats.values():
        stats['proportion'] = stats['total'] / total_steps if total_steps > 0 else 0

    return group_stats

def parse_llm_output(all_results, evaluator, args):
    """Parse each task's raw ``llm_output`` according to ``args.eval_mode``.

    * ``"zerogui"``: ``task["pred"]`` is set from ``evaluator.parse_from_response``.
    * ``"sewsm"`` / ``"sewsm_w_action"``: ``task["pred"]`` and ``task["res_dict"]``
      are set from ``evaluator.parse_from_response_sewsm``.
    * ``"opencua_reflect"`` / ``"opencua_fulltraj"`` / ``"sewsm_targetar"``:
      ``llm_output`` is a mapping; each value is parsed with
      ``evaluator.parse_from_response_opencua`` into ``task["res_dict"][key]``.

    Tasks are updated in place and the same list is returned.

    Raises:
        NotImplementedError: For any other mode (only raised when there is at
            least one task to parse).
    """
    mode = args.eval_mode
    for task in all_results:
        raw = task["llm_output"]
        if mode == "zerogui":
            task["pred"] = evaluator.parse_from_response(raw)
        elif mode in ("sewsm", "sewsm_w_action"):
            pred, parsed = evaluator.parse_from_response_sewsm(raw)
            task["pred"] = pred
            task["res_dict"] = parsed
        elif mode in ("opencua_reflect", "opencua_fulltraj", "sewsm_targetar"):
            task["res_dict"] = {
                step_key: evaluator.parse_from_response_opencua(step_text)
                for step_key, step_text in raw.items()
            }
        else:
            raise NotImplementedError(f"Unsupported evaluation mode: {args.eval_mode}")
    return all_results

# Fallback patterns used when the structured dict cannot be parsed: a quoted
# "Correctness" / 'Correctness' key followed by a JSON or Python boolean.
_CORRECTNESS_TRUE_RE = re.compile(r"""["']Correctness["']\s*:\s*(?:True|true)\b""")
_CORRECTNESS_FALSE_RE = re.compile(r"""["']Correctness["']\s*:\s*(?:False|false)\b""")


def _load_dict_literal(text):
    """Parse ``text`` as JSON, falling back to a Python literal.

    Before the fallback, the bare words ``true``/``false``/``null`` are
    rewritten to ``True``/``False``/``None`` so mixed JSON/Python output parses.
    """
    try:
        return json.loads(text)
    except ValueError:
        normalized = re.sub(r'\btrue\b', 'True', text)
        normalized = re.sub(r'\bfalse\b', 'False', normalized)
        normalized = re.sub(r'\bnull\b', 'None', normalized)
        return ast.literal_eval(normalized)


def _correctness_to_bool(value):
    """Interpret a Correctness value; strings count as True only if they read 'true'."""
    if isinstance(value, str):
        return value.strip().lower() == 'true'
    return bool(value)


def parse_from_response_sewsm(response: str):
    """Parse completion status from an LLM response for sewsm mode.

    The result dict is extracted from the first matching container, in order:
    ``<res_dict>...</res_dict>``, ``\\boxed{\\text...``, a fenced ``json`` block
    (last one), or a fenced ``python`` block (first one). JSON blocks may use
    ``true``/``false``/``null``.

    Args:
        response: Raw response string from the LLM.

    Returns:
        tuple: ``(reward, res_dict)`` where ``reward`` is 1.0 if Correctness is
        true, 0.0 if false, and -2.0 if it could not be determined (then
        ``res_dict`` is ``{}``). If the full dict cannot be parsed but a
        ``"Correctness": <bool>`` fragment is found, ``res_dict`` contains
        only that key.
    """
    try:
        if '<res_dict>' in response:
            res_dict_str = response.split('<res_dict>')[1].split('</res_dict>')[0].strip()
            # The model sometimes omits the surrounding braces.
            if not res_dict_str.startswith('{'):
                res_dict_str = '{' + res_dict_str
            if not res_dict_str.endswith('}'):
                res_dict_str = res_dict_str + '}'
            res_dict = _load_dict_literal(res_dict_str)
        elif "\\boxed{\\text" in response:
            # LaTeX-boxed output: keys are unquoted and spelled with spaces,
            # so normalize them into a Python dict literal.
            res_dict_str = response.split('\\boxed{\\text')[1].split('\\]')[0].strip()
            res_dict_str = res_dict_str.replace("true", "True").replace("false", "False")
            res_dict_str = res_dict_str.replace("First Error Step", "First_Error_Step").replace("Error Type", "Error_Type")
            res_dict_str = res_dict_str.replace("Correct Action", "Correct_Action").replace("N/A", "None").replace("}}", "}")
            res_dict_str = res_dict_str.replace('Correctness', '"Correctness"')
            res_dict_str = res_dict_str.replace('Redundant', '"Redundant"')
            res_dict_str = res_dict_str.replace('Optimized', '"Optimized"')
            res_dict_str = res_dict_str.replace('First_Error_Step', '"First_Error_Step"')
            res_dict_str = res_dict_str.replace('Error_Type', '"Error_Type"')
            res_dict_str = res_dict_str.replace('Correct_Action', '"Correct_Action"')
            res_dict = ast.literal_eval(res_dict_str)
        elif '```json' in response:
            res_dict_str = response.split('```json')[-1].split('```')[0].strip()
            res_dict = _load_dict_literal(res_dict_str)
        elif '```python' in response:
            res_dict_str = response.split('```python')[1].split('```')[0].strip()
            res_dict = ast.literal_eval(res_dict_str)
        else:
            raise ValueError("No valid marker found in response")

        if not isinstance(res_dict, dict) or 'Correctness' not in res_dict:
            raise KeyError("Key 'Correctness' not found in res_dict")
        reward = 1.0 if _correctness_to_bool(res_dict['Correctness']) else 0.0
        return reward, res_dict
    except Exception as e:
        # Best effort: any parsing failure falls through to a targeted search.
        print(f"Error extracting res_dict: {e}")
        try:
            if _CORRECTNESS_TRUE_RE.search(response):
                return 1.0, {'Correctness': True}
            if _CORRECTNESS_FALSE_RE.search(response):
                return 0.0, {'Correctness': False}
        except TypeError as fallback_e:  # e.g. response is not a string
            print(f"Fallback extraction failed: {fallback_e}")
        # Not even a Correctness fragment was found: signal failure with -2.0.
        return -2.0, {}

def update_group_stats(group, task):
    """Add one task's trajectory outcome to a group's confusion counters.

    A prediction that is missing, None, -1 or -2 counts as ``'fail'``. The
    value -2 is the failure sentinel returned by ``parse_from_response_sewsm``
    and already treated as a failure by ``collect_statistics``; previously it
    fell through and was counted as a correct negative. Otherwise ``(gt, pred)``
    of (1, 1), (0, 1) and (1, 0) increment ``'tp'``, ``'fp'`` and ``'fn'``;
    any other pair leaves the counters unchanged.

    Args:
        group: Dict with integer ``'tp'``, ``'fp'``, ``'fn'`` and ``'fail'`` counters (mutated).
        task: Result dict with optional ``'pred'`` and ``'gt'`` keys.
    """
    pred = task.get('pred', -1)
    gt = task.get('gt', -1)

    if pred is None or pred in (-1, -2):
        group['fail'] += 1
    elif gt == 1 and pred == 1:
        group['tp'] += 1
    elif gt == 0 and pred == 1:
        group['fp'] += 1
    elif gt == 1 and pred == 0:
        group['fn'] += 1

def compute_metrics(stats_dict):
    """Turn trajectory confusion counts into accuracy/precision/recall/F1.

    ``stats_dict`` may hold ``tp``, ``fp``, ``fn``, ``fail`` and ``total``
    (missing keys count as 0). Samples that are none of TP/FP/FN/fail are
    treated as correct negatives for accuracy. Failures count in ``total`` and
    therefore lower accuracy. Ratios with a zero denominator are 0.
    """
    tp, fp, fn, fail, total = (stats_dict.get(name, 0) for name in ('tp', 'fp', 'fn', 'fail', 'total'))

    def _ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0

    true_negatives = total - (tp + fp + fn + fail)
    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)

    return {
        'precision': precision,
        'recall': recall,
        'f1': _ratio(2 * (precision * recall), precision + recall),
        'accuracy': _ratio(tp + true_negatives, total),
        'tp': tp,
        'fp': fp,
        'fn': fn,
        'fail': fail,
        'total': total
    }

def compute_step_num_metrics(all_results):
    """Compute trajectory metrics separately for short and long tasks.

    Tasks are stably sorted by ``step_num``. The first ``len // 2`` form
    ``'less_half'`` and the remainder form ``'more_half'``. Each group dict holds
    its counters (``tp``/``fp``/``fn``/``fail``/``total``) merged with the
    output of ``compute_metrics``.
    """
    ordered = sorted(all_results, key=lambda task: task['step_num'])
    cut = len(ordered) // 2
    partitions = {
        'less_half': ordered[:cut],
        'more_half': ordered[cut:],
    }

    groups = {}
    for name, members in partitions.items():
        counters = {'tp': 0, 'fp': 0, 'fn': 0, 'fail': 0, 'total': len(members)}
        for task in members:
            update_group_stats(counters, task)
        counters.update(compute_metrics(counters))
        groups[name] = counters
    return groups

def collect_statistics(all_results, stats, task_type_stats):
    """Accumulate trajectory-level confusion counts from all evaluation results.

    Each result updates three counters: ``stats[model_setting]``,
    ``stats['overall']`` and ``task_type_stats[task_type]``. Both mappings must
    create missing entries on access, e.g. a ``defaultdict`` whose factory
    yields ``{'tp': 0, 'fp': 0, 'fn': 0, 'fail': 0, 'total': 0}``.

    A missing ``'pred'`` or a prediction of None, -1 or -2 counts as
    ``'fail'``. These are the same failure sentinels ``update_group_stats``
    recognises, and -2 is what ``parse_from_response_sewsm`` returns on
    failure. Otherwise ``(gt, pred)`` of (1, 1), (0, 1) and (1, 0) count as
    tp/fp/fn. Every result increments ``'total'``.

    Args:
        all_results: List of result dicts with ``model_setting``, ``task_type``,
            ``pred`` and ``gt``.
        stats: Counters keyed by model_setting plus ``'overall'`` (mutated).
        task_type_stats: Counters keyed by task_type (mutated).
    """
    for res in all_results:
        pred_result = res.get('pred')
        gt_result = res.get('gt')
        counters = (
            stats[res['model_setting']],
            stats['overall'],
            task_type_stats[res['task_type']],
        )

        if pred_result is None or pred_result in (-1, -2):  # evaluation failed
            outcome = 'fail'
        elif gt_result == 1 and pred_result == 1:
            outcome = 'tp'
        elif gt_result == 0 and pred_result == 1:
            outcome = 'fp'
        elif gt_result == 1 and pred_result == 0:
            outcome = 'fn'
        else:
            outcome = None  # correct negative: only 'total' is incremented

        for counter in counters:
            if outcome is not None:
                counter[outcome] += 1
            counter['total'] += 1

def _print_metrics_summary(new_stats):
    """Print the headline metrics stored in ``new_stats`` to stdout.

    Each section is printed only when its data is present, so this works for
    any combination of trajectory ("or") and action ("ar") reward metrics and
    for runs with no results.
    """
    traj_metrics = new_stats.get('trajectory_reward_metrics')
    # 'overall' is {} when there were no results; skip instead of raising KeyError.
    if traj_metrics and traj_metrics.get('overall'):
        overall = traj_metrics['overall']
        print("\nTrajectory Reward Prediction Metrics:")
        print(f"Precision: {overall['precision']:.4f}")
        print(f"Recall: {overall['recall']:.4f}")
        print(f"F1-score: {overall['f1']:.4f}")
        print(f"Accuracy: {overall['accuracy']:.4f}")
        print(f"TP: {overall['tp']}, FP: {overall['fp']}, FN: {overall['fn']}, \nFail: {overall['fail']}, Total: {overall['total']}")

    action_metrics = new_stats.get('action_reward_metrics')
    if action_metrics is None:
        return

    ar_metrics = action_metrics['overall']
    print("\nAction Reward Prediction Metrics:")
    print(f"Precision: {ar_metrics['precision']:.4f}")
    print(f"Recall: {ar_metrics['recall']:.4f}")
    print(f"F1-score: {ar_metrics['f1']:.4f}")
    print(f"Overall Accuracy: {ar_metrics['oa']:.4f}")
    print(f"Total Samples: {ar_metrics['total_samples']}")

    if 'by_reward_type' in action_metrics:
        print("\nAction Reward Prediction Group Metrics:")
        for group, metrics in action_metrics['by_reward_type'].items():
            print(f"\n{group.replace('_', ' ').title()}:")
            print(f"  Accuracy: {metrics['accuracy']:.4f}")
            print(f"  Proportion: {metrics['proportion']:.4f}")
            print(f"  Correct: {metrics['correct']}/{metrics['total']}")

    if 'by_task_type' in action_metrics:
        print("\nAction Reward Prediction Metrics by Task Type:")
        for task_type, task_metrics in action_metrics['by_task_type'].items():
            metrics = task_metrics['metrics']
            print(f"\n{task_type.upper()}:")
            print(f"  Precision: {metrics['precision']:.4f}")
            print(f"  Recall: {metrics['recall']:.4f}")
            print(f"  F1-score: {metrics['f1']:.4f}")
            print(f"  Overall Accuracy: {metrics['oa']:.4f}")
            print(f"  Total Samples: {metrics['total_samples']}")

            print("  Group Metrics:")
            for group, group_stats in task_metrics['group_metrics'].items():
                if group_stats['total'] > 0:
                    print(f"    {group.replace('_', ' ').title()}: {group_stats['accuracy']:.4f} ({group_stats['correct']}/{group_stats['total']})")


def compute_and_save_metrics(all_results, detailed_filepath, metrics_filename, args):
    """Compute trajectory/action reward metrics, save them, and print a summary.

    Args:
        all_results: List of per-task result dicts. When action rewards are
            computed, they are mutated in place by
            ``parse_action_reward_by_task_type``.
        detailed_filepath: Path to which ``all_results`` is written as JSON.
        metrics_filename: File name, inside ``args.output_dir``, for the metrics JSON.
        args: Namespace providing ``reward_type`` and ``output_dir``. Trajectory
            metrics are computed when ``"or" in args.reward_type`` and action
            metrics when ``'ar' in args.reward_type``.

    Failures while writing either file are reported on stdout and do not raise.
    """
    compute_trajectory = "or" in args.reward_type
    compute_action = 'ar' in args.reward_type

    # Confusion counters keyed by model_setting (plus 'overall') and by task_type.
    stats = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0, 'fail': 0, 'total': 0})
    task_type_stats = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0, 'fail': 0, 'total': 0})

    if compute_trajectory:
        collect_statistics(all_results, stats, task_type_stats)

    # 1. Trajectory-level metrics per model_setting and overall.
    for key in list(stats.keys()):
        if stats[key]['total'] > 0:
            stats[key] = compute_metrics(stats[key])

    # 2. Trajectory-level metrics per task_type.
    for task_type in list(task_type_stats.keys()):
        task_type_stats[task_type] = compute_metrics(task_type_stats[task_type])

    # 3. Parse action rewards and collect step labels (overall and per task_type).
    if compute_action:
        all_results, all_true_labels, all_pred_labels, task_type_labels = parse_action_reward_by_task_type(all_results)

    # 4. Save the detailed results, which now include any parsed action rewards.
    try:
        with open(detailed_filepath, 'w', encoding='utf-8') as f:
            json.dump(all_results, f, indent=2, ensure_ascii=False)
    except (OSError, TypeError, ValueError) as e:
        print(f"Error saving final detailed results: {str(e)}")

    # 5. Assemble the metrics structure.
    new_stats = {}

    if compute_trajectory:
        new_stats['trajectory_reward_metrics'] = {
            'overall': stats.get('overall', {}),
            'by_task_type': task_type_stats,
            'by_model_setting': {
                k: v for k, v in stats.items() if k != 'overall'
            },
            'by_step_num': compute_step_num_metrics(all_results)
        }

    if compute_action:
        # Group tasks by type once instead of re-filtering all_results per type.
        results_by_type = defaultdict(list)
        for task in all_results:
            results_by_type[task['task_type']].append(task)

        task_type_action_metrics = {}
        for task_type, labels in task_type_labels.items():
            if labels['true_labels']:  # skip task types without annotated steps
                task_type_action_metrics[task_type] = {
                    'metrics': compute_binary_classification_metrics(
                        labels['true_labels'],
                        labels['pred_labels']
                    ),
                    'group_metrics': analyze_group_performance(results_by_type[task_type])
                }

        new_stats['action_reward_metrics'] = {
            'overall': compute_binary_classification_metrics(all_true_labels, all_pred_labels),
            'by_task_type': task_type_action_metrics,
            'by_reward_type': analyze_group_performance(all_results)
        }

    # 6. Save the metrics. An explicit encoding is required because ensure_ascii=False.
    metrics_filepath = os.path.join(args.output_dir, metrics_filename)
    try:
        with open(metrics_filepath, 'w', encoding='utf-8') as f:
            json.dump(new_stats, f, indent=2, ensure_ascii=False)
    except (OSError, TypeError, ValueError) as e:
        print(f"Error saving metrics: {str(e)}")

    # 7. Print the summary.
    _print_metrics_summary(new_stats)

    print(f"详细结果保存在: {detailed_filepath}")
    print(f"指标结果保存在: {metrics_filepath}")
