import logging
import os
from datetime import datetime
import numpy as np
import pandas as pd
import openai
import base64
import json
import re
from tqdm import tqdm
import matplotlib.pyplot as plt
import multiprocessing
from functools import partial
from collections import Counter, defaultdict
from scipy.stats import spearmanr
import ast
import time
from difflib import SequenceMatcher
import argparse
from sklearn.metrics import f1_score

# Optional dependency: tiktoken is only needed for token counting.
# _tiktoken_available records whether the import succeeded so that callers
# can skip token statistics instead of failing at import time.
try:
    import tiktoken
    _tiktoken_available = True
except ImportError:
    _tiktoken_available = False
    print("警告：未安装 tiktoken，无法统计token数。请 pip install tiktoken")
    

def compute_binary_ece(probs, pre_labels, gt_labels, n_bins=10):
    """
    Expected Calibration Error (ECE) of a binary classifier, equal-width bins.

    Args:
        probs: array-like of shape (N,), predicted probability of the positive class (1).
        pre_labels: array-like of shape (N,), predicted labels (0 or 1).
        gt_labels: array-like of shape (N,), ground-truth labels (0 or 1).
        n_bins: number of equal-width confidence bins over [0, 1] (default 10).

    Returns:
        ece: sum over non-empty bins of
             (bin size / N) * |mean confidence - mean accuracy|.
             Returns 0.0 when no sample falls into any bin (e.g. empty input).
    """
    probs = np.asarray(probs)
    pre_labels = np.asarray(pre_labels)
    gt_labels = np.asarray(gt_labels)

    # Confidence of the *predicted* label: p for a predicted 1, 1 - p for a predicted 0.
    confidences = np.where(pre_labels == 1, probs, 1 - probs)
    # 1.0 where the prediction matches the ground truth, else 0.0.
    accuracies = (pre_labels == gt_labels).astype(float)

    edges = np.linspace(0, 1, n_bins + 1)
    n_total = len(probs)
    ece = 0.0

    for k, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        # Bins are half-open [lo, hi); only the last one also includes its upper edge.
        if k == n_bins - 1:
            mask = (confidences >= lo) & (confidences <= hi)
        else:
            mask = (confidences >= lo) & (confidences < hi)

        n_in_bin = np.sum(mask)
        if n_in_bin == 0:
            continue

        mean_acc = np.mean(accuracies[mask])
        mean_conf = np.mean(confidences[mask])
        ece += (n_in_bin / n_total) * np.abs(mean_conf - mean_acc)

    return ece

# Additional calibration and accuracy helpers that complement compute_binary_ece

def compute_classwise_ece(step_details, n_bins=10):
    """
    Compute ECE independently for ground-truth-correct and ground-truth-incorrect steps.

    Args:
        step_details: list of per-step dicts; every dict must provide the keys
            "gt_label", "pre_label" and "stdconf".
        n_bins: number of bins forwarded to compute_binary_ece.

    Returns:
        (ece_pos, ece_neg): ECE over the steps with gt_label == 1 and over the
        steps with gt_label == 0. A class with no steps yields np.nan.
    """
    def _ece_for_class(target_label):
        selected = [s for s in step_details if s["gt_label"] == target_label]
        if not selected:
            return np.nan
        # "stdconf" is passed as the probability argument of compute_binary_ece.
        return compute_binary_ece(
            np.array([s["stdconf"] for s in selected]),
            np.array([s["pre_label"] for s in selected]),
            np.array([s["gt_label"] for s in selected]),
            n_bins=n_bins,
        )

    return _ece_for_class(1), _ece_for_class(0)

def compute_per_class_accuracy(step_details):
    """
    Accuracy of the predicted labels, computed separately per ground-truth class.

    Args:
        step_details: list of per-step dicts with the keys "gt_label" and "pre_label".

    Returns:
        (acc_pos, acc_neg): fraction of steps with gt_label == 1 (resp. 0) whose
        pre_label equals gt_label. A class with no steps yields np.nan.
    """
    def _accuracy_for(target_label):
        total = 0
        hits = 0
        for step in step_details:
            if step["gt_label"] == target_label:
                total += 1
                if step["pre_label"] == step["gt_label"]:
                    hits += 1
        if not total:
            return np.nan
        return hits / total

    return _accuracy_for(1), _accuracy_for(0)

def compute_adaptive_binning_ece(probs, pre_labels, gt_labels, n_bins=10):
    """
    Expected Calibration Error using equal-frequency (adaptive) bins.

    Samples are sorted by confidence and split into ``n_bins`` groups whose
    sizes differ by at most one (``np.array_split``). Earlier versions put the
    whole remainder into the last bin, so with N < n_bins every sample landed
    in a single bin; results are unchanged whenever N is a multiple of n_bins
    and there are no tied confidences.

    Args:
        probs: array-like of shape (N,), predicted probability of the positive class.
        pre_labels: array-like of shape (N,), predicted labels (0 or 1).
        gt_labels: array-like of shape (N,), ground-truth labels (0 or 1).
        n_bins: number of equal-frequency bins (must be >= 1).

    Returns:
        ece: sum over non-empty bins of
             (bin size / N) * |mean confidence - mean accuracy|; 0.0 for empty input.

    Raises:
        ValueError: if n_bins is smaller than 1.
    """
    if n_bins < 1:
        raise ValueError("n_bins must be a positive integer")

    probs = np.asarray(probs)
    pre_labels = np.asarray(pre_labels)
    gt_labels = np.asarray(gt_labels)

    # Confidence of the predicted label: p for a predicted 1, 1 - p for a predicted 0.
    confidences = np.where(pre_labels == 1, probs, 1 - probs)
    accuracies = (pre_labels == gt_labels).astype(float)

    total = len(confidences)
    if total == 0:
        return 0.0

    # A stable sort keeps tied confidences in input order, so the bin
    # membership of ties does not depend on the sorting algorithm.
    order = np.argsort(confidences, kind="stable")

    ece = 0.0
    for bin_indices in np.array_split(order, n_bins):
        if bin_indices.size == 0:
            # Happens only when there are fewer samples than bins.
            continue
        gap = np.abs(np.mean(confidences[bin_indices]) - np.mean(accuracies[bin_indices]))
        ece += (bin_indices.size / total) * gap

    return ece

