from __future__ import annotations

from config import (
    QUERY_TEMPLATE_2,
    QUERY_TEMPLATE_3,
    QUERY_TEMPLATE_4,
    QUERY_TEMPLATE_5,
    CHOICES,
)

import argparse
import json
import os
import gc
import shutil
import importlib.util
from typing import Dict, List, Optional, Tuple, Iterable, Any, Set

import torch
from rich.console import Console
from rich.progress import track
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import logging
import re
import unicodedata
from rich.markup import escape
import sys
import random

# Resolve project paths relative to this file: BASE_DIR is the directory of
# this script and PROJECT_ROOT is two directory levels above it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.normpath(os.path.join(BASE_DIR, "..", ".."))

# Make packages under PROJECT_ROOT importable (prepended, so they take priority).
# NOTE(review): `from config import ...` at the top of the file executes before
# this insertion, so `config` must already be importable from the script's own
# directory or the pre-existing sys.path — confirm this ordering is intended.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)


logging.set_verbosity_error()  # transformers logging: show ERROR messages only, suppress WARNINGs
console = Console()

# Only applied if the caller has not already set PYTORCH_CUDA_ALLOC_CONF.
# NOTE(review): this runs after `import torch`; presumably still effective
# because the CUDA caching allocator reads it lazily on first CUDA use — confirm.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# Fixed seeds for torch and Python's `random`, for reproducible runs.
torch.manual_seed(42)
random.seed(42)

# NOTE(review): not referenced in the visible part of this file.
USE_FEW_SHOT = False
MODEL_PATH = os.path.join(PROJECT_ROOT, "models")

# Expected dataset record layout ("数据集格式" = "dataset format"). This bare
# string is documentation only; it is never read at runtime.
"""
数据集格式
{
    "id": "",
    "question": "",
    "choices": {"text": ["xxxx", "xxxx", "xxxx", "xxxx"], "label": ["A", "B", "C", "D"]},
    "answer": ""
}
"""

# Character class for answer letters used by the module-level regexes (A-E).
_LETTER = r"[A-E]"

# Regex building blocks.
# AFTER: lookahead requiring the captured letter to be followed by end of text,
# whitespace, common ASCII/CJK punctuation, brackets or quotes, or "项"
# ("option"), so a letter at the start of a longer word is not captured.
AFTER = r'(?=$|[\s,.;:!?，。！？：；、()\[\]{}（）【】"“”\'《》<>\|｜·\-–—~〜～*`]|项)'
# SEP_CHARS: separator characters (whitespace, ASCII/CJK punctuation, brackets,
# quotes, "=", "/", "\\", ...) meant to be placed inside a character class.
SEP_CHARS = r'\s,.;:!?，。！？：；、()\[\]{}（）【】"“”\'《》<>\|｜=·\-–—~〜～/\\\*`'
# The strict "ANSWER: X" pattern is built dynamically in _tail_has_complete_answer.
# Every pattern below is case-insensitive, captures the letter in its single
# capturing group, and ends with the AFTER lookahead.
RE_VARIANTS = [
    # 1) Strong English anchors
    re.compile(rf"(?i)the\s+answer\s+is\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?{AFTER}"),
    re.compile(rf"(?i)\b(?:final|correct)\s+answer\s*(?:is\s*[:：]?)?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?{AFTER}"),  # "is" and the colon are both optional
    re.compile(rf"(?i)\bthe\s+answer\s+is\s*[:：]?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?"+ AFTER),
    re.compile(rf"(?i)\banswer\s+is\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?" + AFTER),
    re.compile(rf"(?i)\banswer\s*[:：]\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?" + AFTER),  # colon form ("answer: X")
    re.compile(rf"(?i)\b(?:correct|final)\s*choice\s*(?:is|:)?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?"+ AFTER),
    re.compile(rf"(?i)\bthe\s+answer\s+is\s*[:：]?\s*(?:option\s*)?[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?"+ AFTER),
    # 2) Common Chinese phrasings ("answer is", "hence choose", "choose", "option")
    re.compile(rf"(?i)(?:正确答案|答案)\s*(?:为|是)?\s*[:：]?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?"+ AFTER),
    re.compile(rf"(?i)(?:故选)\s*[:：]?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?"+ AFTER),
    re.compile(rf"(?i)(?:选择(?:答案)?|选(?:择)?)\s*(?:为|是)?\s*[:：]?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?"+ AFTER),
    re.compile(rf"(?i)(?:选项|选)\s*[:：]?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?\s*项?"+ AFTER),
    # English action verbs (requiring a verb avoids matching enumerated option lines)
    re.compile(rf"(?i)\b(?:choose|chosen|pick|picked|select|selected)\s+(?:option\s*)?[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?"+ AFTER),
    # 3) Markdown (e.g. "answer: **A**")
    re.compile(rf"(?i)\b(?:final|correct)?\s*answer\s*[:：]?\s*\**\s*[\[\{{\(\（【]?\s*({_LETTER})\s*[\)\]\}}）】]?\s*\**"+ AFTER),
]

# Whole-line anchored fallbacks (^...$): a line matches only if it consists of
# nothing but the formatted letter plus optional trailing sentence punctuation,
# so false positives are unlikely.
RE_LASTLINE_FORMAT_FALLBACKS = [
    # **A** / **b**
    re.compile(rf"^\s*\*\*\s*({_LETTER})\s*\*\*\s*[\.\!\?\u3002\uFF1F]*\s*$", re.I),
    # **(A)**
    re.compile(rf"^\s*\*\*\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?\s*\*\*\s*[\.\!\?\u3002\uFF1F]*\s*$", re.I),
    # \boxed{A} / \boxed{\textbf{A}} / \boxed{\mathrm{a}} …
    # (two capturing groups, one per alternative; only one of them is non-empty)
    re.compile(rf"^\s*\\boxed\{{\s*(?:\\(?:textbf|mathbf|mathrm|mathsf|bfseries|boldsymbol)\{{\s*({_LETTER})\s*\}}|({_LETTER}))\s*\}}\s*[\.\!\?\u3002\uFF1F]*\s*$", re.I),
    # \(A\)
    re.compile(rf"^\s*\\\(\s*({_LETTER})\s*\\\)\s*[\.\!\?\u3002\uFF1F]*\s*$", re.I),
    # (A) / [A] / {A}, plus the full-width bracket forms
    re.compile(rf"^\s*[\(\[\{{（【]\s*({_LETTER})\s*[\)\]\}}）】]\s*[\.\!\?\u3002\uFF1F]*\s*$", re.I),
]

### Trimming of an incomplete trailing sentence ###
# Sentence-end marker: one of 。！？!?, an ellipsis (three or more "." or "…"),
# or a "." followed by whitespace, U+3000, end of text or a closing quote/bracket.
# The match also consumes the following whitespace and closing quotes/brackets,
# so match.end() is where the next sentence would start.
_SENT_END_RE = re.compile(r'(?:[。！？!?]|\.{3,}|…|\.(?=(?:\s|\u3000|$|[”"\'’）\)\]】》])))[\s\u3000]*[”"\'’）\)\]】》]*[\s\u3000]*')
# Phrases signalling that the text is stating an answer; searched by
# _contains_answer_phrase.
ANSWER_PHRASES = [
    # === New format: plain "ANSWER: {LETTER}" ===
    r"(?i)ANSWER\s*:\s*[A-E]",
    r"\banswer\s*[:：]\s*[A-Ea-e]",  # answer: A / answer: B
    r"\bthe\s+answer\s+is\s*[:：]\s*[A-Ea-e]",  # the answer is: A
    r"\bfinal\s+answer\s*[:：]\s*[A-Ea-e]",  # final answer: A
    r"\bcorrect\s+answer\s*[:：]\s*[A-Ea-e]",  # correct answer: A
    r"\banswer\s+is\s*[:：]\s*[A-Ea-e]",  # answer is: A
    r"\bthe\s+answer\s+is\s+[A-Ea-e]",  # the answer is A
    r"\banswer\s+([A-Ea-e])(?=$|\b|\s|[.,;:!?，。！？：；、])",  # answer A (tightened: the letter must be a standalone token)
    # === New format: bracketed answers ===
    r"\banswer\s*[:：]\s*\(?[A-Ea-e]\)?",  # answer: (A)
    r"\bthe\s+answer\s+is\s*[:：]?\s*\(?[A-Ea-e]\)?",  # the answer is: (A)
    r"\banswer\s*[:：]\s*\*{1,2}[A-Ea-e]\*{1,2}",  # answer: **A**
    r"\bthe\s+answer\s+is\s+\*{1,2}[A-Ea-e]\*{1,2}",  # the answer is **A**
    # === New format: Markdown / formatted variants ===
    r"\banswer\s*[:：]\s*\**\s*[A-Ea-e]\s*\**",  # answer: **A**
    r"\bthe\s+answer\s+is\s*\**\s*[A-Ea-e]\s*\**",  # the answer is **A**
    r"\banswer\s*[:：]\s*\[([A-Ea-e])\]",  # answer: [A]
    r"\banswer\s*[:：]\s*\{\s*([A-Ea-e])\s*\}",  # answer: {A}
    r"\banswer\s*[:：]\s*\\boxed\{\s*[A-Ea-e]\s*\}",  # answer: \boxed{A}
    # === Generic English (narrowed) - original patterns kept ===
    r"\b(?:so|thus|therefore)\b.*?\b(?:final\s+|correct\s+)?(?:answer|result|solution|choice)\s*(?:is|are|:|：)",  # with a causal/concluding connective
    r"\b(?:the\s+)?(?:final\s+|correct\s+)?answer\s*(?:is|are|:|：)",  # the/final/correct answer ...
    r"\banswer\s*(?:is|:|：)",  # answer is / answer:
    r"\b(?:final|correct)\s+choice\s*(?:is|:|：)",  # final/correct choice ...
    r"\bthe\s+answer\s+is\s*(?:option\s*)?",  # the answer is (option) ...
    r"\b(?:result|solution)\s*[:：]",  # result/solution only accepted in the colon form
    # === Generic Chinese (no leading space required) ===
    r"(?:因此|所以|综上(?:所述)?|故|可知)\s*(?:最终\s*)?(?:答案|结果|选项|选择)\s*(?:是|为|:|：)",
    r"(?:最终\s*)?(?:答案|结果)\s*(?:是|为|:|：)",
    # === Terse form (no verb) ===
    r"(?:answer|答案)\s*[:：]",
    # === Strong multiple-choice conclusions (tightened) ===
    # Only "故选" / "应选" ("hence choose" / "should choose") are kept
    r"(?:故\s*选|应\s*选)\s*[A-Ea-e](?:\s*项|(?:\s*选项)?)?",
    r"(?:故\s*选|应\s*选)\s*(?:[0-9]+|[一二三四五六七八九十])(?:\s*项|(?:\s*选项)?)?",
]
# All compiled case-insensitively; the first pattern additionally with DOTALL.
ANSWER_PHRASES_RE = [re.compile(ANSWER_PHRASES[0], re.I | re.S)] + [re.compile(p, re.I) for p in ANSWER_PHRASES[1:]]


# If the trailing fragment already holds a complete answer, it must not be
# trimmed (expected to be rare in practice).
def _tail_has_complete_answer(tail: str, choices: str = "ABCDE") -> bool:
    """
    Return True if *tail* contains a complete, extractable answer statement.

    First tries the strict ``ANSWER: X`` form with X restricted to *choices*,
    then every pattern in ``RE_VARIANTS``. Note that ``RE_VARIANTS`` uses the
    fixed ``_LETTER`` class (A-E) regardless of *choices*.
    """
    # *choices* is spliced into a character class, so escape it: characters
    # such as "-", "]" or "\\" would otherwise change the class's meaning.
    strict_answer = re.compile(rf"(?i)ANSWER\s*:\s*([{re.escape(choices)}]){AFTER}")
    if strict_answer.search(tail):
        return True
    for pat in RE_VARIANTS:
        m = pat.search(tail)
        # A variant only counts if at least one capture group is non-empty.
        if m and any(m.groups()):
            return True
    return False

def _contains_answer_phrase(text: str) -> bool:
    """Return True as soon as any compiled pattern in ``ANSWER_PHRASES_RE`` occurs in *text*."""
    for phrase_re in ANSWER_PHRASES_RE:
        if phrase_re.search(text):
            return True
    return False


def safe_rich_print(text: str, style: str = "red") -> None:
    """
    Print *text* in *style* through Rich, without letting its content be
    interpreted as console markup (which could raise MarkupError or restyle output).

    Uses ``rich.markup.escape``. Rich only treats "[" that starts a tag as
    markup, and "\\]" is not an escape. So escaping every "[" and "]" by hand
    leaves stray backslashes in the output, e.g. for "[1, 2]". A trailing
    backslash in *text* would also swallow the closing tag; ``escape``
    handles all of these cases.
    """
    console.print(f"[{style}]{escape(str(text))}[/{style}]")


