"""
ProcessBench Error Detection using Uncertainty Drop
基于 Uncertainty Drop 进行过程级别的错误检测
"""

import argparse
import numpy as np
import os
import torch
import json
from collections import Counter
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn.functional as F
from datasets import load_dataset
import re
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
from sklearn.model_selection import train_test_split
import matplotlib
matplotlib.use('Agg')  # 使用非交互式后端
import matplotlib.pyplot as plt
import seaborn as sns


def calculate_ema_numpy(data, span=5):
    """Smooth a sequence with an exponential moving average (EMA).

    Recurrence: ``ema[0] = x[0]`` and
    ``ema[t] = alpha * x[t] + (1 - alpha) * ema[t-1]`` with
    ``alpha = 2 / (span + 1)``.

    Args:
        data: sequence of numbers (list, tuple or array).
        span: EMA span; larger values smooth more heavily.

    Returns:
        np.ndarray of floats with the same length as ``data``; an empty array
        when ``data`` is empty.
    """
    # Always work in floating point: with integer input, zeros_like() would
    # create an int array and silently truncate every smoothed value.
    values = np.asarray(data, dtype=float)
    if len(values) == 0:
        return np.array([])
    alpha = 2 / (span + 1)
    ema_values = np.empty_like(values)
    ema_values[0] = values[0]
    for t in range(1, len(values)):
        ema_values[t] = alpha * values[t] + (1 - alpha) * ema_values[t - 1]
    return ema_values


def get_logprobs_from_step_data(step_data):
    """
    Unpack one position's logprob dict.

    Args:
        step_data: ``{token_id: obj}`` where each ``obj`` has a ``.logprob``
            attribute, or ``None``.

    Returns:
        ``(actual_token_logprob, all_logprobs)``. ``all_logprobs`` maps
        token_id -> logprob; ``actual_token_logprob`` is the first non-None
        logprob in dict iteration order, or ``None`` if there is none.
        Treating the first entry as the emitted token assumes the producer
        inserts that token first (NOTE(review): confirm per producer).
    """
    if step_data is None:
        return None, {}

    token_logprobs = {tok: entry.logprob for tok, entry in step_data.items()}
    first_valid = next(
        (lp for lp in token_logprobs.values() if lp is not None),
        None,
    )
    return first_valid, token_logprobs


def get_entropy_curve(logprobs_seq):
    """
    Build a per-token negative Shannon entropy curve.

    For every non-None position, the candidate logprobs are exponentiated and
    renormalised to sum to 1 (a top-k approximation of the full distribution).
    Then H = -sum(p * log(p)) is computed. The curve stores -H, so a rise in
    uncertainty appears as a drop in the curve.

    Args:
        logprobs_seq: sequence of ``{token_id: obj}`` dicts (``obj.logprob``);
            ``None`` positions are skipped.

    Returns:
        list with one value per non-None position; empty for empty/None input.
    """
    if not logprobs_seq:
        return []

    curve = []
    for position in logprobs_seq:
        if position is None:
            continue
        p = np.array([np.exp(entry.logprob) for entry in position.values()])
        total = np.sum(p)
        # Renormalise over the returned candidates only; leave as-is if no mass.
        if total > 0:
            p = p / total
        entropy = -np.sum(p * np.log(p + 1e-10))
        curve.append(-entropy)
    return curve


def calculate_length_normalized_logprob(logprobs_seq):
    """
    Average log-probability of the actual tokens in a sequence.

    The actual-token logprob of each non-None position comes from
    ``get_logprobs_from_step_data``; positions without one are ignored.

    Returns:
        float mean logprob, or 0.0 when the sequence is empty or has no
        usable positions.
    """
    if not logprobs_seq:
        return 0.0

    candidates = [
        get_logprobs_from_step_data(step)[0]
        for step in logprobs_seq
        if step is not None
    ]
    usable = [lp for lp in candidates if lp is not None]
    if not usable:
        return 0.0
    return sum(usable, 0.0) / len(usable)


def calculate_mean_logprob(logprobs_seq):
    """
    Mean log-probability of the actual tokens.

    Alias of ``calculate_length_normalized_logprob``, exposed under its own
    measure name; returns exactly the same value (0.0 for empty input).
    """
    mean_value = calculate_length_normalized_logprob(logprobs_seq)
    return mean_value


def calculate_max_logprob(logprobs_seq):
    """
    Largest actual-token log-probability in a sequence.

    Returns:
        float maximum; 0.0 when the sequence is empty, has no usable
        positions, or every actual-token logprob is -inf.
    """
    if not logprobs_seq:
        return 0.0

    neg_inf = float('-inf')
    # Seed with -inf so the fold matches a running max starting from -inf.
    candidates = [neg_inf]
    for step in logprobs_seq:
        if step is None:
            continue
        lp = get_logprobs_from_step_data(step)[0]
        if lp is not None:
            candidates.append(lp)

    best = max(candidates)
    return 0.0 if best == neg_inf else best


def calculate_variance_logprob(logprobs_seq):
    """
    Variance of the actual-token log-probabilities.

    A larger variance is read as higher uncertainty.

    Returns:
        float population variance (``np.var``), or 0.0 when the sequence is
        empty or has no usable positions.
    """
    if not logprobs_seq:
        return 0.0

    actual = [
        lp
        for lp in (get_logprobs_from_step_data(s)[0] for s in logprobs_seq if s is not None)
        if lp is not None
    ]
    return float(np.var(actual)) if actual else 0.0


def calculate_perplexity(logprobs_seq):
    """
    Perplexity of the actual tokens: PPL = exp(-mean(logprob)).

    Larger values mean higher uncertainty. If every actual token has logprob
    0.0 (the model was certain), PPL = 1.0, the minimum.

    Args:
        logprobs_seq: sequence of ``{token_id: obj}`` dicts (``obj.logprob``);
            ``None`` positions are skipped. A position's actual-token logprob
            is its first non-None entry, the same rule as
            ``get_logprobs_from_step_data``.

    Returns:
        float perplexity. Returns 1.0 when there are no usable positions
        (neutral, like the 0.0 the other logprob measures return), and 1e10
        when exp() overflows.
    """
    actual = []
    for step in logprobs_seq or []:
        if step is None:
            continue
        lp = next(
            (entry.logprob for entry in step.values() if entry.logprob is not None),
            None,
        )
        if lp is not None:
            actual.append(lp)

    # "No data" must be detected explicitly: a mean logprob of exactly 0.0
    # means every token was certain (PPL = 1), not that the input was empty.
    # Mapping it to +inf would rank the most confident step as the most
    # uncertain and poison downstream mean/std statistics.
    if not actual:
        return 1.0

    mean_logprob = sum(actual, 0.0) / len(actual)
    with np.errstate(over='ignore'):
        perplexity = np.exp(-mean_logprob)
    return float(perplexity) if not np.isinf(perplexity) else 1e10


def calculate_confidence_score(logprobs_seq):
    """
    Entropy-based confidence score.

    confidence = 1 - min(mean_entropy / ln(20), 1), where mean_entropy is
    the average per-token entropy from ``get_entropy_curve``. ln(20) is used
    as the entropy ceiling because 20 candidates are requested per position.
    Larger means more confident.

    Returns:
        float confidence; 0.0 for empty input or when no position is usable.
    """
    if not logprobs_seq:
        return 0.0

    neg_entropy = get_entropy_curve(logprobs_seq)
    if len(neg_entropy) == 0:
        return 0.0

    # The curve stores negative entropy; flip the sign back.
    avg_entropy = -np.mean(neg_entropy)

    ceiling = np.log(20)  # top-k = 20
    ratio = avg_entropy / ceiling if ceiling > 0 else 0.0
    return float(1.0 - min(ratio, 1.0))


def get_evidence_curve(logprobs_seq):
    """
    Mahuan method: per-token "evidence" curve.

    For each non-None position, this sums the logprobs of all returned
    candidates. This is a sum of log-probabilities, not the log of the
    captured probability mass (that would be a logsumexp). It is kept as-is
    because changing it would change the evidence/mahuan scores.

    Args:
        logprobs_seq: sequence of ``{token_id: obj}`` dicts (``obj.logprob``);
            ``None`` positions are skipped.

    Returns:
        list with one summed value per non-None position; empty for
        empty/None input.
    """
    if not logprobs_seq:
        return []
    return [
        sum(entry.logprob for entry in step.values())
        for step in logprobs_seq
        if step is not None
    ]