def build_messages(problem: str, final_answer, process_evaluation, student_solution) -> list:
    """
    Build the chat history for step-by-step correctness evaluation.

    The returned list contains, in order: the system prompt, one worked
    example (user + assistant turn) and the user turn for the current problem.

    Args:
        problem: question text.
        final_answer: reference answer; a list/tuple/ndarray is joined with
            newlines, anything else is converted with str().
        process_evaluation: sequence whose items are list/tuple/ndarray with the
            step text at index 0; used as the primary source of steps.
        student_solution: sequence of step texts; used when process_evaluation
            yields no steps.

    Returns:
        list of four message dicts with "role" and "content" keys.

    Raises:
        ValueError: if no solution step can be extracted from either source.
    """
    system_message = {
        "role": "system",
        "content": (
            "You are a teacher skilled in evaluating the intermediate steps of a student's solution to a given problem.\n\n"
            "You are given a scientific problem, its correct final answer, and a student's step-by-step solution.\n\n"
            "Your task is as follows:\n"
            "1. Carefully solve the problem yourself, using the correct final answer as a hint to guide you to a consistent, correct reasoning path.\n"
            "2. Then, evaluate the **correctness of each step** in the student's solution.\n\n"
            "Focus exclusively on the scientific, logical, or mathematical correctness of the solution. Ignore differences in formatting, expression style, specific wording, or presentation order, as long as the reasoning and results are valid.\n\n"
            "For **each step**, output:\n\n"
            "- The full original step (as a string)\n"
            "- A correctness label:  \n"
            "  - `1` if the model believes the step is correct (i.e., if `P(correct) ≥ 0.5`)  \n"
            "  - `0` otherwise\n"
            "- A probability estimate `P(correct)` ∈ (0, 1), representing the model's assessment of the likelihood that the step is correct (`correctness label = 1`)\n"
            "- If the step is incorrect (`correctness label = 0`), also provide:\n"
            "  - An error category (from the list below)\n\n"
            "**Error categories (for incorrect steps):**\n"
            "- Numerical Calculation Error\n"
            "- Symbolic Calculation Error\n"
            "- Visual Interpretation Error\n"
            "- Reasoning Error\n"
            "- Knowledge Error\n"
            "- Question Understanding Error\n"
            "- No solution provided\n\n"
            "## Output Format:\n"
            "Wrap your output in this Python list format (and nothing else), enclosed by `<evaluation>` and `</evaluation>` tags:\n\n"
            "<evaluation>  \n"
            "[\n"
            "  [\"Step 1: ...\", correctness_label, P_correct, \"Error type if incorrect\"],\n"
            "  ...\n"
            "]\n"
            "</evaluation>\n\n"
            "**Important Requirements:**\n"
            "- You must return one and only one evaluation entry **per step** in the student's solution.\n"
            "- The number of output entries **must exactly match the number of steps** (e.g., if the student has 15 steps, your output list must contain 15 entries).\n"
            "- Do not skip, merge, or summarize steps.\n"
            # The next two bullets end with "\n" so that the final three
            # requirements are separate list items instead of one run-on line.
            "- If the step is correct, use an empty string for the error type: `\"\"`.\n"
            "- Keep each step as a single complete unit, even if it contains multiple sentences.\n"
            "- Please evaluate each step one by one. Every step must be assessed and scored individually, even if it is very short. Do not merge, omit, or skip any steps."
        )
    }

    # Worked example: user turn
    example1_user = {
        "role": "user",
        "content": (
            "Question:  \n"
            "b. Next, imagine a long row of dominoes with equal spacing $l$ between the nearest sides of any pair of adjacent dominos, as shown above. "
            "When a domino topples, it collides with the next domino in the row. Imagine this collision to be completely inelastic. "
            "What fraction of the total kinetic energy is lost in the collision of the first domino with the second domino?\n"
            "Solution:\n"
            "Step 1: To solve this problem, we need to consider the dynamics of the dominoes in the row and the energy conservation during the collisions. Let's break down the problem step by step.\n"
            "Step 2: a. i. To topple the domino, we need to apply an impulse that overcomes the gravitational force acting on the domino. "
            "The impulse required is given by the product of the mass of the domino and the velocity at which it is pushed. The velocity at which the domino is pushed is related to the height $H$ above the table where the impulse is required to topple the domino. Using the equation of motion for the domino, we have:\n"
            "Step 3: $$m \\cdot \\frac{d^2H}{dt^2} = -mg \\cdot \\frac{dH}{dt}$$\n"
            "Step 4: where $m$ is the mass of the domino, $g$ is the acceleration due to gravity, and $H$ is the height above the table.\n"
            "Step 5: Solving this differential equation, we get:\n"
            "Step 6: $$H(t) = H_0 \\cdot \\cosh(\\sqrt{\\frac{g}{m}}t)$$\n"
            "Step 7: where $H_0$ is the initial height of the domino.\n"
            "Step 8: To topple the domino, we need to find the smallest height $H$ at which the impulse is required. This occurs when the velocity of the domino is maximum, which is given by:\n"
            "Step 9: $$v = \\sqrt{g \\cdot \\left(\\sqrt{t^2 + H^2} - H\\right)}$$\n"
            "Step 10: ii. The minimum value of $p$ to topple the domino is given by:\n"
            "Step 11: $$p_{\\min} = \\frac{1}{\\sqrt{3}} \\frac{m}{H} \\sqrt{g\\left(\\sqrt{t^2 + H^2} - H\\right)\\left(H^2 + t^2\\right)}$$\n"
            "Step 12: b. When a domino topples, it collides with the next domino in the row. Since the collision is completely inelastic, the total kinetic energy is conserved. "
            "Let's denote the kinetic energy of the first domino as $E_1$ and the kinetic energy of the second domino as $E_2$. The total kinetic energy is given by:\n"
            "Step 13: $$E_{\\text{total}} = E_1 + E_2$$\n"
            "Step 14: where $E_1 = \\frac{1}{2}mv^2$ and $E_2 = \\frac{1}{2}mv^2$.\n"
            "Step 15: The fraction of the total kinetic energy lost in the collision is given by:\n"
            "Step 16: $$\\frac{E_2 - E_1}{E_{\\text{total}}} = \\frac{\\frac{1}{2}mv^2 - \\frac{1}{2}mv^2}{\\frac{1}{2}mv^2 + \\frac{1}{2}mv^2} = \\frac{0}{\\frac{1}{2}mv^2} = 0$$\n"
            "Step 17: Therefore, the fraction of the total kinetic energy lost in the collision of the first domino with the second domino is 0.\n"
            "Step 18: The final answer is \\boxed{0}.\n"
            "Answer: 50"
        )
    }

    # Worked example: expected assistant output
    example1_assistant = {
        "role": "assistant",
        "content": (
            "<evaluation>\n"
            "[\n"
            "  [\"Step 1: To solve this problem, we need to consider the dynamics of the dominoes in the row and the energy conservation during the collisions. Let's break down the problem step by step.\", 1, 0.9, \"\"],\n"
            "  [\"Step 2: a. i. To topple the domino, we need to apply an impulse that overcomes the gravitational force acting on the domino. The impulse required is given by the product of the mass of the domino and the velocity at which it is pushed. The velocity at which the domino is pushed is related to the height $H$ above the table where the impulse is required to topple the domino. Using the equation of motion for the domino, we have:\", 0, 0.3, \"Reasoning Error\"],\n"
            "  [\"Step 3: $$m \\\\cdot \\\\frac{d^2H}{dt^2} = -mg \\\\cdot \\\\frac{dH}{dt}$$\", 0, 0.2, \"Reasoning Error\"],\n"
            "  [\"Step 4: where $m$ is the mass of the domino, $g$ is the acceleration due to gravity, and $H$ is the height above the table.\", 1, 0.95, \"\"],\n"
            "  [\"Step 5: Solving this differential equation, we get:\", 0, 0.3, \"Reasoning Error\"],\n"
            "  [\"Step 6: $$H(t) = H_0 \\\\cdot \\\\cosh(\\\\sqrt{\\\\frac{g}{m}}t)$$\", 0, 0.1, \"Reasoning Error\"],\n"
            "  [\"Step 7: where $H_0$ is the initial height of the domino.\", 1, 0.95, \"\"],\n"
            "  [\"Step 8: To topple the domino, we need to find the smallest height $H$ at which the impulse is required. This occurs when the velocity of the domino is maximum, which is given by:\", 0, 0.3, \"Reasoning Error\"],\n"
            "  [\"Step 9: $$v = \\\\sqrt{g \\\\cdot \\\\left(\\\\sqrt{t^2 + H^2} - H\\\\right)}$$\", 0, 0.2, \"Reasoning Error\"],\n"
            "  [\"Step 10: ii. The minimum value of $p$ to topple the domino is given by:\", 0, 0.3, \"Reasoning Error\"],\n"
            "  [\"Step 11: $$p_{\\\\min} = \\\\frac{1}{\\\\sqrt{3}} \\\\frac{m}{H} \\\\sqrt{g\\\\left(\\\\sqrt{t^2 + H^2} - H\\\\right)\\\\left(H^2 + t^2\\\\right)}$$\", 0, 0.2, \"Reasoning Error\"],\n"
            "  [\"Step 12: b. When a domino topples, it collides with the next domino in the row. Since the collision is completely inelastic, the total kinetic energy is conserved. Let's denote the kinetic energy of the first domino as $E_1$ and the kinetic energy of the second domino as $E_2$. The total kinetic energy is given by:\", 0, 0.2, \"Reasoning Error\"],\n"
            "  [\"Step 13: $$E_{\\\\text{total}} = E_1 + E_2$$\", 1, 0.85, \"\"],\n"
            "  [\"Step 14: where $E_1 = \\\\frac{1}{2}mv^2$ and $E_2 = \\\\frac{1}{2}mv^2$.\", 0, 0.1, \"Reasoning Error\"],\n"
            "  [\"Step 15: The fraction of the total kinetic energy lost in the collision is given by:\", 1, 0.8, \"\"],\n"
            "  [\"Step 16: $$\\\\frac{E_2 - E_1}{E_{\\\\text{total}}} = \\\\frac{\\\\frac{1}{2}mv^2 - \\\\frac{1}{2}mv^2}{\\\\frac{1}{2}mv^2 + \\\\frac{1}{2}mv^2} = \\\\frac{0}{\\\\frac{1}{2}mv^2} = 0$$\", 0, 0.1, \"Reasoning Error\"],\n"
            "  [\"Step 17: Therefore, the fraction of the total kinetic energy lost in the collision of the first domino with the second domino is 0.\", 0, 0.15, \"Reasoning Error\"],\n"
            "  [\"Step 18: The final answer is \\\\boxed{0}.\", 0, 0.1, \"Reasoning Error\"]\n"
            "]\n"
            "</evaluation>"
        )
    }

    # Normalise the reference answer to a single string.
    if isinstance(final_answer, (list, tuple, np.ndarray)):
        final_answer_str = "\n".join([str(ans) for ans in final_answer])
    else:
        final_answer_str = str(final_answer)

    # Primary source of steps: first element of each process_evaluation item.
    # The step number follows the item's position, so skipped items leave a gap.
    sequence_types = (list, tuple, np.ndarray)
    solution_steps = []
    if isinstance(process_evaluation, sequence_types):
        for i, item in enumerate(process_evaluation):
            if isinstance(item, sequence_types) and len(item) > 0:
                solution_steps.append(f"Step {i+1}: {str(item[0])}")

    # Fallback: use student_solution whenever process_evaluation produced
    # nothing (missing, empty, or containing no usable items).
    if not solution_steps and isinstance(student_solution, sequence_types):
        solution_steps = [
            f"Step {i+1}: {str(step)}" for i, step in enumerate(student_solution)
        ]

    if not solution_steps:
        logging.warning("solution_steps 为空，未能从 process_evaluation 或 student_solution 提取到解题步骤。")

    solution_text = "\n".join(solution_steps)
    if not solution_text:
        raise ValueError("solution_text is empty.")

    current_user = {
        "role": "user",
        "content": (
            f"Question: {problem}\n"
            f"Solution:\n{solution_text}\n"
            f"Answer: {final_answer_str}"
        )
    }

    return [
        system_message,
        example1_user,
        example1_assistant,
        current_user,
    ]
        
def escape_illegal_backslashes(json_str):
    r"""
    Double every backslash that does not start a JSON escape sequence.

    A lone backslash followed by anything other than one of ``" / b f n r t u``
    (for example LaTeX such as ``\s`` or ``\c``) is replaced by two
    backslashes. An already escaped backslash (``\\``) is consumed as a pair
    and left untouched, so its second character is never mistaken for the
    start of a new escape.

    Note: sequences such as ``\f``, ``\b``, ``\n``, ``\r``, ``\t`` and ``\u``
    look like JSON escapes and are therefore kept as they are, even when they
    originate from LaTeX commands.

    Args:
        json_str: text to repair.

    Returns:
        The repaired string.
    """
    def _fix(match):
        token = match.group(0)
        # Two characters means an escaped backslash pair: keep it verbatim.
        return token if len(token) == 2 else '\\\\'

    return re.sub(r'\\\\|\\(?![\"/bfnrtu])', _fix, json_str)

def clean_and_parse_json(raw_output):
    """
    Extract evaluation entries from raw model output with a regular expression.

    Every occurrence of ``["<step text>", 0|1, <number>, "<reason>"]`` in
    ``raw_output`` is collected; no JSON parser is involved, so surrounding
    tags or malformed neighbouring entries do not matter. The step text may
    contain escaped quotes (backslash + double quote); the captured text is
    returned exactly as written, escapes included.

    Args:
        raw_output: raw text returned by the model.

    Returns:
        list of ``[step_text, label(int), confidence(float), reason(str)]``,
        or None when no entry could be extracted (the raw output is logged).
    """
    # Step text: any run of escaped quotes or non-quote characters.
    pattern = r'\[\s*"((?:\\"|[^"])+)"\s*,\s*(0|1)\s*,\s*([0-9.]+)\s*,\s*"([^"]*)"\s*\]'
    matches = re.findall(pattern, raw_output)

    res = []
    for m in matches:
        # m: (step_text, label, conf, reason)
        try:
            label = int(m[1])
            conf = float(m[2])
        except ValueError as e:
            # e.g. a confidence such as "0.1.2" passes the regex but is not a float
            logging.warning(f"正则提取后解析失败: {e}, match={m}")
            continue
        res.append([m[0], label, conf, m[3]])

    if not res:
        logging.warning("未能从content中提取到任何['Step ...', 0/1, float, 'str']结构")
        logging.error(f"raw_output: {raw_output}")
        return None

    return res
        
def remove_first_string_field_from_json_array(json_like_str: str) -> str:
    """
    删除 JSON 数组中每个列表项的第一个字符串字段（通常是 Step 描述）。
    如果第一个text是形如["Step 9:      \\]"]或["Step 9:      \\\\"]，则删除整个该项。
    例如：["Step 9:      \\\\", 0, 0.1, "Reasoning Error"]会被整个移除。
    """

    # 匹配所有形如 ["xxx", 0/1, float, "str"] 的项
    pattern = r'\[\s*"((?:[^"\\]|\\.)*)"\s*,\s*([^,\[\]]+)\s*,\s*([^,\[\]]+)\s*,\s*("[^"]*")\s*\]'

    def repl(match):
        first_str = match.group(1)
        # 去除前后空格
        first_str_strip = first_str.strip()
        # 检查是否为 Step x:      \] 或 Step x:      \\] 形式
        # 允许任意空格，允许单斜杠或双斜杠
        if re.match(r'^Step\s*\d+:\s*\\\\*]$', first_str_strip):
            # 整个项删除
            return ''
        else:
            # 正常去掉第一个字符串字段
            return f'[{match.group(2)}, {match.group(3)}, {match.group(4)}]'

    # 先做一次替换，去掉需要删除的项和去掉第一个text
    cleaned = re.sub(pattern, repl, json_like_str)

    # 再去掉多余的逗号（可能出现连续逗号的情况），以及多余的逗号导致的语法问题
    # 例如: [[1, 0.9, ""], , [0, 0.1, "Reasoning Error"]]
    cleaned = re.sub(r',\s*,', ',', cleaned)
    cleaned = re.sub(r'\[\s*,', '[', cleaned)
    cleaned = re.sub(r',\s*\]', ']', cleaned)
    return cleaned