def parse_input(example):
    """
    Render the prompt for one multiple-choice record.

    *example* provides ``question`` and ``choices`` with parallel ``text`` and
    ``label`` lists. The number of labels (2-5) selects QUERY_TEMPLATE_2 ...
    QUERY_TEMPLATE_5. Each template is formatted with ``question`` plus
    ``textA``, ``textB``, ... taken positionally from ``choices["text"]``.

    Raises:
        ValueError: if the number of labels is not 2, 3, 4 or 5.
    """
    question = example["question"]
    option_texts = example["choices"]["text"]
    n_options = len(example["choices"]["label"])

    templates_by_count = {
        2: QUERY_TEMPLATE_2,
        3: QUERY_TEMPLATE_3,
        4: QUERY_TEMPLATE_4,
        5: QUERY_TEMPLATE_5,
    }
    if n_options not in templates_by_count:
        raise ValueError(f"Unsupported number of choices: {n_options}")

    # Placeholder names are textA, textB, ... in positional order.
    option_fields = {
        f"text{letter}": option_texts[pos]
        for pos, letter in enumerate("ABCDE"[:n_options])
    }
    return templates_by_count[n_options].format(question=question, **option_fields)


# Drop an unfinished trailing sentence when it could confuse answer extraction.
def trim_incomplete_sentence(ans: str, choices: str = "ABCDE") -> str:
    """
    Remove the text after the last sentence terminator (``_SENT_END_RE``),
    but only when all of the following hold:

    * that trailing fragment contains an answer phrase;
    * the text before it also contains an answer phrase;
    * the fragment is not itself a complete answer
      (``_tail_has_complete_answer``).

    In that case the head is returned stripped; otherwise *ans* is returned
    unchanged (including when it is empty or has no terminator at all).
    """
    if not ans:
        return ans

    sentence_ends = list(_SENT_END_RE.finditer(ans))
    if not sentence_ends:
        return ans

    cut = sentence_ends[-1].end()
    head, tail = ans[:cut], ans[cut:]

    should_trim = (
        _contains_answer_phrase(tail)
        and _contains_answer_phrase(head)
        and not _tail_has_complete_answer(tail, choices)
    )
    return head.strip() if should_trim else ans


# 将字符串转换为规范形式，去除空白，压缩空白，保留换行
def _normalize(s: str) -> str:
    s = s or ""
    s = unicodedata.normalize("NFKC", s)
    s = s.replace("（", "(").replace("）", ")")
    for z in ("\u200b", "\xa0", "\u202f", "\ufeff", "\u200d", "\u2060", "\t"):
        s = s.replace(z, " ")
    # 仅压缩空白但保留换行：避免破坏“最后一行”逻辑
    s = re.sub(r"[^\S\r\n]+", " ", s)
    return s


def _strip_latex_noise_oneline(s: str) -> str:
    # 去掉行内/陈列数学环境定界符
    s = re.sub(r"\\\[|\\\]|\\\(|\\\)|\$\$?", " ", s)
    # 去掉常见 \frac / \dfrac / \tfrac 形式
    s = re.sub(r"\\[dt]?frac\s*\{[^{}]*\}\s*\{[^{}]*\}", " ", s)
    # 去掉通用 \cmd{...}（保留 \boxed{...} 等已在上面兜底过的）
    s = re.sub(r"\\(?!boxed\b)[A-Za-z]+\s*\{([^{}]*)\}", r" \1 ", s)
    return s


# Fallback extractor: inspect only the last three non-empty lines.
def extract_final(text: str, option_answer_text: str):
    """
    Fallback extraction over the last three non-empty lines of *text*.

    Lines are examined from the last one upwards. For each line (after
    ``_normalize``), in order:
      1. the anchored answer phrases in ``RE_VARIANTS``, then the whole-line
         formats in ``RE_LASTLINE_FORMAT_FALLBACKS`` (**A**, (A), boxed{A}, ...);
      2. if *option_answer_text* is given: non-letter content found by
         ``extract_latex_non_letter`` that equals it after
         ``_normalize_for_compare`` (that content is returned as-is);
      3. after ``_strip_latex_noise_oneline``, a line that is just one letter,
         optionally bracketed and followed by sentence punctuation;
      4. the last stand-alone letter on the line, delimited by separator
         characters and not followed by "/", "\\", "=", "^" or "(".
    Letters come from ``_LETTER`` (A-E) and are returned upper-cased.

    Returns:
        A letter, the matched option text (case 2), or None.
    """
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    if not lines:
        return None

    # Loop-invariant regexes, compiled once per call instead of once per line.
    single_letter_re = re.compile(
        rf"^\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?\s*[\.\!\?\u3002\uFF1F]*\s*$",
        re.I,
    )
    # Left delimiters = SEP_CHARS without "/" and "\\" (presumably so a letter
    # right after them, as in "x/a", is not taken as an answer).
    sep_chars_left = r'\s,.;:!?，。！？：；、()\[\]{}（）【】"“”\'《》<>\|｜=·\-–—~〜～*`'
    standalone_letter_re = re.compile(
        rf"(?i)(?:^|[{sep_chars_left}])\s*[\(\[\{{（【]?\b({_LETTER})\b\s*[\)\]\}}）】]?(?!\s*[/\\=^\(])(?=$|[{SEP_CHARS}]|项)",
        re.I,
    )
    anchored_patterns = (*RE_VARIANTS, *RE_LASTLINE_FORMAT_FALLBACKS)

    for raw_line in reversed(lines[-3:]):
        last = _normalize(raw_line)

        # 1) Anchored phrases, then whole-line formatting fallbacks.
        for pat in anchored_patterns:
            m = pat.search(last)
            if m:
                g = next((x for x in m.groups() if x), None)
                if g:
                    return g.upper()

        # 2) Non-letter content that matches the correct option's text exactly.
        if option_answer_text:
            cand = extract_latex_non_letter(last)
            if cand and _normalize_for_compare(cand) == _normalize_for_compare(option_answer_text):
                return cand

        last_sanit = _strip_latex_noise_oneline(last)

        # 3) The whole line is a single (optionally bracketed) letter.
        m = single_letter_re.match(last_sanit)
        if m:
            return m.group(1).upper()

        # 4) Last separator-delimited stand-alone letter on the line.
        matches = list(standalone_letter_re.finditer(last_sanit))
        if matches:
            return matches[-1].group(1).upper()

    return None

# 标准化字符串，去除空白，压缩空白，保留换行，方便答案比较
def _normalize_for_compare(s: str) -> str:
    s = s or ""
    s = unicodedata.normalize("NFKC", s)
    s = re.sub(r"\s+", " ", s)  # 压缩连续空白符
    return s.strip()

def _latex_clean_payload(s: str) -> str:
    """
    Reduce a LaTeX-ish answer payload to plain, comparable text.

    Steps, in order:
      * drop fractions entirely: frac / tfrac / dfrac and plain "a/b"
        (with optional sign);
      * map the Unicode minus sign to "-";
      * unwrap text/style macros (text, mathrm, textbf, ...) repeatedly, so
        nested wrappers are peeled from the inside out;
      * remove math delimiters, leftover commands and braces;
      * normalise whitespace with ``_normalize_for_compare``;
      * remove a leading "the answer is" / "答案(为|是)" lead-in;
      * strip one enclosing bracket pair if the whole value is bracketed.

    Returns "" for None.
    """
    if s is None:
        return ""

    # Fractions are discarded outright (LaTeX form first, then plain "a/b").
    for fraction_pattern in (
        r"\\[td]?frac\s*\{\s*[^{}]*?\s*\}\s*\{\s*[^{}]*?\s*\}",
        r"[-+−]?\s*\d+\s*/\s*\d+",
    ):
        s = re.sub(fraction_pattern, " ", s)

    s = s.replace("\u2212", "-")

    # Peel text/style wrappers until none remain; every replacement strictly
    # shortens the string, so "nothing replaced" means "nothing changed".
    style_macro_re = re.compile(
        r"\\(?:text|mathrm|mathbf|mathsf|textbf|operatorname|textrm|rm)\s*\{([^{}]*)\}"
    )
    while True:
        s, n_peeled = style_macro_re.subn(r"\1", s)
        if not n_peeled:
            break

    s = re.sub(r"(\\\[|\\\]|\$\$|\\\(|\\\))", " ", s)
    s = re.sub(r"\\[A-Za-z]+", " ", s)  # bare commands such as \left / \right
    s = s.replace("{", " ").replace("}", " ")

    s = _normalize_for_compare(s)
    s = re.sub(r"(?i)\b(?:the\s+answer\s+is|答案(?:为|是)?)\b\s*[:：]?\s*", "", s).strip()

    wrapped_in_one_pair = re.fullmatch(
        r"[\(\[\{（【〔〈《]\s*[^()\[\]{}（）【】〔〕〈〉《》]*\s*[\)\]\}）】〕〉》]", s
    )
    return s[1:-1].strip() if wrapped_in_one_pair else s


# Extract "non-letter" option content (numbers, boxed/text payloads) from a line.
def extract_latex_non_letter(text: str) -> str | None:
    """
    Extract non-letter answer content, e.g. a number or a boxed/text payload.

    Candidates are collected in this fixed order:
      1. the payload of the last boxed{...} in *text*;
      2. one text-style macro payload (text, textbf, mathrm, mathbf, mathsf,
         operatorname, textrm, rm). Macro names are scanned in that order, so
         this is the last match of the last macro name that matched at all;
      3. the number right after the first "the answer is";
      4. for every math environment (\\( \\), \\[ \\], $$ $$): each boxed{...}
         inside it, then the last number inside it.

    The list is then walked in REVERSE collection order. This is collection
    order, not position in *text*, so math-environment candidates are tried
    first. Each candidate is cleaned with ``_latex_clean_payload``. Empty
    results and bare single letters are skipped; letters are left to the
    letter-based extractors. The first survivor is returned, or None.
    """
    if not text:
        return None

    candidates: list[str] = []

    # 1) \boxed{...}: keep only the last occurrence.
    last_boxed = None
    for m in re.finditer(r"\\boxed\s*\{\s*([^{}]+?)\s*\}", text):
        last_boxed = m.group(1)
    if last_boxed:
        candidates.append(last_boxed)

    # 2) Common text/style macros.
    last_text_like = None
    for macro in ("text", "textbf", "mathrm", "mathbf", "mathsf", "operatorname", "textrm", "rm"):
        for m in re.finditer(rf"\\{macro}\s*\{{\s*([^{{}}]+?)\s*\}}", text):
            last_text_like = m.group(1)
    if last_text_like:
        candidates.append(last_text_like)

    # 3) "the answer is" followed by a number (fractions are dropped later by cleaning).
    m_num = re.search(r"(?i)the\s+answer\s+is\s*[:：]?\s*\(?\s*([-+−]?\d+(?:\.\d+)?)", text)
    if m_num:
        candidates.append(m_num.group(1))

    # 4) Math environments \( ... \), \[ ... \], $$ ... $$.
    for m in re.finditer(r"(?:\\\(|\\\[|\$\$)(.*?)(?:\\\)|\\\]|\$\$)", text, re.S):
        inner = m.group(1)
        # \boxed{...} inside the math environment.
        for m2 in re.finditer(r"\\boxed\s*\{\s*([^{}]+?)\s*\}", inner):
            candidates.append(m2.group(1))
        # Then the last number inside it.
        nums = re.findall(r"[-+−]?\d+(?:\.\d+)?", inner)
        if nums:
            candidates.append(nums[-1])

    for raw in reversed(candidates):
        s = _latex_clean_payload(raw)
        if not s:
            continue
        # A bare (optionally parenthesised) letter belongs to the letter pipeline.
        if re.fullmatch(rf"\(?\s*{_LETTER}\s*\)?", s, re.I):
            continue
        return s
    return None


# 去除字符串末尾的标点符号和空白字符
def _strip_trailing_punct(s: str) -> str:
    # 去掉常见句尾标点（英文/中文）
    return re.sub(r"[\.。!,，；;：:\s]+$", "", s or "").strip()