def calculate_evidence_drop(evidence_curve, ema_span=5, drop_k=5):
    """
    Mahuan method: risk score from the worst drops of an evidence curve.

    The algorithm is identical to ``calculate_uncertainty_drop``:
    - EMA smoothing, first differences, and the negated mean of the
      ``drop_k`` most negative differences;
    - 0.0 for fewer than two points, no drops, or a NaN/inf result.

    It delegates there rather than keeping a second copy that could drift.

    Args:
        evidence_curve: 1-D sequence, e.g. from ``get_evidence_curve``.
        ema_span: EMA span used for smoothing.
        drop_k: number of steepest drops to average.

    Returns:
        float risk >= 0; larger means a steeper fall in evidence.
    """
    return calculate_uncertainty_drop(evidence_curve, ema_span=ema_span, drop_k=drop_k)


def calculate_uncertainty_measures(logprobs_seq, measures=None,
                                     ema_span=5, drop_k=5, use_drop=True):
    """
    Compute several uncertainty measures for one logprob sequence.

    Args:
        logprobs_seq: sequence of ``{token_id: obj}`` dicts (``obj.logprob``).
        measures: measure names. Supported base names:
            ``'entropy'``, ``'evidence'`` / ``'mahuan'``,
            ``'length_norm_logprob'``, ``'mean_logprob'``, ``'max_logprob'``,
            ``'variance_logprob'``, ``'perplexity'`` and ``'confidence'``.
            Any ``'_drop'`` suffix is stripped. Defaults to
            ``['entropy_drop', 'length_norm_logprob']``.
        ema_span: EMA span for the drop-based measures.
        drop_k: number of worst drops averaged by the drop-based measures.
        use_drop: controls how entropy/evidence are reported.
            - True: as drop scores (``entropy_drop``; ``evidence_drop`` plus
              alias ``mahuan_risk``).
            - False: as means (``entropy``; ``evidence`` plus alias
              ``mahuan``).
            All other measures have no drop variant and always use their
            plain name.

    Returns:
        dict mapping result key -> float. Empty input yields 0.0 under the
        same keys a non-empty input would produce.
    """
    if measures is None:
        measures = ['entropy_drop', 'length_norm_logprob']

    # Normalise names (strip '_drop') and de-duplicate, keeping request order.
    base_measures = list(dict.fromkeys(m.replace('_drop', '') for m in measures))

    if not logprobs_seq:
        # Use the same keys as the non-empty path: no '<raw>_drop' keys for
        # measures that have no drop variant, and no doubled 'entropy_drop_drop'.
        results = {}
        for base in base_measures:
            if base == 'entropy':
                keys = ['entropy_drop'] if use_drop else ['entropy']
            elif base in ('evidence', 'mahuan'):
                keys = ['evidence_drop', 'mahuan_risk'] if use_drop else ['evidence', 'mahuan']
            else:
                keys = [base]
            for key in keys:
                results[key] = 0.0
        return results

    results = {}

    # 1. Entropy: drop score over the negative-entropy curve, or mean entropy.
    if 'entropy' in base_measures:
        entropy_curve = get_entropy_curve(logprobs_seq)
        if use_drop:
            drop_score = 0.0
            if len(entropy_curve) > 0:
                drop_score = calculate_uncertainty_drop(entropy_curve, ema_span=ema_span, drop_k=drop_k)
                if drop_score == 0.0:
                    # No drop detected (too short or monotone curve):
                    # fall back to the mean entropy.
                    drop_score = -np.mean(entropy_curve)
            results['entropy_drop'] = float(drop_score)
        else:
            results['entropy'] = float(-np.mean(entropy_curve)) if len(entropy_curve) > 0 else 0.0

    # 2-7. Raw measures without a drop variant, in a fixed order.
    raw_measures = {
        'length_norm_logprob': calculate_length_normalized_logprob,
        'mean_logprob': calculate_mean_logprob,
        'max_logprob': calculate_max_logprob,
        'variance_logprob': calculate_variance_logprob,
        'perplexity': calculate_perplexity,
        'confidence': calculate_confidence_score,
    }
    for name, measure_fn in raw_measures.items():
        if name in base_measures:
            results[name] = measure_fn(logprobs_seq)

    # 8. Evidence (Mahuan method); 'mahuan*' keys are aliases of 'evidence*'.
    if 'evidence' in base_measures or 'mahuan' in base_measures:
        evidence_curve = get_evidence_curve(logprobs_seq)
        if use_drop:
            risk_score = 0.0
            if len(evidence_curve) > 0:
                risk_score = calculate_evidence_drop(evidence_curve, ema_span=ema_span, drop_k=drop_k)
            results['evidence_drop'] = float(risk_score)
            results['mahuan_risk'] = float(risk_score)
        else:
            mean_evidence = float(np.mean(evidence_curve)) if len(evidence_curve) > 0 else 0.0
            results['evidence'] = mean_evidence
            results['mahuan'] = mean_evidence

    return results


def calculate_uncertainty_drop(entropy_curve, ema_span=5, drop_k=5):
    """
    Uncertainty-drop score of a (negative-entropy) curve.

    The curve is EMA-smoothed and differenced. The negative differences
    (drops) are kept, and the score is the negated mean of the ``drop_k``
    most negative drops.

    Args:
        entropy_curve: 1-D sequence, e.g. from ``get_entropy_curve``.
        ema_span: EMA span used for smoothing.
        drop_k: number of steepest drops to average.

    Returns:
        float >= 0, where larger means a sharper rise in uncertainty. Returns
        0.0 when the curve has fewer than two points, never drops, or the
        result is NaN/inf.
    """
    series = np.array(entropy_curve)
    if len(series) < 2:
        return 0.0

    # With >= 2 points there is always at least one difference.
    deltas = np.diff(calculate_ema_numpy(series, span=ema_span))
    falling = deltas[deltas < 0]
    if len(falling) == 0:
        return 0.0

    # Ascending sort puts the steepest (most negative) drops first.
    steepest = np.sort(falling)[:drop_k]
    score = -float(np.mean(steepest))
    if np.isnan(score) or np.isinf(score):
        return 0.0
    return score


def prepare_paragraph_prompt_with_target(problem, steps, paragraph_idx, target_step=None):
    """
    Build a prompt with the problem and every step up to ``paragraph_idx``.

    Steps ``0..paragraph_idx`` are joined with blank lines after a
    ``"Solution:"`` header. ``target_step`` is not used to build the prompt;
    it is returned unchanged so a caller can use it as a forced-decoding
    target.

    Args:
        problem: problem statement.
        steps: sequence of step texts (``steps[0]`` must exist when
            ``paragraph_idx == 0``).
        paragraph_idx: index of the last step to include.
        target_step: optional target text passed through.

    Returns:
        ``(prompt, target_step)``.
    """
    visible_steps = [steps[0]] if paragraph_idx == 0 else steps[:paragraph_idx + 1]
    solution_body = "\n\n".join(f"{s}" for s in visible_steps)
    return f"{problem}\n\nSolution:\n{solution_body}", target_step


def get_logprobs_for_target_sequence_vllm(llm, tokenizer, prompt_text, target_text):
    """
    Greedily generate a continuation of ``prompt_text`` with vLLM and return its logprobs.

    This is NOT forced decoding: the model generates freely at temperature 0,
    so the produced tokens may differ from ``target_text``. ``target_text``
    only sizes the run:
    - at most len(target) + 20 tokens are generated, capped at 512;
    - at most len(target) + 10 positions are returned.

    Returns:
        the per-position top-20 logprob list from vLLM (truncated as above),
        or vLLM's value unchanged when it is None/empty.
    """
    chat_prompt = tokenizer.apply_chat_template(
        [{'role': 'user', 'content': prompt_text}],
        add_generation_prompt=True,
        tokenize=False,
    )

    n_target = len(tokenizer.encode(target_text, add_special_tokens=False))

    # Greedy decoding keeps the run deterministic.
    params = SamplingParams(
        temperature=0.0,
        max_tokens=min(n_target + 20, 512),
        logprobs=20,
        stop_token_ids=[tokenizer.eos_token_id],
        stop=["\n\n\n"],
    )

    generation = llm.generate([chat_prompt], params)[0].outputs[0]
    seq = generation.logprobs
    if seq is not None and len(seq) > 0:
        seq = seq[:min(len(seq), n_target + 10)]
    return seq