def call_gpt4o_multimodal(image_paths, messages, client, model_name="gpt-4o", max_tokens=5000):
    """
    Send a chat completion request, optionally attaching images to the last user turn.

    When ``image_paths`` is non-empty, the last message with role "user" is
    replaced (in a copy of ``messages``; the caller's list is not modified) by
    a multi-part message holding the original text followed by one base64
    data-URL part per readable image. Entries that are empty, not strings, or
    do not point to an existing file are skipped.

    Args:
        image_paths: list of image file paths, or None / empty for text only.
        messages: chat history as a list of {"role", "content"} dicts.
        client: object exposing ``chat.completions.create(...)``.
        model_name: model identifier passed to the API.
        max_tokens: completion token limit passed to the API.

    Returns:
        The content of the first choice's message.
    """
    import mimetypes  # local import: only needed when images are attached

    if image_paths:
        messages = messages.copy()
        last_user_index = next(
            (i for i in range(len(messages) - 1, -1, -1) if messages[i]["role"] == "user"),
            -1,
        )

        if last_user_index >= 0:
            new_content = [{"type": "text", "text": messages[last_user_index]["content"]}]
            for img_path in image_paths:
                if not (img_path and isinstance(img_path, str) and os.path.isfile(img_path)):
                    continue
                # Declare the real image type (png, webp, ...) in the data URL;
                # fall back to JPEG when the extension is unknown.
                mime_type, _ = mimetypes.guess_type(img_path)
                if not mime_type or not mime_type.startswith("image/"):
                    mime_type = "image/jpeg"
                with open(img_path, "rb") as f:
                    base64_image = base64.b64encode(f.read()).decode("utf-8")
                new_content.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:{mime_type};base64,{base64_image}"
                    }
                })
            messages[last_user_index] = {
                "role": "user",
                "content": new_content
            }

    response = client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=0,
        max_tokens=max_tokens,
    )

    return response.choices[0].message.content

def call_gpt4o_text(messages, client, model_name="gpt-4o", max_tokens=5000):
    """
    Send a text-only chat completion request with temperature 0.

    Args:
        messages: chat history as a list of message dicts.
        client: object exposing ``chat.completions.create(...)``.
        model_name: model identifier passed to the API.
        max_tokens: completion token limit passed to the API.

    Returns:
        The content of the first choice's message.
    """
    request = {
        "model": model_name,
        "messages": messages,
        "temperature": 0,
        "max_tokens": max_tokens,
    }
    completion = client.chat.completions.create(**request)
    first_choice = completion.choices[0]
    return first_choice.message.content

def score_to_confidence_and_label(score, threshold=0.0):
    """
    Convert a score into a (confidence, label) pair.

    Args:
        score: score, expected in [-1, 1].
        threshold: the label is 1 only when score is strictly greater than this value.

    Returns:
        (confidence, label): confidence is the score mapped linearly onto
        [0, 1] via (score + 1) / 2; label is 1 or 0.
    """
    if score > threshold:
        label = 1
    else:
        label = 0
    return (score + 1) / 2, label

def extract_gt_labels(process_evaluation):
    """
    Extract the ground-truth label of every step from ``process_evaluation``.

    The label is read from index 1 of each item and converted with
    ``int(float(...))``. An item whose label is missing or cannot be parsed
    contributes 0, and a warning is logged so that the fallback is visible.

    Args:
        process_evaluation: iterable of per-step items.

    Returns:
        list of int labels, one per item.
    """
    labels = []
    for idx, item in enumerate(process_evaluation):
        try:
            label = int(float(item[1]))
        except (IndexError, KeyError, TypeError, ValueError, OverflowError) as e:
            logging.warning(
                f"extract_gt_labels: could not parse label of item {idx + 1} "
                f"({item!r}): {e}; falling back to 0"
            )
            label = 0
        labels.append(label)
    return labels

def extract_confidence_and_label(evaluation_list, threshold=0.5, sample_id=None):
    """
    Extract per-step confidences and predicted labels from an evaluation list.

    Each item is read as ``[correctness_mark, confidence, ...]``: the mark is
    taken from index 0 and the confidence from index 1.

    Rules applied per item:
      * The returned label is always the item's correctness mark.
      * If the confidence equals the threshold it is kept unchanged.
      * Otherwise, if the label implied by the confidence (1 when
        confidence > threshold, else 0) disagrees with the mark, a warning is
        logged and the confidence is replaced by ``1 - confidence``.
      * If anything fails while handling an item, (0.0, 0) is recorded and a
        warning containing ``sample_id`` is logged.

    Args:
        evaluation_list: list of items as described above.
        threshold: confidence threshold (default 0.5).
        sample_id: identifier used only in log messages.

    Returns:
        (confs, preds): two lists with one entry per item.
    """
    confs, preds = [], []
    for idx, item in enumerate(evaluation_list):
        try:
            conf = float(item[1])
            correct_mark = int(item[0])
            implied_label = 1 if conf > threshold else 0
            if conf != threshold and implied_label != correct_mark:
                # The mark wins; mirror the confidence so that it agrees with the mark.
                logging.warning(f"警告：样本id={sample_id} 第{idx+1}步置信度({conf})与正确性标记({correct_mark})不一致，已自动修正。item: {item}")
                conf = 1 - conf
            step_conf, step_pred = conf, correct_mark
        except Exception as e:
            step_conf, step_pred = 0.0, 0
            logging.warning(f"extract_confidence_and_label报错, 样本id={sample_id} 解析evaluation_list第{idx+1}项失败: {item}, 错误: {e}, confs.append(0.0), preds.append(0)")
        confs.append(step_conf)
        preds.append(step_pred)
    return confs, preds

def extract_confidence_and_label_1(evaluation_list, threshold=0.5, sample_id=None):
    """
    Extract per-step confidences and predicted labels from full evaluation items.

    Each item is read as ``[step_text, correctness_mark, confidence, ...]``:
    the mark is taken from index 1 and the confidence from index 2.

    The returned label is always the item's correctness mark. When the
    confidence differs from the threshold and the label it implies
    (1 if confidence > threshold, else 0) contradicts the mark, a warning is
    logged and the confidence is replaced by ``1 - confidence``. Items that
    cannot be processed are recorded as (0.0, 0) with a warning.

    Args:
        evaluation_list: list of items as described above.
        threshold: confidence threshold (default 0.5).
        sample_id: identifier used only in log messages.

    Returns:
        (confs, preds): two lists with one entry per item.
    """
    confs = []
    preds = []
    for idx, item in enumerate(evaluation_list):
        recorded_conf, recorded_pred = 0.0, 0
        try:
            conf = float(item[2])
            correct_mark = int(item[1])  # label stated by the model
            conf_label = 1 if conf > threshold else 0  # label implied by the confidence
            contradicts = conf != threshold and conf_label != correct_mark
            if contradicts:
                logging.warning(f"警告：样本id={sample_id} 第{idx+1}步置信度({conf})与正确性标记({correct_mark})不一致，已自动修正。item: {item}")
            recorded_conf = 1 - conf if contradicts else conf
            recorded_pred = correct_mark
        except Exception as e:
            recorded_conf, recorded_pred = 0.0, 0
            logging.warning(f"extract_confidence_and_label_1报错, 样本id={sample_id} 解析evaluation_list第{idx+1}项失败: {item}, 错误: {e}")
        confs.append(recorded_conf)
        preds.append(recorded_pred)
    return confs, preds

def merge_by_step_markers(evaluation_steps):
    """
    Merge evaluation entries so that there is one entry per "Step N:" header.

    An entry whose (stripped) text starts with "Step N:" opens a new group;
    following entries without a header are appended to that group. Each group
    is emitted as ``[text, majority_label, mean_confidence, ""]``:
      * text is the header followed by the group's texts joined with spaces;
        entries that appear before any header are emitted without a header,
      * majority_label is the most common label in the group (the first one
        encountered wins ties),
      * the error reason is not carried over and is always "".

    The header is stripped from the whitespace-trimmed text, so leading
    whitespace in front of "Step N:" does not corrupt the remaining text.

    Note: a header entry with no text after the header that is immediately
    followed by another header produces no output entry.

    Args:
        evaluation_steps: iterable of ``[text, label, confidence, reason]``.

    Returns:
        list of merged entries.
    """
    step_pattern = re.compile(r"Step\s+\d+:")
    merged = []

    def flush(header, texts, scores, confs):
        if not texts:
            return
        merged_text = " ".join(texts)
        majority_label = Counter(scores).most_common(1)[0][0]
        avg_conf = sum(confs) / len(confs)
        full_text = f"{header} {merged_text}" if header else merged_text
        merged.append([full_text, majority_label, avg_conf, ""])

    step_header = None
    current_texts, current_scores, current_confs = [], [], []

    for text, score, conf, _reason in evaluation_steps:
        stripped = text.strip()
        match = step_pattern.match(stripped)
        if match:
            flush(step_header, current_texts, current_scores, current_confs)
            step_header = match.group()  # e.g. "Step 1:"
            # Slice the same (stripped) string the header was matched on.
            rest = stripped[match.end():].strip()
            current_texts = [rest] if rest else []
            current_scores = [score]
            current_confs = [conf]
        else:
            current_texts.append(stripped)
            current_scores.append(score)
            current_confs.append(conf)

    flush(step_header, current_texts, current_scores, current_confs)
    return merged

def grade_to_difficulty(grade_str):
    """
    Map a grade name to a difficulty level.

    "Middle_School" -> 1, "High_School" -> 2, "Competition" -> 3;
    any other value -> 0.
    """
    known_levels = (
        ("Middle_School", 1),
        ("High_School", 2),
        ("Competition", 3),
    )
    for grade_name, level in known_levels:
        if grade_str == grade_name:
            return level
    return 0

def calculate_standard_confidence(conf, pred_label):
    """
    Return the standard confidence for a prediction.

    The confidence is returned unchanged when the predicted label is 1;
    for any other predicted label the result is 1 - conf.
    """
    if pred_label == 1:
        return conf
    return 1 - conf

def analyze_confidence_gap_by_error_type(all_step_details):
    """
    Compare the mean confidence of correct steps with that of each error type.

    Args:
        all_step_details: list of per-step dicts. The keys read here are:
            'conf'       confidence of the step (required),
            'gt_label'   ground-truth label, 1 = correct, 0 = incorrect
                         (a missing key is treated as 0),
            'error_type' error category of an incorrect step; a missing,
                         blank or non-string value (e.g. None) is grouped
                         under "(empty)".

    Returns:
        list of (error_type, delta_conf) tuples sorted by delta_conf in
        descending order, where
        delta_conf = mean conf of correct steps - mean conf of that error type.
        The mean over correct steps is 0.0 when there are none.

    Side effects:
        The table is written both to the log (INFO level) and to stdout.
    """
    # Mean confidence over all ground-truth-correct steps.
    confs_correct = [d["conf"] for d in all_step_details if d.get("gt_label", 0) == 1]
    conf_correct = np.mean(confs_correct) if confs_correct else 0.0

    # Group the confidences of incorrect steps by error type.
    error_type_to_confs = defaultdict(list)
    for d in all_step_details:
        if d.get("gt_label", 0) != 0:
            continue
        raw_type = d.get("error_type", "")
        # Guard against None / NaN coming from upstream data: only strings are stripped.
        error_type = raw_type.strip() if isinstance(raw_type, str) else ""
        error_type_to_confs[error_type or "(empty)"].append(d["conf"])

    # Mean confidence per error type (every group holds at least one value).
    conf_error_type = {
        etype: np.mean(confs) for etype, confs in error_type_to_confs.items()
    }

    # delta_conf[type] = conf_correct - conf_error(type)
    delta_conf = {
        etype: conf_correct - conf_err for etype, conf_err in conf_error_type.items()
    }
    sorted_delta_conf = sorted(delta_conf.items(), key=lambda x: x[1], reverse=True)

    header_lines = [
        "\n模型对不同错误类型的置信度敏感性分析（Δconf = 正确步骤均值 - 该类错误步骤均值）：",
        f"正确步骤平均置信度 conf_correct = {conf_correct:.4f}",
        f"{'错误类型':<30} {'错误均值':>10} {'Δconf':>10}",
    ]
    for line in header_lines:
        logging.info(line)
    for line in header_lines:
        print(line)
    for etype, dconf in sorted_delta_conf:
        row = f"{etype:<30} {conf_error_type[etype]:>10.4f} {dconf:>10.4f}"
        logging.info(row)
        print(row)

    return sorted_delta_conf