def _arc_answer_patterns(letters: str) -> List[str]:
    """
    Return the ordered regex sources tried line by line by ``extract_arc_answer``.

    *letters* must already be passed through ``re.escape`` because it is
    spliced into character classes. Earlier patterns take precedence. All
    sources are raw strings, so sequences such as "\\s" reach the regex engine
    unchanged; in a non-raw string "\\s" is an invalid escape and raises a
    SyntaxWarning on Python 3.12+.
    """
    return [
        # Chinese phrasings ("the answer is / should be / choose ...")
        rf"答案是?\s*([{letters}])",
        rf"答案是?\s*：\s*([{letters}])",
        rf"答案是?\s*:\s*([{letters}])",
        rf"答案选项应?该?是\s*([{letters}])",
        rf"答案选项应?该?为\s*([{letters}])",
        rf"答案应该?是\s*([{letters}])",
        rf"答案应该?选\s*([{letters}])",
        rf"答案选项为?\s*：\s*([{letters}])",
        rf"答案选项为?\s+\(?\*?\*?([{letters}])\*?\*?\)?",
        rf"答案选项是?\s*:\s*([{letters}])",
        rf"答案为\s*([{letters}])",
        rf"答案选\s*([{letters}])",
        rf"选择?\s*([{letters}])",
        rf"故选?\s*([{letters}])",
        # Logical-judgement phrasings ("only option X is right/wrong", ...)
        rf"只有选?项?\s?([{letters}])\s?是?对",
        rf"只有选?项?\s?([{letters}])\s?是?错",
        rf"只有选?项?\s?([{letters}])\s?不?正确",
        rf"只有选?项?\s?([{letters}])\s?错误",
        rf"说法不?对选?项?的?是\s?([{letters}])",
        rf"说法不?正确选?项?的?是\s?([{letters}])",
        rf"说法错误选?项?的?是\s?([{letters}])",
        rf"([{letters}])\s?是正确的",
        rf"([{letters}])\s?是正确答案",
        rf"选项\s?([{letters}])\s?正确",
        # Causal connectives ("so", "therefore", "obviously", ...)
        rf"所以答\s?([{letters}])",
        rf"所以\s?([{letters}])[.。$]?$",
        rf"所有\s?([{letters}])[.。$]?$",
        rf"[\s，,：:][故即]([{letters}])[。\.]?",
        rf"[\s，,：:]因此([{letters}])[。\.]?",
        rf"[是为。]\s?([{letters}])[。\.]?",
        rf"因此\s?([{letters}])[。\.]?",
        rf"显然\s?([{letters}])[。\.]?",
        # English patterns (shared with the existing MMLU-Pro extractor)
        rf"(?i)the\s+answer\s+is\s*[\(\[\{{（【]?\s*([{letters}])\s*[\)\]\}}）]?.*",
        rf"(?i)\b(?:final|correct)\s+answer\s*(?:is\s*[:：]?)?\s*[\(\[\{{（【]?\s*([{letters}])\s*[\)\]\}}）]?.*",
        rf"(?i)\bthe\s+answer\s+is\s*[:：]?\s*[\(\[\{{（【]?\s*([{letters}])\s*[\)\]\}}）]?.*",
        rf"(?i)\banswer\s+is\s*[\(\[\{{（【]?\s*([{letters}])\s*[\)\]\}}）]?.*",
        rf"(?i)\banswer\s*[:：]\s*[\(\[\{{（【]?\s*([{letters}])\s*[\)\]\}}）]?.*",
        rf"(?i)\b(?:correct|final)\s*choice\s*(?:is|:)?\s*[\(\[\{{（【]?\s*([{letters}])\s*[\)\]\}}）]?.*",
        rf"(?i)\bthe\s+answer\s+is\s*[:：]?\s*(?:option\s*)?[\(\[\{{（【]?\s*([{letters}])\s*[\)\]\}}）]?.*",
        rf"(?i)(?:正确答案|答案)\s*(?:为|是)?\s*[:：]?\s*[\(\[\{{（【]?\s*([{letters}])\s*[\)\]\}}）]?.*",
        rf"(?i)(?:故选)\s*[:：]?\s*[\(\[\{{（【]?\s*([{letters}])\s*[\)\]\}}）]?.*",
        rf"(?i)(?:选择(?:答案)?|选(?:选择)?)\s*(?:为|是)?\s*[:：]?\s*[\(\[\{{（【]?\s*([{letters}])\s*[\)\]\}}）]?.*",
        rf"(?i)(?:选项|选)\s*[:：]?\s*[\(\[\{{（【]?\s*([{letters}])\s*[\)\]\}}）]?\s*项?.*",
        rf"(?i)\b(?:choose|chosen|pick|picked|select|selected)\s+(?:option\s*)?[\(\[\{{（【]?\s*([{letters}])\s*[\)\]\}}）]?.*",
        # More English patterns (from OpenCompass)
        rf"(?i)ANSWER\s*:\s*([{letters}])",
        rf"[Tt]he answer is:?\s+\(?([{letters}])\)?",
        rf"[Tt]he answer is:?\s+\(?\*?\*?([{letters}])\*?\*?\)?",
        rf"[Tt]he answer is option:?\s+\(?([{letters}])\)?",
        rf"[Tt]he correct answer is:?\s+\(?([{letters}])\)?",
        rf"[Tt]he correct answer is option:?\s+\(?([{letters}])\)?",
        rf"[Tt]he correct answer is:?.*?boxed{{([{letters}])}}",
        rf"[Tt]he correct option is:?.*?boxed{{([{letters}])}}",
        rf"[Tt]he correct answer option is:?.*?boxed{{([{letters}])}}",
        rf"[Tt]he answer to the question is:?\s+\(?([{letters}])\)?",
        # Stand-alone option patterns
        rf"^选项\s?([{letters}])",
        rf"^([{letters}])\s?选?项",
        rf"1.\s?([{letters}])[.。$]?",
        # Generic fallbacks with a multi-character capture; such captures are
        # resolved by the whole-line letter search in _letter_from_line.
        r"答案是\s?(\S+)(?:。|$)",
        r"答案应该是\s?(\S+)(?:。|$)",
        r"答案为\s?(\S+)(?:。|$)",
        # A very loose single-letter catch-all is intentionally NOT used: it
        # would match letters inside words such as "answer" / "correct".
    ]


def _letter_from_line(line, compiled_patterns, choices, standalone_letter_re):
    """
    Try each compiled pattern against one line; return an upper-case letter or None.

    For each pattern that matches, in order:
      * a captured single letter from *choices* is returned, unless it is glued
        to an ASCII letter/digit on either side (i.e. part of a word);
      * otherwise all stand-alone option letters on the whole line are
        considered, and the last upper-case one (else the last one) is returned.
    If neither yields a letter, the next matching pattern is tried.
    """
    for pattern in compiled_patterns:
        match = pattern.search(line)
        if not match:
            continue

        for gi, group in enumerate(match.groups(), start=1):
            if not group:
                continue
            g = group.strip()
            if len(g) != 1 or g.upper() not in choices:
                # Multi-character groups (e.g. "B选项") fall through to the
                # whole-line stand-alone letter search below.
                continue
            start, end = match.start(gi), match.end(gi)
            before = line[start - 1] if start > 0 else ""
            after = line[end] if end < len(line) else ""
            # An ASCII letter/digit neighbour means the letter is part of a word.
            if (before.isascii() and before.isalnum()) or (after.isascii() and after.isalnum()):
                continue
            return g.upper()

        # Search the whole line (not just the matched span) for stand-alone letters.
        candidates = list(standalone_letter_re.finditer(line))
        if candidates:
            for c in reversed(candidates):
                if c.group(1).isupper():
                    return c.group(1).upper()
            return candidates[-1].group(1).upper()
    return None


def extract_arc_answer(
    text: str, choices: str = "ABCDE", option_answer_text: str = None
) -> Optional[str]:
    """
    Extract the ARC answer letter from a model response.

    Combines the MMLU-Pro style regexes with OpenCompass-derived patterns.

    Strategy:
      0. If the response starts with an option letter that is clearly used as
         a label, return it. Accepted forms: "C. ...", "B) ...", "D：...", or
         the letter alone on the first line.
      1. Scan the lines from last to first and apply ``_letter_from_line`` with
         the patterns from ``_arc_answer_patterns``.
      2. Otherwise fall back to ``extract_final`` (last three lines).

    Args:
        text: raw model output.
        choices: valid option letters, e.g. "ABCD".
        option_answer_text: text of the correct option; forwarded to
            ``extract_final`` for matching non-letter answers.

    Returns:
        An upper-case letter, a non-letter answer string from ``extract_final``,
        or None.
    """
    if not text or not text.strip():
        return None

    # Normalise first; case is preserved (patterns match case-insensitively).
    text = _normalize(text).strip()
    letters = re.escape(choices)

    # --- Step 0: leading option label ---
    # The letter must be followed by option punctuation or end its own line, so
    # an opening English article ("A plant needs ...") or a word such as
    # "Answer" is not mistaken for a label. Full-width punctuation needs nothing
    # after it ("D：xxx"); ASCII punctuation must be followed by whitespace/end.
    prefix_re = re.compile(
        rf"^\s*([{letters}])"
        r"(?:\s*[：）】》]"
        r"|\s*[\.\)\]:](?=\s|$)"
        r"|[^\S\n]*(?=\n|$))",
        re.IGNORECASE,
    )
    m0 = prefix_re.match(text)
    if m0:
        return m0.group(1).upper()

    # Unwrap common text/style macros to plain text so \text{A} etc. can match.
    text = re.sub(
        r"\\(?:text|mathrm|mathbf|mathsf|textbf|operatorname|textrm|rm)\s*\{([^{}]*)\}",
        r"\1",
        text,
    )

    # Compile the pattern table once per call; an unusable pattern is skipped.
    compiled_patterns = []
    for source in _arc_answer_patterns(letters):
        try:
            compiled_patterns.append(re.compile(source, re.IGNORECASE))
        except re.error:
            continue
    standalone_letter_re = re.compile(
        rf"(?<![A-Za-z0-9])([{letters}])(?![A-Za-z0-9])", re.IGNORECASE
    )

    # 1) Scan lines from last to first: final statements are the most reliable.
    for line in reversed(text.strip().splitlines()):
        line = line.strip()
        if not line:
            continue
        found = _letter_from_line(line, compiled_patterns, choices, standalone_letter_re)
        if found:
            return found

    # 2) Fallback: inspect the last three lines.
    return extract_final(text, option_answer_text)


def arc_challenge_parse(pred, true, hit_limit, choices, option_answer_text):
    """
    Extract an answer from *pred* and score it against the gold label *true*.

    Returns ``(extracted, is_correct)``:
      * when nothing usable is extracted, or a single letter outside *choices*
        is extracted: ``("Incomplete", False)`` if ``hit_limit is True``,
        otherwise ``(None, False)``;
      * otherwise *extracted* is correct if its upper-cased form equals the
        gold label, or if it equals *option_answer_text* after lower-casing,
        trailing-punctuation stripping and ``_normalize_for_compare``.
    """
    def _nothing_extracted():
        if hit_limit is True:
            return "Incomplete", False
        return None, False

    extracted = extract_arc_answer(pred, choices, option_answer_text)
    if not extracted or not extracted.strip():
        return _nothing_extracted()

    normalized = extracted.strip().upper()

    # A single letter that is not one of the offered options counts as no answer
    # (also catches out-of-range letters produced by extract_final).
    out_of_range_letter = (
        len(normalized) == 1
        and normalized.isalpha()
        and normalized not in choices
    )
    if out_of_range_letter:
        return _nothing_extracted()

    if true.strip().upper() == normalized:
        return extracted, True

    option_text = option_answer_text or ""
    same_text = (
        _normalize_for_compare(_strip_trailing_punct(extracted.lower()))
        == _normalize_for_compare(_strip_trailing_punct(option_text.lower()))
    )
    return extracted, same_text


# Remove every entry inside dir_path (dir_path itself is kept).
def clear_dir(dir_path: str):
    """
    Empty *dir_path* in place; a missing or non-directory path is a no-op.

    Files and symlinks (including symlinks to directories, which are not
    followed) are unlinked, real sub-directories are removed recursively, and
    other entry types are left alone. Deletion failures are printed and do not
    stop the loop.
    """
    if not os.path.isdir(dir_path):
        return
    for entry_name in os.listdir(dir_path):
        entry_path = os.path.join(dir_path, entry_name)
        if os.path.islink(entry_path) or os.path.isfile(entry_path):
            remover = os.remove
        elif os.path.isdir(entry_path):
            remover = shutil.rmtree
        else:
            continue
        try:
            remover(entry_path)
        except Exception as e:
            print(f"删除失败: {entry_path}, 原因: {e}")


def get_item_id(item: Dict[str, Any]) -> str:
    """
    Return the stripped string ID of a dataset record.

    The first of ``unique_id`` (MATH) or ``question_id`` (MMLU-Pro) that is
    neither None nor "" wins. Otherwise ``id`` (BBH / arc_challenge) is used
    as-is. A missing ``id`` gives "". Falsy values like 0 are valid IDs ("0").
    """
    for key in ("unique_id", "question_id"):
        candidate = item.get(key)
        if candidate is None or candidate == "":
            continue
        return str(candidate).strip()
    fallback = item.get("id")
    return "" if fallback is None else str(fallback).strip()


# 自动化 batch 估计
def _load_module(module_path: str, module_name: str):
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import {module_name} from {module_path}")
    mod = importlib.util.module_from_spec(spec)
    import sys as _sys

    _sys.modules[module_name] = mod
    spec.loader.exec_module(mod)
    return mod


# Helper scripts used for automatic batch-size (VRAM) estimation, loaded by path.
PTS_PATH = os.path.join(PROJECT_ROOT, "utils", "vram_estimation", "prompt_token_states_post.py")
VRAM_PATH = os.path.join(PROJECT_ROOT, "utils", "vram_estimation", "vram_estimate_post.py")
PTS_POST = _load_module(PTS_PATH, "prompt_token_states_post")
VRAM_POST = _load_module(VRAM_PATH, "vram_estimate_post")
# Verify the expected entry points exist. Explicit raises instead of `assert`
# so the check still runs under `python -O`.
if not hasattr(PTS_POST, "token_calculation"):
    raise ImportError("prompt_token_states_post 模块缺少 token_calculation")