def get_logprobs_for_target_sequence_transformers(model, tokenizer, prompt_text, target_text, device='cuda',
                                                  top_k=20):
    """
    Score ``target_text`` under the model by teacher forcing (forced decoding).

    The chat-templated prompt and the target tokens are concatenated and run
    through ``model`` in one forward pass. For every target token, a
    vLLM-style dict ``{token_id: obj}`` (``obj.logprob``) is built. The
    actual target token is inserted FIRST, so ``get_logprobs_from_step_data``
    reads its logprob, and is followed by the ``top_k`` most likely tokens.

    Args:
        model: causal LM returning an object with ``.logits`` of shape
            [batch, seq_len, vocab] when called on input ids.
        tokenizer: tokenizer with ``apply_chat_template`` and ``encode``.
        prompt_text: user message placed in the chat template.
        target_text: text whose tokens are scored.
        device: device for the input ids.
        top_k: number of top candidates stored per position (capped at the
            vocabulary size).

    Returns:
        list with one dict per scorable target token. Empty when
        ``target_text`` tokenizes to nothing. If the prompt is empty, the
        first target token has no context and is skipped.
    """
    from types import SimpleNamespace

    formatted_prompt = tokenizer.apply_chat_template(
        [{'role': 'user', 'content': prompt_text}], add_generation_prompt=True, tokenize=False
    )
    prompt_tokens = tokenizer.encode(formatted_prompt, add_special_tokens=False, return_tensors='pt').to(device)
    target_tokens = tokenizer.encode(target_text, add_special_tokens=False, return_tensors='pt').to(device)

    prompt_len = prompt_tokens.shape[1]
    target_len = target_tokens.shape[1]
    if target_len == 0:
        return []

    input_ids = torch.cat([prompt_tokens, target_tokens], dim=1)
    with torch.no_grad():
        logits = model(input_ids).logits

    # logits[:, p] is the distribution over the token at position p + 1, so
    # the target token at absolute position `pos` is scored by logits[:, pos - 1].
    first_pos = max(prompt_len, 1)  # position 0 has no preceding context
    end_pos = prompt_len + target_len
    scoring_logits = logits[0, first_pos - 1:end_pos - 1].float()
    log_probs = F.log_softmax(scoring_logits, dim=-1)
    k = min(top_k, log_probs.shape[-1])
    top_vals, top_ids = torch.topk(log_probs, k, dim=-1)

    logprobs_seq = []
    for row, pos in enumerate(range(first_pos, end_pos)):
        actual_id = int(input_ids[0, pos].item())
        # Actual token first (mirrors the vLLM prompt_logprobs layout),
        # then the top-k candidates that are not already present.
        step = {actual_id: SimpleNamespace(logprob=float(log_probs[row, actual_id].item()))}
        for value, token_id in zip(top_vals[row].tolist(), top_ids[row].tolist()):
            step.setdefault(token_id, SimpleNamespace(logprob=value))
        logprobs_seq.append(step)

    return logprobs_seq


# def get_logprobs_for_target_sequence(llm, tokenizer, prompt_text, target_text, 
#                                      use_transformers=False, model=None, device='cuda'):
#     """
#     获取模型对特定序列的 logprobs
    
#     如果 use_transformers=True，使用 transformers 直接计算（forced decoding）
#     否则使用 vLLM 生成方式
#     """
#     if use_transformers and model is not None:
#         return get_logprobs_for_target_sequence_transformers(
#             model, tokenizer, prompt_text, target_text, device
#         )
#     else:
#         return get_logprobs_for_target_sequence_vllm(llm, tokenizer, prompt_text, target_text)


def _empty_step_measures(uncertainty_measures):
    """Return the all-zero measure dict used for a step whose logprobs are unavailable."""
    empty_measures = {}
    for measure in uncertainty_measures:
        # Drop-capable measures get both the '_drop' key and the plain key.
        if measure in ['entropy', 'evidence', 'mahuan']:
            empty_measures[f'{measure}_drop'] = 0.0
        empty_measures[measure] = 0.0
    return empty_measures


def _step_token_spans(tokenizer, formatted_prompt, full_text, solution_prefix, steps):
    """Return one ``(start, end)`` token range per step, indexed into the tokens of ``formatted_prompt``.

    Step boundaries are located in character space inside the chat-templated
    prompt, so the template's header tokens are accounted for. Each boundary
    is converted to a token index by tokenizing the prompt prefix up to it.
    A token that would merge across a boundary makes the split approximate.
    """
    content_start = formatted_prompt.find(full_text)
    if content_start < 0:
        # NOTE(review): the chat template rewrote the content, so character
        # offsets cannot be recovered; fall back to counting the tokens of the
        # raw pieces (this ignores any template header before the content).
        spans = []
        pos = len(tokenizer.encode(solution_prefix, add_special_tokens=False))
        for step_idx, step in enumerate(steps):
            piece = step if step_idx == 0 else "\n\n" + step
            end = pos + len(tokenizer.encode(piece, add_special_tokens=False))
            spans.append((pos, end))
            pos = end
        return spans

    def tokens_before(char_pos):
        return len(tokenizer.encode(formatted_prompt[:char_pos], add_special_tokens=False))

    char_pos = content_start + len(solution_prefix)
    start = tokens_before(char_pos)
    spans = []
    for step_idx, step in enumerate(steps):
        # Step k > 0 owns the "\n\n" separator that precedes it.
        char_pos += len(step) if step_idx == 0 else len("\n\n") + len(step)
        end = tokens_before(char_pos)
        spans.append((start, end))
        start = end
    return spans


def _prompt_token_offset(output, prompt_tokens):
    """Count tokens vLLM prepended (e.g. a BOS) relative to our own tokenization of the prompt; 0 if unknown."""
    vllm_ids = getattr(output, 'prompt_token_ids', None)
    if not vllm_ids:
        return 0
    vllm_ids = list(vllm_ids)
    extra = len(vllm_ids) - len(prompt_tokens)
    # Only shift when vLLM's ids are exactly ours with a prefix added.
    if extra > 0 and vllm_ids[extra:] == list(prompt_tokens):
        return extra
    return 0


def _select_primary_measure(paragraph_uncertainties, uncertainty_measures):
    """Pick the measure used for prediction: the first requested one, preferring its '_drop' variant."""
    if not paragraph_uncertainties or not paragraph_uncertainties[0]:
        return 'entropy_drop'
    available = list(paragraph_uncertainties[0].keys())
    if uncertainty_measures:
        first_measure = uncertainty_measures[0]
        if f'{first_measure}_drop' in available:
            return f'{first_measure}_drop'
        if first_measure in available:
            return first_measure
    drop_measures = [m for m in available if m.endswith('_drop')]
    return drop_measures[0] if drop_measures else available[0]


def _predict_first_error(values, primary_measure):
    """Return the first paragraph whose value is an outlier by the mean ± std rule, or -1."""
    # Measures where a larger value means more uncertainty; all others are
    # treated as "smaller means more uncertain".
    high_uncertainty_measures = {'entropy_drop', 'entropy', 'variance_logprob', 'perplexity',
                                 'evidence_drop', 'mahuan_risk', 'evidence', 'mahuan'}
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return -1
    if values.max() == values.min():
        # No spread (a single step, or every step fell back to 0.0): nothing
        # stands out, so do not flag step 0 by default.
        return -1

    mean_uncertainty = np.mean(values)
    std_uncertainty = np.std(values)
    if primary_measure in high_uncertainty_measures:
        hits = np.where(values >= mean_uncertainty + std_uncertainty)[0]
        extreme = int(np.argmax(values))
        is_outlier = values[extreme] > mean_uncertainty
    else:
        hits = np.where(values <= mean_uncertainty - std_uncertainty)[0]
        extreme = int(np.argmin(values))
        is_outlier = values[extreme] < mean_uncertainty

    if len(hits) > 0:
        return int(hits[0])  # earliest flagged step
    # Nothing crossed the threshold: fall back to the most extreme step.
    return extreme if is_outlier else -1