def calculate_adversarial_metrics(stepwise_conf_records, power=5):
    """
    Compute step-level robustness metrics from confidences before/after perturbation.

    Args:
        stepwise_conf_records: list of per-step dicts with the keys
            'stdconf_original' and 'stdconf_perturbed'; 'pre_label' and
            'gt_label' are optional and only used for "PRI_correct_only".
        power: constant factor by which the absolute confidence change is
            multiplied (default 5). The scaling is linear, although the
            returned keys contain the word "nonlinear".

    Returns:
        {} for empty input, otherwise a dict with:
            "Δconf_avg"                mean absolute confidence change,
            "Δconf_nonlinear_weighted" scaled change averaged with the original
                                       confidence as weight (0.0 if the
                                       weights do not sum to a positive value),
            "Δconf_nonlinear_mean"     plain mean of the scaled change,
            "CriticalStepRate"         share of steps whose change exceeds 0.2,
            "PRI_correct_only"         mean change over steps with
                                       pre_label == gt_label (np.nan if the
                                       columns are missing or no step qualifies),
            "Δconf_mean_on_changed"    mean change over steps whose change
                                       exceeds 0.01 (0.0 if none),
            "Δconf_changed_ratio_step" share of steps whose change exceeds 0.01,
            "punish_power"             the ``power`` argument.
    """
    if not stepwise_conf_records:
        return {}

    frame = pd.DataFrame(stepwise_conf_records)
    original_conf = frame['stdconf_original']

    # Absolute confidence change per step and its linearly scaled version.
    shift = (original_conf - frame['stdconf_perturbed']).abs()
    scaled_shift = shift * power

    # Weighted mean of the scaled change, weights = original confidence.
    weighted_total = (scaled_shift * original_conf).sum()
    weight_sum = original_conf.sum()
    if weight_sum > 0:
        weighted_scaled = weighted_total / weight_sum
    else:
        weighted_scaled = 0.0

    has_rows = not frame.empty

    # Share of steps whose confidence moved by more than the critical threshold.
    critical_threshold = 0.2
    critical_rate = (shift > critical_threshold).mean() if has_rows else 0.0

    # Mean change restricted to steps that were predicted correctly.
    pri_correct = np.nan
    if 'pre_label' in frame.columns and 'gt_label' in frame.columns:
        shift_on_correct = shift[frame['pre_label'] == frame['gt_label']]
        if not shift_on_correct.empty:
            pri_correct = shift_on_correct.mean()

    # Steps counted as "changed": change larger than 0.01.
    moved = shift > 0.01
    moved_shift = shift[moved]
    mean_on_changed = moved_shift.mean() if not moved_shift.empty else 0.0
    changed_ratio = moved.mean() if has_rows else 0.0

    return {
        "Δconf_avg": shift.mean(),
        "Δconf_nonlinear_weighted": weighted_scaled,
        "Δconf_nonlinear_mean": scaled_shift.mean(),
        "CriticalStepRate": critical_rate,
        "PRI_correct_only": pri_correct,
        "Δconf_mean_on_changed": mean_on_changed,
        "Δconf_changed_ratio_step": changed_ratio,
        "punish_power": power,
    }

# Correlation between standard confidence and step position
def compute_stdconf_position_correlation_mixed(step_details_with_stdconf_and_pos):
    """
    Spearman correlation between standard confidence and normalised step position.

    The position of a step is step_idx / (num_steps - 1). Entries that lack
    'num_steps', 'step_idx' or 'stdconf', or whose num_steps is <= 1, are skipped.

    Args:
        step_details_with_stdconf_and_pos: list of dicts with the keys
            'stdconf', 'step_idx', 'num_steps' and 'gt_label'.

    Returns:
        (corr, corr_correct, corr_error): correlation over all kept steps,
        over those with gt_label == 1, and over those with gt_label == 0.
        A group with fewer than two steps yields np.nan.
    """
    def _spearman(points):
        if len(points) <= 1:
            return np.nan
        conf_values = [p[0] for p in points]
        pos_values = [p[1] for p in points]
        return spearmanr(conf_values, pos_values)[0]

    all_points, correct_points, error_points = [], [], []
    for d in step_details_with_stdconf_and_pos:
        n = d.get("num_steps", None)
        i = d.get("step_idx", None)
        stdconf = d.get("stdconf", None)
        gt_label = d.get("gt_label", None)
        if n is None or i is None:
            continue
        if n <= 1 or stdconf is None:
            continue

        point = (stdconf, i / (n - 1))
        all_points.append(point)
        if gt_label == 1:
            correct_points.append(point)
        elif gt_label == 0:
            error_points.append(point)

    return _spearman(all_points), _spearman(correct_points), _spearman(error_points)

# 计算标准置信度与token数的相关性
def compute_stdconf_token_correlation_mixed(step_details_with_stdconf_and_token):
    """
    Spearman correlation between per-step standard confidence and step token count.

    Each input record is a dict read through the keys 'stdconf',
    'token_count' and 'gt_label'. Records whose 'token_count' or 'stdconf'
    is missing (None) are skipped.

    Returns:
        (corr, corr_correct, corr_error): the coefficient over all kept steps,
        over the steps with gt_label == 1, and over the steps with
        gt_label == 0. A coefficient is np.nan when its group holds fewer
        than two steps.
    """
    def _rho(values, counts):
        # A rank correlation needs at least two observations.
        if len(values) > 1:
            return spearmanr(values, counts)[0]
        return np.nan

    all_conf, all_tok = [], []
    ok_conf, ok_tok = [], []
    bad_conf, bad_tok = [], []

    for record in step_details_with_stdconf_and_token:
        n_tokens = record.get("token_count", None)
        conf_value = record.get("stdconf", None)
        label = record.get("gt_label", None)

        if n_tokens is not None and conf_value is not None:
            # Mixed series over every kept step.
            all_conf.append(conf_value)
            all_tok.append(n_tokens)

            # Split series by ground-truth label.
            if label == 1:
                ok_conf.append(conf_value)
                ok_tok.append(n_tokens)
            elif label == 0:
                bad_conf.append(conf_value)
                bad_tok.append(n_tokens)

    return _rho(all_conf, all_tok), _rho(ok_conf, ok_tok), _rho(bad_conf, bad_tok)

def count_tokens_cl100k_base(text):
    """
    Return the number of tokens `text` encodes to under the "cl100k_base" encoding.

    Returns None when the module-level `_tiktoken_available` flag is false
    (tiktoken could not be imported).
    """
    if _tiktoken_available:
        encoding = tiktoken.get_encoding("cl100k_base")
        token_ids = encoding.encode(text)
        return len(token_ids)
    return None

def _is_adv_enabled(if_adv):
    """Interpret `if_adv` as a boolean, accepting the "True"/"False" strings produced by argparse."""
    if isinstance(if_adv, str):
        # A non-empty string such as "False" is truthy, so it must be parsed explicitly.
        return if_adv.strip().lower() not in ("", "false", "0", "no", "none")
    return bool(if_adv)


def _existing_image_paths(img_field):
    """Return the non-empty string paths in a dict's values or a list that exist on disk."""
    if isinstance(img_field, dict):
        candidates = img_field.values()
    elif isinstance(img_field, list):
        candidates = img_field
    else:
        return []
    return [v for v in candidates if v and isinstance(v, str) and os.path.exists(v)]