if not hasattr(VRAM_POST, "batch_estimation"):
    raise ImportError("vram_estimate_post 模块缺少 batch_estimation")


# eos 规范化
def _normalize_to_list(x):
    if x is None:
        return []
    if isinstance(x, (list, tuple, set)):
        return list(x)
    return [x]


class CheckpointManager:
    """
    Persist run progress as the set of already-processed sample IDs.

    On disk the checkpoint is a JSON object with keys ``processed_ids`` (sorted
    list of str), ``completed`` (bool) and ``total_samples`` (int). Every save
    first writes ``<path>.tmp`` and then ``os.replace``-s it over *path*.

    With ``load_existing=False`` any existing checkpoint is ignored and is
    overwritten on the first save ("start over, but keep checkpointing").
    """

    def __init__(self, path: str, load_existing: bool = True):
        self.path = path
        self.processed_ids: Set[str] = set()
        self.completed: bool = False
        self.total_samples: int = 0
        if load_existing:
            self._load()
        # When not loading, an old file is simply overwritten by the first _atomic_save.

    @staticmethod
    def _to_id(_id: Any) -> str:
        """Canonical string form of a sample ID (None -> "")."""
        return str(_id).strip() if _id is not None else ""

    def _is_complete(self) -> bool:
        """True when a positive total is known and at least that many IDs are recorded."""
        return self.total_samples > 0 and len(self.processed_ids) >= self.total_samples

    def _load(self) -> None:
        """Best-effort load; on any error the old file is ignored and the state stays empty."""
        if not os.path.exists(self.path):
            return
        try:
            with open(self.path, "r", encoding="utf-8") as f:
                data = json.load(f)
            # Parse every field before assigning any, so a malformed field
            # cannot leave a half-loaded state behind.
            processed_ids = set(map(str, data.get("processed_ids", [])))
            completed = bool(data.get("completed", False))
            total_samples = int(data.get("total_samples", 0))
        except Exception as e:
            print(f"[ckpt] 读取失败，忽略旧文件：{self.path} ({e})")
            return
        self.processed_ids = processed_ids
        self.completed = completed
        self.total_samples = total_samples

    def reset(self) -> None:
        """Clear in-memory progress; the file on disk is overwritten by the next save."""
        self.processed_ids.clear()
        self.completed = False
        # total_samples is kept; set_total updates it later.
        # Nothing is written here, to avoid clobbering the file by accident.

    def set_total(self, n: int) -> None:
        """
        Record the dataset size and persist if it changed.

        ``completed`` is recomputed against the new total, so a checkpoint that
        finished a smaller dataset is not reported complete for a larger one.
        """
        if self.total_samples != n:
            self.total_samples = int(n)
            self.completed = self._is_complete()
            self._atomic_save()

    def mark_batch_processed(self, ids: Iterable[str]) -> None:
        """Record *ids* (None/blank entries are ignored) and persist immediately."""
        for _id in ids:
            sid = self._to_id(_id)
            if sid:
                self.processed_ids.add(sid)
        if self._is_complete():
            self.completed = True
        self._atomic_save()

    def is_processed(self, _id: Any) -> bool:
        """True if the (stringified, stripped) ID is non-empty and already recorded."""
        sid = self._to_id(_id)
        return bool(sid) and (sid in self.processed_ids)

    def _atomic_save(self) -> None:
        """Write the checkpoint via a temp file and ``os.replace``."""
        tmp = self.path + ".tmp"
        data = {
            "processed_ids": sorted(self.processed_ids),
            "completed": self.completed,
            "total_samples": self.total_samples,
        }
        os.makedirs(os.path.dirname(self.path) or ".", exist_ok=True)
        try:
            with open(tmp, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False)
            os.replace(tmp, self.path)
        except BaseException:
            # The previous checkpoint (if any) is untouched; don't leave a stale temp file.
            try:
                os.remove(tmp)
            except OSError:
                pass
            raise