def process_sample_with_uncertainty_drop(llm, tokenizer, problem, steps,
                                         ema_span=5, drop_k=5,
                                         use_transformers=False, model=None, device='cuda',
                                         uncertainty_measures=None,
                                         use_drop=True, batch_paragraphs=True):
    """
    Score every step of one solution with uncertainty measures and predict its first error.

    The problem and all steps are sent as ONE chat prompt. vLLM's
    ``prompt_logprobs`` (top-20 per position) give the logprobs of every
    prompt token in a single pass. Each step's slice is located via
    character offsets inside the chat-templated prompt.

    Args:
        llm: vLLM ``LLM`` used for scoring.
        tokenizer: tokenizer matching ``llm`` (``apply_chat_template``, ``encode``).
        problem: problem statement.
        steps: list of step strings.
        ema_span, drop_k: passed to the drop-based measures.
        use_transformers, model, device, batch_paragraphs: accepted for
            backward compatibility; not used by this function.
        uncertainty_measures: measure names to compute. Defaults to
            ``['entropy', 'length_norm_logprob']``. The first one (its
            ``_drop`` variant when present) drives the prediction.
        use_drop: if True, each step dict holds both drop and raw variants.

    Returns:
        ``(paragraph_uncertainties, predicted_error_idx)``: one measure dict
        per step, and the index of the first step whose primary measure is an
        outlier by the mean ± std rule, or -1.
    """
    if uncertainty_measures is None:
        uncertainty_measures = ['entropy', 'length_norm_logprob']

    solution_prefix = f"{problem}\n\nSolution:\n"
    full_text = solution_prefix + "\n\n".join(steps)

    formatted_prompt = tokenizer.apply_chat_template(
        [{'role': 'user', 'content': full_text}], add_generation_prompt=True, tokenize=False
    )
    prompt_tokens = tokenizer.encode(formatted_prompt, add_special_tokens=False)
    step_spans = _step_token_spans(tokenizer, formatted_prompt, full_text, solution_prefix, steps)

    # max_tokens=1 is the smallest generation request; only prompt_logprobs are used.
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=1,
        logprobs=20,
        prompt_logprobs=20,
    )
    output = llm.generate([formatted_prompt], sampling_params)[0]
    prompt_logprobs = output.prompt_logprobs  # one entry per prompt token (first is None)

    paragraph_uncertainties = []
    if not prompt_logprobs:
        paragraph_uncertainties = [_empty_step_measures(uncertainty_measures) for _ in steps]
    else:
        offset = _prompt_token_offset(output, prompt_tokens)
        for step_start, step_end in step_spans:
            window = prompt_logprobs[step_start + offset:step_end + offset]
            step_logprobs = [lp for lp in window if lp is not None]
            if not step_logprobs:
                paragraph_uncertainties.append(_empty_step_measures(uncertainty_measures))
                continue

            step_measures = calculate_uncertainty_measures(
                step_logprobs, measures=uncertainty_measures,
                ema_span=ema_span, drop_k=drop_k, use_drop=False
            )
            if use_drop:
                drop_measures = calculate_uncertainty_measures(
                    step_logprobs, measures=uncertainty_measures,
                    ema_span=ema_span, drop_k=drop_k, use_drop=True
                )
                step_measures = {**drop_measures, **step_measures}
            paragraph_uncertainties.append(step_measures)

    primary_measure = _select_primary_measure(paragraph_uncertainties, uncertainty_measures)
    primary_values = [para.get(primary_measure, 0.0) for para in paragraph_uncertainties]
    predicted_error_idx = _predict_first_error(primary_values, primary_measure)

    return paragraph_uncertainties, predicted_error_idx


def find_first_error_by_threshold(uncertainties, threshold):
    """
    Return the index of the first value >= ``threshold`` (the earliest flagged paragraph).

    Returns:
        int index, or -1 when no value reaches the threshold.
    """
    return next(
        (position for position, value in enumerate(uncertainties) if value >= threshold),
        -1,
    )


def collect_error_step_uncertainties(paragraph_uncertainties_list, labels, measure_name):
    """
    Gather the ``measure_name`` value at the labelled first-error step of each calibration sample.

    Args:
        paragraph_uncertainties_list: per sample, a list of per-step measure dicts.
        labels: per sample, the index of the first erroneous step; -1 (or any
            negative value) means the sample has no error.
        measure_name: key to read from each step dict.

    Returns:
        np.ndarray of collected values. A sample is skipped when:
        - it has no error label or no steps;
        - its label is out of range;
        - the labelled entry is not a dict;
        - the entry lacks the key.
    """
    error_uncertainties = []

    for sample_uncertainties, label in zip(paragraph_uncertainties_list, labels):
        # Negative labels mean "no error". Without the lower bound, a value
        # other than -1 (e.g. -2) would silently index from the end of the list.
        if label < 0 or not sample_uncertainties or label >= len(sample_uncertainties):
            continue

        para_measures = sample_uncertainties[label]
        if isinstance(para_measures, dict):
            value = para_measures.get(measure_name)
            if value is not None:
                error_uncertainties.append(value)

    return np.array(error_uncertainties)


def collect_all_step_uncertainties(paragraph_uncertainties_list, labels, measure_name):
    """
    Split every step's ``measure_name`` value into error and correct buckets.

    A step goes to the error bucket when its index equals the sample's label
    (and the label is not -1). Every other step goes to the correct bucket,
    including steps that come after the labelled one.

    Args:
        paragraph_uncertainties_list: per sample, a list of per-step measure dicts.
        labels: per sample, the first-error index (-1 means no error).
        measure_name: key to read from each step dict. Steps that are not
            dicts or lack the key are ignored.

    Returns:
        ``(error_uncertainties, correct_uncertainties)`` as np.ndarrays.
    """
    errors, corrects = [], []

    for per_step, label in zip(paragraph_uncertainties_list, labels):
        if not per_step:
            continue
        for idx, measures in enumerate(per_step):
            if not isinstance(measures, dict):
                continue
            value = measures.get(measure_name, None)
            if value is None:
                continue
            bucket = errors if (label != -1 and idx == label) else corrects
            bucket.append(value)

    return np.array(errors), np.array(corrects)


def calculate_threshold_from_calibration(error_uncertainties, measure_name,
                                         significant_level=0.05, fallback_percentile=95):
    """
    Derive a decision threshold from the error-step uncertainty distribution of a calibration set.

    The threshold puts a fraction ``1 - significant_level`` of the
    calibration error steps on the "flagged" side:

    - Larger-means-more-uncertain measures (``entropy_drop``, ``entropy``,
      ``variance_logprob``, ``perplexity``, ``evidence_drop``,
      ``mahuan_risk``, ``evidence``, ``mahuan``): the threshold is the
      ``significant_level`` quantile. With alpha = 0.05, about 95% of error
      steps have value >= threshold.
    - Every other measure (smaller means more uncertain, e.g.
      ``length_norm_logprob``, ``confidence``): the threshold is the
      ``1 - significant_level`` quantile. About 95% of error steps have
      value <= threshold.

    NaN values are ignored.

    Args:
        error_uncertainties: array-like of uncertainty values at error steps.
        measure_name: measure the values belong to; selects the direction.
        significant_level: alpha in [0, 1].
        fallback_percentile: not applied by this function; kept for
            interface compatibility.

    Returns:
        float threshold, or None when there are no (non-NaN) error values.
    """
    values = np.asarray(error_uncertainties, dtype=float).ravel()
    values = values[~np.isnan(values)]
    if values.size == 0:
        print(f"Warning: No error samples found for {measure_name}, returning None (no calibrated threshold)")
        return None

    high_uncertainty_measures = {'entropy_drop', 'entropy', 'variance_logprob', 'perplexity',
                                 'evidence_drop', 'mahuan_risk', 'evidence', 'mahuan'}

    if measure_name in high_uncertainty_measures:
        # Larger = more uncertain: error steps should lie at or above the threshold.
        percentile = significant_level * 100
    else:
        # Smaller = more uncertain: error steps should lie at or below the threshold.
        percentile = (1 - significant_level) * 100

    return float(np.percentile(values, percentile))


def predict_error_with_measure(paragraph_uncertainties_list, measure_name, threshold=None):
    """
    Predict each sample's first erroneous step from one uncertainty measure.

    Args:
        paragraph_uncertainties_list: per sample, a list of per-step measure
            dicts (bare numbers are accepted as a legacy format).
        measure_name: measure to read. Larger-means-more-uncertain measures
            are ``entropy_drop``, ``entropy``, ``variance_logprob``,
            ``perplexity``, ``evidence_drop``, ``mahuan_risk``, ``evidence``
            and ``mahuan``. Every other measure is treated as
            smaller-means-more-uncertain.
        threshold: fixed decision threshold. If None, a per-sample
            mean ± std rule is used instead. If no step crosses that rule,
            the most extreme step is chosen provided it lies beyond the mean.

    Returns:
        list of predicted first-error indices, one per sample (-1 = no error).
        A sample gets -1 in these cases:
        - it is empty;
        - any step lacks ``measure_name``;
        - (mean ± std rule only) all its values are equal, since nothing
          stands out.
    """
    high_uncertainty_measures = {'entropy_drop', 'entropy', 'variance_logprob', 'perplexity',
                                 'evidence_drop', 'mahuan_risk', 'evidence', 'mahuan'}
    higher_is_worse = measure_name in high_uncertainty_measures

    predictions = []
    for sample_uncertainties in paragraph_uncertainties_list:
        if not sample_uncertainties:
            predictions.append(-1)
            continue

        measure_values = []
        missing = False
        for para_measures in sample_uncertainties:
            if isinstance(para_measures, dict):
                value = para_measures.get(measure_name)
                if value is None:
                    missing = True
                    break
                measure_values.append(value)
            else:
                measure_values.append(para_measures)  # legacy: bare numbers

        # A step without this measure makes the sample unusable; skip it
        # instead of predicting on only the steps before the gap.
        if missing or not measure_values:
            predictions.append(-1)
            continue

        values = np.array(measure_values)

        if threshold is not None:
            if higher_is_worse:
                hits = np.where(values >= threshold)[0]
            else:
                hits = np.where(values <= threshold)[0]
            predictions.append(int(hits[0]) if len(hits) > 0 else -1)
            continue

        if values.max() == values.min():
            # Zero spread (e.g. a single step): mean ± std would flag step 0 trivially.
            predictions.append(-1)
            continue

        mean_uncertainty = np.mean(values)
        std_uncertainty = np.std(values)
        if higher_is_worse:
            hits = np.where(values >= mean_uncertainty + std_uncertainty)[0]
            extreme = int(np.argmax(values))
            is_outlier = values[extreme] > mean_uncertainty
        else:
            hits = np.where(values <= mean_uncertainty - std_uncertainty)[0]
            extreme = int(np.argmin(values))
            is_outlier = values[extreme] < mean_uncertainty

        if len(hits) > 0:
            predictions.append(int(hits[0]))  # earliest flagged step
        else:
            predictions.append(extreme if is_outlier else -1)

    return predictions