def process_sample(row, if_adv, model_name, max_tokens):
    """
    Evaluate a single sample with the LLM and collect step-level records.

    Intended as the worker function for multiprocessing; every call builds
    its own OpenAI client from the OPENAI_API_KEY / OPENAI_API_BASE
    environment variables.

    Args:
        row: mapping for one sample. Read via row[...] for "question",
            "final_answer", "data_type", "process_evaluation", "grade" and
            "subject", and via row.get(...) for "id", "image_paths" and the
            perturbation fields ("student_solution_synonym",
            "student_solution_structure", "perturbed_image_paths").
        if_adv: whether to run the additional perturbed inference. Accepts a
            bool or a string; the strings "", "false", "0", "no" and "none"
            (case-insensitive) disable it.
        model_name: model identifier forwarded to the LLM call helpers.
        max_tokens: generation limit forwarded to the LLM call helpers.

    Returns:
        dict with the keys "confs", "preds", "gts", "difficulty",
        "step_details", "stepwise_conf_records",
        "step_details_with_stdconf_and_pos" and
        "step_details_with_stdconf_and_token". When parsing fails, lengths
        cannot be reconciled, or an exception is raised, the dict is returned
        as filled so far and the problem is logged.
    """
    result = {
        "confs": [],
        "preds": [],
        "gts": [],
        "difficulty": 0,
        "step_details": [],
        "stepwise_conf_records": [],
        "step_details_with_stdconf_and_pos": [],
        "step_details_with_stdconf_and_token": []
    }

    # Bound before the try block so the except handler can always format it,
    # even when reading the id itself is what failed.
    sample_id = None

    try:
        # Unique id of the sample; falls back to the object's id() when absent.
        sample_id = row.get("id", id(row))

        problem = row["question"]
        final_answer = row["final_answer"]
        data_type = row["data_type"]  # modality of the sample
        process_evaluation = row["process_evaluation"]

        # Drop steps whose text is only a '\[' or '\]' delimiter. Items that
        # are not non-empty lists/arrays are dropped as well.
        if isinstance(process_evaluation, (list, np.ndarray)):
            cleaned_process_evaluation = []
            for item in process_evaluation:
                if isinstance(item, (list, np.ndarray)) and len(item) > 0:
                    item_text = item[0] if isinstance(item[0], str) else str(item[0])
                    if item_text.strip() not in ['\\[', '\\]']:
                        cleaned_process_evaluation.append(item)
            process_evaluation = cleaned_process_evaluation

        # Quantize the grade field into a difficulty value.
        grade_str = row["grade"]
        difficulty = grade_to_difficulty(grade_str)
        result["difficulty"] = difficulty

        # Only non-text samples carry images.
        image_paths = []
        if data_type != "pure_text":
            image_paths = _existing_image_paths(row.get("image_paths", {}))

        # Message history for the original (unperturbed) solution.
        messages = build_messages(
            problem=problem,
            final_answer=final_answer,
            process_evaluation=process_evaluation,
            student_solution=None
        )

        # One client per call, so each worker process owns its own.
        client = openai.OpenAI(
            api_key=os.getenv("OPENAI_API_KEY", ""),
            base_url=os.getenv("OPENAI_API_BASE", "")
        )

        # ------------------------- LLM inference ------------------------- #
        if image_paths:
            output = call_gpt4o_multimodal(image_paths, messages, client, model_name, max_tokens=max_tokens)
        else:
            output = call_gpt4o_text(messages, client, model_name, max_tokens=max_tokens)

        # ------------------------ post-processing ------------------------ #
        eval_list = clean_and_parse_json(output)
        if eval_list is None:
            logging.warning(f"样本 (id={sample_id}) LLM输出解析失败")
            return result

        # Per-step confidences and predicted labels.
        confs, preds = extract_confidence_and_label_1(eval_list, threshold=0.5, sample_id=sample_id)

        # Ground-truth labels.
        if process_evaluation is not None:
            gts = extract_gt_labels(process_evaluation)
        else:
            logging.warning(f"样本 (id={sample_id}) 缺少process_evaluation字段")
            return result

        # ---------------------- length consistency ----------------------- #
        if len(confs) == len(preds) == len(gts):
            pass  # consistent, carry on
        elif len(confs) == len(preds) > len(gts):
            # More predicted steps than ground-truth steps: re-parse the raw
            # output and merge steps by their step markers, then re-check.
            try:
                output_clean = output.strip().removeprefix("<evaluation>\n").removesuffix("\n</evaluation>")
                evaluation_steps = ast.literal_eval(output_clean)
                merged_steps = merge_by_step_markers(evaluation_steps)
                confs, preds = extract_confidence_and_label_1(merged_steps, threshold=0.5, sample_id=sample_id)

                if len(confs) == len(preds) == len(gts):
                    eval_list = merged_steps
                else:
                    logging.warning(f"样本 (id={sample_id}) 长度不一致: confs={len(confs)}, preds={len(preds)}, gts={len(gts)} - 未解决")
                    return result
            except Exception as e:
                logging.warning(f"样本 (id={sample_id}) 长度不一致: confs={len(confs)}, preds={len(preds)}, gts={len(gts)}, 尝试解决时解析出错: {str(e)}")
                return result
        elif len(confs) == len(preds) < len(gts):
            logging.warning(f"样本 (id={sample_id}) 长度不一致: confs={len(confs)}, preds={len(preds)}, gts={len(gts)} - 未解决")
            return result
        else:
            logging.warning(f"样本 (id={sample_id}) 长度不一致且无法解决: confs={len(confs)}, preds={len(preds)}, gts={len(gts)} - 可能confs和preds长度也不一致")
            return result

        result["confs"] = confs
        result["preds"] = preds
        result["gts"] = gts

        # ---------------------- per-step details ------------------------- #
        subject = row["subject"]
        num_steps = len(eval_list)
        for step_idx, item in enumerate(eval_list):
            # The third element of a process_evaluation entry, when present,
            # is recorded as the error type.
            error_type = ""
            if process_evaluation is not None and step_idx < len(process_evaluation):
                step_info = process_evaluation[step_idx]
                if len(step_info) >= 3:
                    error_type = str(step_info[2]).strip()

            # Token count of the step text (None when tiktoken is unavailable).
            step_text = ""
            if process_evaluation is not None and step_idx < len(process_evaluation):
                if isinstance(process_evaluation[step_idx][0], str):
                    step_text = process_evaluation[step_idx][0]
            token_count = count_tokens_cl100k_base(step_text) if _tiktoken_available else None

            std_conf = calculate_standard_confidence(confs[step_idx], preds[step_idx])

            step_detail = {
                "conf": float(confs[step_idx]),
                "stdconf": float(std_conf),
                "pre_label": int(preds[step_idx]),
                "gt_label": int(gts[step_idx]),
                "error_type": error_type,
                "id": sample_id,
                "difficulty": difficulty,
                "step_idx": step_idx,
                "data_type": data_type,
                "subject": subject,
            }
            result["step_details"].append(step_detail)

            # Standard confidence together with position information.
            result["step_details_with_stdconf_and_pos"].append({
                "stdconf": float(std_conf),
                "gt_label": int(gts[step_idx]),
                "sample_id": sample_id,
                "step_idx": step_idx,
                "num_steps": num_steps
            })

            # Standard confidence together with the token count.
            if token_count is not None:
                result["step_details_with_stdconf_and_token"].append({
                    "stdconf": float(std_conf),
                    "gt_label": int(gts[step_idx]),
                    "token_count": token_count
                })

        # ------------- perturbation: one extra inference pass ------------- #
        if _is_adv_enabled(if_adv):
            # Pick the first perturbation field that is present.
            if row.get("student_solution_synonym") is not None:
                perturbed_field = "student_solution_synonym"
            elif row.get("student_solution_structure") is not None:
                perturbed_field = "student_solution_structure"
            elif row.get("perturbed_image_paths") is not None:
                perturbed_field = "perturbed_image"
            else:
                logging.warning(f"样本 (id={sample_id}) 未找到可用的扰动字段。")
                return result

            if perturbed_field == "perturbed_image":
                # Image perturbation: the steps stay the original ones.
                perturbed_solution = process_evaluation
                perturbed_image_paths = _existing_image_paths(row.get("perturbed_image_paths", {}))
            else:
                # Text perturbation: the images stay the original ones.
                perturbed_solution = row.get(perturbed_field, None)
                if isinstance(perturbed_solution, (list, np.ndarray)):
                    # Keep only string items that are not a '\[' / '\]' delimiter.
                    perturbed_solution = [
                        item for item in perturbed_solution
                        if isinstance(item, str) and item.strip() not in ['\\[', '\\]']
                    ]
                perturbed_image_paths = image_paths

            # Run the perturbed inference only when there is content for it.
            if perturbed_solution is None:
                logging.warning(f"样本 (id={sample_id}) 扰动内容不存在，无法进行推理。")
                return result

            if perturbed_field == "perturbed_image":
                perturbed_messages = build_messages(
                    problem=problem,
                    final_answer=final_answer,
                    process_evaluation=perturbed_solution,
                    student_solution=None
                )
            else:
                perturbed_messages = build_messages(
                    problem=problem,
                    final_answer=final_answer,
                    process_evaluation=None,
                    student_solution=perturbed_solution
                )

            try:
                if perturbed_image_paths:
                    perturbed_output = call_gpt4o_multimodal(perturbed_image_paths, perturbed_messages, client, model_name, max_tokens=max_tokens)
                else:
                    perturbed_output = call_gpt4o_text(perturbed_messages, client, model_name, max_tokens=max_tokens)
            except Exception as e:
                logging.warning(f"样本 (id={sample_id}) 扰动推理API调用异常: {e}")
                return result

            perturbed_eval_list = clean_and_parse_json(perturbed_output)
            if perturbed_eval_list is None:
                logging.warning(f"样本 (id={sample_id}) 扰动LLM输出解析失败")
                return result
            perturbed_confs, perturbed_preds = extract_confidence_and_label_1(perturbed_eval_list, threshold=0.5, sample_id=sample_id)

            # Length check for the perturbed output. Unlike the original pass,
            # a mismatch only invalidates the perturbed confidences
            # (perturbed_confs = None) instead of returning early.
            if len(perturbed_confs) == len(perturbed_preds) == len(gts):
                pass  # consistent, carry on
            elif len(perturbed_confs) == len(perturbed_preds) > len(gts):
                try:
                    perturbed_output_clean = perturbed_output.strip().removeprefix("<evaluation>\n").removesuffix("\n</evaluation>")
                    perturbed_evaluation_steps = ast.literal_eval(perturbed_output_clean)
                    perturbed_merged_steps = merge_by_step_markers(perturbed_evaluation_steps)
                    perturbed_confs, perturbed_preds = extract_confidence_and_label_1(
                        perturbed_merged_steps, threshold=0.5, sample_id=sample_id
                    )

                    if len(perturbed_confs) == len(perturbed_preds) == len(gts):
                        logging.warning(f"样本 (id={sample_id}) 扰动输出长度不一致已解决")
                    else:
                        logging.warning(f"样本 (id={sample_id}) 扰动输出长度不一致: confs={len(perturbed_confs)}, preds={len(perturbed_preds)}, gts={len(gts)} - 未解决")
                        perturbed_confs = None
                except Exception as e:
                    logging.warning(f"样本 (id={sample_id}) 扰动输出长度不一致: confs={len(perturbed_confs)}, preds={len(perturbed_preds)}, gts={len(gts)}, 尝试解决时解析出错: {str(e)}")
                    perturbed_confs = None
            elif len(perturbed_confs) == len(perturbed_preds) < len(gts):
                logging.warning(f"样本 (id={sample_id}) 扰动输出长度不一致: confs={len(perturbed_confs)}, preds={len(perturbed_preds)}, gts={len(gts)} - 未解决")
                perturbed_confs = None
            else:
                logging.warning(f"样本 (id={sample_id}) 扰动输出长度不一致且无法解决: confs={len(perturbed_confs)}, preds={len(perturbed_preds)}, gts={len(gts)} - 可能confs和preds长度也不一致")
                perturbed_confs = None

            # Pair original and perturbed confidences step by step.
            if perturbed_confs is not None and len(confs) == len(perturbed_confs):
                for step_idx, conf_original in enumerate(confs):
                    conf_perturbed = perturbed_confs[step_idx]
                    std_conf_original = calculate_standard_confidence(conf_original, preds[step_idx])
                    std_conf_perturbed = calculate_standard_confidence(conf_perturbed, perturbed_preds[step_idx])

                    result["stepwise_conf_records"].append({
                        "sample_id": sample_id,
                        "conf_original": float(conf_original),
                        "conf_perturbed": float(conf_perturbed),
                        "stdconf_original": float(std_conf_original),
                        "stdconf_perturbed": float(std_conf_perturbed),
                        "perturbed_field": perturbed_field,
                        "pre_label": int(preds[step_idx]),
                        "gt_label": int(gts[step_idx])
                    })
            else:
                logging.warning(f"样本 (id={sample_id}) confs与perturbed_confs长度不一致或perturbed_confs为None: len(confs)={len(confs)}, len(perturbed_confs)={len(perturbed_confs) if perturbed_confs is not None else 'None'}")

    except Exception as e:
        # Best effort: a failing sample must not take down the whole batch.
        logging.error(f"处理样本 (id={sample_id}) 时出错: {str(e)}")

    return result

# 添加按subject分组的ECE计算函数
def calculate_metrics_by_subject(step_details):
    """
    Compute accuracy and ECE separately for each subject.

    Steps without a 'subject' key are ignored. Each remaining step is read
    through its 'stdconf', 'pre_label' and 'gt_label' entries.

    Returns:
        dict mapping subject -> {'accuracy', 'ECE', 'num_steps'}.
    """
    per_subject = defaultdict(list)
    for record in step_details:
        if 'subject' not in record:
            continue
        per_subject[record['subject']].append(record)

    summary = {}
    for subject_name, members in per_subject.items():
        confidence_values = [m['stdconf'] for m in members]
        predicted = [m['pre_label'] for m in members]
        actual = [m['gt_label'] for m in members]
        # NOTE(review): 'stdconf' is handed to compute_binary_ece as its
        # `probs` argument, which that function converts to a confidence
        # again for label-0 predictions — confirm this is the intended input.
        summary[subject_name] = {
            'accuracy': np.mean(np.array(predicted) == np.array(actual)),
            'ECE': compute_binary_ece(confidence_values, predicted, actual),
            'num_steps': len(members),
        }
    return summary