class HiddenExtractor:
    def __init__(self, model_key: str, model_path: str, device: str = "cuda") -> None:
        """Load the tokenizer and bf16 causal-LM weights and initialize hook/debug state.

        Loading tries three configurations, each a fallback for the previous failure:

        1. ``device_map="auto"`` with an explicit attention implementation: FlashAttention-2,
           or SDPA when the path contains "gemma"/"c4ai" (case-insensitive).
        2. ``device_map="auto"`` with the default attention implementation.
        3. A plain single-device load.

        Args:
            model_key: short model identifier; other methods switch on substrings of it
                (e.g. "qwen3").
            model_path: path or hub id passed to ``from_pretrained``.
            device: device the model is moved to when its placement spans at most one device.
        """
        self.model_key = model_key
        self.device = device
        console.print(f"[bold]加载模型:[/bold] {model_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        path_lower = model_path.lower()
        attn_implementation = "flash_attention_2"
        if "gemma" in path_lower or "c4ai" in path_lower:
            attn_implementation = "sdpa"
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                attn_implementation=attn_implementation,
            )
        except Exception as e:
            console.print("[yellow]Warning: Failed to load model with auto device map and flash-attention[/yellow]")
            console.print(f"[yellow]  reason: {escape(str(e))}[/yellow]")
            try:
                # Some models do not support the requested attention implementation.
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                )
            except Exception as e2:
                console.print(f"[yellow]Warning: Failed to load model with auto device map: {escape(str(e2))}[/yellow]")
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    torch_dtype=torch.bfloat16,
                    device_map=None,  # single device
                ).to(self.device)

        # With device_map="auto" the weights may be split across several devices or offloaded.
        # .to() would collapse such a placement onto one device (or fail for offloaded weights),
        # so only move the model when it occupies at most one device.
        placement = getattr(self.model, "hf_device_map", None) or {}
        if len(set(placement.values())) <= 1:
            self.model.to(self.device)
        self.model.eval()
        self.layer_map = self._build_layer_index()
        self.calculate_right_padding_length = self._calculate_right_padding_length

        # Hook state
        self.hooks = []
        self.captured_hidden_states = {}
        self.is_batch_mode = False
        self.current_sample_idx = 0

        # Debug state
        self.debug_mode = False
        self.debug_info = {}

    def _build_layer_index(self) -> Dict[str, int]:
        """Map symbolic layer names to decoder-layer indices from ``config.num_hidden_layers``.

        The key order is also the default capture order when callers pass no
        ``needed_layers``. The indices use the same arithmetic as ``_register_hooks``.
        """
        n_layers = self.model.config.num_hidden_layers
        mapping = {
            "middle": n_layers // 2,
            "last": n_layers - 1,
            # One below "last"; clamped so a 1-layer model still gets a valid index.
            "second_last": max(n_layers - 2, 0),
            "quarter": n_layers // 4,  # 1/4 depth
            "three_quarters": (3 * n_layers) // 4,  # 3/4 depth
        }
        console.print(
            f"总层数: {n_layers} → quarter={mapping['quarter']}, middle={mapping['middle']}, "
            f"three_quarters={mapping['three_quarters']}, second_last={mapping['second_last']}, last={mapping['last']}"
        )
        return mapping

    def _calculate_right_padding_length(self, total_sequence) -> int:
        """Return how many trailing tokens of one generated sequence to strip as right padding.

        The aim is for the kept prefix to end with exactly one end-of-sequence token:

        1. Count the trailing run of ``pad_token_id`` tokens.
        2. If some pads were found and the token right before them is NOT one of the EOS ids,
           return one less than the pad count, i.e. keep the first "pad" as the terminating
           token (the original comment attributes this to pad sharing the EOS id).
        3. Otherwise count the run of EOS ids immediately before the pads and also strip all
           but one of them.

        Args:
            total_sequence: token ids of a single sequence, either a tensor (converted with
                ``tolist()``; presumably 1-D — TODO confirm) or a list of ints.

        Returns:
            Number of tokens to drop from the right end; 0 when nothing should be dropped.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()
        right_pad_len = 0
        pad_id = self.tokenizer.pad_token_id
        # EOS ids = union of the generation-config and tokenizer EOS ids. `_normalize_to_list`
        # is defined elsewhere; presumably it turns None / int / list into a list — verify.
        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = list(set(gc_eos + tok_eos))

        # Walk back from the end counting consecutive pad tokens; all are candidates to strip.
        n = len(total_sequence)
        j = n - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        if j >= 0 and total_sequence[j] not in eos_ids and right_pad_len > 0:
            # Pads exist but are not preceded by an EOS: the first "pad" is taken to be the real
            # EOS (pad id == EOS id), which the loop above over-counted by one — keep it.
            return right_pad_len - 1

        # Keep walking backwards, counting consecutive EOS tokens before the pads.
        eos_count = 0
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            eos_count += 1
            i -= 1
        # Keep one EOS; any additional EOS tokens count as right padding.
        right_pad_len += max(0, eos_count - 1)

        return right_pad_len

    # 钩子注册
    def _register_hooks(self, needed_layers: List[str]) -> None:
        """注册forward hooks来捕获指定层的隐状态"""
        self._clear_hooks()  # 先清除之前的hooks
        self.captured_hidden_states = {k: [] for k in needed_layers}

        def create_hook(layer_name: str):
            def hook_fn(module, _, output):

                if isinstance(output, tuple):
                    hidden_state = output[0]  # 通常隐状态是第一个元素
                else:
                    hidden_state = output

                # 复制到CPU后立即删除GPU引用，但使用更安全的方式
                hidden_state_cpu = hidden_state.detach().to("cpu")

                # 显式删除GPU引用以节省显存，但要确保Flash Attention已经完成计算
                if hidden_state.is_cuda:
                    del hidden_state

                # if self.debug_mode:
                #     console.print(f"[cyan]{layer_name} hidden state shape: {hidden_state_cpu.shape}[/cyan]")

                # 存储隐藏状态 - 不需要步骤计数，按顺序存储即可
                self.captured_hidden_states[layer_name].append(hidden_state_cpu)

            return hook_fn

        # 根据模型架构获取正确的层
        if hasattr(self.model, "model"):  # 如 LlamaForCausalLM
            layers = self.model.model.layers
        elif hasattr(self.model, "transformer"):  # 如 GPT类模型
            layers = self.model.transformer.h
        else:
            # 尝试其他可能的属性名
            for attr in ["layers", "h", "decoder_layers"]:
                if hasattr(self.model, attr):
                    layers = getattr(self.model, attr)
                    break
            else:
                raise AttributeError("无法找到模型的层结构")

        # 修复问题1：直接根据layer_name计算索引，避免重复计算
        for layer_name in needed_layers:
            if layer_name == "last":
                layer_idx = len(layers) - 1
            elif layer_name == "second_last":
                layer_idx = len(layers) - 2
            elif layer_name == "middle":
                layer_idx = len(layers) // 2
            elif layer_name == "quarter":
                layer_idx = len(layers) // 4
            elif layer_name == "three_quarters":
                layer_idx = (3 * len(layers)) // 4
            else:
                # 如果是其他层名，可以从layer_map获取
                layer_idx = self.layer_map.get(layer_name, 0)

            if 0 <= layer_idx < len(layers):
                hook = layers[layer_idx].register_forward_hook(create_hook(layer_name))
                self.hooks.append(hook)
                # console.print(f"注册hook for {layer_name} (layer {layer_idx})")

    def _clear_hooks(self) -> None:
        """清除所有hooks"""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        # 注意：这里不清除captured_hidden_states，因为后续还需要使用
        # self.captured_hidden_states.clear()

    # 方便统一处理chat/base模型
    def _build_input_ids(self, question: str) -> Tuple[torch.Tensor, list[int], str]:
        """Tokenize one prompt for generation, handling chat and base models uniformly.

        Tokenizers with a non-empty ``chat_template`` wrap the question as a single user turn
        and append the generation prompt. For keys containing "qwen3", ``enable_thinking=True``
        is tried first, with a plain call as the fallback. Other tokenizers encode the raw text
        with their own special tokens, and a BOS token is prepended when ``bos_token_id`` is set
        but missing at position 0 (including an empty encoding).

        Returns:
            input_ids: ``(1, seq_len)`` tensor on the device of the input-embedding weights
                (which may differ from ``self.device`` when the model is sharded).
            prompt_ids: the same ids as a plain ``list[int]``.
            prompt_txt: ``input_ids`` decoded with special tokens kept.
        """
        has_tpl = bool(getattr(self.tokenizer, "chat_template", None))
        if hasattr(self.tokenizer, "apply_chat_template") and has_tpl:
            prompt_struct = [{"role": "user", "content": question}]
            tpl_kwargs = dict(tokenize=True, add_generation_prompt=True, return_tensors="pt")
            encoded = None
            if "qwen3" in self.model_key.lower():
                try:
                    encoded = self.tokenizer.apply_chat_template(
                        prompt_struct, enable_thinking=True, **tpl_kwargs
                    )
                except Exception:
                    encoded = None  # template rejected enable_thinking; retry without it
            if encoded is None:
                encoded = self.tokenizer.apply_chat_template(prompt_struct, **tpl_kwargs)
        else:
            batch = self.tokenizer(
                question,
                add_special_tokens=True,  # let the tokenizer add whatever specials it uses
                return_tensors="pt",
                truncation=False,
            )
            input_ids = batch["input_ids"]  # shape: (1, L)
            bos_id = self.tokenizer.bos_token_id
            # Force a leading BOS; the shape check also avoids an IndexError on empty encodings.
            if bos_id is not None and (input_ids.shape[1] == 0 or input_ids[0, 0].item() != bos_id):
                bos_tensor = torch.tensor([[bos_id]], dtype=input_ids.dtype)
                input_ids = torch.cat([bos_tensor, input_ids], dim=1)
            encoded = input_ids

        embed_dev = self.model.get_input_embeddings().weight.device
        encoded = encoded.to(embed_dev)
        prompt_ids = encoded[0].tolist()
        prompt_txt = self.tokenizer.decode(encoded[0], skip_special_tokens=False)
        return encoded, prompt_ids, prompt_txt

    def _build_batch_input_ids(self, questions: List[str]) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]], List[str], torch.Tensor]:
        """Tokenize a batch of prompts with left padding.

        Prompts are rendered the way ``_build_input_ids`` renders a single item:
        - the chat template with a generation prompt when available; for "qwen3" keys,
          ``enable_thinking=True`` is tried first;
        - otherwise the raw text, with a BOS prepended when the tokenizer did not add one.

        Each prompt is encoded once and the batch is left-padded with ``tokenizer.pad``, so
        ``prompt_ids_list[i]`` is exactly the unpadded content of row ``i``.

        If the tokenizer has no pad token, a ``<|pad|>`` token is added and its input-embedding
        row zeroed. If that fails, EOS is reused as pad. The pad id is synced into
        ``model.config`` and ``model.generation_config``.

        Returns:
            padded_input_ids: ``(batch, max_len)`` ids on the input-embedding device.
            padded_attention_mask: matching mask (1 = real token, 0 = left padding).
            prompt_ids_list: unpadded ids per sample.
            prompt_texts: rendered prompt text per sample.
            left_pad_lens: ``(batch,)`` number of left-padding positions per row.
        """
        # Left padding puts every prompt's last token in the final column.
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token_id is None:
            try:
                special_tokens_dict = {"pad_token": "<|pad|>"}
                num_added = self.tokenizer.add_special_tokens(special_tokens_dict)
                if num_added > 0:
                    self.model.resize_token_embeddings(len(self.tokenizer))
                    # Zero the freshly created pad embedding row.
                    with torch.no_grad():
                        emb = self.model.get_input_embeddings().weight
                        emb[self.tokenizer.pad_token_id].zero_()
            except Exception:
                # Fall back to EOS as pad (its embedding is not zeroed).
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
                console.print("[yellow]Warn: fallback pad uses eos embedding (not zeroed).[/yellow]")

        # Keep model / generation config in sync with the tokenizer's pad id.
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id

        has_tpl = bool(getattr(self.tokenizer, "chat_template", None))
        use_chat = hasattr(self.tokenizer, "apply_chat_template") and has_tpl
        if use_chat:
            prompt_texts: List[str] = []
            for q in questions:
                messages = [{"role": "user", "content": q}]
                txt = None
                if "qwen3" in self.model_key.lower():
                    try:
                        txt = self.tokenizer.apply_chat_template(
                            messages,
                            tokenize=False,
                            add_generation_prompt=True,
                            enable_thinking=True,
                        )
                    except Exception:
                        txt = None  # template rejected enable_thinking; retry without it
                if txt is None:
                    txt = self.tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                prompt_texts.append(txt)
        else:
            prompt_texts = list(questions)

        # Chat templates already contain their special tokens; base prompts let the tokenizer
        # add its own and then get a BOS forced at the front (same rule as the single path).
        add_special = not use_chat
        bos_id = self.tokenizer.bos_token_id
        prompt_ids_list: List[List[int]] = []
        for txt in prompt_texts:
            ids = self.tokenizer.encode(txt, add_special_tokens=add_special, truncation=False)
            if add_special and bos_id is not None and (not ids or ids[0] != bos_id):
                ids = [bos_id] + ids
            prompt_ids_list.append(ids)

        batch_enc = self.tokenizer.pad(
            {"input_ids": prompt_ids_list},
            padding="longest",  # align to the longest prompt in this batch
            return_tensors="pt",
            return_attention_mask=True,
        )

        embed_dev = self.model.get_input_embeddings().weight.device
        padded_input_ids = batch_enc["input_ids"].to(embed_dev)
        padded_attention_mask = batch_enc["attention_mask"].to(embed_dev)

        # Left-pad length per row = padded width - number of real tokens.
        T = padded_attention_mask.shape[1]
        left_pad_lens = T - padded_attention_mask.sum(dim=1)

        return (padded_input_ids, padded_attention_mask, prompt_ids_list, prompt_texts, left_pad_lens,)

    # 特征提取
    def _extract_features_from_hooks(
        self,
        prompt_len: int,
        output_len: int,
        needed_layers: List[str],
        sample_idx: int = 0,
        hit_limit: bool = False,
    ) -> Dict[str, Dict[str, torch.Tensor]]:
        """从hooks收集的隐状态中提取特征"""
        layer_features: Dict[str, Dict[str, torch.Tensor]] = {k: {} for k in needed_layers}

        if self.debug_mode:
            console.print("[bold magenta]=== 提取特征调试信息 ===[/bold magenta]")
            console.print(f"Prompt长度: {prompt_len}, 输出长度: {output_len}, 样本索引: {sample_idx}")
            console.print(f"捕获的隐状态层数: {len(self.captured_hidden_states)}")
            for k, states in self.captured_hidden_states.items():
                console.print(f"  {k}: {len(states)} 个状态")

        if output_len <= 0:
            console.log("[yellow]Warn: no output tokens generated.[/yellow]")
            hidden_dim = self.model.config.hidden_size
            _dtype = next(self.model.parameters()).dtype
            _zero = torch.zeros(hidden_dim, dtype=_dtype, device="cpu")
            for k in needed_layers:
                layer_features[k]["avg_with_prompt"] = _zero.clone()
                layer_features[k]["avg_without_prompt"] = _zero.clone()
                layer_features[k]["last_token"] = _zero.clone()
            return layer_features

        for k in needed_layers:
            if k not in self.captured_hidden_states:
                continue

            layer_states = self.captured_hidden_states[k]
            if not layer_states:
                console.log(f"[yellow]Warn: no hidden states captured for {k}.[/yellow]")
                hidden_dim = self.model.config.hidden_size
                _dtype = next(self.model.parameters()).dtype
                _zero = torch.zeros(hidden_dim, dtype=_dtype, device="cpu")
                layer_features[k]["avg_with_prompt"] = _zero.clone()
                layer_features[k]["avg_without_prompt"] = _zero.clone()
                layer_features[k]["last_token"] = _zero.clone()
                continue

            # 在自回归生成中，每个前向传播都会产生隐藏状态
            # 第一个前向传播：处理整个prompt + 生成第一个token
            # 后续前向传播：每次处理之前的序列 + 新生成的token

            # 获取第一个前向传播的结果（包含prompt的处理）
            first_step_state = layer_states[0]

            if self.debug_mode:
                console.print(f"[cyan]处理 {k} 层的第一个状态: shape={first_step_state.shape}[/cyan]")

            # 处理张量维度 - 更健壮的维度处理
            if first_step_state.dim() == 3:
                # 批处理模式: [batch_size, seq_len, hidden_dim]
                first_step_hidden = first_step_state[sample_idx]  # [seq_len, hidden_dim]
                if self.debug_mode:
                    console.print(f"[cyan]批处理模式，提取样本 {sample_idx}: {first_step_hidden.shape}[/cyan]")
            elif first_step_state.dim() == 2:
                # 单样本模式: [seq_len, hidden_dim]
                first_step_hidden = first_step_state
                if self.debug_mode:
                    console.print(f"[cyan]单样本模式: {first_step_hidden.shape}[/cyan]")
            else:
                # 一维情况，可能是错误的
                console.print(f"[red]Unexpected hidden state dimension: {first_step_state.dim()}[/red]")
                hidden_dim = self.model.config.hidden_size
                _dtype = next(self.model.parameters()).dtype
                _zero = torch.zeros(hidden_dim, dtype=_dtype, device="cpu")
                layer_features[k]["avg_with_prompt"] = _zero.clone()
                layer_features[k]["avg_without_prompt"] = _zero.clone()
                layer_features[k]["last_token"] = _zero.clone()
                continue

            # 更准确的prompt处理
            if self.is_batch_mode:
                # 在批处理模式中，需要处理左侧padding
                # 第一个前向传播包含了padding + prompt
                # 我们需要提取真正的prompt部分
                if first_step_hidden.shape[0] >= prompt_len:
                    # 从右侧提取prompt_len长度的序列（因为是左padding）
                    actual_prompt_hidden = first_step_hidden[-prompt_len:]
                    if self.debug_mode:
                        console.print(f"[cyan]批处理模式：从右侧提取prompt，原始长度={first_step_hidden.shape[0]}, prompt长度={prompt_len}[/cyan]")
                else:
                    # 如果长度不够，使用全部
                    actual_prompt_hidden = first_step_hidden
                    if self.debug_mode:
                        console.print("[yellow]警告：批处理模式长度不足，使用全部序列[/yellow]")
            else:
                # 单样本模式，第一个前向传播的长度应该等于prompt_len
                if first_step_hidden.shape[0] >= prompt_len:
                    actual_prompt_hidden = first_step_hidden[:prompt_len]
                    if self.debug_mode:
                        console.print(f"[cyan]单样本模式：提取前{prompt_len}个token[/cyan]")
                else:
                    actual_prompt_hidden = first_step_hidden
                    if self.debug_mode:
                        console.print("[yellow]警告：单样本模式长度不足[/yellow]")

            # 收集输出部分的隐状态
            output_states = []

            # 从第一个前向传播中提取第一个生成token的表示
            # 这是prompt处理后的最后一个位置的隐状态，这部分暂定，感觉加不加影响不大，虽然CoE里面加了

            # if len(actual_prompt_hidden) > 0:
            #     # 修复：第一个生成token的隐状态应该是prompt最后一个位置的隐状态
            #     output_states.append(actual_prompt_hidden[-1:])  # [1, hidden_dim]

            # 处理后续的前向传播结果
            # 每个后续的前向传播都会产生一个新的token的隐状态
            for i in range(1, len(layer_states)):
                # if len(output_states) >= output_len:
                #     break  # 已经收集够了

                step_state = layer_states[i]

                # 处理维度
                if step_state.dim() == 3:
                    # 批处理模式: [batch_size, seq_len, hidden_dim]
                    step_hidden = step_state[sample_idx]  # [seq_len, hidden_dim]
                elif step_state.dim() == 2:
                    # 单样本模式: [seq_len, hidden_dim]
                    step_hidden = step_state
                else:
                    continue

                # 新生成的token的隐状态在序列的最后一个位置
                if len(step_hidden) > 0:
                    output_states.append(step_hidden[-1:])  # [1, hidden_dim]

            if self.debug_mode:
                console.print(f"[yellow]对齐检查: output_len={output_len}, captured={len(output_states)}[/yellow]")

            # 计算真正最后一个 token 的隐状态
            if hit_limit:
                effective_len = min(output_len, len(output_states))
            else:
                effective_len = min(output_len - 1, len(output_states))

            # 考虑极短输出（output_len=1）的情况
            if output_len == 1:
                # 此时输出压根没有隐状态，所以直接用prompt的最后一个token的隐状态
                output_hidden = actual_prompt_hidden[-1:].contiguous()
                # avg_without_prompt：没有生成 token，置零向量（或按你需要的策略）
                hidden_dim = self.model.config.hidden_size
                _dtype = next(self.model.parameters()).dtype
                _zero = torch.zeros(hidden_dim, dtype=_dtype, device="cpu")
                vec_avg_without_prompt = _zero.clone()
                # avg_with_prompt：只对 prompt 做平均
                vec_avg_with_prompt = actual_prompt_hidden.mean(dim=0)
                # last_token：沿用占位（注意语义是 prompt[-1]，必要时可加一个标记位）
                vec_last_token = output_hidden[0]

                if self.debug_mode:
                    console.print(f"[cyan]{k} 层输出状态统计（G=1）[/cyan]")
                    console.print(f"  Prompt hidden shape: {actual_prompt_hidden.shape}")
                    console.print(f"  avg_with_prompt mean={vec_avg_with_prompt.mean():.4f}, std={vec_avg_with_prompt.std():.4f}")
                    console.print(f"  avg_without_prompt mean={vec_avg_without_prompt.mean():.4f}, std={vec_avg_without_prompt.std():.4f}")
                    console.print(f"  last_token mean={vec_last_token.mean():.4f}, std={vec_last_token.std():.4f}")

                layer_features[k]["avg_with_prompt"] = vec_avg_with_prompt
                layer_features[k]["avg_without_prompt"] = vec_avg_without_prompt
                layer_features[k]["last_token"] = vec_last_token
            elif len(output_states) > 0 and effective_len > 0:
                # 只取实际输出长度的隐状态
                output_states = output_states[:effective_len]
                normed = []
                for t in output_states:
                    if t.dim() == 1:
                        t = t.unsqueeze(0)
                    normed.append(t)
                # 拼接得到 [effective_len, hidden_dim]
                output_hidden = torch.cat(normed, dim=0)

                # output_hidden = torch.cat(output_states, dim=0)  # [output_len, hidden_dim]

                if self.debug_mode:
                    console.print(f"[cyan]{k} 层输出状态统计:[/cyan]")
                    console.print(f"  实际输出状态数: {len(output_states)}")
                    console.print(f"  输出hidden shape: {output_hidden.shape}")
                    console.print(f"  Prompt hidden shape: {actual_prompt_hidden.shape}")

                # 1. 输出部分token平均值（不包含prompt）
                vec_avg_without_prompt = output_hidden.mean(dim=0)

                # 2. 输出部分token平均值（包含prompt）
                all_hidden = torch.cat([actual_prompt_hidden, output_hidden], dim=0)
                vec_avg_with_prompt = all_hidden.mean(dim=0)

                # 3. 最后一个token的表征
                vec_last_token = output_hidden[-1]

                if self.debug_mode:
                    console.print(f"  avg_with_prompt: mean={vec_avg_with_prompt.mean():.4f}, std={vec_avg_with_prompt.std():.4f}")
                    console.print(f"  avg_without_prompt: mean={vec_avg_without_prompt.mean():.4f}, std={vec_avg_without_prompt.std():.4f}")
                    console.print(f"  last_token: mean={vec_last_token.mean():.4f}, std={vec_last_token.std():.4f}")

                layer_features[k]["avg_with_prompt"] = vec_avg_with_prompt
                layer_features[k]["avg_without_prompt"] = vec_avg_without_prompt
                layer_features[k]["last_token"] = vec_last_token
            else:
                console.log(f"[yellow]Warn: insufficient output states for {k}, got {len(output_states)}, need {output_len}.[/yellow]")
                hidden_dim = self.model.config.hidden_size
                _dtype = next(self.model.parameters()).dtype
                _zero = torch.zeros(hidden_dim, dtype=_dtype, device="cpu")
                layer_features[k]["avg_with_prompt"] = _zero.clone()
                layer_features[k]["avg_without_prompt"] = _zero.clone()
                layer_features[k]["last_token"] = _zero.clone()

        return layer_features

    # 单次前向传播同时获取隐状态和label
    @torch.inference_mode()
    def forward_once(
        self,
        item: dict,
        max_new_tokens: int = 2048,
        needed_layers: Optional[List[str]] = None,
    ) -> Tuple[Dict[str, Dict[str, torch.Tensor]], str, bool]:
        """Greedy-decode one item and pool the hooked layers' hidden states from the same pass.

        Args:
            item: dataset record, turned into a prompt by ``parse_input``.
            max_new_tokens: generation budget. Models whose key contains "qwen3", "distill"
                or "qwq" must be run with exactly 8192.
            needed_layers: symbolic layer names to capture; defaults to all keys of
                ``self.layer_map``.

        Returns:
            ``(layer_features, answer_txt, hit_limit)``:
            - ``layer_features[name]`` holds the ``avg_with_prompt`` / ``avg_without_prompt`` /
              ``last_token`` vectors.
            - ``answer_txt`` is the generated text with special tokens removed.
            - ``hit_limit`` is True only when the budget was exhausted without a final EOS.
            If generation or extraction fails, the features are zero vectors, the answer is
            ``"Failed"`` and ``hit_limit`` is False.

        Raises:
            ValueError: if the model requires ``max_new_tokens=8192`` and another value is given.
        """
        if needed_layers is None:
            needed_layers = list(self.layer_map.keys())

        # Validate the budget before touching the tokenizer/model so a bad call has no side effects.
        mk_lower = self.model_key.lower()
        if any(k in mk_lower for k in ["qwen3", "distill", "qwq"]):
            if max_new_tokens != 8192:
                raise ValueError(f"模型 {self.model_key} 需要设置 max_new_tokens=8192，但当前传入的是 {max_new_tokens}")

        # 1) Build and tokenize the prompt.
        prompt = parse_input(item)
        input_ids, prompt_ids, _ = self._build_input_ids(prompt)

        # Make sure a pad token exists (same policy as the batch path).
        if self.tokenizer.pad_token_id is None:
            try:
                special_tokens_dict = {"pad_token": "<|pad|>"}
                num_added = self.tokenizer.add_special_tokens(special_tokens_dict)
                if num_added > 0:
                    self.model.resize_token_embeddings(len(self.tokenizer))
                    with torch.no_grad():
                        emb = self.model.get_input_embeddings().weight
                        emb[self.tokenizer.pad_token_id].zero_()
            except Exception:
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
                console.print("[yellow]Warn: fallback pad uses eos embedding (not zeroed).[/yellow]")

        pad_id = self.tokenizer.pad_token_id

        # 2) Register hooks for single-sample capture.
        self.is_batch_mode = False
        self.current_sample_idx = 0
        self._register_hooks(needed_layers)

        # 3) Generate, then pool features; hooks are always removed in `finally`.
        try:
            gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
            tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
            eos_ids = list(set(gc_eos + tok_eos))
            gen_out = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                eos_token_id=(eos_ids if eos_ids else None),
                pad_token_id=pad_id,
                return_dict_in_generate=True,
                attention_mask=torch.ones_like(input_ids),
            )

            prompt_len = len(prompt_ids)
            total_sequence = gen_out.sequences[0]
            right_pad_len = self.calculate_right_padding_length(total_sequence)
            valid_sequence = total_sequence[: len(total_sequence) - right_pad_len]
            output_len = max(len(valid_sequence) - prompt_len, 0)
            answer_txt = self.tokenizer.decode(valid_sequence[prompt_len:].tolist(), skip_special_tokens=True)

            # Incomplete only if the budget was used up AND the last generated token is not EOS.
            ended_with_eos = (
                output_len > 0
                and int(valid_sequence[prompt_len + output_len - 1].item()) in eos_ids
            )
            hit_limit = output_len >= max_new_tokens and not ended_with_eos
            layer_features = self._extract_features_from_hooks(
                prompt_len, output_len, needed_layers, sample_idx=0, hit_limit=hit_limit
            )

        except Exception as e:
            console.print_exception(show_locals=False)
            console.print(f"[red]Single sample processing failed: {escape(str(e))}[/red]")
            hidden_dim = self.model.config.hidden_size
            _dtype = next(self.model.parameters()).dtype
            _zero = torch.zeros(hidden_dim, dtype=_dtype, device="cpu")
            layer_features = {
                k: {
                    "avg_with_prompt": _zero.clone(),
                    "avg_without_prompt": _zero.clone(),
                    "last_token": _zero.clone(),
                }
                for k in needed_layers
            }
            answer_txt = "Failed"
            hit_limit = False
        finally:
            self._clear_hooks()
            self.captured_hidden_states.clear()

        return layer_features, answer_txt, hit_limit

    # 批量前向传播同时获取隐状态和label
    @torch.inference_mode()
    def forward_batch(
        self,
        items: List[dict],
        max_new_tokens: int = 2048,
        needed_layers: List[str] = None,
    ) -> Tuple[List[Dict[str, Dict[str, torch.Tensor]]], List[str], List[bool]]:
        if needed_layers is None:
            needed_layers = list(self.layer_map.keys())

        self.tokenizer.padding_side = "left"

        # 1) 构建批量 prompts
        prompts = [parse_input(item) for item in items]

        # 2) 批量 tokenize（pad_token已在_build_batch_input_ids中统一设置）
        (batch_input_ids, padded_attention_mask, prompt_ids_list, _, left_pad_lens) = self._build_batch_input_ids(prompts)

        # 获取pad_id（此时pad_token_id已经在_build_batch_input_ids中设置好了）
        pad_id = self.tokenizer.pad_token_id

        mk_lower = self.model_key.lower()
        if any(k in mk_lower for k in ["qwen3", "distill", "qwq"]):
            if max_new_tokens != 8192:
                raise ValueError(f"模型 {self.model_key} 需要设置 max_new_tokens=8192，但当前传入的是 {max_new_tokens}")

        # 3) 注册hooks
        self.is_batch_mode = True
        self._register_hooks(needed_layers)

        # 4) 生成
        try:
            gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
            tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
            eos_ids = list(set(gc_eos + tok_eos))
            gen_out = self.model.generate(
                input_ids=batch_input_ids,
                max_new_tokens=max_new_tokens,
                attention_mask=padded_attention_mask,
                do_sample=False,
                # eos_token_id=self.tokenizer.eos_token_id,
                eos_token_id=(eos_ids if eos_ids else None),
                pad_token_id=pad_id,
                return_dict_in_generate=True,
                # cache_implementation="dynamic"
            )

            # 5) 批量答案解码
            batch_answers = []

            # 6) 批量提取隐状态 - 针对输出部分
            batch_layer_features = []

            hit_limit_flags = []  # 记录是否命中最大长度限制
            gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
            tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
            eos_ids = list(set(gc_eos + tok_eos))

            for i, prompt_ids in enumerate(prompt_ids_list):
                # 获取该样本的实际序列信息（考虑左右两侧padding）
                sample_sequence = gen_out.sequences[i]

                # 计算左侧padding长度
                left_pad_len = int(left_pad_lens[i].item())

                # 提取去除左侧padding的序列
                actual_sequence = sample_sequence[left_pad_len:]

                # 计算右侧padding长度
                right_pad_len = self.calculate_right_padding_length(actual_sequence)
                if right_pad_len > 0:
                    valid_sequence = actual_sequence[:-right_pad_len]
                else:
                    valid_sequence = actual_sequence

                # 计算有效序列长度（去除左右padding）
                valid_seq_len = len(valid_sequence)

                # 获取prompt长度和输出长度
                # prompt_len = len(prompt_ids)
                prompt_len = int(padded_attention_mask[i].sum().item())
                output_len = valid_seq_len - prompt_len

                # 确保prompt长度不超过有效序列长度
                if prompt_len > valid_seq_len:
                    console.print(f"[red]错误: 样本{i}的prompt长度({prompt_len})超过有效序列长度({valid_seq_len})[/red]")
                    output_len = 0  # 强制设为0，使用零向量

                ended_with_eos = False
                if output_len > 0:
                    last_gen_id = int(valid_sequence[prompt_len + output_len - 1].item())
                    if last_gen_id in eos_ids:
                        ended_with_eos = True

                hit_limit_tmp = output_len >= max_new_tokens and not ended_with_eos
                hit_limit_flags.append(hit_limit_tmp)
                # 从hooks提取特征
                layer_features = self._extract_features_from_hooks(
                    prompt_len,
                    output_len,
                    needed_layers,
                    sample_idx=i,
                    hit_limit=hit_limit_tmp,
                )
                # 答案解码
                batch_answers.append(self.tokenizer.decode(valid_sequence[prompt_len:].tolist(), skip_special_tokens=True))
                batch_layer_features.append(layer_features)

                ####### DEBUG #########
                console.print(f"[green]{self.tokenizer.decode(valid_sequence[prompt_len:].tolist(), skip_special_tokens=False)}[/green]")
                #######################

        except Exception as e:
            console.print(f"[red]Batch processing failed: {escape(str(e))}[/red]")
            # 提供默认值
            batch_answers = ["Failed"] * len(items)
            batch_layer_features = []
            hit_limit_flags = [False] * len(items)
            hidden_dim = self.model.config.hidden_size

            _dtype = next(self.model.parameters()).dtype
            _zero = torch.zeros(hidden_dim, dtype=_dtype, device="cpu")
            for i in range(len(items)):
                layer_features = {}
                for k in needed_layers:
                    layer_features[k] = {
                        "avg_with_prompt": _zero.clone(),
                        "avg_without_prompt": _zero.clone(),
                        "last_token": _zero.clone(),
                    }
                batch_layer_features.append(layer_features)
            raise  # 重新抛出异常以便上层处理
        finally:
            # 7) 清理hooks - 修复：确保在finally中清理
            self._clear_hooks()
            # 清理captured_hidden_states
            self.captured_hidden_states.clear()

        return batch_layer_features, batch_answers, hit_limit_flags

    #  把原始数据按 prompt token 长度从长到短排序，尽可能避免长短序列混搭导致大批次OOM
    def _sort_data_by_prompt_len(self, data):
        """
        输入: data(List[Dict])   输出: 排序后的 data(List[Dict])
        """

        def _build_prompt_texts(items):
            # 用你上传的 parse_input 拿最终 prompt 文本
            questions = [parse_input(it) for it in items]

            # 如果 tokenizer 支持 chat template，就按你后续 generate 的方式展开
            prompt_txts = []
            add_special = True
            has_tpl = bool(getattr(self.tokenizer, "chat_template", None))
            if hasattr(self.tokenizer, "apply_chat_template") and has_tpl:
                add_special = False
                for q in questions:
                    messages = [{"role": "user", "content": q}]
                    try:
                        if "qwen3" in str(self.model_key).lower():
                            txt = self.tokenizer.apply_chat_template(
                                messages,
                                tokenize=False,
                                add_generation_prompt=True,
                                enable_thinking=True,
                            )
                        else:
                            txt = self.tokenizer.apply_chat_template(
                                messages,
                                tokenize=False,
                                add_generation_prompt=True,
                            )
                    except TypeError:
                        # 老版本 tokenizer 兼容
                        txt = self.tokenizer.apply_chat_template(messages, tokenize=False)
                    prompt_txts.append(txt)
            else:
                # 非 chat tokenizer，直接用文本
                prompt_txts = questions

            return prompt_txts, add_special

        # 纯 CPU 分词测长度，分块以避免一次性处理太大
        LENS_CHUNK = 256
        n = len(data)
        lens = [0] * n
        for i in range(0, n, LENS_CHUNK):
            sub = data[i : i + LENS_CHUNK]
            prompt_txts, add_special = _build_prompt_texts(sub)
            enc = self.tokenizer(
                prompt_txts,
                add_special_tokens=add_special,
                truncation=False,
                return_attention_mask=False,
                return_token_type_ids=False,
            )
            for k, ids in enumerate(enc["input_ids"]):
                lens[i + k] = len(ids)

        # 长 → 短 排序
        order = sorted(range(n), key=lambda t: (lens[t], t), reverse=True)
        return [data[idx] for idx in order]

    def _save_batch_features(
        self,
        feats_avg_with_prompt: Dict[str, List[torch.Tensor]],
        feats_avg_without_prompt: Dict[str, List[torch.Tensor]],
        feats_last_token: Dict[str, List[torch.Tensor]],
        labels: List[int],
        ids: List[str],
        questions: List[str],
        true_answers: List[str],
        pred_answers: List[str],
        options: List[List[str]],
        parsed_answers: List[str],
        incomplete_flags: List[bool],
        avg_with_prompt_dir: str,
        avg_without_prompt_dir: str,
        last_token_outputdir: str,
    ) -> None:
        """保存当前批次的特征"""

        def save_features(feats, output_dir, des):
            for k, vec_list in feats.items():
                if not vec_list:
                    continue
                out_path = os.path.join(output_dir, f"{k}_features.pt")

                # 如果文件已存在，加载并合并
                if os.path.exists(out_path):
                    existing_data = torch.load(out_path, map_location="cpu")

                    # Id 去重，防止在 batch 处理完，pt写入但是 id 没有及时写入导致的重复处理
                    existing_ids = existing_data["ids"]
                    existing_id_set = set(existing_ids)

                    if len(vec_list) != len(ids):
                        raise ValueError(f"{k}: vec_list({len(vec_list)}) 与 ids({len(ids)}) 长度不一致")

                    # 以当前要保存的这一路特征 vec_list 为例，先把它和同索引的 meta 列表打包
                    new_records = list(zip(ids, labels, questions, true_answers, pred_answers, options, parsed_answers, incomplete_flags, vec_list,))

                    # 过滤：仅保留还未出现过的 id
                    new_records = [r for r in new_records if r[0] not in existing_id_set]

                    if new_records:
                        n_ids, n_labels, n_qs, n_trues, n_preds, n_opts, n_pars, n_incomp, n_vecs = zip(*new_records)
                        tensor_new = torch.stack(list(n_vecs)).cpu()
                        tensor = torch.cat([existing_data["features"], tensor_new], dim=0)

                        merged_ids = existing_ids + list(n_ids)
                        merged_labels = existing_data["labels"].tolist() + list(n_labels)
                        merged_questions = existing_data["questions"] + list(n_qs)
                        merged_true_answers = existing_data["true_answers"] + list(n_trues)
                        merged_pred_answers = existing_data["pred_answers"] + list(n_preds)
                        merged_options = existing_data["options"] + list(n_opts)
                        merged_parsed_answers = existing_data["parsed_answers"] + list(n_pars)
                        merged_incomplete = existing_data["incomplete_flags"] + list(n_incomp)
                    else:
                        # 没有新增；直接复用 existing_data
                        tensor = existing_data["features"]
                        merged_ids = existing_ids
                        merged_labels = existing_data["labels"].tolist()
                        merged_questions = existing_data["questions"]
                        merged_true_answers = existing_data["true_answers"]
                        merged_pred_answers = existing_data["pred_answers"]
                        merged_options = existing_data["options"]
                        merged_parsed_answers = existing_data["parsed_answers"]
                        merged_incomplete = existing_data["incomplete_flags"]
                else:
                    # 首次写入
                    tensor = torch.stack(vec_list)
                    merged_ids = ids
                    merged_labels = labels
                    merged_questions = questions
                    merged_true_answers = true_answers
                    merged_pred_answers = pred_answers
                    merged_options = options
                    merged_parsed_answers = parsed_answers
                    merged_incomplete = incomplete_flags

                tmp = out_path + ".tmp"
                torch.save({
                    "features": tensor,
                    "labels": torch.tensor(merged_labels),
                    "ids": merged_ids,
                    "questions": merged_questions,
                    "true_answers": merged_true_answers,
                    "pred_answers": merged_pred_answers,
                    "options": merged_options,
                    "parsed_answers": merged_parsed_answers,
                    "incomplete_flags": merged_incomplete,
                }, tmp)
                os.replace(tmp, out_path)
                console.print(f"[green]保存 {k} {des} → {tensor.shape} 到 {out_path}[/green]")

        save_features(feats_avg_with_prompt, avg_with_prompt_dir, des="输出平均值(含prompt)")
        save_features(feats_avg_without_prompt, avg_without_prompt_dir, des="输出平均值(不含prompt)",)
        save_features(feats_last_token, last_token_outputdir, des="最后一个token的表征")

    def extract_dataset(
        self,
        dataset_name: str,
        dataset_path: str,
        avg_with_prompt_dir: str,
        avg_without_prompt_dir: str,
        last_token_outputdir: str,
        layer_req: str = "middle",
        max_new_tokens: int = 2048,
        batch_size: int = 4,
        resume: bool = False,
        checkpoint_root: Optional[str] = None,
    ) -> None:
        """Generate answers for one JSONL dataset and persist hidden-state features.

        Steps:
          1. Load ``dataset_path`` (one JSON object per non-blank line). When
             ``self.debug_mode`` is set, only the first 250 samples are kept.
          2. Open a ``CheckpointManager`` at
             ``<checkpoint_root>/<model_key>/<dataset_name>/checkpoint_meta.json``.
             Without ``resume``, the checkpoint and all three feature
             directories are cleared first.
          3. Drop samples whose id ``get_item_id`` cannot resolve. With
             ``resume``, also drop samples already marked processed.
          4. Sort the rest longest-prompt-first and run them through
             ``forward_batch`` in batches. If a batch raises, each of its samples
             is retried with ``forward_once``. A sample that still fails is
             recorded with answer ``"Failed"``, label 0 and zero vectors.
          5. After each batch, save with ``_save_batch_features`` and mark the
             batch's ids processed in the checkpoint.

        Args:
            dataset_name: Used for logging and as the checkpoint sub-directory.
            dataset_path: Path to the ``.jsonl`` file.
            avg_with_prompt_dir: Output directory for the ``avg_with_prompt``
                features.
            avg_without_prompt_dir: Output directory for the
                ``avg_without_prompt`` features.
            last_token_outputdir: Output directory for the ``last_token``
                features.
            layer_req: One layer key, or ``"all"`` for every key of
                ``self.layer_map``.
            max_new_tokens: Generation budget per sample.
            batch_size: Samples per ``forward_batch`` call; values below 1 are
                treated as 1.
            resume: Skip samples recorded in the checkpoint instead of starting
                over.
            checkpoint_root: Root directory for checkpoints. If ``None``, uses
                ``<output_dir>/checkpoints``, where ``output_dir`` is the
                grandparent of ``avg_with_prompt_dir``. This assumes the
                ``<output_dir>/<model>_avg_with_prompt/<dataset>`` layout
                built by ``main``.
        """
        os.makedirs(avg_with_prompt_dir, exist_ok=True)
        os.makedirs(avg_without_prompt_dir, exist_ok=True)
        os.makedirs(last_token_outputdir, exist_ok=True)

        with open(dataset_path, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
            data = [json.loads(line) for line in f if line.strip()]
        if self.debug_mode:
            data = data[:250]
        console.rule(f"处理数据集 → {dataset_name}")

        needed_layers: List[str] = (
            list(self.layer_map.keys()) if layer_req == "all" else [layer_req]
        )

        if checkpoint_root is None:
            # os.path.normpath(None) would raise TypeError; derive a default instead.
            output_root = os.path.dirname(os.path.dirname(os.path.abspath(avg_with_prompt_dir)))
            checkpoint_root = os.path.join(output_root, "checkpoints")
        checkpoint_dir = os.path.join(os.path.normpath(checkpoint_root), self.model_key, dataset_name)
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, "checkpoint_meta.json")
        ckpt = CheckpointManager(checkpoint_path, load_existing=resume)

        if not resume:
            # Fresh run: reset the checkpoint and delete old checkpoint/feature files.
            ckpt.reset()
            if os.path.exists(checkpoint_path):
                os.remove(checkpoint_path)
            clear_dir(avg_with_prompt_dir)
            clear_dir(avg_without_prompt_dir)
            clear_dir(last_token_outputdir)

        # Drop samples without a resolvable id and cache the id on each item.
        cleaned_data: List[Dict[str, Any]] = []
        bad_id_count = 0
        for item in data:
            _id = get_item_id(item)
            if _id == "":
                bad_id_count += 1
                continue
            item["_resolved_id"] = _id
            cleaned_data.append(item)
        if bad_id_count:
            print(f"[warn] {bad_id_count} 条样本缺少 id/unique_id，已跳过。")

        ckpt.set_total(len(cleaned_data))

        # Filter by checkpoint before sorting, so only unprocessed samples get tokenized.
        if resume:
            remaining = [it for it in cleaned_data if not ckpt.is_processed(it["_resolved_id"])]
        else:
            remaining = cleaned_data

        if not remaining:
            print("[info] 没有剩余样本可处理。")
            return

        # Sort by prompt length (longest first) before slicing into batches.
        data = self._sort_data_by_prompt_len(remaining)
        console.print("[cyan]已按prompt长度排序[/cyan]")
        console.print(f"加载 {len(data)} 个样本.")

        batch_size = max(1, int(batch_size))  # 0 would divide by zero below
        total_batches = (len(data) + batch_size - 1) // batch_size
        console.print(f"使用批量大小 {batch_size}, 总共 {total_batches} 个批次")

        feature_kinds = ("avg_with_prompt", "avg_without_prompt", "last_token")

        for batch_idx in track(range(total_batches), description="Extracting features"):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(data))
            batch_items = data[start_idx:end_idx]

            console.print(f"处理批次 {batch_idx + 1}/{total_batches}, 样本 {start_idx}-{end_idx-1}")

            # One complete record per sample. A record is appended only after it
            # is fully built, so the per-field lists saved below always stay
            # index-aligned.
            records: List[Dict[str, Any]] = []
            try:
                batch_layer_vecs, batch_answers, hit_limit_flags = self.forward_batch(
                    batch_items,
                    max_new_tokens=max_new_tokens,
                    needed_layers=needed_layers,
                )
                for i, (item, layer_vecs, answer) in enumerate(zip(batch_items, batch_layer_vecs, batch_answers)):
                    records.append(
                        self._build_sample_record(item, layer_vecs, answer, hit_limit_flags[i], needed_layers)
                    )

            except Exception as e:
                console.print(f"[red]批次 {batch_idx + 1} 处理失败: {escape(str(e))}[/red]")
                console.print("[yellow]回退到单样本处理模式[/yellow]")

                # Free memory only when the failure looks like an OOM.
                if "out of memory" in str(e).lower():
                    gc.collect()
                    torch.cuda.empty_cache()

                # Discard partial batch results; every sample is redone below.
                if records:
                    console.print("[yellow]清空当前批次已收集的数据，避免部分数据重复保存[/yellow]")
                    records = []

                for i, item in enumerate(batch_items):
                    try:
                        layer_vecs, answer, hit_limit = self.forward_once(
                            item,
                            max_new_tokens=max_new_tokens,
                            needed_layers=needed_layers,
                        )
                        records.append(
                            self._build_sample_record(item, layer_vecs, answer, hit_limit, needed_layers)
                        )
                    except Exception as single_e:
                        console.print(f"[red]样本 {start_idx + i} 处理失败: {escape(str(single_e))}[/red]")
                        # Placeholder keeps one row per sample.
                        records.append(self._failed_sample_record(item, needed_layers))

            feats_by_kind = {
                kind: {k: [r["features"][k][kind] for r in records] for k in needed_layers}
                for kind in feature_kinds
            }
            batch_ids = [r["id"] for r in records]
            self._save_batch_features(
                feats_by_kind["avg_with_prompt"],
                feats_by_kind["avg_without_prompt"],
                feats_by_kind["last_token"],
                [r["label"] for r in records],
                batch_ids,
                [r["question"] for r in records],
                [r["true_answer"] for r in records],
                [r["pred_answer"] for r in records],
                [r["options"] for r in records],
                [r["parsed_answer"] for r in records],
                [r["incomplete"] for r in records],
                avg_with_prompt_dir,
                avg_without_prompt_dir,
                last_token_outputdir,
            )

            # Written even when resume=False, so an interrupted run can resume.
            ckpt.mark_batch_processed(batch_ids)
        print(f"[done] 共 {len(remaining)} 条新样本完成。累计完成 {len(ckpt.processed_ids)}/{ckpt.total_samples}.")

    def _build_sample_record(self, item, layer_vecs, answer, hit_limit, needed_layers):
        """Grade one generated answer and bundle everything saved for the sample.

        Nothing is appended to shared state here. If any step raises, the caller
        receives no partial record.

        Args:
            item: Dataset sample. Reads ``question``, ``answer``,
                ``choices["text"]``, ``choices["label"]`` and ``_resolved_id``.
            layer_vecs: Mapping ``layer_key -> {"avg_with_prompt",
                "avg_without_prompt", "last_token"} -> vector``, as returned by
                ``forward_batch`` / ``forward_once``.
            answer: Decoded model output. It is stored as ``pred_answer``
                untrimmed.
            hit_limit: True if generation hit the token budget. In that case the
                trailing incomplete sentence is trimmed before parsing.
            needed_layers: Layer keys copied out of ``layer_vecs``.

        Returns:
            Dict with keys ``id``, ``question``, ``true_answer``, ``options``,
            ``pred_answer``, ``parsed_answer``, ``label`` (``int(is_correct)``),
            ``incomplete`` (``hit_limit``) and ``features``
            (``{layer_key: {kind: vector}}``).

        Raises:
            ValueError: if ``item["answer"]`` maps to no choice index.
        """
        labels = item["choices"]["label"]
        option_texts = item["choices"]["text"]
        record = {
            "id": item.get("_resolved_id", ""),
            "question": item["question"].strip(),
            "true_answer": item["answer"],
            "options": option_texts,
            "pred_answer": answer,
            "features": {
                k: {
                    kind: layer_vecs[k][kind]
                    for kind in ("avg_with_prompt", "avg_without_prompt", "last_token")
                }
                for k in needed_layers
            },
        }

        # Truncated output: drop the last, incomplete sentence before parsing.
        if hit_limit:
            answer = trim_incomplete_sentence(answer, "".join(labels))

        ans_label = str(item["answer"]).strip()
        try:
            ans_idx = labels.index(ans_label)
        except ValueError:
            # Fallback when the answer is a bare letter such as 'A' / 'b'.
            ans_idx = ord(ans_label.upper()) - ord("A")
        if not 0 <= ans_idx < len(option_texts):
            # A negative index would silently pick the wrong option text.
            raise ValueError(f"answer {ans_label!r} does not match choice labels {labels}")
        option_answer_text = option_texts[ans_idx]

        extracted_answer, is_correct = arc_challenge_parse(
            answer,
            item["answer"],
            hit_limit,
            CHOICES[len(labels)],
            option_answer_text,
        )

        shown = extracted_answer
        if extracted_answer == "Incomplete" or hit_limit:
            console.print("\n[red]Incomplete answer[/red]")
            # Long parse results are summarized instead of dumped to the log.
            if extracted_answer is not None and len(extracted_answer) >= 500:
                shown = "Incomplete"
        print("pred_answer:", shown, "true answer:", item["answer"], "is_correct:", is_correct)

        record["parsed_answer"] = extracted_answer
        record["label"] = int(is_correct)
        record["incomplete"] = hit_limit
        return record

    def _failed_sample_record(self, item, needed_layers):
        """Placeholder record for a sample whose processing failed.

        The answer fields are ``"Failed"``, the label is 0 and the sample is
        not marked incomplete. Every feature is a CPU zero vector of length
        ``self.model.config.hidden_size`` in the dtype of the model's first
        parameter, so saved rows stay index-aligned.
        """
        hidden_dim = self.model.config.hidden_size
        dtype = next(self.model.parameters()).dtype
        zero = torch.zeros(hidden_dim, dtype=dtype, device="cpu")
        return {
            "id": item.get("_resolved_id", ""),
            "question": item["question"].strip(),
            "true_answer": item["answer"],
            "options": item["choices"]["text"],
            "pred_answer": "Failed",
            "parsed_answer": "Failed",
            "label": 0,
            "incomplete": False,
            "features": {
                k: {
                    kind: zero.clone()
                    for kind in ("avg_with_prompt", "avg_without_prompt", "last_token")
                }
                for k in needed_layers
            },
        }


def load_datasets(dataset: List[str], data_dir: str) -> Dict[str, str]:
    """Map each requested dataset name to its ``<data_dir>/<name>.jsonl`` path.

    Every name is reported to the console. Names whose file does not exist are
    left out of the returned mapping.

    Args:
        dataset: Dataset names (file stems).
        data_dir: Directory containing the ``.jsonl`` files.

    Returns:
        ``{name: path}`` for the files that exist, in request order.
    """
    found: Dict[str, str] = {}
    for name in dataset:
        candidate = os.path.join(data_dir, f"{name}.jsonl")
        if not os.path.exists(candidate):
            console.print(f"[red] 数据集不存在 {candidate}[/red]")
            continue
        found[name] = candidate
        console.print(f"找到数据集: {candidate}")
    return found


def main() -> None:
    """Command-line entry point: extract hidden features for every model × dataset.

    Which models run:
      - ``--model all``: every sub-directory of ``--models_root`` that contains
        a ``config.json`` or ``tokenizer.json``.
      - otherwise: the comma-separated names given to ``--model``.

    Datasets are ``<data_dir>/<name>.jsonl`` for each comma-separated name in
    ``--dataset``; the default is the ARC-Challenge train/val/test splits.

    Output locations:
      - features: ``<output_dir>/<model>_{avg_with_prompt,avg_without_prompt,last_token}/<dataset>``
      - checkpoints: under ``--checkpoint_root``

    The batch size comes from ``--batch_size``, the ``--auto_batch`` estimate,
    or the smaller of the two when both are available. Each model is released
    before the next one is loaded.
    """
    p = argparse.ArgumentParser("Hidden feature extractor with checkpoint")
    p.add_argument("--models_root", default=os.path.join(PROJECT_ROOT, "models"), help="本地模型根目录（用于拼接 模型名→路径）",)
    p.add_argument("--data_dir", default=os.path.join(PROJECT_ROOT, "new_benchmark", "arc_challenge"),)
    p.add_argument("--output_dir", default=os.path.join(PROJECT_ROOT, "feats/arc_challenge/arc_challenge_feats"),)
    p.add_argument("--device", default="cuda")
    p.add_argument("--model", default="all")
    p.add_argument("--dataset", default=None)
    p.add_argument("--use_few_shot", action="store_true")
    p.add_argument("--layer_type", default="all", choices=["quarter", "middle", "three_quarters", "last", "second_last", "all"])
    p.add_argument("--batch_size", type=int, default=None, help="批量处理的batch size")
    p.add_argument("--resume", action="store_true", help="是否从断点恢复")
    p.add_argument("--checkpoint_root", type=str, default=os.path.join(PROJECT_ROOT, "checkpoints/arc_challenge/arc_challenge_feats"), help="checkpoint存储目录")
    # Parameters for automatic batch-size estimation.
    p.add_argument("--auto_batch", choices=["off", "p99", "max"], default="max", help="按分片自动估计 batch（p99 或 max）")
    p.add_argument("--gen_tokens", type=int, default=2048, help="估算时的最大生成长度")
    p.add_argument("--n_gpus", type=int, default=1, help="参与推理的 GPU 数")
    p.add_argument("--per_gpu_vram_gib", type=float, default=80.0, help="每张卡可用显存 GiB，用于估算")
    p.add_argument("--pct", type=float, default=0.99, help="分位数（默认 0.99，用于 pXX 计算）")
    args = p.parse_args()

    global USE_FEW_SHOT
    USE_FEW_SHOT = args.use_few_shot

    # With no explicit batch size and auto-batch off, nothing can run. Stop here,
    # before any model is loaded, instead of after the first model.
    if args.batch_size is None and args.auto_batch == "off":
        console.print("[yellow]警告: 未指定 --batch_size，且 --auto_batch=off，无法执行，请重新设置 batch[/yellow]")
        return

    os.makedirs(args.models_root, exist_ok=True)
    os.makedirs(args.output_dir, exist_ok=True)

    # --dataset accepts a comma-separated list, like --model.
    if args.dataset is None:
        dataset_names = ["arc_challenge_train", "arc_challenge_val", "arc_challenge_test"]
    else:
        dataset_names = [d.strip() for d in str(args.dataset).split(",") if d.strip()]
    datasets = load_datasets(dataset_names, args.data_dir)

    if not datasets:
        console.print("[red] 数据集不存在[/red]")
        return

    if args.checkpoint_root is None:
        checkpoint_root = os.path.join(args.output_dir, "checkpoints")
    else:
        checkpoint_root = args.checkpoint_root
    os.makedirs(checkpoint_root, exist_ok=True)

    if args.model == "all":
        models = []
        for d in sorted(os.listdir(args.models_root)):
            mp = os.path.join(args.models_root, d)
            if os.path.isdir(mp) and (
                os.path.exists(os.path.join(mp, "config.json"))
                or os.path.exists(os.path.join(mp, "tokenizer.json"))
            ):
                models.append(d)
    else:
        models = [m.strip() for m in str(args.model).split(",") if m.strip()]

    # Batch size used when auto-estimation is off or fails; never below 1.
    fallback_bsz = max(1, int(args.batch_size)) if args.batch_size is not None else 1

    for mk in models:
        model_path = os.path.join(args.models_root, mk)

        if not os.path.exists(model_path):
            console.print(f"[red] {mk} 模型目录不存在：{model_path}[/red]")
            continue

        console.rule(f"📦  Processing model — {mk}")
        extractor = HiddenExtractor(mk, model_path, args.device)
        try:
            # Reasoning-style models get a larger generation budget.
            mk_lower = mk.lower()
            if any(k in mk_lower for k in ["qwen3", "distill", "qwq"]):
                gen_tok = 8192
            else:
                gen_tok = int(args.gen_tokens)

            for dname, dpath in datasets.items():
                avg_with_prompt_dir = os.path.join(args.output_dir, mk + "_avg_with_prompt", dname)
                avg_without_prompt_dir = os.path.join(args.output_dir, mk + "_avg_without_prompt", dname)
                last_token_out_dir = os.path.join(args.output_dir, mk + "_last_token", dname)

                effective_bsz = fallback_bsz
                if args.auto_batch != "off":
                    try:
                        # Token statistics (max, pXX) for this split; dname is
                        # passed directly as dataset_with_split.
                        mx, pxx = PTS_POST.token_calculation(
                            model_name=mk,
                            dataset_with_split=dname,
                            pct=args.pct,
                        )

                        bsz_p99, bsz_mx = VRAM_POST.batch_estimation(
                            mx=mx,
                            p99=pxx,
                            model_name=mk,
                            gen_tokens=gen_tok,
                            n_gpus=args.n_gpus,
                            per_gpu_vram_gib=args.per_gpu_vram_gib,
                        )

                        auto_bsz = bsz_p99 if args.auto_batch == "p99" else bsz_mx

                        # If --batch_size was also given, use the smaller value to be safe.
                        effective_bsz = max(1, int(auto_bsz))
                        if args.batch_size is not None:
                            effective_bsz = min(effective_bsz, fallback_bsz)

                        if args.auto_batch == "p99":
                            # :g avoids int() truncation, e.g. 0.29 * 100 -> 28.999...
                            console.print(f"[cyan]AUTO-BATCH(p99)[/cyan] {mk}/{dname}: p{args.pct * 100:g}={pxx} → use batch_size={effective_bsz}")
                        else:
                            console.print(f"[cyan]AUTO-BATCH(max)[/cyan] {mk}/{dname}: mx={mx} → use batch_size={effective_bsz}")

                    except Exception as e:
                        console.print(f"[yellow]AUTO-BATCH 估计失败 {mk}/{dname}：{escape(str(e))}；回退 batch_size={args.batch_size}[/yellow]")
                        effective_bsz = fallback_bsz
                        console.print(f"[yellow]使用回退 batch_size={effective_bsz}[/yellow]")

                extractor.extract_dataset(
                    dname,
                    dpath,
                    avg_with_prompt_dir,
                    avg_without_prompt_dir,
                    last_token_out_dir,
                    layer_req=args.layer_type,
                    max_new_tokens=gen_tok,
                    batch_size=effective_bsz,
                    resume=args.resume,
                    checkpoint_root=checkpoint_root,
                )
        finally:
            # Drop this loop's reference to the model before the next one is
            # constructed. Otherwise the old extractor stays bound (and resident)
            # while the new HiddenExtractor loads.
            del extractor
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    console.rule("程序结束！")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