def _finite_values(values):
    """Convert ``values`` to a 1-D float array, dropping NaN / inf / None entries.

    Non-finite values cannot be histogrammed and make ``Axes.set_xlim`` raise,
    so they are excluded from the plot.
    """
    arr = np.asarray(values, dtype=float).ravel()
    return arr[np.isfinite(arr)]


def _draw_distribution_panel(ax, correct_unc, error_unc, threshold, measure_name,
                             title, bin_edges, x_limits):
    """Draw one panel: correct-vs-error step histograms, their means, the threshold and stats."""
    series = (
        (correct_unc, 'Correct Steps', 'green', 'darkgreen'),
        (error_unc, 'Error Steps', 'red', 'darkred'),
    )
    for values, series_name, fill_color, mean_color in series:
        if len(values) == 0:
            continue
        # Shared bin edges keep both series (and both panels) directly comparable.
        ax.hist(values, bins=bin_edges, alpha=0.6, label=f'{series_name} (n={len(values)})',
                color=fill_color, density=True, edgecolor='black', linewidth=0.8)
        # Dotted line at the group mean.
        ax.axvline(x=np.mean(values), color=mean_color, linestyle=':', linewidth=3, alpha=0.7)

    if threshold is not None:
        ax.axvline(x=threshold, color='blue', linestyle='--', linewidth=4,
                   label=f'Threshold = {threshold:.4f}', zorder=10)

    ax.set_xlim(*x_limits)
    ax.set_xlabel(f'{measure_name} Value', fontsize=22, labelpad=10)
    ax.set_ylabel('Density', fontsize=22, labelpad=10)
    ax.set_title(title, fontsize=24, fontweight='bold', pad=15)
    ax.legend(fontsize=16, loc='best', frameon=True, shadow=True)
    ax.grid(True, alpha=0.3)
    ax.tick_params(axis='both', which='major', labelsize=18, length=6, width=1.5)

    # Summary statistics box in the top-left corner.
    stats_text = []
    if len(correct_unc) > 0:
        stats_text.append(f'Correct: μ={np.mean(correct_unc):.4f}, σ={np.std(correct_unc):.4f}')
    if len(error_unc) > 0:
        stats_text.append(f'Error: μ={np.mean(error_unc):.4f}, σ={np.std(error_unc):.4f}')
    if stats_text:
        ax.text(0.02, 0.98, '\n'.join(stats_text), transform=ax.transAxes,
                fontsize=14, verticalalignment='top', bbox=dict(boxstyle='round',
                facecolor='wheat', alpha=0.7, pad=0.8))


def plot_uncertainty_distributions(cal_error_unc, cal_correct_unc, test_error_unc, test_correct_unc,
                                   threshold, measure_name, output_dir, config_name):
    """
    Plot correct-step vs. error-step uncertainty distributions on the
    calibration set (left) and the test set (right), and save them as a PNG.

    Both panels share the same bin edges and x-axis limits, so the
    distributions are directly comparable. Non-finite values (NaN/inf/None)
    are dropped before plotting, and the ``n=`` legend counts reflect the
    remaining values.

    Args:
        cal_error_unc: uncertainty values of error steps in the calibration set.
        cal_correct_unc: uncertainty values of correct steps in the calibration set.
        test_error_unc: uncertainty values of error steps in the test set.
        test_correct_unc: uncertainty values of correct steps in the test set.
        threshold: decision threshold drawn as a vertical line (None or non-finite: not drawn).
        measure_name: name of the measure (used in titles and the file name).
        output_dir: directory the PNG is written to.
        config_name: configuration name (file-name prefix).

    Returns:
        None. Writes ``{config_name}_{measure_name}_distribution.png`` into
        ``output_dir`` unless there is nothing to plot.
    """
    if len(cal_error_unc) == 0 and len(cal_correct_unc) == 0:
        print(f"  Warning: No data for {measure_name}, skipping plot")
        return

    cal_error = _finite_values(cal_error_unc)
    cal_correct = _finite_values(cal_correct_unc)
    test_error = _finite_values(test_error_unc)
    test_correct = _finite_values(test_correct_unc)

    all_values = np.concatenate([cal_correct, cal_error, test_correct, test_error])
    if all_values.size == 0:
        print(f"  Warning: No valid data for {measure_name}, skipping plot")
        return

    # A NaN/inf threshold cannot be drawn and would break the axis limits.
    if threshold is not None and not np.isfinite(threshold):
        threshold = None

    # Shared bin edges over the data range (both panels, both series).
    data_min, data_max = float(np.min(all_values)), float(np.max(all_values))
    n_bins = min(50, max(20, int(np.sqrt(all_values.size))))
    if data_max > data_min:
        bin_edges = np.linspace(data_min, data_max, n_bins + 1)
    else:
        # All values identical: give the histogram a non-degenerate width.
        bin_edges = np.linspace(data_min - 0.5, data_max + 0.5, n_bins + 1)

    # Shared x-limits: cover the data and the threshold line, with 5% padding.
    lo, hi = data_min, data_max
    if threshold is not None:
        lo, hi = min(lo, threshold), max(hi, threshold)
    span = hi - lo
    pad = 0.05 * span if span > 0 else 0.5
    x_limits = (lo - pad, hi + pad)

    # Style must be set before the figure is created to take effect on it.
    sns.set_style("whitegrid")
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))
    try:
        fig.suptitle(f'Uncertainty Distribution: {measure_name}', fontsize=32, fontweight='bold', y=0.98)
        _draw_distribution_panel(axes[0], cal_correct, cal_error, threshold, measure_name,
                                 'Calibration Set', bin_edges, x_limits)
        _draw_distribution_panel(axes[1], test_correct, test_error, threshold, measure_name,
                                 'Test Set', bin_edges, x_limits)
        fig.tight_layout()

        plot_path = os.path.join(output_dir, f'{config_name}_{measure_name}_distribution.png')
        fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    print(f"  Distribution plot saved to: {plot_path}")


def evaluate_stepwise_predictions(uncertainties_list, labels, measure_name, threshold):
    """
    Compute step-wise (per-paragraph) binary classification metrics for one measure.

    Each step of each sample is one binary decision. A step is predicted to be
    an error when its value crosses ``threshold`` in the direction that means
    "more uncertain": ``>=`` for the high-uncertainty measures listed below and
    ``<=`` for every other measure. The ground-truth error step of a sample is
    the single step whose index equals ``label``. Every other step, including
    steps after that index, counts as correct.

    Args:
        uncertainties_list: per-sample list of per-step dicts ``{measure: value}``.
            Empty samples, non-dict entries and steps whose value is None are skipped.
        labels: ground-truth error step index per sample (-1 means no error).
        measure_name: key of the measure to evaluate.
        threshold: decision threshold. Must be comparable with the step values
            whenever at least one step is counted.

    Returns:
        dict: the percentages ``step_accuracy``, ``step_precision``,
        ``step_recall``, ``step_f1`` and ``step_specificity``, plus the aliases
        ``step_error_recall`` and ``step_correct_specificity``. Also the step
        counts and ``confusion_matrix`` (``tp``/``fp``/``tn``/``fn``, where
        positive = error step). The same keys, all zero, are returned when no
        step is counted.
    """
    # Measures where a LARGER value means MORE uncertainty.
    high_uncertainty_measures = ['entropy_drop', 'entropy', 'variance_logprob', 'perplexity',
                                 'evidence_drop', 'mahuan_risk', 'evidence', 'mahuan']
    higher_is_uncertain = measure_name in high_uncertainty_measures

    # Confusion-matrix counters (positive class = error step).
    tp = fp = tn = fn = 0

    for sample_uncertainties, label in zip(uncertainties_list, labels):
        if not sample_uncertainties:
            continue

        for step_idx, para_measures in enumerate(sample_uncertainties):
            if not isinstance(para_measures, dict):
                continue
            uncertainty = para_measures.get(measure_name, None)
            if uncertainty is None:
                continue

            is_error_step = label != -1 and step_idx == label
            if higher_is_uncertain:
                predicted_error = uncertainty >= threshold
            else:
                predicted_error = uncertainty <= threshold

            if is_error_step:
                if predicted_error:
                    tp += 1
                else:
                    fn += 1
            elif predicted_error:
                fp += 1
            else:
                tn += 1

    # Every ratio falls back to 0.0 when its denominator is zero, so an empty
    # input yields all-zero metrics with the same keys as the normal case.
    total = tp + fp + tn + fn
    accuracy = (tp + tn) / total if total > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

    return {
        'step_accuracy': float(accuracy * 100),
        'step_precision': float(precision * 100),
        'step_recall': float(recall * 100),
        'step_f1': float(f1 * 100),
        'step_specificity': float(specificity * 100),
        'step_error_recall': float(recall * 100),  # alias
        'step_correct_specificity': float(specificity * 100),  # alias
        'num_total_steps': int(total),
        'num_error_steps': int(tp + fn),
        'num_correct_steps': int(tn + fp),
        'confusion_matrix': {
            'tp': int(tp),
            'fp': int(fp),
            'tn': int(tn),
            'fn': int(fn)
        }
    }