# Accuracy and ECE grouped by difficulty level
def calculate_metrics_by_difficulty(step_details):
    """
    Compute accuracy and ECE separately for each difficulty value.

    Steps without a 'difficulty' key are ignored. Each remaining step is
    read through its 'stdconf', 'pre_label' and 'gt_label' entries.

    Returns:
        dict mapping difficulty -> {'accuracy', 'ECE'}.
    """
    buckets = defaultdict(list)
    for record in step_details:
        if 'difficulty' not in record:
            continue
        buckets[record['difficulty']].append(record)

    summary = {}
    for level, members in buckets.items():
        confidence_values = [m['stdconf'] for m in members]
        predicted = [m['pre_label'] for m in members]
        actual = [m['gt_label'] for m in members]
        hit_rate = np.mean(np.array(predicted) == np.array(actual))
        # NOTE(review): 'stdconf' is handed to compute_binary_ece as its
        # `probs` argument, which that function converts to a confidence
        # again for label-0 predictions — confirm this is the intended input.
        calibration_error = compute_binary_ece(confidence_values, predicted, actual)
        summary[level] = {'accuracy': hit_rate, 'ECE': calibration_error}
    return summary

def calculate_metrics_by_error_type(step_details):
    """
    Compute prediction accuracy and ECE for each ground-truth error type.

    Only steps whose 'gt_label' equals 0 (erroneous steps) are considered.
    They are grouped by their 'error_type' entry, with 'Unknown' used when
    that key is absent.

    Returns:
        dict mapping error_type -> {'accuracy', 'ECE'}.
    """
    buckets = defaultdict(list)
    for record in step_details:
        if record['gt_label'] != 0:
            continue  # correct steps carry no error type
        buckets[record.get('error_type', 'Unknown')].append(record)

    summary = {}
    for kind, members in buckets.items():
        confidence_values = [m['stdconf'] for m in members]
        predicted = [m['pre_label'] for m in members]
        actual = [m['gt_label'] for m in members]
        hit_rate = np.mean(np.array(predicted) == np.array(actual))
        # NOTE(review): 'stdconf' is handed to compute_binary_ece as its
        # `probs` argument, which that function converts to a confidence
        # again for label-0 predictions — confirm this is the intended input.
        calibration_error = compute_binary_ece(confidence_values, predicted, actual)
        summary[kind] = {'accuracy': hit_rate, 'ECE': calibration_error}
    return summary

# Macro F1 over the correct-step and incorrect-step classes
def calculate_macro_f1(y_true, y_pred):
    """
    Compute per-class F1 scores and their unweighted mean (Macro F1).

    Label 1 is the positive class (correct steps) and label 0 the negative
    class (incorrect steps). Both inputs are cast to integer arrays first.

    Returns:
        dict with the keys "F1_correct_steps", "F1_incorrect_steps" and
        "Macro_F1".
    """
    truth = np.array(y_true).astype(int)
    guess = np.array(y_pred).astype(int)

    # One F1 per class, each time treating that class as the positive label.
    per_class = {
        label: f1_score(truth, guess, pos_label=label, zero_division=0)
        for label in (1, 0)
    }

    return {
        "F1_correct_steps": per_class[1],
        "F1_incorrect_steps": per_class[0],
        "Macro_F1": (per_class[1] + per_class[0]) / 2,
    }

# Accuracy and ECE grouped by data type (modality)
def calculate_metrics_by_data_type(step_details):
    """
    Compute accuracy and ECE separately for each data type (modality).

    Steps without a 'data_type' key are ignored. Each remaining step is
    read through its 'stdconf', 'pre_label' and 'gt_label' entries.

    Returns:
        dict mapping data_type -> {'accuracy', 'ECE', 'num_steps'}.
    """
    per_modality = defaultdict(list)
    for record in step_details:
        if 'data_type' not in record:
            continue
        per_modality[record['data_type']].append(record)

    summary = {}
    for modality, members in per_modality.items():
        confidence_values = [m['stdconf'] for m in members]
        predicted = [m['pre_label'] for m in members]
        actual = [m['gt_label'] for m in members]
        # NOTE(review): 'stdconf' is handed to compute_binary_ece as its
        # `probs` argument, which that function converts to a confidence
        # again for label-0 predictions — confirm this is the intended input.
        summary[modality] = {
            'accuracy': np.mean(np.array(predicted) == np.array(actual)),
            'ECE': compute_binary_ece(confidence_values, predicted, actual),
            'num_steps': len(members),
        }
    return summary

# ============================== 综合指标计算函数 ==============================

def calculate_css(adversarial_metrics, scale_factor_mean=5.0, scale_factor_critical=5.0):
    """
    Compute the Confidence Stability Score (CSS) from three sub-metrics:
    Δconf_changed_ratio_step, Δconf_mean_on_changed and CriticalStepRate.

    The input dict is not modified; a missing, None or NaN sub-metric is
    treated as 0.

    Args:
        adversarial_metrics: dict of metrics holding the three keys above.
        scale_factor_mean: multiplier applied to Δconf_mean_on_changed
            before clipping to [0, 1] (default 5.0).
        scale_factor_critical: multiplier applied to CriticalStepRate
            before clipping to [0, 1] (default 5.0).

    Returns:
        (css, stability_level): css is the weighted score
        0.4 * (1 - ratio) + 0.4 * (1 - scaled mean) + 0.2 * (1 - scaled
        critical rate), larger meaning more stable; stability_level is the
        matching textual level.
    """
    def _metric(name):
        # Read without writing back, so the caller's dict keeps its values.
        value = adversarial_metrics.get(name, 0)
        if value is None or np.isnan(value):
            return 0
        return value

    def _scale_and_clip(raw, factor, label):
        scaled = raw * factor
        # Log the out-of-range value BEFORE clamping so the warning shows it.
        if scaled > 1:
            logging.warning(f"{label} 超出[0,1]范围，已截断为1: {scaled}")
            scaled = 1
        elif scaled < 0:
            logging.warning(f"{label} 超出[0,1]范围，已截断为0: {scaled}")
            scaled = 0
        # Boundary values are recorded as well, including clamped ones.
        if scaled == 1 or scaled == 0:
            logging.info(f"{label} 为边界值: {scaled}")
        return scaled

    delta_conf_changed_ratio_step = _metric("Δconf_changed_ratio_step")
    delta_conf_mean_on_changed = _metric("Δconf_mean_on_changed")
    critical_step_rate = _metric("CriticalStepRate")

    delta_conf_mean_on_changed_scaled = _scale_and_clip(
        delta_conf_mean_on_changed, scale_factor_mean, "delta_conf_mean_on_changed_scaled"
    )
    critical_step_rate_scaled = _scale_and_clip(
        critical_step_rate, scale_factor_critical, "critical_step_rate_scaled"
    )

    # Larger is better: every term is one minus an instability measure.
    css = (
        0.4 * (1 - min(delta_conf_changed_ratio_step, 1)) +
        0.4 * (1 - delta_conf_mean_on_changed_scaled) +
        0.2 * (1 - critical_step_rate_scaled)
    )

    if css >= 0.9:
        stability_level = "Highly Stable"
    elif css >= 0.7:
        stability_level = "Reasonably Stable"
    elif css >= 0.5:
        stability_level = "Moderately Stable"
    else:
        stability_level = "Unstable under perturbation"

    return css, stability_level


def calculate_cess_from_sorted_delta_conf(sorted_delta_conf):
    """
    Compute the Confidence Error Sensitivity Score (CESS) from precomputed gaps.

    Args:
        sorted_delta_conf: sequence of (error_type, delta_conf) pairs.

    Returns:
        The mean of the delta_conf values, or np.nan for an empty input.
    """
    if sorted_delta_conf:
        # Only the gap of each pair matters; the error type is dropped.
        gaps = [gap for _error_type, gap in sorted_delta_conf]
        return np.mean(gaps)
    return np.nan

def calculate_cmcs(ece, ece_pos, ece_neg, scale_factor=5.0):
    """
    Compute the Confidence Multiview Calibration Score (CMCS).

    The score averages two terms: one minus the global ECE multiplied by
    `scale_factor` (clipped to [0, 1]), and one minus the absolute gap
    between the positive-class and negative-class ECE.

    Args:
        ece: global ECE.
        ece_pos: ECE of the positive class.
        ece_neg: ECE of the negative class.
        scale_factor: multiplier applied to the global ECE (default 5.0).

    Returns:
        (cmcs, calibration_level): the score and its textual level.
    """
    class_gap = abs(ece_pos - ece_neg)

    # Amplify the global ECE, then keep it inside [0, 1].
    amplified = ece * scale_factor
    if amplified > 1:
        logging.info(f"放大后ECE大于1，已截断为1，原值: {amplified}")
        amplified = 1.0
    elif amplified < 0:
        logging.info(f"放大后ECE小于0，已截断为0，原值: {amplified}")
        amplified = 0.0

    # Boundary values are recorded too, including freshly clipped ones.
    if amplified in (0, 1):
        logging.info(f"放大后ECE为边界值: {amplified}")

    cmcs = (
        0.5 * (1 - amplified) +
        0.5 * (1 - class_gap)
    )

    # First threshold reached from the top decides the level.
    level_floors = (
        (0.85, "Excellent calibration"),
        (0.65, "Good calibration"),
        (0.45, "Moderate calibration"),
    )
    calibration_level = "Poor calibration"
    for floor, label in level_floors:
        if cmcs >= floor:
            calibration_level = label
            break

    return cmcs, calibration_level


if __name__ == "__main__":
    
    start_time = time.time()  # 统计程序开始时间

    # ========== 命令行参数 ==========
    parser = argparse.ArgumentParser()
    parser.add_argument("--adv", type=str, default="True", choices=["True", "False"], help="是否进行对抗扰动")
    parser.add_argument("--model_name", type=str, default="gpt-4o", help="指定使用的模型名称")
    parser.add_argument("--model_name_loguse", type=str, default="gpt-4o", help="日志文件名中使用的模型名称")
    parser.add_argument("--data_path", type=str, default="data/conf_test.parquet", help="指定样本数据的parquet文件路径")
    parser.add_argument("--max_tokens", type=int, default=5000, help="LLM生成的最大token数")
    parser.add_argument("--batch_size", type=int, default=100, help="每个批次的样本大小") #100
    parser.add_argument("--num_workers", type=int, default=32, help="并行工作进程数") #32
    parser.add_argument("--log_dir", type=str, default="log_eval", help="日志文件目录")

    args = parser.parse_args()
    if_adv = args.adv
    model_name = args.model_name
    data_path = args.data_path
    max_tokens = args.max_tokens
    batch_size = args.batch_size
    num_workers = args.num_workers
    log_dir = args.log_dir
    model_name_loguse = args.model_name_loguse
    
    # 确保日志文件夹存在
    os.makedirs(log_dir, exist_ok=True)
    # 日志文件名中包含model_name和当前日期时间
    log_filename = f"{model_name_loguse}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
    # log_filename = f"QVQ-72B-Preview_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
    log_path = os.path.join(log_dir, log_filename)
    logging.basicConfig(
        filename=log_path,
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    # 读取样本数据
    df = pd.read_parquet(data_path)
    total_samples = len(df)
    
    # 打印配置信息
    print(f"总样本数: {total_samples}")
    print(f"批大小: {batch_size}")
    print(f"工作进程数: {num_workers}")
    print(f"模型名称: {model_name}")
    print(f"日志路径: {log_path}")
    logging.info(f"总样本数: {total_samples}")
    logging.info(f"批大小: {batch_size}")
    logging.info(f"工作进程数: {num_workers}")
    logging.info(f"模型名称: {model_name}")

    # 将DataFrame转换为字典列表，便于分批次处理
    samples = df.to_dict('records')
    
    # 分批次处理样本
    batches = [samples[i:i + batch_size] for i in range(0, total_samples, batch_size)]
    
    # 初始化结果收集器
    all_results = {
        "confs": [],
        "preds": [],
        "gts": [],
        "difficulties": [],
        "step_details": [],
        "stepwise_conf_records": [],
        "step_details_with_stdconf_and_pos": [],
        "step_details_with_stdconf_and_token": []
    }
    
    # for sample in batches[0]:  # 只跑第一个批次
    #     result = process_sample(sample, if_adv, model_name, max_tokens)
    #     print(result)  


    # 使用进程池并行处理批次
    with multiprocessing.Pool(processes=num_workers) as pool:       
        # 创建部分函数，固定参数
        process_func = partial(process_sample, if_adv=if_adv, model_name=model_name, max_tokens=max_tokens)
               
        # 处理所有批次
        for batch_idx, batch in enumerate(batches):
            logging.info(f"开始处理批次 {batch_idx+1}/{len(batches)}，包含 {len(batch)} 个样本")
            print(f"处理批次 {batch_idx+1}/{len(batches)}，样本数: {len(batch)}")
            
            # 处理当前批次
            batch_results = []
            try:
                batch_results = pool.map(process_func, batch)
            except Exception as e:
                logging.error(f"处理批次 {batch_idx+1} 时出错: {str(e)}")
                continue
            
            # 收集当前批次的结果
            for result in batch_results:
                all_results["confs"].extend(result["confs"])
                all_results["preds"].extend(result["preds"])
                all_results["gts"].extend(result["gts"])
                all_results["difficulties"].append(result["difficulty"])
                all_results["step_details"].extend(result["step_details"])
                all_results["stepwise_conf_records"].extend(result["stepwise_conf_records"])
                all_results["step_details_with_stdconf_and_pos"].extend(result["step_details_with_stdconf_and_pos"])
                all_results["step_details_with_stdconf_and_token"].extend(result["step_details_with_stdconf_and_token"])
     
    #---------------------------------------------------------------------------            
    # 将all_results存到一个json文件中，文件名包含模型名和当前日期时间，放在result文件夹下，没有则新建
    # 获取当前时间，格式化为字符串
    now_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    # 构造result文件夹路径
    result_dir = "result"
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)
    # 构造文件名
    json_filename = os.path.join(result_dir, f"all_results_{model_name_loguse}_{now_str}.json")
    # json_filename = os.path.join(result_dir, f"all_results_QVQ-72B-Preview_{now_str}.json")

    # 为了避免numpy类型无法序列化，定义一个转换函数
    def convert_to_builtin_type(obj):
        if isinstance(obj, (np.integer,)):
            return int(obj)
        elif isinstance(obj, (np.floating,)):
            return float(obj)
        elif isinstance(obj, (np.ndarray,)):
            return obj.tolist()
        return obj

    # 保存到json文件
    with open(json_filename, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2, default=convert_to_builtin_type)
    print(f"已将all_results保存到 {json_filename}")
    logging.info(f"已将all_results保存到 {json_filename}")
    # ------------------------------------------------------------------
    
    # 打印处理完成信息
    processed_steps = len(all_results["confs"])
    print(f"\n处理完成！总处理步骤数: {processed_steps}")
    logging.info(f"处理完成！总处理步骤数: {processed_steps}")
    
    # #####################################################计算指标##################################################################
    # ========== 0.分类性能指标: Macro F1 和 PRM-Score ==========
    print("\n" + "="*60)
    logging.info("="*60)
    print("0. Classification Performance Metrics: Macro F1 and PRM-Score")
    logging.info("0. Classification Performance Metrics: Macro F1 and PRM-Score")

    # 准备数据
    all_preds = [s["pre_label"] for s in all_results["step_details"]]
    all_gts = [s["gt_label"] for s in all_results["step_details"]]

    # 计算Macro F1
    macro_f1_metrics = calculate_macro_f1(all_gts, all_preds)
    print(f"\nMacro F1 指标:")
    print(f"正确步骤F1: {macro_f1_metrics['F1_correct_steps']:.4f}")
    print(f"错误步骤F1: {macro_f1_metrics['F1_incorrect_steps']:.4f}")
    print(f"Macro F1: {macro_f1_metrics['Macro_F1']:.4f}")
    logging.info(f"正确步骤F1: {macro_f1_metrics['F1_correct_steps']:.4f}")
    logging.info(f"错误步骤F1: {macro_f1_metrics['F1_incorrect_steps']:.4f}")
    logging.info(f"Macro F1: {macro_f1_metrics['Macro_F1']:.4f}")

    print("="*60)
    logging.info("="*60)
    
    # ========== 1.Confidence Calibration: 计算ECE ========== 
    print("\n" + "="*60)
    logging.info("="*60)
    print("1. Confidence Calibration: 计算ECE")
    logging.info("1. Confidence Calibration: 计算ECE")
    # 使用标准置信度计算ECE
    confs = [s["conf"] for s in all_results["step_details"]] #会在compute_binary_ece中转化为标准置信度
    preds = [s["pre_label"] for s in all_results["step_details"]]
    gts = [s["gt_label"] for s in all_results["step_details"]]
    
    ece = compute_binary_ece(
        np.array(confs),
        np.array(preds),
        np.array(gts),
        n_bins=10
    )
    acc = np.mean(np.array(preds) == np.array(gts))
    logging.info(f"总样本数: {len(gts)}")
    logging.info(f"ECE: {ece:.4f}")
    logging.info(f"准确率: {acc:.4f}")
    print(f"总样本数: {len(gts)}")
    print(f"ECE: {ece:.4f}")
    print(f"准确率: {acc:.4f}")
    print("="*60)
    logging.info("="*60)

    # ========== 2.prob Across Error Types：confidence_gap_by_error_type==========
    print("\n" + "="*60)
    logging.info("="*60)
    print("2. prob Across Error Types")
    logging.info("2. prob Across Error Types")
    sorted_delta_conf = analyze_confidence_gap_by_error_type(all_results["step_details"])
    print("="*60)
    logging.info("="*60)

    # ========== 3.Difficulty Sensitivity: 计算平均标准置信度与任务难度等级的 Spearman 相关系数 ========== 
    print("\n" + "="*60)
    logging.info("="*60)
    print("3. Difficulty Sensitivity: 计算平均标准置信度与任务难度等级的 Spearman 相关系数")
    logging.info("3. Difficulty Sensitivity: 计算平均标准置信度与任务难度等级的 Spearman 相关系数")
    # 计算每个样本的平均标准置信度
    sample_avg_stdconfs = []
    sample_difficulties = []
    
    # 按样本分组
    sample_id_to_details = defaultdict(list)
    for step in all_results["step_details"]:
        sample_id_to_details[step["id"]].append(step)
    
    for sample_id, steps in sample_id_to_details.items():
        # 计算该样本所有步骤的平均标准置信度
        avg_stdconf = np.mean([s["stdconf"] for s in steps])
        # 获取难度（所有步骤难度相同）
        difficulty = steps[0]["difficulty"] if steps else 0
        
        sample_avg_stdconfs.append(avg_stdconf)
        sample_difficulties.append(difficulty)
    
    # 计算相关系数
    if len(sample_avg_stdconfs) > 1:
        spearman_corr, spearman_p = spearmanr(sample_avg_stdconfs, sample_difficulties)
        print(f"平均标准置信度与任务难度等级的 Spearman 相关系数: {spearman_corr:.4f} (p={spearman_p:.4g})")
        logging.info(f"平均标准置信度与任务难度等级的 Spearman 相关系数: {spearman_corr:.4f} (p={spearman_p:.4g})")
    else:
        print("有效样本数不足，无法计算 Spearman 相关系数。")
        logging.info("有效样本数不足，无法计算 Spearman 相关系数。")
    print("="*60)
    logging.info("="*60)

    # ========== 4. Adversarial Calibration: Δconfavg ==========
    print("\n" + "="*60)
    logging.info("="*60)
    print("4. Advzersarial Calibration: Δconfavg")
    logging.info("4. Advzersarial Calibration: Δconfavg")
    if len(all_results["stepwise_conf_records"]) > 0:
        perturbed_types = ["student_solution_synonym", "student_solution_structure", "perturbed_image"]
        
        # 1. 整体计算（所有样本的步骤）
        print("\n整体对抗扰动指标:")
        logging.info("\n整体对抗扰动指标:")
        adversarial_metrics_all = calculate_adversarial_metrics(all_results["stepwise_conf_records"])
        for metric, value in adversarial_metrics_all.items():
            print(f"{metric}: {value:.4f}")
            logging.info(f"{metric}: {value:.4f}")
        
        # 2. 按扰动类型分组计算
        for perturbed_type in perturbed_types:
            # 过滤当前扰动类型的记录
            type_records = [r for r in all_results["stepwise_conf_records"] if r["perturbed_field"] == perturbed_type]
            
            if not type_records:
                print(f"\n扰动类型 {perturbed_type} 没有可用数据，无法计算对抗扰动指标。")
                logging.info(f"扰动类型 {perturbed_type} 没有可用数据，无法计算对抗扰动指标。")
                continue
                
            # 计算该扰动类型的对抗扰动指标
            print(f"\n扰动类型 {perturbed_type} 的对抗扰动指标:")
            logging.info(f"扰动类型 {perturbed_type} 的对抗扰动指标:")
            adversarial_metrics_type = calculate_adversarial_metrics(type_records)
            for metric, value in adversarial_metrics_type.items():
                print(f"{metric}: {value:.4f}")
                logging.info(f"{metric}: {value:.4f}")
    else:
        print("\n没有可用于计算对抗扰动指标的数据")
        logging.info("\n没有可用于计算对抗扰动指标的数据")
    print("="*60)
    logging.info("="*60)
    # ========== End of 4. Adversarial Calibration ==========

    # ========== 5.Critic StdConf-Position Correlation (SPC) ==========
    print("\n" + "="*60)
    logging.info("="*60)
    print("5. Critic StdConf-Position Correlation (SPC)")
    logging.info("5. Critic StdConf-Position Correlation (SPC)")
    spc_mixed, spc_correct, spc_error = compute_stdconf_position_correlation_mixed(all_results["step_details_with_stdconf_and_pos"])
    print(f"Critic StdConf-Position Correlation (SPC-mixed): {spc_mixed if not np.isnan(spc_mixed) else 'NaN'}")
    print(f"Critic StdConf-Position Correlation (SPC-correct): {spc_correct if not np.isnan(spc_correct) else 'NaN'}")
    print(f"Critic StdConf-Position Correlation (SPC-error): {spc_error if not np.isnan(spc_error) else 'NaN'}")
    logging.info(f"Critic StdConf-Position Correlation (SPC-mixed): {spc_mixed if not np.isnan(spc_mixed) else 'NaN'}")
    logging.info(f"Critic StdConf-Position Correlation (SPC-correct): {spc_correct if not np.isnan(spc_correct) else 'NaN'}")
    logging.info(f"Critic StdConf-Position Correlation (SPC-error): {spc_error if not np.isnan(spc_error) else 'NaN'}")
    print("="*60)
    logging.info("="*60)

    # ========== 6.Critic StdConf-Token Count Correlation (STC) ==========
    print("\n" + "="*60)
    logging.info("="*60)
    print("6. Critic StdConf-Token Count Correlation (STC)")
    logging.info("6. Critic StdConf-Token Count Correlation (STC)")
    if _tiktoken_available:
        stc_mixed, stc_correct, stc_error = compute_stdconf_token_correlation_mixed(all_results["step_details_with_stdconf_and_token"])
        print(f"Critic StdConf-Token Count Correlation (STC-mixed): {stc_mixed if not np.isnan(stc_mixed) else 'NaN'}")
        print(f"Critic StdConf-Token Count Correlation (STC-correct): {stc_correct if not np.isnan(stc_correct) else 'NaN'}")
        print(f"Critic StdConf-Token Count Correlation (STC-error): {stc_error if not np.isnan(stc_error) else 'NaN'}")
        logging.info(f"Critic StdConf-Token Count Correlation (STC-mixed): {stc_mixed if not np.isnan(stc_mixed) else 'NaN'}")
        logging.info(f"Critic StdConf-Token Count Correlation (STC-correct): {stc_correct if not np.isnan(stc_correct) else 'NaN'}")
        logging.info(f"Critic StdConf-Token Count Correlation (STC-error): {stc_error if not np.isnan(stc_error) else 'NaN'}")
    else:
        print("未安装 tiktoken，无法计算标准置信度与token数的相关性。")
        logging.info("未安装 tiktoken，无法计算标准置信度与token数的相关性。")
    print("="*60)
    logging.info("="*60)
    
    # =====================7. 按难度分组的准确率和ECE=======================
    print("\n" + "="*60)
    logging.info("="*60)
    print("7. 按难度分组的准确率和ECE")
    logging.info("7. 按难度分组的准确率和ECE")
    metrics_by_difficulty = calculate_metrics_by_difficulty(all_results["step_details"])
    print("\n按难度分组的准确率和ECE:")
    logging.info("\n按难度分组的准确率和ECE:")
    for diff, metrics in metrics_by_difficulty.items():
        print(f"难度 {diff}: 准确率={metrics['accuracy']:.4f}, ECE={metrics['ECE']:.4f}")
        logging.info(f"难度 {diff}: 准确率={metrics['accuracy']:.4f}, ECE={metrics['ECE']:.4f}")
    print("="*60)
    logging.info("="*60)
    
    # ======================8. 不同错误类型的准确率和ECE========================
    print("\n" + "="*60)
    logging.info("="*60)
    print("8. 不同错误类型的准确率和ECE")
    logging.info("8. 不同错误类型的准确率和ECE")
    metrics_by_error_type = calculate_metrics_by_error_type(all_results["step_details"])
    print("\n不同错误类型的准确率和ECE:")
    logging.info("\n不同错误类型的准确率和ECE:")
    for error_type, metrics in metrics_by_error_type.items():
        print(f"{error_type}: 准确率={metrics['accuracy']:.4f}, ECE={metrics['ECE']:.4f}")
        logging.info(f"{error_type}: 准确率={metrics['accuracy']:.4f}, ECE={metrics['ECE']:.4f}")
    print("="*60)
    logging.info("="*60)
        
    # 在计算指标部分添加以下代码

    # ========== 9.Class-wise ECE, ΔECE between classes, Per-class Accuracy ==========
    print("\n" + "="*60)
    logging.info("="*60)
    print("9. Class-wise ECE, ΔECE between classes, Per-class Accuracy")
    logging.info("9. Class-wise ECE, ΔECE between classes, Per-class Accuracy")
    ece_pos, ece_neg = compute_classwise_ece(all_results["step_details"])
    delta_ece = abs(ece_pos - ece_neg) if not np.isnan(ece_pos) and not np.isnan(ece_neg) else np.nan

    acc_pos, acc_neg = compute_per_class_accuracy(all_results["step_details"])

    logging.info(f"正类(正确步骤) ECE: {ece_pos:.4f}")
    logging.info(f"负类(错误步骤) ECE: {ece_neg:.4f}")
    logging.info(f"ΔECE between classes: {delta_ece:.4f}")
    logging.info(f"正类(正确步骤)准确率: {acc_pos:.4f}")
    logging.info(f"负类(错误步骤)准确率: {acc_neg:.4f}")

    print(f"\nClass-wise ECE:")
    print(f"正类(正确步骤) ECE: {ece_pos:.4f}")
    print(f"负类(错误步骤) ECE: {ece_neg:.4f}")
    print(f"ΔECE between classes: {delta_ece:.4f}")
    print(f"正类(正确步骤)准确率: {acc_pos:.4f}")
    print(f"负类(错误步骤)准确率: {acc_neg:.4f}")
    print("="*60)
    logging.info("="*60)

    # ========== 10. Adaptive Binning ECE ==========
    print("\n" + "="*60)
    logging.info("="*60)
    print("10. Adaptive Binning ECE")
    logging.info("10. Adaptive Binning ECE")
    adaptive_ece = compute_adaptive_binning_ece(
        np.array([s["stdconf"] for s in all_results["step_details"]]),
        np.array([s["pre_label"] for s in all_results["step_details"]]),
        np.array([s["gt_label"] for s in all_results["step_details"]]),
        n_bins=10
    )

    logging.info(f"Adaptive Binning ECE: {adaptive_ece:.4f}")
    print(f"Adaptive Binning ECE: {adaptive_ece:.4f}")
    print("="*60)
    logging.info("="*60)
    
    # ========== 11. Accuracy and ECE grouped by modality (data type) ==========
    print("\n" + "="*60)
    logging.info("="*60)
    print("11. 按模态分组的准确率和ECE")
    logging.info("11. 按模态分组的准确率和ECE")
    metrics_by_data_type = calculate_metrics_by_data_type(all_results["step_details"])
    print("\n按模态分组的准确率和ECE:")
    logging.info("\n按模态分组的准确率和ECE:")
    for data_type, metrics in metrics_by_data_type.items():
        print(f"模态 {data_type}: 步骤数={metrics['num_steps']}, 准确率={metrics['accuracy']:.4f}, ECE={metrics['ECE']:.4f}")
        logging.info(f"模态 {data_type}: 步骤数={metrics['num_steps']}, 准确率={metrics['accuracy']:.4f}, ECE={metrics['ECE']:.4f}")
    print("="*60)
    logging.info("="*60)
    
    # ========== 综合指标计算 ==========
    # ========== 12. Confidence Stability Score (CSS) ==========
    print("\n" + "="*60)
    logging.info("="*60)
    print("12. Confidence Stability Score (CSS)")
    logging.info("12. Confidence Stability Score (CSS)")

    # 计算并输出CSS（整体）
    if 'adversarial_metrics_all' in locals():
        # 整体对抗扰动指标
        css_all, stability_level_all = calculate_css(adversarial_metrics_all)
        
        print("\n整体Confidence Stability Score (CSS):")
        print(f"CSS: {css_all:.4f} - {stability_level_all}")
        logging.info(f"整体CSS: {css_all:.4f} - {stability_level_all}")
        
        # 按扰动类型分组计算CSS
        perturbed_types = ["student_solution_synonym", "student_solution_structure", "perturbed_image"]
        css_results = {}
        
        for perturbed_type in perturbed_types:
            # 过滤当前扰动类型的记录
            type_records = [r for r in all_results["stepwise_conf_records"] if r["perturbed_field"] == perturbed_type]
            
            if not type_records:
                continue
                
            # 计算该扰动类型的对抗指标
            adversarial_metrics_type = calculate_adversarial_metrics(type_records)
            css, stability_level = calculate_css(adversarial_metrics_type)
            css_results[perturbed_type] = (css, stability_level)
        
        print("\n按扰动类型的Confidence Stability Score (CSS):")
        logging.info("\n按扰动类型的CSS:")
        for perturbed_type, (css, stability_level) in css_results.items():
            print(f"{perturbed_type}: CSS = {css:.4f} - {stability_level}")
            logging.info(f"{perturbed_type}: CSS = {css:.4f} - {stability_level}")
    else:
        print("\n没有可用于计算CSS的数据")
        logging.info("\n没有可用于计算CSS的数据")

    print("="*60)
    logging.info("="*60)

    # ========== 13. Confidence Error Sensitivity Score (CESS) ==========
    print("\n" + "="*60)
    logging.info("="*60)
    print("13. Confidence Error Sensitivity Score (CESS)")
    logging.info("13. Confidence Error Sensitivity Score (CESS)")

    # 计算并输出CESS
    cess = calculate_cess_from_sorted_delta_conf(sorted_delta_conf)
    if not np.isnan(cess):
        print(f"\nConfidence Error Sensitivity Score (CESS): {cess:.4f}")
        logging.info(f"CESS: {cess:.4f}")
        
        # 解释CESS值
        if cess > 0.6:
            interpretation = "模型能很好地区分正确步骤和各类错误"
        elif cess > 0.3:
            interpretation = "模型能较好地区分正确步骤和各类错误"
        elif cess > 0.1:
            interpretation = "模型区分正确步骤和各类错误的能力一般"
        else:
            interpretation = "模型区分正确步骤和各类错误的能力较弱"
        
        print(f"解释: {interpretation}")
        logging.info(f"解释: {interpretation}")
    else:
        print("\n无法计算CESS（缺少必要数据）")
        logging.info("无法计算CESS（缺少必要数据）")

    print("="*60)
    logging.info("="*60)

    # ========== 14. Confidence Multiview Calibration Score (CMCS) ==========
    print("\n" + "="*60)
    logging.info("="*60)
    print("14. Confidence Multiview Calibration Score (CMCS)")
    logging.info("14. Confidence Multiview Calibration Score (CMCS)")

    # 计算并输出CMCS
    if not np.isnan(ece) and not np.isnan(ece_pos) and not np.isnan(ece_neg):
        cmcs, calibration_level = calculate_cmcs(ece, ece_pos, ece_neg)
        
        print(f"\nConfidence Multiview Calibration Score (CMCS): {cmcs:.4f}")
        print(f"校准级别: {calibration_level}")
        logging.info(f"CMCS: {cmcs:.4f}")
        logging.info(f"校准级别: {calibration_level}")
    else:
        print("\n无法计算CMCS（缺少必要的ECE数据）")
        logging.info("无法计算CMCS（缺少必要的ECE数据）")

    print("="*60)
    logging.info("="*60)
    
    # ========== 15. 按科目分组的准确率和ECE ==========
    print("\n" + "="*60)
    logging.info("="*60)
    print("15. 按科目分组的准确率和ECE")
    logging.info("15. 按科目分组的准确率和ECE")
    metrics_by_subject = calculate_metrics_by_subject(all_results["step_details"])
    print("\n按科目分组的准确率和ECE:")
    logging.info("\n按科目分组的准确率和ECE:")
    for subject, metrics in metrics_by_subject.items():
        print(f"科目 {subject}: 步骤数={metrics['num_steps']}, 准确率={metrics['accuracy']:.4f}, ECE={metrics['ECE']:.4f}")
        logging.info(f"科目 {subject}: 步骤数={metrics['num_steps']}, 准确率={metrics['accuracy']:.4f}, ECE={metrics['ECE']:.4f}")
    print("="*60)
    logging.info("="*60)
    
    
    print("\n所有指标计算完成！")
    #-------------------------------------------------------------------------------------------------

    # 统计并输出程序运行总时间
    end_time = time.time()
    elapsed_time = end_time - start_time
    elapsed_h = int(elapsed_time // 3600)
    elapsed_m = int((elapsed_time % 3600) // 60)
    elapsed_s = int(elapsed_time % 60)
    print(f"\n程序运行总时间: {elapsed_h}小时{elapsed_m}分{elapsed_s}秒")
    logging.info(f"程序运行总时间: {elapsed_h}小时{elapsed_m}分{elapsed_s}秒")