def evaluate_predictions(predictions, labels, uncertainties_list, output_dir, config_name, measure_name=None, threshold=None):
    """
    Compute response-wise (sample-level) and, optionally, step-wise metrics.

    Args:
        predictions: predicted error step index per sample (-1 = no error predicted).
        labels: ground-truth error step index per sample (-1 = no error).
        uncertainties_list: per-sample list of per-step dicts ``{measure: value}``.
        output_dir: output directory (not used by this function; kept for API compatibility).
        config_name: configuration name (not used by this function; kept for API compatibility).
        measure_name: measure used to build the sample-level AUC score and the
            step-wise metrics.
        threshold: decision threshold. Step-wise metrics are computed only when
            both ``threshold`` and ``measure_name`` are given.

    Returns:
        dict: response-level percentages (with legacy aliases), ``auc_roc`` /
        ``auc_pr`` (0.0 when not computable), sample counts, and, when
        computed, the step-wise metrics from ``evaluate_stepwise_predictions``.
    """
    # Measures where a LARGER value means MORE uncertainty; for all others a
    # SMALLER value means more uncertainty (same convention as the step-wise evaluation).
    high_uncertainty_measures = ['entropy_drop', 'entropy', 'variance_logprob', 'perplexity',
                                 'evidence_drop', 'mahuan_risk', 'evidence', 'mahuan']

    # ============ Response-wise metrics ============
    correct_predictions = [p == l for p, l in zip(predictions, labels)]
    # Guard empty input: np.mean([]) would return nan with a RuntimeWarning.
    accuracy = np.mean(correct_predictions) * 100 if correct_predictions else 0.0

    # Split samples into erroneous and fully-correct ones.
    error_indices = [i for i, label in enumerate(labels) if label != -1]
    correct_indices = [i for i, label in enumerate(labels) if label == -1]

    error_acc = np.mean([correct_predictions[i] for i in error_indices]) * 100 if error_indices else 0.0
    correct_acc = np.mean([correct_predictions[i] for i in correct_indices]) * 100 if correct_indices else 0.0
    f1 = 2 * error_acc * correct_acc / (error_acc + correct_acc) if (error_acc + correct_acc) > 0 else 0.0

    # Diagnostic: how many steps a correct sample has, and the resulting
    # multiple-testing false-positive rate.
    if correct_indices and uncertainties_list:
        num_steps_per_sample = [
            len(uncertainties_list[i]) for i in correct_indices if uncertainties_list[i]
        ]
        if num_steps_per_sample:
            avg_steps = np.mean(num_steps_per_sample)
            # Probability that at least one of avg_steps independent tests at a
            # fixed alpha=0.05 fires. Illustrative only: it does not use the
            # significance level of the current run.
            alpha = 0.05
            theoretical_false_positive = 1 - (1 - alpha) ** avg_steps
            print(f"    [Response-wise] Correct samples: avg {avg_steps:.1f} steps/sample")
            print(f"    [Response-wise] Multiple testing issue: expected {theoretical_false_positive*100:.1f}% false positive rate")

    # ============ Sample-level ranking score for AUC ============
    # Binary target: does the sample contain an error?
    binary_labels = [1 if label != -1 else 0 for label in labels]

    # Score = uncertainty of the MOST uncertain step in the sample. For
    # high-uncertainty measures that is the max value. For the others it is the
    # min value, negated so that a higher score still means "more likely
    # erroneous", as roc_auc_score / precision_recall_curve expect.
    higher_is_uncertain = measure_name in high_uncertainty_measures
    if measure_name and uncertainties_list:
        sample_scores = []
        for unc_dict_list in uncertainties_list:
            values = [
                para_measures[measure_name]
                for para_measures in (unc_dict_list or [])
                if isinstance(para_measures, dict) and para_measures.get(measure_name) is not None
            ]
            if not values:
                sample_scores.append(0.0)
            elif higher_is_uncertain:
                sample_scores.append(max(values))
            else:
                sample_scores.append(-min(values))
    else:
        sample_scores = [0.0] * len(labels)

    # AUC is only defined with both classes present and non-constant scores.
    can_rank = len(set(binary_labels)) > 1 and len(set(sample_scores)) > 1

    # AUC-ROC
    auc_roc = 0.0
    if can_rank:
        try:
            auc_roc = roc_auc_score(binary_labels, sample_scores)
        except ValueError as err:
            print(f"    Warning: AUC-ROC could not be computed: {err}")

    # AUC-PR
    auc_pr = 0.0
    if can_rank:
        try:
            precision, recall, _ = precision_recall_curve(binary_labels, sample_scores)
            auc_pr = auc(recall, precision)
        except ValueError as err:
            print(f"    Warning: AUC-PR could not be computed: {err}")

    # Position accuracy on erroneous samples: is the predicted error step exact?
    error_position_acc = 0.0
    if error_indices:
        position_correct = [predictions[i] == labels[i] for i in error_indices]
        error_position_acc = np.mean(position_correct) * 100

    # ============ Step-wise metrics ============
    stepwise_results = {}
    if threshold is not None and measure_name is not None:
        stepwise_results = evaluate_stepwise_predictions(
            uncertainties_list, labels, measure_name, threshold
        )
        print(f"    [Step-wise] Accuracy: {stepwise_results['step_accuracy']:.2f}%")
        print(f"    [Step-wise] Precision: {stepwise_results['step_precision']:.2f}% | Recall: {stepwise_results['step_recall']:.2f}% | F1: {stepwise_results['step_f1']:.2f}%")
        print(f"    [Step-wise] Error Recall: {stepwise_results['step_error_recall']:.2f}% | Correct Specificity: {stepwise_results['step_correct_specificity']:.2f}%")
        print(f"    [Step-wise] Total steps: {stepwise_results['num_total_steps']} (Error: {stepwise_results['num_error_steps']}, Correct: {stepwise_results['num_correct_steps']})")

    results = {
        # Response-wise metrics
        'response_overall_accuracy': accuracy,
        'response_error_accuracy': error_acc,
        'response_correct_accuracy': correct_acc,
        'response_f1_score': f1,
        'response_error_position_accuracy': error_position_acc,

        # Backward-compatible field names
        'overall_accuracy': accuracy,
        'error_accuracy': error_acc,
        'correct_accuracy': correct_acc,
        'f1_score': f1,
        'error_position_accuracy': error_position_acc,

        # Binary (has-error) ranking metrics
        'auc_roc': auc_roc,
        'auc_pr': auc_pr,

        # Sample counts
        'num_samples': len(labels),
        'num_errors': len(error_indices),
        'num_correct': len(correct_indices)
    }

    if stepwise_results:
        results.update(stepwise_results)

    return results


def main():
    parser = argparse.ArgumentParser(description='ProcessBench Error Detection using Uncertainty Drop')
    parser.add_argument('--configs', type=str, nargs='+', default=['gsm8k', 'math'],
                        choices=['gsm8k', 'math', 'olympiadbench', 'omnimath'],
                        help='Dataset configurations to evaluate')
    parser.add_argument('--model_path', type=str, default='/home/linyeli/cqj/exercise/Qwen_model/Qwen3-8B',
                        help='Path to the model')
    parser.add_argument('--output_dir', type=str, default='./outputs',
                        help='Output directory for results')
    parser.add_argument('--ema_span', type=float, default=5.0,
                        help='EMA smoothing span for uncertainty calculation')
    parser.add_argument('--drop_k', type=int, default=5,
                        help='Number of worst drops to consider')
    parser.add_argument('--max_tokens', type=int, default=512,
                        help='Max tokens per paragraph generation (deprecated, now using forced decoding)')
    parser.add_argument('--threshold', type=float, default=None,
                        help='Uncertainty threshold for error detection (None for auto)')
    parser.add_argument('--use_max_uncertainty', action='store_true',
                        help='Use max uncertainty paragraph as error (default strategy)')
    parser.add_argument('--use_transformers', action='store_true',
                        help='Use transformers for forced decoding (more accurate, but slower)')
    parser.add_argument('--uncertainty_measures', type=str, nargs='+', 
                        default=["entropy", "evidence", "length_norm_logprob", "mean_logprob", "perplexity", "confidence"],
                        choices=['entropy', 'evidence', 'mahuan', 'length_norm_logprob', 
                                'mean_logprob', 'max_logprob', 'variance_logprob', 
                                'perplexity', 'confidence'],
                        help='Base uncertainty measures to compute (will generate both drop and raw versions for supported measures)')
    parser.add_argument('--use_drop', action='store_true', default=True,
                        help='Calculate both drop and raw versions for supported measures (default: True, generates N*2 measures)')
    parser.add_argument('--drop_only', action='store_true', default=False,
                        help='Only calculate drop versions (overrides --use_drop)')
    parser.add_argument('--batch_paragraphs', action='store_true', default=True,
                        help='Batch process paragraphs to improve GPU utilization (default: True)')
    parser.add_argument('--tensor_parallel_size', type=int, default=None,
                        help='Number of GPUs for tensor parallelism (default: auto-detect, use all available GPUs)')
    parser.add_argument('--gpu_memory_utilization', type=float, default=0.8,
                        help='GPU memory utilization ratio (default: 0.8)')
    parser.add_argument('--max_num_seqs', type=int, default=20,
                        help='Maximum number of sequences to process in parallel (default: 20)')
    parser.add_argument('--master_port', type=int, default=None,
                        help='Master port for distributed communication (default: auto-select, avoids port conflicts)')
    parser.add_argument('--use_calibration', action='store_true', default=False,
                        help='Use calibration set to compute thresholds via hypothesis testing (default: False)')
    parser.add_argument('--calibration_ratio', type=float, default=0.3,
                        help='Ratio of calibration set (default: 0.3, i.e., 3:7 split)')
    parser.add_argument('--significant_level', type=float, default=0.1,
                        help='Significant level for hypothesis testing (default: 0.10)')
    parser.add_argument('--random_seed', type=int, default=42,
                        help='Random seed for data splitting (default: 42)')
    
    args = parser.parse_args()
    
    args.model_name = os.path.basename(args.model_path)
    
    # 设置分布式通信端口（避免端口冲突）
    if args.master_port is not None:
        os.environ['MASTER_PORT'] = str(args.master_port)
        print(f"Using master port: {args.master_port}")
    else:
        # 自动选择一个未被占用的端口
        import socket
        def find_free_port():
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(('', 0))
                return s.getsockname()[1]
        free_port = find_free_port()
        os.environ['MASTER_PORT'] = str(free_port)
        print(f"Auto-selected master port: {free_port}")
    
    # 设置其他 NCCL 环境变量以避免通信问题
    os.environ.setdefault('NCCL_DEBUG', 'INFO')  # 可选：设置为 WARN 减少日志
    os.environ.setdefault('NCCL_SOCKET_IFNAME', 'lo')  # 使用 loopback 接口（单机多卡）
    
    # 初始化模型
    print(f"Loading model: {args.model_path}")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # 诊断 GPU 信息
    num_gpus = torch.cuda.device_count()
    print(f"Detected {num_gpus} GPU(s)")
    for i in range(num_gpus):
        props = torch.cuda.get_device_properties(i)
        print(f"  GPU {i}: {props.name}, Memory: {props.total_memory / 1024**3:.1f} GB")
    
    # 检查是否有其他进程占用 GPU
    import subprocess
    try:
        result = subprocess.run(['nvidia-smi', '--query-gpu=index,memory.used,utilization.gpu', 
                                '--format=csv,noheader'], 
                               capture_output=True, text=True, timeout=5)
        print("\nCurrent GPU status:")
        print(result.stdout)
    except:
        pass
    
    # 设置 tensor_parallel_size
    if args.tensor_parallel_size is not None:
        tensor_parallel = args.tensor_parallel_size
        if tensor_parallel > num_gpus:
            print(f"Warning: tensor_parallel_size ({tensor_parallel}) > available GPUs ({num_gpus}), using {num_gpus}")
            tensor_parallel = num_gpus
    else:
        tensor_parallel = num_gpus  # 默认使用所有可用 GPU
    
    print(f"\nUsing tensor_parallel_size={tensor_parallel}")
    
    llm = LLM(
        model=args.model_path,
        tokenizer=args.model_path,
        gpu_memory_utilization=args.gpu_memory_utilization,
        tensor_parallel_size=tensor_parallel,
        enable_prefix_caching=True,
        swap_space=16,
        max_num_seqs=args.max_num_seqs,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    
    # 如果使用 transformers 进行 forced decoding
    model = None
    if args.use_transformers:
        print("Loading transformers model for forced decoding...")
        model = AutoModelForCausalLM.from_pretrained(
            args.model_path,
            torch_dtype=torch.bfloat16,
            device_map='auto'
        )
        model.eval()
    
    if args.configs is None:
        args.configs = ['gsm8k', 'math', 'olympiadbench', 'omnimath']
    
    # 处理每个配置
    for config in tqdm(args.configs, desc="Processing configs"):
        print(f"\n{'='*60}")
        print(f"Processing {config} dataset")
        print(f"{'='*60}")
        
        # 创建输出目录
        output_dir = os.path.join(args.output_dir, f'{args.model_name}_uncertainty_drop')
        os.makedirs(output_dir, exist_ok=True)
        
        # 加载数据集
        print(f"Loading ProcessBench {config} dataset...")
        input_data = load_dataset('Qwen/ProcessBench', split=config)
        print(f"Loaded {len(input_data)} samples")
        
        # 数据划分：calibration set 和 test set
        if args.use_calibration:
            # 转换为列表以便划分
            data_list = list(input_data)
            indices = list(range(len(data_list)))
            
            # 划分 calibration set 和 test set (3:7)
            cal_indices, test_indices = train_test_split(
                indices, 
                test_size=1 - args.calibration_ratio,
                random_state=args.random_seed,
                shuffle=True
            )
            
            calibration_data = [data_list[i] for i in cal_indices]
            test_data = [data_list[i] for i in test_indices]
            
            print(f"\nData split:")
            print(f"  Calibration set: {len(calibration_data)} samples ({len(calibration_data)/len(data_list)*100:.1f}%)")
            print(f"  Test set: {len(test_data)} samples ({len(test_data)/len(data_list)*100:.1f}%)")
            
            # 处理 calibration set
            print(f"\nProcessing calibration set...")
            cal_labels = []
            cal_uncertainties = []
            
            for idx, sample in enumerate(tqdm(calibration_data, desc="Calibration set")):
                problem = sample['problem']
                steps = sample['steps']
                label = sample['label']
                
                actual_use_drop = args.use_drop and not args.drop_only
                paragraph_uncertainties, _ = process_sample_with_uncertainty_drop(
                    llm, tokenizer, problem, steps,
                    ema_span=args.ema_span,
                    drop_k=args.drop_k,
                    use_transformers=args.use_transformers,
                    model=model,
                    device=device,
                    uncertainty_measures=args.uncertainty_measures,
                    use_drop=actual_use_drop,
                    batch_paragraphs=args.batch_paragraphs
                )
                
                cal_labels.append(label)
                cal_uncertainties.append(paragraph_uncertainties)
            
            # 获取所有可用的 measures
            available_measures = []
            if cal_uncertainties and len(cal_uncertainties) > 0:
                first_sample = cal_uncertainties[0]
                if first_sample and len(first_sample) > 0:
                    first_para = first_sample[0]
                    if isinstance(first_para, dict):
                        available_measures = list(first_para.keys())
            
            print(f"\nFound {len(available_measures)} uncertainty measures: {available_measures}")
            
            # 在 calibration set 上计算每个 measure 的阈值
            print(f"\nComputing thresholds from calibration set (significant_level={args.significant_level})...")
            print(f"Using CORRECT steps distribution to set thresholds (hypothesis testing approach)")
            thresholds = {}
            
            for measure_name in available_measures:
                # 收集错误步骤的 uncertainty 值
                error_uncertainties = collect_error_step_uncertainties(
                    cal_uncertainties, cal_labels, measure_name
                )
                
                if len(error_uncertainties) > 0:
                    threshold = calculate_threshold_from_calibration(
                        error_uncertainties, measure_name, 
                        significant_level=args.significant_level
                    )
                    thresholds[measure_name] = threshold
                    print(f"  {measure_name}: threshold = {threshold:.6f} (from {len(error_uncertainties)} error steps)")
                else:
                    thresholds[measure_name] = None
                    print(f"  {measure_name}: No error samples found, threshold = None")
            
            # 保存阈值信息
            thresholds_path = os.path.join(output_dir, f'{config}_thresholds.json')
            with open(thresholds_path, 'w') as f:
                json.dump(thresholds, f, indent=2)
            print(f"\nThresholds saved to: {thresholds_path}")
            
            # 处理 test set
            print(f"\nProcessing test set...")
            test_labels = []
            test_uncertainties = []
            test_results = []
            
            for idx, sample in enumerate(tqdm(test_data, desc="Test set")):
                problem = sample['problem']
                steps = sample['steps']
                label = sample['label']
                
                actual_use_drop = args.use_drop and not args.drop_only
                paragraph_uncertainties, _ = process_sample_with_uncertainty_drop(
                    llm, tokenizer, problem, steps,
                    ema_span=args.ema_span,
                    drop_k=args.drop_k,
                    use_transformers=args.use_transformers,
                    model=model,
                    device=device,
                    uncertainty_measures=args.uncertainty_measures,
                    use_drop=actual_use_drop,
                    batch_paragraphs=args.batch_paragraphs
                )
                
                test_labels.append(label)
                test_uncertainties.append(paragraph_uncertainties)
                
                result_item = {
                    'id': sample.get('id', idx),
                    'problem': problem,
                    'steps': steps,
                    'label': label,
                    'paragraph_uncertainties': paragraph_uncertainties
                }
                test_results.append(result_item)
            
            # 使用阈值在 test set 上进行预测和评估
            all_measure_results = {}
            
            print(f"\nGenerating uncertainty distribution plots...")
            for measure_name in available_measures:
                print(f"\nEvaluating with measure: {measure_name}...")
                threshold = thresholds.get(measure_name, None)
                
                if threshold is not None:
                    print(f"  Using threshold: {threshold:.6f}")
                    # 使用阈值进行预测
                    predictions = predict_error_with_measure(test_uncertainties, measure_name, threshold=threshold)
                else:
                    print(f"  Warning: No threshold available, using mean±std method")
                    # 回退到 mean±std 方法
                    predictions = predict_error_with_measure(test_uncertainties, measure_name, threshold=None)
                
                # 评估该 measure 的性能
                eval_results = evaluate_predictions(
                    predictions, test_labels, test_uncertainties,
                    output_dir, config, measure_name=measure_name, threshold=threshold
                )
                
                all_measure_results[measure_name] = {
                    'predictions': predictions,
                    'metrics': eval_results,
                    'threshold': threshold
                }
                
                # 打印结果
                print(f"  Overall Accuracy: {eval_results['overall_accuracy']:.2f}%")
                print(f"  Error Accuracy: {eval_results['error_accuracy']:.2f}%")
                print(f"  Correct Accuracy: {eval_results['correct_accuracy']:.2f}%")
                print(f"  F1 Score: {eval_results['f1_score']:.2f}%")
                print(f"  Error Position Accuracy: {eval_results['error_position_accuracy']:.2f}%")
                print(f"  AUC-ROC: {eval_results['auc_roc']:.4f}")
                print(f"  AUC-PR: {eval_results['auc_pr']:.4f}")
                
                # 收集所有步骤的 uncertainty 值用于可视化
                cal_error_unc, cal_correct_unc = collect_all_step_uncertainties(
                    cal_uncertainties, cal_labels, measure_name
                )
                test_error_unc, test_correct_unc = collect_all_step_uncertainties(
                    test_uncertainties, test_labels, measure_name
                )
                
                # 绘制分布图
                plot_uncertainty_distributions(
                    cal_error_unc, cal_correct_unc,
                    test_error_unc, test_correct_unc,
                    threshold, measure_name, output_dir, config
                )
            
            # 保存结果
            all_labels = test_labels
            all_uncertainties = test_uncertainties
            all_results = test_results
            
        else:
            # 原来的流程：不使用 calibration set
            print(f"\nUsing mean±std method (no calibration set)")
            
            # 处理每个样本
            all_predictions = []
            all_labels = []
            all_uncertainties = []
            all_results = []
            
            for idx, sample in enumerate(tqdm(input_data, desc=f"Processing {config}")):
                problem = sample['problem']
                steps = sample['steps']
                label = sample['label']
                
                # 计算每段的多种 uncertainty measures（使用 forced decoding）
                # 如果 use_drop=True，会同时计算 drop 和原始版本，得到 N*2 个 measures
                # 如果 drop_only=True，只计算 drop 版本
                actual_use_drop = args.use_drop and not args.drop_only
                paragraph_uncertainties, predicted_error_idx = process_sample_with_uncertainty_drop(
                    llm, tokenizer, problem, steps,
                    ema_span=args.ema_span,
                    drop_k=args.drop_k,
                    use_transformers=args.use_transformers,
                    model=model,
                    device=device,
                    uncertainty_measures=args.uncertainty_measures,
                    use_drop=actual_use_drop,
                    batch_paragraphs=args.batch_paragraphs
                )
                
                all_labels.append(label)
                all_uncertainties.append(paragraph_uncertainties)
                
                # 保存详细结果（包含所有 measures 的原始数据）
                result_item = {
                    'id': sample.get('id', idx),
                    'problem': problem,
                    'steps': steps,
                    'label': label,
                    'paragraph_uncertainties': paragraph_uncertainties
                }
                all_results.append(result_item)
            
            # 获取所有可用的 measures
            available_measures = []
            if all_uncertainties and len(all_uncertainties) > 0:
                first_sample = all_uncertainties[0]
                if first_sample and len(first_sample) > 0:
                    first_para = first_sample[0]
                    if isinstance(first_para, dict):
                        available_measures = list(first_para.keys())
            
            print(f"\nFound {len(available_measures)} uncertainty measures: {available_measures}")
            
            # 对每个 measure 单独进行评估
            all_measure_results = {}
            
            for measure_name in available_measures:
                print(f"\nEvaluating with measure: {measure_name}...")
                
                # 使用该 measure 预测错误位置（使用 mean±std 方法）
                predictions = predict_error_with_measure(all_uncertainties, measure_name, threshold=None)
                
                # 评估该 measure 的性能
                eval_results = evaluate_predictions(
                    predictions, all_labels, all_uncertainties,
                    output_dir, config, measure_name=measure_name, threshold=None
                )
                
                all_measure_results[measure_name] = {
                    'predictions': predictions,
                    'metrics': eval_results
                }
                
                # 打印结果
                print(f"  Overall Accuracy: {eval_results['overall_accuracy']:.2f}%")
                print(f"  Error Accuracy: {eval_results['error_accuracy']:.2f}%")
                print(f"  Correct Accuracy: {eval_results['correct_accuracy']:.2f}%")
                print(f"  F1 Score: {eval_results['f1_score']:.2f}%")
                print(f"  Error Position Accuracy: {eval_results['error_position_accuracy']:.2f}%")
                print(f"  AUC-ROC: {eval_results['auc_roc']:.4f}")
                print(f"  AUC-PR: {eval_results['auc_pr']:.4f}")
        
        # 保存每个 measure 的结果
        results_summary = {}
        for measure_name, measure_data in all_measure_results.items():
            measure_summary = measure_data['metrics'].copy()
            if 'threshold' in measure_data:
                measure_summary['threshold'] = measure_data['threshold']
            results_summary[measure_name] = measure_summary
        
        # 添加配置信息
        results_summary['config'] = {
            'use_calibration': args.use_calibration,
            'calibration_ratio': args.calibration_ratio if args.use_calibration else None,
            'significant_level': args.significant_level if args.use_calibration else None,
            'random_seed': args.random_seed if args.use_calibration else None
        }
        
        suffix = '_calibration' if args.use_calibration else ''
        results_path = os.path.join(output_dir, f'{config}_all_measures_results{suffix}.json')
        with open(results_path, 'w') as f:
            json.dump(results_summary, f, indent=2)
        print(f"\nAll measures results saved to: {results_path}")
        
        # 保存详细结果（包含每个 measure 的预测）
        for idx, result_item in enumerate(all_results):
            result_item['predictions'] = {}
            for measure_name, measure_data in all_measure_results.items():
                result_item['predictions'][measure_name] = measure_data['predictions'][idx]
                result_item[f'match_{measure_name}'] = (measure_data['predictions'][idx] == result_item['label'])
        
        detailed_results_path = os.path.join(output_dir, f'{config}_detailed_results{suffix}.jsonl')
        with open(detailed_results_path, 'w') as f:
            for result in all_results:
                f.write(json.dumps(result, ensure_ascii=False) + '\n')
        
        print(f"Detailed results saved to: {detailed_results_path}")


if __name__ == "__main__":
    # Script entry point: run the evaluation only when this file is executed
    # directly. Importing the module just defines its functions and does not
    # call main().
    main()

