from __future__ import annotations
import argparse
import json
import os
import shutil
import importlib.util
from typing import Dict, List, Optional, Tuple, Iterable, Any, Set

import torch
import gc
from rich.console import Console
from rich.progress import track
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import logging
import re, unicodedata
import sys
import random
from rich.markup import escape
from collections import Counter

# Directory containing this file, and the project root located two levels above it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.normpath(os.path.join(BASE_DIR, "..", ".."))

# Put the project root at the front of sys.path (presumably so the `utils` package
# imported right below can be resolved when this file is run as a script).
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)


from utils.mmlu_pro_examples import MMLU_PRO_FEWSHOT

logging.set_verbosity_error() # show only ERROR messages from transformers, ignore WARNINGs
console = Console()

# Only applied when the variable is not already set in the environment.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# When True, parse_input() prepends few-shot examples to the prompt.
USE_FEW_SHOT = False
torch.manual_seed(42)

############# Logging helper ##############################
# logf = open("mmlu_pro_forward_hook_logs.txt", "w")
def log_cmp(extracted_answer, item_answer, is_correct):
    """Record one prediction / ground-truth comparison as a single log line.

    The line is written to the module-level ``logf`` file object and flushed
    right away. If no ``logf`` exists (its creation above is commented out),
    the line is printed to stdout instead of raising ``NameError``.

    Args:
        extracted_answer: Answer extracted from the model output.
        item_answer: Reference answer of the sample.
        is_correct: Whether the two were judged equal.
    """
    # Flatten newlines so that one record never spans several lines.
    ea = str(extracted_answer).replace("\n", "\\n")
    ta = str(item_answer).replace("\n", "\\n")
    line = f"[pred_answer]: {ea} true answer: {ta} is_correct: {is_correct}"

    sink = globals().get("logf")
    if sink is None:
        # No log file configured: fall back to the console.
        print(line)
        return
    sink.write(line + "\n")
    sink.flush()

# Option labels used by format_prompt() when rendering the options (A–P).
choices = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P"]
_LETTER = r"[A-J]"   # MMLU-Pro answers are always one of A–J

random.seed(42)

# Regex building blocks for answer extraction.
# AFTER is a lookahead: the captured letter must be followed by end-of-string,
# whitespace / punctuation (ASCII or CJK), or the character "项".
AFTER = r'(?=$|[\s,.;:!?，。！？：；、()\[\]{}（）【】"“”\'《》<>\|｜·\-–—~〜～*`]|项)'
# Separator characters, meant to be embedded inside a character class ([...]).
SEP_CHARS = r'\s,.;:!?，。！？：；、()\[\]{}（）【】"“”\'《》<>\|｜=·\-–—~〜～/\\\*`'
# Canonical form: "the answer is (X)", with optional (possibly full-width) brackets.
RE_MAIN_PAREN = re.compile(rf'(?i)the\s+answer\s+is\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?{AFTER}')
RE_VARIANTS = [
    # 1) Strong English anchors
    re.compile(rf'(?i)\b(?:final|correct)\s+answer\s*(?:is\s*[:：]?)?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?{AFTER}'),  # "is" and the colon are optional
    re.compile(rf'(?i)\bthe\s+answer\s+is\s*[:：]?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?'+AFTER),
    re.compile(rf'(?i)\banswer\s+is\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?'+AFTER),
    re.compile(rf'(?i)\banswer\s*[:：]\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?'+AFTER),               # colon form
    re.compile(rf'(?i)\b(?:correct|final)\s*choice\s*(?:is|:)?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?'+AFTER),
    re.compile(rf'(?i)\bthe\s+answer\s+is\s*[:：]?\s*(?:option\s*)?[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?'+AFTER),

    # 2) Common Chinese phrasings
    re.compile(rf'(?i)(?:正确答案|答案)\s*(?:为|是)?\s*[:：]?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?'+AFTER),
    re.compile(rf'(?i)(?:故选)\s*[:：]?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?'+AFTER),
    re.compile(rf'(?i)(?:选择(?:答案)?|选(?:择)?)\s*(?:为|是)?\s*[:：]?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?' + AFTER),
    re.compile(rf'(?i)(?:选项|选)\s*[:：]?\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?\s*项?'+AFTER),

    # English action verbs (anchored on the verb to avoid matching option-enumeration lines)
    re.compile(rf'(?i)\b(?:choose|chosen|pick|picked|select|selected)\s+(?:option\s*)?[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?' + AFTER),

    # 3) Markdown (letter optionally wrapped in asterisks)
    re.compile(rf'(?i)\b(?:final|correct)?\s*answer\s*[:：]?\s*\**\s*[\[\{{\(\（【]?\s*({_LETTER})\s*[\)\]\}}）】]?\s*\**'+AFTER),
    # re.compile(r'(?i)\*\*\s*({_LETTER})\s*\*\*' + AFTER),

    # 4) LaTeX (currently disabled)
    # re.compile(r'(?i)\\boxed\{\s*\(?\s*({_LETTER})\s*\)?\s*\}'),
    # re.compile(r'(?i)\\boxed\{\s*(?:\\(?:textbf|mathbf|mathrm|mathsf)\{\s*({_LETTER})\s*\}|({_LETTER}))\s*\}'),
    # re.compile(r'(?i)\\\(\s*({_LETTER})\s*\\\)'),                                # allow the letter wrapped in \( \)
    # (optionally also add: \(\s*({_LETTER})\s*\), for forms such as \(D\))
]

# Each pattern is anchored to a whole line (^...$), so false positives are unlikely.
# Trailing sentence punctuation (. ! ? 。 ？) is tolerated after the letter.
RE_LASTLINE_FORMAT_FALLBACKS = [
    # **A** / **b**
    re.compile(rf'^\s*\*\*\s*({_LETTER})\s*\*\*\s*[\.\!\?\u3002\uFF1F]*\s*$', re.I),
    # **(A)**
    re.compile(rf'^\s*\*\*\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?\s*\*\*\s*[\.\!\?\u3002\uFF1F]*\s*$', re.I),
    # \boxed{A} / \boxed{\textbf{A}} / \boxed{\mathrm{a}} …
    # (two capture groups: callers take the first non-empty one)
    re.compile(
        rf'^\s*\\boxed\{{\s*(?:\\(?:textbf|mathbf|mathrm|mathsf|bfseries|boldsymbol)\{{\s*({_LETTER})\s*\}}|({_LETTER}))\s*\}}\s*[\.\!\?\u3002\uFF1F]*\s*$',
        re.I
    ),
    # \(A\)
    re.compile(rf'^\s*\\\(\s*({_LETTER})\s*\\\)\s*[\.\!\?\u3002\uFF1F]*\s*$', re.I),
    # (A) / [A] / {A}, including full-width brackets
    re.compile(rf'^\s*[\(\[\{{（【]\s*({_LETTER})\s*[\)\]\}}）】]\s*[\.\!\?\u3002\uFF1F]*\s*$', re.I),
]

### Support for trimming an incomplete trailing sentence ###
# Matches a sentence terminator (CJK or ASCII punctuation, an ellipsis, or a period
# followed by whitespace / end / a closing quote or bracket) plus any trailing
# closing quotes, brackets and whitespace.
_SENT_END_RE = re.compile(r'(?:[。！？!?]|\.{3,}|…|\.(?=(?:\s|\u3000|$|[”"\'’）\)\]】》])))[\s\u3000]*[”"\'’）\)\]】》]*[\s\u3000]*')
# Phrases that introduce an answer statement.
ANSWER_PHRASES = [
    # === English, general (tightened) ===
    r"\b(?:so|thus|therefore)\b.*?\b(?:final\s+|correct\s+)?(?:answer|result|solution|choice)\s*(?:is|are|:|：)",  # with a causal / concluding word
    r"\b(?:the\s+)?(?:final\s+|correct\s+)?answer\s*(?:is|are|:|：)",   # the/final/correct answer ...
    r"\banswer\s*(?:is|:|：)",                                          # answer is / answer:
    r"\b(?:final|correct)\s+choice\s*(?:is|:|：)",                      # final/correct choice ...
    r"\bthe\s+answer\s+is\s*(?:option\s*)?",                            # the answer is (option) ...
    r"\b(?:result|solution)\s*[:：]",                                   # result/solution accepted only in the colon form

    # === Chinese, general (no leading space required) ===
    r"(?:因此|所以|综上(?:所述)?|故|可知)\s*(?:最终\s*)?(?:答案|结果|选项|选择)\s*(?:是|为|:|：)",
    r"(?:最终\s*)?(?:答案|结果)\s*(?:是|为|:|：)",

    # === Terse form (no verb) ===
    r"(?:answer|答案)\s*[:：]",

    # === Strong multiple-choice endings (tightened) ===
    # Only 故选 / 应选 are kept, and the suffix is restricted to A–J, full-width Ａ–Ｊ,
    # digits or Chinese numerals.
    r"(?:故\s*选|应\s*选)\s*[A-JＡ-Ｊ](?:\s*项|(?:\s*选项)?)?",
    r"(?:故\s*选|应\s*选)\s*(?:[0-9]+|[一二三四五六七八九十])(?:\s*项|(?:\s*选项)?)?",
]
# The first phrase contains ".*?" and is compiled with re.S so it can span line breaks;
# all remaining phrases are compiled case-insensitively only.
ANSWER_PHRASES_RE = [re.compile(ANSWER_PHRASES[0], re.I | re.S)] + \
                    [re.compile(p, re.I) for p in ANSWER_PHRASES[1:]]

# If the tail already holds a complete answer it must not be trimmed (expected to be rare).
def _tail_has_complete_answer(tail: str) -> bool:
    """Return True when *tail* contains a fully formed answer statement.

    The strict canonical pattern is tried first; otherwise any variant
    pattern that matches with at least one non-empty capture group counts.
    """
    def _captures_letter(pattern) -> bool:
        found = pattern.search(tail)
        return found is not None and any(found.groups())

    if RE_MAIN_PAREN.search(tail) is not None:
        return True
    return any(_captures_letter(variant) for variant in RE_VARIANTS)

def _contains_answer_phrase(text: str) -> bool:
    """Return True if any compiled answer-introducing phrase occurs in *text*."""
    for phrase_re in ANSWER_PHRASES_RE:
        if phrase_re.search(text):
            return True
    return False

# Drop an unfinished trailing sentence
def trim_incomplete_sentence(ans: str) -> str:
    """Cut off the text after the last sentence terminator when it is a dangling answer.

    The tail is removed only if both the tail and the text before it contain
    an answer-introducing phrase, and the tail does not itself hold a
    complete answer. In every other case *ans* is returned unchanged.
    """
    if not ans:
        return ans

    # Position right after the last sentence boundary, if there is one.
    last_end = None
    for boundary in _SENT_END_RE.finditer(ans):
        last_end = boundary.end()
    if last_end is None:
        return ans

    head, tail = ans[:last_end], ans[last_end:]
    if not (_contains_answer_phrase(tail) and _contains_answer_phrase(head)):
        return ans
    if _tail_has_complete_answer(tail):
        return ans
    return head.strip()

# Few-shot is not considered here for now
def format_prompt(example):
    """Render one example as a "Question / Options" text block.

    Args:
        example: Mapping with a ``"question"`` string and an ``"options"`` sequence.

    Returns:
        The question followed by one ``"<label>. <option>"`` line per option,
        labels taken from the module-level ``choices`` list.
    """
    question = example["question"]
    options = example["options"]
    option_lines = [
        "{}. {}\n".format(choices[pos], text) for pos, text in enumerate(options)
    ]
    return "".join(["Question:\n", question, "\n", "Options:\n", *option_lines])

# Item fields: question, options, answer, answer_index, category
def parse_input(item):
    """Build the full prompt text for one sample.

    The prompt starts with an instruction naming the sample's category,
    optionally followed by few-shot examples (when ``USE_FEW_SHOT`` is set),
    and ends with the formatted question itself.
    """
    category = item["category"]
    pieces = [
        f"The following are multiple choice questions (with answers) about {category}. Think step by step and then finish your answer with \"the answer is (X)\" where X is one of Options(e.g. one of ABCDEFGHIJ).\n"
    ]
    if USE_FEW_SHOT:
        pool = MMLU_PRO_FEWSHOT[category]
        for shot in pool[:MMLU_PRO_FEWSHOT["few-shot-num"]]:
            pieces.append(format_prompt(shot))
            pieces.append("\n")
    pieces.append(format_prompt(item))
    return "".join(pieces)

# 将字符串转换为规范形式，去除空白，压缩空白，保留换行
def _normalize(s: str) -> str:
    s = s or ""
    s = unicodedata.normalize('NFKC', s)
    s = s.replace("（", "(").replace("）", ")")
    for z in ("\u200b", "\xa0", "\u202f", "\ufeff", "\u200d", "\u2060", "\t"):
        s = s.replace(z, " ")
    # 仅压缩空白但保留换行：避免破坏“最后一行”逻辑
    s = re.sub(r'[^\S\r\n]+', ' ', s)
    return s

def extract_answer(text: str, option_answer_text: str):
    """Extract the answer, preferring the canonical "the answer is (X)" form.

    Matching is case-insensitive and the last occurrence wins. When the
    canonical form is absent, extraction is delegated to ``extract_again``.
    Returns ``None`` for empty input.
    """
    if not text:
        return None

    normalized = _normalize(text)
    # Unwrap common text/style macros into plain text so that e.g. \text{A} still matches.
    flattened = re.sub(
        r'\\(?:text|mathrm|mathbf|mathsf|textbf|operatorname|textrm|rm)\s*\{([^{}]*)\}',
        r'\1',
        normalized,
    )

    final_hit = None
    for final_hit in RE_MAIN_PAREN.finditer(flattened):
        pass
    if final_hit is None:
        return extract_again(flattened, option_answer_text)
    return final_hit.group(1).upper()

# Variant matching over all RE_VARIANTS patterns
def extract_again(text: str, option_answer_text: str):
    """Return the letter of the variant match that ends furthest into *text*.

    For each match the first non-empty capture group is used. If no variant
    yields a letter, extraction falls back to ``extract_final``.
    """
    furthest_end = -1
    answer = None
    for variant in RE_VARIANTS:
        for hit in variant.finditer(text):
            letter = next(filter(None, hit.groups()), None)
            if letter and hit.end() > furthest_end:
                furthest_end = hit.end()
                answer = letter.upper()
    if answer is None:
        return extract_final(text, option_answer_text)
    return answer

def _strip_latex_noise_oneline(s: str) -> str:
    # 去掉行内/陈列数学环境定界符
    s = re.sub(r'\\\[|\\\]|\\\(|\\\)|\$\$?', ' ', s)
    # 去掉常见 \frac / \dfrac / \tfrac 形式
    s = re.sub(r'\\[dt]?frac\s*\{[^{}]*\}\s*\{[^{}]*\}', ' ', s)
    # 去掉通用 \cmd{...}（保留 \boxed{...} 等已在上面兜底过的）
    s = re.sub(r'\\(?!boxed\b)[A-Za-z]+\s*\{([^{}]*)\}', r' \1 ', s)
    return s

# Last resort: inspect the final three lines
def extract_final(text: str, option_answer_text: str):
    """Fallback extraction over the last three non-empty lines, bottom-up.

    For each line, in order: the variant patterns, the whole-line format
    fallbacks, a non-letter payload that equals *option_answer_text*, a line
    holding only a single letter, and finally the last stand-alone A–J letter
    in the line. Returns ``None`` when nothing is found.
    """
    nonblank = [piece for piece in (raw.strip() for raw in text.splitlines()) if piece]
    if not nonblank:
        return None

    # A line made of a single letter (optional brackets and trailing punctuation).
    single_letter_re = rf'^\s*[\(\[\{{（【]?\s*({_LETTER})\s*[\)\]\}}）】]?\s*[\.\!\?\u3002\uFF1F]*\s*$'
    # A stand-alone letter delimited by separators; the left set omits / and \.
    right_seps = SEP_CHARS
    left_seps = r'\s,.;:!?，。！？：；、()\[\]{}（）【】"“”\'《》<>\|｜=·\-–—~〜～*`'
    isolated_letter_re = rf'(?i)(?:^|[{left_seps}])\s*[\(\[\{{（【]?\b({_LETTER})\b\s*[\)\]\}}）】]?(?!\s*[/\\=^\(])(?=$|[{right_seps}]|项)'

    for candidate_line in reversed(nonblank[-3:]):
        last = _normalize(candidate_line)

        # Anchored patterns first: variants, then whole-line format fallbacks.
        for pat in (*RE_VARIANTS, *RE_LASTLINE_FORMAT_FALLBACKS):
            hit = pat.search(last)
            if hit is None:
                continue
            letter = next((grp for grp in hit.groups() if grp), None)
            if letter:
                return letter.upper()

        if option_answer_text:
            cand = extract_latex_non_letter(last)
            # Accept non-letter content only on an exact match with the option text.
            if cand and _normalize_for_compare(cand) == _normalize_for_compare(option_answer_text):
                return cand

        sanitized = _strip_latex_noise_oneline(last)

        whole_line = re.match(single_letter_re, sanitized, re.I)
        if whole_line:
            return whole_line.group(1).upper()

        isolated = list(re.finditer(isolated_letter_re, sanitized, re.I))
        if isolated:
            return isolated[-1].group(1).upper()

    return None

# 标准化字符串，去除空白，压缩空白，保留换行，方便答案比较
def _normalize_for_compare(s: str) -> str:
    s = s or ""
    s = unicodedata.normalize("NFKC", s)
    s = re.sub(r"\s+", " ", s) # 压缩连续空白符
    return s.strip()

def _latex_clean_payload(s: str) -> str:
    """Reduce a LaTeX-ish answer payload to plain comparable text.

    Fractions are removed outright, text/style macro wrappers are peeled
    repeatedly, math delimiters, bare commands and braces are dropped, a
    leading "the answer is" / "答案是" phrase is removed, and a single pair
    of enclosing brackets is stripped. ``None`` yields ``""``.
    """
    if s is None:
        return ""

    # Fractions are deleted entirely: \frac / \tfrac / \dfrac, then plain "a/b".
    payload = re.sub(r"\\[td]?frac\s*\{\s*[^{}]*?\s*\}\s*\{\s*[^{}]*?\s*\}", " ", s)
    payload = re.sub(r"[-+−]?\s*\d+\s*/\s*\d+", " ", payload)

    # Unify the Unicode minus sign with ASCII "-".
    payload = payload.replace("\u2212", "-")

    # Peel text/style macro shells until nothing changes any more.
    macro_shell = r"\\(?:text|mathrm|mathbf|mathsf|textbf|operatorname|textrm|rm)\s*\{([^{}]*)\}"
    while True:
        peeled = re.sub(macro_shell, r"\1", payload)
        if peeled == payload:
            break
        payload = peeled

    # Drop math delimiters, remaining bare commands (\left, \right, ...) and braces.
    payload = re.sub(r"(\\\[|\\\]|\$\$|\\\(|\\\))", " ", payload)
    payload = re.sub(r"\\[A-Za-z]+", " ", payload)
    payload = payload.replace("{", " ").replace("}", " ")

    payload = _normalize_for_compare(payload)

    # Remove a leading answer phrase.
    payload = re.sub(r"(?i)\b(?:the\s+answer\s+is|答案(?:为|是)?)\b\s*[:：]?\s*", "", payload).strip()

    # Strip one enclosing bracket pair when the whole payload is wrapped in it.
    wrapped = re.fullmatch(r"[\(\[\{（【〔〈《]\s*[^()\[\]{}（）【】〔〕〈〉《》]*\s*[\)\]\}）】〕〉》]", payload)
    if wrapped:
        payload = payload[1:-1].strip()
    return payload

# Extract "non-letter option content"; candidates are collected in a list and checked from the end.
def extract_latex_non_letter(text: str) -> str | None:
    """Extract answer content that is not a plain option letter.

    Candidates are collected in this order: the last ``\\boxed{...}``, a
    text/style macro argument, a number after "the answer is", and content
    found inside math environments. They are then cleaned and examined from
    the end of that list; the first non-empty candidate that is not a single
    option letter is returned. Returns ``None`` if there is none.
    """
    if not text:
        return None

    found: list[str] = []
    boxed_re = r"\\boxed\s*\{\s*([^{}]+?)\s*\}"

    # 1) \boxed{...}: only the last occurrence becomes a candidate.
    boxed_hits = re.findall(boxed_re, text)
    if boxed_hits and boxed_hits[-1]:
        found.append(boxed_hits[-1])

    # 2) Common text/style macros: macros are scanned in the listed order and
    #    the last match of the last macro that matched is kept.
    text_like = None
    for macro in ("text", "textbf", "mathrm", "mathbf", "mathsf", "operatorname", "textrm", "rm"):
        macro_hits = re.findall(rf"\\{macro}\s*\{{\s*([^{{}}]+?)\s*\}}", text)
        if macro_hits:
            text_like = macro_hits[-1]
    if text_like:
        found.append(text_like)

    # 3) "the answer is ..." followed by a number.
    numeric = re.search(r"(?i)the\s+answer\s+is\s*[:：]?\s*\(?\s*([-+−]?\d+(?:\.\d+)?)", text)
    if numeric:
        found.append(numeric.group(1))

    # 4) Math environments \( ... \), \[ ... \], $$ ... $$
    for env in re.finditer(r"(?:\\\(|\\\[|\$\$)(.*?)(?:\\\)|\\\]|\$\$)", text, re.S):
        inner = env.group(1)
        # boxed content inside the environment first
        found.extend(re.findall(boxed_re, inner))
        # then the last number appearing inside it
        numbers = re.findall(r"[-+−]?\d+(?:\.\d+)?", inner)
        if numbers:
            found.append(numbers[-1])

    # A lone A–J letter is left to the letter-based extraction path.
    LETTER = globals().get("_LETTER", r"[A-J]")

    for raw in reversed(found):
        cleaned = _latex_clean_payload(raw)
        if cleaned and not re.fullmatch(rf"\(?\s*{LETTER}\s*\)?", cleaned, re.I):
            return cleaned
    return None

# 去除字符串末尾的标点符号和空白字符
def _strip_trailing_punct(s: str) -> str:
    # 去掉常见句尾标点（英文/中文）
    return re.sub(r'[\.。!,，；;：:\s]+$', '', s or '').strip()

def mmlu_parse(pred, true, hit_limit, option_answer_text):
    """Extract the predicted answer from *pred* and judge it against the reference.

    Returns:
        ``(extracted, is_correct)``. When nothing can be extracted the first
        element is ``"Incomplete"`` if *hit_limit* is ``True`` and ``None``
        otherwise, and ``is_correct`` is ``False``. A prediction is correct
        when it equals the reference letter *true* (case-insensitive) or when
        its text equals *option_answer_text* after lower-casing, removal of
        trailing punctuation and whitespace normalisation.
    """
    extracted = extract_answer(pred, option_answer_text)
    if not extracted or extracted.strip() == "":
        return ("Incomplete" if hit_limit is True else None), False

    if true.strip().upper() == extracted.upper():
        return extracted, True

    predicted_text = _normalize_for_compare(_strip_trailing_punct(extracted.lower()))
    reference_text = _normalize_for_compare(_strip_trailing_punct((option_answer_text or "").lower()))
    return extracted, predicted_text == reference_text

# Remove every file and sub-directory inside the given directory
def clear_dir(dir_path: str):
    """Empty *dir_path* without removing the directory itself.

    Does nothing when *dir_path* is not a directory. A failure on one entry
    is reported with ``print`` and the remaining entries are still processed.
    """
    if not os.path.isdir(dir_path):
        return
    for entry in os.listdir(dir_path):
        target = os.path.join(dir_path, entry)
        try:
            is_file_or_link = os.path.isfile(target) or os.path.islink(target)
            if is_file_or_link:
                os.remove(target)        # regular file or symbolic link
            elif os.path.isdir(target):
                shutil.rmtree(target)    # whole sub-directory tree
        except Exception as e:
            print(f"删除失败: {target}, 原因: {e}")

def get_item_id(item: Dict[str, Any]) -> str:
    """Return the sample identifier of *item* as a stripped string.

    Keys are tried in the order ``unique_id`` (MATH), ``question_id``
    (MMLU-Pro) and ``id`` (BBH); a key is skipped when its value is ``None``
    or ``""``. Returns ``""`` when the finally selected value is ``None``.
    Values such as ``0`` are valid ids and become ``"0"``.
    """
    val = None
    for key in ("unique_id", "question_id", "id"):
        val = item.get(key)
        if not (val is None or val == ""):
            break
    return "" if val is None else str(val).strip()

# 自动化 batch 估计
def _load_module(module_path: str, module_name: str):
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import {module_name} from {module_path}")
    mod = importlib.util.module_from_spec(spec)
    import sys as _sys
    _sys.modules[module_name] = mod
    spec.loader.exec_module(mod)
    return mod

# Helper modules under <PROJECT_ROOT>/utils/vram_estimation, loaded by file path at import time.
PTS_PATH = os.path.join(PROJECT_ROOT, "utils", "vram_estimation", "prompt_token_states_post.py")
VRAM_PATH = os.path.join(PROJECT_ROOT, "utils", "vram_estimation", "vram_estimate_post.py")
PTS_POST = _load_module(PTS_PATH, "prompt_token_states_post")
VRAM_POST = _load_module(VRAM_PATH, "vram_estimate_post")
# Check that each loaded module exposes the expected entry point.
# NOTE(review): these are `assert` statements, so they are skipped under `python -O`.
assert hasattr(PTS_POST, "token_calculation"), "prompt_token_states_post 模块缺少 token_calculation"
assert hasattr(VRAM_POST, "batch_estimation"), "vram_estimate_post 模块缺少 batch_estimation"

# eos 规范化
def _normalize_to_list(x):
    if x is None:
        return []
    if isinstance(x, (list, tuple, set)):
        return list(x)
    return [x]

class CheckpointManager:
    """Persist progress as the set of already processed sample ids.

    The state (processed ids, completion flag, total sample count) is stored
    as JSON at ``path``. With ``load_existing=False`` an existing checkpoint
    is ignored and overwritten by the first save, i.e. "start from scratch
    but still write checkpoints".
    """

    def __init__(self, path: str, load_existing: bool = True):
        self.path = path
        self.processed_ids: Set[str] = set()
        self.completed: bool = False
        self.total_samples: int = 0
        # When load_existing is False an old file is simply left alone until
        # the first _atomic_save() replaces it.
        if load_existing:
            self._load()

    def _load(self) -> None:
        """Load state from ``self.path`` if the file exists.

        The file is parsed completely before any attribute is touched, so a
        malformed checkpoint is ignored as a whole instead of leaving a
        partially loaded state behind.
        """
        if not os.path.exists(self.path):
            return
        try:
            with open(self.path, "r", encoding="utf-8") as f:
                data = json.load(f)
            processed = set(map(str, data.get("processed_ids", [])))
            completed = bool(data.get("completed", False))
            total = int(data.get("total_samples", 0))
        except Exception as e:
            print(f"[ckpt] 读取失败，忽略旧文件：{self.path} ({e})")
            return
        self.processed_ids = processed
        self.completed = completed
        self.total_samples = total

    def reset(self) -> None:
        """Clear the in-memory progress; the file is overwritten by the next save.

        ``total_samples`` is kept (``set_total`` updates it later) and nothing
        is written to disk here.
        """
        self.processed_ids.clear()
        self.completed = False

    def set_total(self, n: int) -> None:
        """Record the total number of samples and save if the value changed."""
        if self.total_samples != n:
            self.total_samples = int(n)
            self._atomic_save()

    def mark_batch_processed(self, ids: Iterable[str]) -> None:
        """Add *ids* to the processed set, update the completion flag and save.

        Ids are converted to stripped strings; ``None`` and empty ids are skipped.
        """
        for _id in ids:
            sid = str(_id).strip() if _id is not None else ""
            if sid:
                self.processed_ids.add(sid)
        if self.total_samples > 0 and len(self.processed_ids) >= self.total_samples:
            self.completed = True
        self._atomic_save()

    def is_processed(self, _id: Any) -> bool:
        """Return True if *_id* (as a stripped string) has been marked processed."""
        sid = str(_id).strip() if _id is not None else ""
        return bool(sid) and (sid in self.processed_ids)

    def _atomic_save(self) -> None:
        """Write the state to ``<path>.tmp`` and move it over ``path`` with ``os.replace``."""
        tmp = self.path + ".tmp"
        data = {
            "processed_ids": sorted(self.processed_ids),
            "completed": self.completed,
            "total_samples": self.total_samples,
        }
        os.makedirs(os.path.dirname(self.path) or ".", exist_ok=True)
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False)
        os.replace(tmp, self.path)


class HiddenExtractor:
    def __init__(self, model_key: str, model_path: str, device: str = "cuda") -> None:
        """Load tokenizer and causal LM from *model_path* and initialise hook state.

        Model loading is attempted in three steps, each one only if the
        previous step raised: (1) ``device_map="auto"`` with an explicit
        attention implementation, (2) ``device_map="auto"`` without it,
        (3) no device map, followed by ``.to(device)``.

        Args:
            model_key: Identifier stored as ``self.model_key``.
            model_path: Path or id passed to ``from_pretrained``.
            device: Device used only by the final single-device fallback.
        """
        self.model_key = model_key
        self.device = device
        console.print(f"[bold]加载模型:[/bold] {model_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        try:
            attn_implementation = "flash_attention_2"
            # Paths containing "gemma" use SDPA instead of flash-attention.
            if "gemma" in model_path:
                attn_implementation = "sdpa"
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                trust_remote_code=True,
                # use_flash_attn=True,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                attn_implementation=attn_implementation
            )
        except Exception as e:
            console.print(f"[yellow]Warning: Failed to load model with auto device map and flash-attention: {escape(str(e))}[/yellow]")
            # Some models do not support flash-attention: retry without specifying it.
            try:
                # Second attempt: still device_map="auto", default attention implementation.
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                )
            except Exception as e2:
                console.print(f"[yellow]Warning: Failed to load model with auto device map: {escape(str(e2))}[/yellow]")
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    torch_dtype=torch.bfloat16,
                    device_map=None,             # single device
                ).to(self.device)

        self.model.eval()
        self.layer_map = self._build_layer_index()
        # Public alias of the private helper.
        self.calculate_right_padding_length = self._calculate_right_padding_length

        # Hook-related state
        self.hooks = []
        self.captured_hidden_states = {}
        self.is_batch_mode = False
        self.current_sample_idx = 0

        # Debugging state
        self.debug_mode = False
        self.debug_info = {}

    def _build_layer_index(self) -> Dict[str, int]:
        """Map the layer names "middle", "last" and "second_last" to layer indices.

        Indices are derived from ``self.model.config.num_hidden_layers``; the
        resulting mapping is also printed to the console.
        """
        total = self.model.config.num_hidden_layers
        mapping = dict(middle=total // 2, last=total - 1, second_last=total - 2)
        console.print(
            f"总层数: {total} → middle={mapping['middle']}, "
            f"second_last={mapping['second_last']}, last={mapping['last']}"
        )
        return mapping

    def _calculate_right_padding_length(self, total_sequence) -> int:
        """Count how many trailing tokens of *total_sequence* are right-side padding.

        Trailing pad tokens are counted first. If the token before them is not
        an EOS token although pads were found (pad id equals EOS id, so the
        real EOS was counted as a pad), one token is given back. Otherwise the
        run of EOS tokens before the pads is examined: one EOS is kept and the
        rest is counted as padding.

        Args:
            total_sequence: Token ids as a list or a ``torch.Tensor``.
        """
        tokens = total_sequence.tolist() if isinstance(total_sequence, torch.Tensor) else total_sequence
        pad_id = self.tokenizer.pad_token_id
        # EOS ids from both the generation config and the tokenizer.
        eos_ids = set(
            _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
            + _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        )

        # Walk backwards over the trailing run of pad tokens.
        cursor = len(tokens) - 1
        while cursor >= 0 and tokens[cursor] == pad_id:
            cursor -= 1
        pad_run = len(tokens) - 1 - cursor

        if cursor >= 0 and pad_run > 0 and tokens[cursor] not in eos_ids:
            # Pads exist and pad == eos swallowed the real EOS: give one token back.
            return pad_run - 1

        # Walk backwards over the run of EOS tokens preceding the pads.
        eos_run = 0
        while cursor - eos_run >= 0 and tokens[cursor - eos_run] in eos_ids:
            eos_run += 1
        # Keep one EOS; any further EOS tokens count as right padding.
        return pad_run + max(0, eos_run - 1)


    # 钩子注册
    def _register_hooks(self, needed_layers: List[str]) -> None:
        """注册forward hooks来捕获指定层的隐状态"""
        self._clear_hooks()  # 先清除之前的hooks
        self.captured_hidden_states = {k: [] for k in needed_layers}

        def create_hook(layer_name: str):
            def hook_fn(module, input, output):

                if isinstance(output, tuple):
                    hidden_state = output[0]  # 通常隐状态是第一个元素
                else:
                    hidden_state = output

                # 复制到CPU后立即删除GPU引用，但使用更安全的方式
                hidden_state_cpu = hidden_state.detach().to('cpu')

                # 显式删除GPU引用以节省显存，但要确保Flash Attention已经完成计算
                if hidden_state.is_cuda:
                    del hidden_state

                # if self.debug_mode:
                #     console.print(f"[cyan]{layer_name} hidden state shape: {hidden_state_cpu.shape}[/cyan]")

                # 存储隐藏状态 - 不需要步骤计数，按顺序存储即可
                self.captured_hidden_states[layer_name].append(hidden_state_cpu)

            return hook_fn

        # 根据模型架构获取正确的层
        if hasattr(self.model, 'model'):  # 如 LlamaForCausalLM
            layers = self.model.model.layers
        elif hasattr(self.model, 'transformer'):  # 如 GPT类模型
            layers = self.model.transformer.h
        else:
            # 尝试其他可能的属性名
            for attr in ['layers', 'h', 'decoder_layers']:
                if hasattr(self.model, attr):
                    layers = getattr(self.model, attr)
                    break
            else:
                raise AttributeError("无法找到模型的层结构")

        # 修复问题1：直接根据layer_name计算索引，避免重复计算
        for layer_name in needed_layers:
            if layer_name == "last":
                layer_idx = len(layers) - 1
            elif layer_name == "second_last":
                layer_idx = len(layers) - 2
            elif layer_name == "middle":
                layer_idx = len(layers) // 2
            else:
                # 如果是其他层名，可以从layer_map获取
                layer_idx = self.layer_map.get(layer_name, 0)

            if 0 <= layer_idx < len(layers):
                hook = layers[layer_idx].register_forward_hook(create_hook(layer_name))
                self.hooks.append(hook)
                # console.print(f"注册hook for {layer_name} (layer {layer_idx})")

    def _clear_hooks(self) -> None:
        """清除所有hooks"""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        # 注意：这里不清除captured_hidden_states，因为后续还需要使用
        # self.captured_hidden_states.clear()

    # 方便统一处理chat/base模型
    def _build_input_ids(self, question: str) -> Tuple[torch.Tensor, list[int], str]:
        """Build the model input for a single question (chat and base models).

        With a chat template the question is wrapped as a single user message
        and tokenised through ``apply_chat_template``. For model keys that
        contain "qwen3", ``enable_thinking=True`` is tried first and dropped
        if the call fails. Without a chat template the raw text is tokenised
        and a BOS token is prepended when the tokenizer defines one and did
        not add it itself.

        Returns:
            A tuple ``(input_ids, prompt_ids, prompt_txt)``:
            - input_ids: tensor moved to the device of the input embeddings
            - prompt_ids: the first row of ``input_ids`` as a list of ints
            - prompt_txt: that row decoded with special tokens kept
        """
        has_tpl = bool(getattr(self.tokenizer, "chat_template", None))
        if hasattr(self.tokenizer, "apply_chat_template") and has_tpl:
            prompt_struct = [{"role": "user", "content": question}]
            template_kwargs = dict(
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
            )
            if "qwen3" in self.model_key.lower():
                try:
                    encoded = self.tokenizer.apply_chat_template(
                        prompt_struct,
                        enable_thinking=True,
                        **template_kwargs,
                    )
                except Exception:
                    # Not a bare `except`, so KeyboardInterrupt / SystemExit still
                    # propagate; retry without the thinking switch.
                    encoded = self.tokenizer.apply_chat_template(prompt_struct, **template_kwargs)
            else:
                encoded = self.tokenizer.apply_chat_template(prompt_struct, **template_kwargs)
        else:
            batch = self.tokenizer(
                question,
                add_special_tokens=True,      # let the tokenizer add its own special tokens first
                return_tensors="pt",
                truncation=False,
            )
            input_ids = batch["input_ids"]
            bos_id = self.tokenizer.bos_token_id
            if bos_id is not None and input_ids[0, 0].item() != bos_id:
                bos_tensor = torch.tensor([[bos_id]], dtype=input_ids.dtype)
                input_ids = torch.cat([bos_tensor, input_ids], dim=1)

            encoded = input_ids

        embed_dev = self.model.get_input_embeddings().weight.device
        encoded = encoded.to(embed_dev)
        prompt_ids = encoded[0].tolist()
        prompt_txt = self.tokenizer.decode(encoded[0], skip_special_tokens=False)
        return encoded, prompt_ids, prompt_txt

    def _build_batch_input_ids(self, questions: List[str]) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]], List[str], torch.Tensor]:
        """Build left-padded model inputs for a batch of questions.

        Side effects: sets the tokenizer's padding side to "left", makes sure
        a pad token exists (a dedicated ``<|pad|>`` token with a zeroed
        embedding, or the EOS token if adding one fails) and copies the pad
        token id to ``model.config`` and ``model.generation_config``.

        Returns:
            A tuple ``(input_ids, attention_mask, prompt_ids_list, prompt_txts, left_pad_lens)``:
            - input_ids: padded ids, moved to the device of the input embeddings
            - attention_mask: mask returned by the tokenizer, on the same device
            - prompt_ids_list: un-padded token ids of each prompt
            - prompt_txts: prompt text of each sample (chat template applied when available)
            - left_pad_lens: per sample, padded length minus the attention-mask sum
        """
        # Left padding keeps the last real token of every sample in the final column.
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token_id is None:
            # Prefer a dedicated pad token.
            try:
                special_tokens_dict = {'pad_token': '<|pad|>'}
                num_added = self.tokenizer.add_special_tokens(special_tokens_dict)
                if num_added > 0:
                    self.model.resize_token_embeddings(len(self.tokenizer))
                    # Zero the embedding of the newly added pad token.
                    with torch.no_grad():
                        emb = self.model.get_input_embeddings().weight
                        emb[self.tokenizer.pad_token_id].zero_()
            except Exception as e:
                # Fall back to the EOS token as pad token and report why.
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
                console.print("[yellow]Warn: fallback pad uses eos embedding (not zeroed).[/yellow]")
                console.print(f"[yellow]Warn: adding a dedicated pad token failed: {escape(str(e))}[/yellow]")

        # Keep model / generation_config in sync with the tokenizer.
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id
        # self.model.generation_config.eos_token_id = self.tokenizer.eos_token_id

        prompt_txts: List[str] = []
        add_special = True
        has_tpl = bool(getattr(self.tokenizer, "chat_template", None))
        if hasattr(self.tokenizer, "apply_chat_template") and has_tpl:
            # The expanded template text already carries the special tokens.
            add_special = False
            use_thinking = "qwen3" in self.model_key.lower()
            for q in questions:
                messages = [{"role": "user", "content": q}]
                txt = None
                if use_thinking:
                    try:
                        txt = self.tokenizer.apply_chat_template(
                            messages,
                            tokenize=False,                # expand the template to plain text first
                            add_generation_prompt=True,
                            enable_thinking=True,
                        )
                    except Exception:
                        # Same fallback as the single-sample path: retry without the switch.
                        txt = None
                if txt is None:
                    txt = self.tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                prompt_txts.append(txt)
        else:
            # No chat template: use the raw question text.
            prompt_txts = list(questions)

        batch_enc = self.tokenizer(
            prompt_txts,
            padding="longest",                 # pad to the longest prompt of this batch
            return_tensors="pt",
            return_attention_mask=True,
            add_special_tokens=add_special,    # needed for base models, not for chat templates
            truncation=False,                  # avoid silent truncation
        )
        padded_input_ids = batch_enc["input_ids"]
        padded_attention_mask = batch_enc["attention_mask"]

        prompt_ids_list: List[List[int]] = [
            self.tokenizer.encode(txt, add_special_tokens=add_special)
            for txt in prompt_txts
        ]

        # Move the tensors to the device of the input embeddings.
        embed_dev = self.model.get_input_embeddings().weight.device
        padded_input_ids = padded_input_ids.to(embed_dev)
        padded_attention_mask = padded_attention_mask.to(embed_dev)

        # Left padding length = padded length - number of attended tokens.
        T = padded_attention_mask.shape[1]
        prompt_lens = padded_attention_mask.sum(dim=1)
        left_pad_lens = T - prompt_lens

        return padded_input_ids, padded_attention_mask, prompt_ids_list, prompt_txts, left_pad_lens

    # 特征提取
    def _extract_features_from_hooks(self, prompt_len: int, output_len: int, needed_layers: List[str], sample_idx: int = 0, hit_limit: bool=False) -> Dict[str, Dict[str, torch.Tensor]]:
        """Turn the hidden states collected by the forward hooks into per-layer features.

        For every layer key in ``needed_layers`` three vectors are produced:

        * ``avg_with_prompt``    - mean over the prompt states plus the kept output states
        * ``avg_without_prompt`` - mean over the kept output states only
        * ``last_token``         - the last kept output state

        ``self.captured_hidden_states[k]`` is read as a list with one entry per
        forward pass. Entry 0 is treated as the prompt pass; from every later
        entry only the last sequence position is used.

        Args:
            prompt_len: Number of real (non-padding) prompt tokens of this sample.
            output_len: Number of generated tokens of this sample.
            needed_layers: Layer keys to build features for.
            sample_idx: Row of this sample inside 3-D (batched) captured states.
            hit_limit: When True up to ``output_len`` output states are kept,
                otherwise up to ``output_len - 1``.

        Returns:
            ``{layer_key: {"avg_with_prompt": ..., "avg_without_prompt": ...,
            "last_token": ...}}``. Whenever features cannot be computed for a
            layer (no generated tokens, layer not captured, unexpected tensor
            rank, too few output states) all three entries are zero vectors of
            length ``model.config.hidden_size`` on the CPU.
        """
        layer_features: Dict[str, Dict[str, torch.Tensor]] = {k: {} for k in needed_layers}

        def _zero_vec() -> torch.Tensor:
            # Fresh zero vector: hidden_size long, model parameter dtype, on CPU.
            hidden_dim = self.model.config.hidden_size
            _dtype = next(self.model.parameters()).dtype
            return torch.zeros(hidden_dim, dtype=_dtype, device="cpu")

        def _fill_with_zeros(layer_key: str) -> None:
            # Shared fallback so every failure path returns the same complete structure.
            layer_features[layer_key]["avg_with_prompt"] = _zero_vec()
            layer_features[layer_key]["avg_without_prompt"] = _zero_vec()
            layer_features[layer_key]["last_token"] = _zero_vec()

        def _pick_sample(state: torch.Tensor) -> Optional[torch.Tensor]:
            # 3-D states are indexed by sample, 2-D states are used as-is,
            # any other rank is reported as unusable (None).
            if state.dim() == 3:
                return state[sample_idx]
            if state.dim() == 2:
                return state
            return None

        if self.debug_mode:
            console.print(f"[bold magenta]=== 提取特征调试信息 ===[/bold magenta]")
            console.print(f"Prompt长度: {prompt_len}, 输出长度: {output_len}, 样本索引: {sample_idx}")
            console.print(f"捕获的隐状态层数: {len(self.captured_hidden_states)}")
            for k, states in self.captured_hidden_states.items():
                console.print(f"  {k}: {len(states)} 个状态")

        if output_len <= 0:
            console.log(f"[yellow]Warn: no output tokens generated.[/yellow]")
            for k in needed_layers:
                _fill_with_zeros(k)
            return layer_features

        for k in needed_layers:
            # A layer that was never captured is handled like a layer with an
            # empty capture list: warn and zero-fill, so callers always find
            # all three feature keys for every requested layer.
            if k not in self.captured_hidden_states or not self.captured_hidden_states[k]:
                console.log(f"[yellow]Warn: no hidden states captured for {k}.[/yellow]")
                _fill_with_zeros(k)
                continue

            layer_states = self.captured_hidden_states[k]

            # Entry 0 comes from the first forward pass, which processes the prompt.
            first_step_state = layer_states[0]

            if self.debug_mode:
                console.print(f"[cyan]处理 {k} 层的第一个状态: shape={first_step_state.shape}[/cyan]")

            first_step_hidden = _pick_sample(first_step_state)
            if first_step_hidden is None:
                console.print(f"[red]Unexpected hidden state dimension: {first_step_state.dim()}[/red]")
                _fill_with_zeros(k)
                continue
            if self.debug_mode:
                if first_step_state.dim() == 3:
                    console.print(f"[cyan]批处理模式，提取样本 {sample_idx}: {first_step_hidden.shape}[/cyan]")
                else:
                    console.print(f"[cyan]单样本模式: {first_step_hidden.shape}[/cyan]")

            # Isolate the real prompt positions from the first pass.
            if self.is_batch_mode:
                # Batch inputs are left-padded, so the prompt is the right-most
                # prompt_len positions.
                if first_step_hidden.shape[0] >= prompt_len:
                    actual_prompt_hidden = first_step_hidden[-prompt_len:]
                    if self.debug_mode:
                        console.print(f"[cyan]批处理模式：从右侧提取prompt，原始长度={first_step_hidden.shape[0]}, prompt长度={prompt_len}[/cyan]")
                else:
                    # Shorter than expected: fall back to everything captured.
                    actual_prompt_hidden = first_step_hidden
                    if self.debug_mode:
                        console.print(f"[yellow]警告：批处理模式长度不足，使用全部序列[/yellow]")
            else:
                # Single-sample mode: the prompt is the first prompt_len positions.
                if first_step_hidden.shape[0] >= prompt_len:
                    actual_prompt_hidden = first_step_hidden[:prompt_len]
                    if self.debug_mode:
                        console.print(f"[cyan]单样本模式：提取前{prompt_len}个token[/cyan]")
                else:
                    actual_prompt_hidden = first_step_hidden
                    if self.debug_mode:
                        console.print(f"[yellow]警告：单样本模式长度不足[/yellow]")

            # Collect one output state per later forward pass: the last sequence
            # position of that pass. The state at the last prompt position is
            # deliberately NOT counted as an output state.
            output_states = []
            for step_state in layer_states[1:]:
                step_hidden = _pick_sample(step_state)
                if step_hidden is None:
                    continue
                if len(step_hidden) > 0:
                    output_states.append(step_hidden[-1:])  # [1, hidden_dim]

            if self.debug_mode:
                console.print(f"[yellow]对齐检查: output_len={output_len}, captured={len(output_states)}[/yellow]")

            # Number of output states to keep; one fewer unless the length limit was hit.
            if hit_limit:
                effective_len = min(output_len, len(output_states))
            else:
                effective_len = min(output_len-1, len(output_states))

            if output_len == 1:
                # Exactly one generated token: no output state is used. The last
                # prompt state stands in for last_token, the prompt mean for
                # avg_with_prompt, and avg_without_prompt is a zero vector.
                output_hidden = actual_prompt_hidden[-1:].contiguous()
                vec_avg_without_prompt = _zero_vec()
                vec_avg_with_prompt = actual_prompt_hidden.mean(dim=0)
                vec_last_token = output_hidden[0]

                if self.debug_mode:
                    console.print(f"[cyan]{k} 层输出状态统计（G=1）[/cyan]")
                    console.print(f"  Prompt hidden shape: {actual_prompt_hidden.shape}")
                    console.print(f"  avg_with_prompt mean={vec_avg_with_prompt.mean():.4f}, std={vec_avg_with_prompt.std():.4f}")
                    console.print(f"  avg_without_prompt mean={vec_avg_without_prompt.mean():.4f}, std={vec_avg_without_prompt.std():.4f}")
                    console.print(f"  last_token mean={vec_last_token.mean():.4f}, std={vec_last_token.std():.4f}")

                layer_features[k]["avg_with_prompt"] = vec_avg_with_prompt
                layer_features[k]["avg_without_prompt"] = vec_avg_without_prompt
                layer_features[k]["last_token"] = vec_last_token
            elif len(output_states) > 0 and effective_len > 0:
                # Keep only the states that belong to this sample's real output.
                output_states = output_states[:effective_len]
                # Make every piece 2-D before concatenating to [effective_len, hidden_dim].
                normed = [t.unsqueeze(0) if t.dim() == 1 else t for t in output_states]
                output_hidden = torch.cat(normed, dim=0)

                if self.debug_mode:
                    console.print(f"[cyan]{k} 层输出状态统计:[/cyan]")
                    console.print(f"  实际输出状态数: {len(output_states)}")
                    console.print(f"  输出hidden shape: {output_hidden.shape}")
                    console.print(f"  Prompt hidden shape: {actual_prompt_hidden.shape}")

                # 1. Mean over the output states only.
                vec_avg_without_prompt = output_hidden.mean(dim=0)

                # 2. Mean over prompt states followed by output states.
                all_hidden = torch.cat([actual_prompt_hidden, output_hidden], dim=0)
                vec_avg_with_prompt = all_hidden.mean(dim=0)

                # 3. Representation of the last kept output state.
                vec_last_token = output_hidden[-1]

                if self.debug_mode:
                    console.print(f"  avg_with_prompt: mean={vec_avg_with_prompt.mean():.4f}, std={vec_avg_with_prompt.std():.4f}")
                    console.print(f"  avg_without_prompt: mean={vec_avg_without_prompt.mean():.4f}, std={vec_avg_without_prompt.std():.4f}")
                    console.print(f"  last_token: mean={vec_last_token.mean():.4f}, std={vec_last_token.std():.4f}")

                layer_features[k]["avg_with_prompt"] = vec_avg_with_prompt
                layer_features[k]["avg_without_prompt"] = vec_avg_without_prompt
                layer_features[k]["last_token"] = vec_last_token
            else:
                console.log(f"[yellow]Warn: insufficient output states for {k}, got {len(output_states)}, need {output_len}.[/yellow]")
                _fill_with_zeros(k)

        return layer_features

    # 单次前向传播同时获取隐状态和label
    @torch.inference_mode()
    def forward_once(
        self,
        item: dict,
        max_new_tokens: int = 2048,
        needed_layers: Optional[List[str]] = None,
    ) -> Tuple[Dict[str, Dict[str, torch.Tensor]], str, bool]:
        """Generate an answer for one item and extract hook features in the same pass.

        Args:
            item: Dataset record; it is turned into a prompt by ``parse_input``.
            max_new_tokens: Generation budget passed to ``model.generate``.
            needed_layers: Layer keys to hook; defaults to all keys of
                ``self.layer_map``.

        Returns:
            ``(layer_features, answer_txt, hit_limit)``. ``hit_limit`` is True
            only when the number of generated tokens reached ``max_new_tokens``
            and the last generated token is not an EOS id. If generation or
            feature extraction raises, the error is printed and zero feature
            vectors, the answer ``"Failed"`` and ``hit_limit=False`` are returned.

        Raises:
            ValueError: If ``self.model_key`` contains "qwen3", "distill" or
                "qwq" and ``max_new_tokens`` is not 8192.
        """
        if needed_layers is None:
            needed_layers = list(self.layer_map.keys())

        # 1) Build the prompt.
        prompt = parse_input(item)

        # 2) Tokenize (the returned prompt text is not needed here).
        input_ids, prompt_ids, _prompt_txt = self._build_input_ids(prompt)

        # Make sure a pad token exists, consistent with batch processing.
        if self.tokenizer.pad_token_id is None:
            try:
                special_tokens_dict = {'pad_token': '<|pad|>'}
                num_added = self.tokenizer.add_special_tokens(special_tokens_dict)
                if num_added > 0:
                    self.model.resize_token_embeddings(len(self.tokenizer))
                    with torch.no_grad():
                        emb = self.model.get_input_embeddings().weight
                        emb[self.tokenizer.pad_token_id].zero_()
            except Exception:
                # Best-effort fallback; narrowed from a bare ``except`` so that
                # KeyboardInterrupt / SystemExit are no longer swallowed.
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
                console.print("[yellow]Warn: fallback pad uses eos embedding (not zeroed).[/yellow]")

        pad_id = self.tokenizer.pad_token_id

        mk_lower = self.model_key.lower()
        if any(k in mk_lower for k in ["qwen3", "distill", "qwq"]):
            if max_new_tokens != 8192:
                raise ValueError(f"模型 {self.model_key} 需要设置 max_new_tokens=8192，但当前传入的是 {max_new_tokens}")

        # 3) Register hooks.
        self.is_batch_mode = False
        self.current_sample_idx = 0
        self._register_hooks(needed_layers)

        # 4) Generate.
        try:
            # Union of the EOS ids known to the generation config and the
            # tokenizer; used both to stop generation and to detect a natural end.
            gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
            tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
            eos_ids = list(set(gc_eos + tok_eos))
            attn_mask = torch.ones_like(input_ids)
            gen_out = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                eos_token_id=(eos_ids if eos_ids else None),
                pad_token_id=pad_id,
                return_dict_in_generate=True,
                attention_mask=attn_mask,
            )

            # 5) Work out prompt / output lengths after dropping right padding.
            prompt_len = len(prompt_ids)
            total_sequence = gen_out.sequences[0]
            right_pad_len = self.calculate_right_padding_length(total_sequence)
            valid_total_len = len(total_sequence) - right_pad_len
            output_len = max(valid_total_len - prompt_len, 0)

            valid_sequence = total_sequence[:-right_pad_len] if right_pad_len > 0 else total_sequence
            answer_txt = self.tokenizer.decode(valid_sequence[prompt_len:].tolist(), skip_special_tokens=True)

            # 6) Decide whether the output was cut off by the length limit.
            ended_with_eos = False
            if output_len > 0:
                last_gen_id = int(valid_sequence[prompt_len + output_len - 1].item())
                if last_gen_id in eos_ids:
                    ended_with_eos = True

            # Incomplete only if the limit was reached without an EOS token.
            hit_limit = (output_len >= max_new_tokens) and not ended_with_eos
            layer_features = self._extract_features_from_hooks(prompt_len, output_len, needed_layers, sample_idx=0, hit_limit=hit_limit)

        except Exception as e:
            console.print(f"[red]Single sample processing failed: {escape(str(e))}[/red]")
            # Fall back to default values so the caller still gets a full result.
            hidden_dim = self.model.config.hidden_size
            _dtype = next(self.model.parameters()).dtype
            _zero = torch.zeros(hidden_dim, dtype=_dtype, device="cpu")
            layer_features = {
                k: {
                    "avg_with_prompt": _zero.clone(),
                    "avg_without_prompt": _zero.clone(),
                    "last_token": _zero.clone()
                }
                for k in needed_layers
            }
            answer_txt = "Failed"
            hit_limit = False
        finally:
            # 7) Always remove the hooks and drop the captured states.
            self._clear_hooks()
            self.captured_hidden_states.clear()

        return layer_features, answer_txt, hit_limit

    # 批量前向传播同时获取隐状态和label
    @torch.inference_mode()
    def forward_batch(
        self,
        items: List[dict],
        max_new_tokens: int = 2048,
        needed_layers: Optional[List[str]] = None,
    ) -> Tuple[List[Dict[str, Dict[str, torch.Tensor]]], List[str], List[bool]]:
        """Generate answers for a batch of items and extract hook features per sample.

        The tokenizer is switched to left padding before the batch is built.

        Args:
            items: Dataset records; each is turned into a prompt by ``parse_input``.
            max_new_tokens: Generation budget passed to ``model.generate``.
            needed_layers: Layer keys to hook; defaults to all keys of
                ``self.layer_map``.

        Returns:
            ``(batch_layer_features, batch_answers, hit_limit_flags)``, one entry
            per item. A flag is True only when that sample produced at least
            ``max_new_tokens`` tokens and did not end with an EOS id.

        Raises:
            ValueError: If ``self.model_key`` contains "qwen3", "distill" or
                "qwq" and ``max_new_tokens`` is not 8192.
            Exception: Any error raised during generation or feature extraction
                is printed and re-raised for the caller to handle.
        """
        if needed_layers is None:
            needed_layers = list(self.layer_map.keys())

        self.tokenizer.padding_side = "left"

        # 1) Build the batch of prompts.
        prompts = [parse_input(item) for item in items]

        # 2) Batch tokenize.
        # NOTE(review): the pad token is presumably configured inside
        # _build_batch_input_ids - confirm against that method.
        batch_input_ids, padded_attention_mask, prompt_ids_list, prompt_txts, left_pad_lens = self._build_batch_input_ids(prompts)

        pad_id = self.tokenizer.pad_token_id

        mk_lower = self.model_key.lower()
        if any(k in mk_lower for k in ["qwen3", "distill", "qwq"]):
            if max_new_tokens != 8192:
                raise ValueError(f"模型 {self.model_key} 需要设置 max_new_tokens=8192，但当前传入的是 {max_new_tokens}")

        # 3) Register hooks.
        self.is_batch_mode = True
        self._register_hooks(needed_layers)

        # 4) Generate.
        try:
            # Union of the EOS ids known to the generation config and the
            # tokenizer; used both to stop generation and to detect a natural end.
            gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
            tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
            eos_ids = list(set(gc_eos + tok_eos))
            gen_out = self.model.generate(
                input_ids=batch_input_ids,
                max_new_tokens=max_new_tokens,
                attention_mask=padded_attention_mask,
                do_sample=False,
                eos_token_id=(eos_ids if eos_ids else None),
                pad_token_id=pad_id,
                return_dict_in_generate=True,
            )

            # 5) Per-sample decoding and feature extraction.
            batch_answers = []
            batch_layer_features = []
            hit_limit_flags = []  # whether each sample was cut off by the length limit

            for i, _ in enumerate(prompt_ids_list):
                sample_sequence = gen_out.sequences[i]

                # Drop the left padding added by batch tokenization.
                left_pad_len = int(left_pad_lens[i].item())
                actual_sequence = sample_sequence[left_pad_len:]

                # Drop the right padding reported by the helper.
                right_pad_len = self.calculate_right_padding_length(actual_sequence)
                if right_pad_len > 0:
                    valid_sequence = actual_sequence[:-right_pad_len]
                else:
                    valid_sequence = actual_sequence

                valid_seq_len = len(valid_sequence)

                # The prompt length is the number of attended input positions.
                prompt_len = int(padded_attention_mask[i].sum().item())
                output_len = valid_seq_len - prompt_len

                if prompt_len > valid_seq_len:
                    console.print(f"[red]错误: 样本{i}的prompt长度({prompt_len})超过有效序列长度({valid_seq_len})[/red]")
                    output_len = 0  # forces the zero-vector path in feature extraction

                ended_with_eos = False
                if output_len > 0:
                    last_gen_id = int(valid_sequence[prompt_len + output_len - 1].item())
                    if last_gen_id in eos_ids:
                        ended_with_eos = True

                hit_limit_tmp = output_len >= max_new_tokens and not ended_with_eos
                hit_limit_flags.append(hit_limit_tmp)

                layer_features = self._extract_features_from_hooks(prompt_len, output_len, needed_layers, sample_idx=i, hit_limit=hit_limit_tmp)
                batch_answers.append(self.tokenizer.decode(valid_sequence[prompt_len:].tolist(), skip_special_tokens=True))
                batch_layer_features.append(layer_features)

        except Exception as e:
            console.print(f"[red]Batch processing failed: {escape(str(e))}[/red]")
            # Re-raise so the caller decides how to handle the failed batch.
            # No default results are built here: they could never be returned,
            # and building them could mask the original error.
            raise
        finally:
            # 6) Always remove the hooks and drop the captured states.
            self._clear_hooks()
            self.captured_hidden_states.clear()

        return batch_layer_features, batch_answers, hit_limit_flags

    #  把原始数据按 prompt token 长度从长到短排序，尽可能避免长短序列混搭导致大批次OOM
    def _sort_data_by_prompt_len(self, data):
        """Return the items of ``data`` ordered from longest to shortest prompt.

        Prompt length is the token count of the text that would be fed to the
        model: the ``parse_input`` output, expanded through the tokenizer's chat
        template when one is available. Items with equal length are ordered by
        descending original index.

        Args:
            data: List of dataset records.

        Returns:
            A new list with the same records, longest prompt first.
        """
        tok = self.tokenizer

        def _render_prompts(items):
            """Return (prompt texts, add_special_tokens flag) for ``items``."""
            raw_questions = [parse_input(entry) for entry in items]

            template_present = bool(getattr(tok, "chat_template", None))
            if not (hasattr(tok, "apply_chat_template") and template_present):
                # Plain tokenizer: use the raw text and let it add special tokens.
                return raw_questions, True

            rendered = []
            for question in raw_questions:
                chat = [{"role": "user", "content": question}]
                try:
                    # Only qwen3 models get the thinking switch.
                    if "qwen3" in str(self.model_key).lower():
                        extra = {"enable_thinking": True}
                    else:
                        extra = {}
                    text = tok.apply_chat_template(
                        chat,
                        tokenize=False,
                        add_generation_prompt=True,
                        **extra,
                    )
                except TypeError:
                    # Compatibility path for tokenizers rejecting those keywords.
                    text = tok.apply_chat_template(chat, tokenize=False)
                rendered.append(text)
            # The template already contains the special tokens.
            return rendered, False

        # Measure lengths chunk by chunk to avoid one huge tokenizer call.
        chunk_size = 256
        total = len(data)
        token_counts = [0] * total
        for begin in range(0, total, chunk_size):
            texts, with_special = _render_prompts(data[begin:begin + chunk_size])
            encoded = tok(
                texts,
                add_special_tokens=with_special,
                truncation=False,
                return_attention_mask=False,
                return_token_type_ids=False,
            )
            for pos, token_ids in enumerate(encoded["input_ids"], start=begin):
                token_counts[pos] = len(token_ids)

        # Descending by (length, original index).
        ranked = sorted(zip(token_counts, range(total)), reverse=True)
        return [data[idx] for _, idx in ranked]

    def _save_batch_features(
        self,
        feats_avg_with_prompt: Dict[str, List[torch.Tensor]],
        feats_avg_without_prompt: Dict[str, List[torch.Tensor]],
        feats_last_token: Dict[str, List[torch.Tensor]],
        labels: List[int],
        ids: List[str],
        questions: List[str],
        true_answers: List[str],
        pred_answers: List[str],
        categories: List[str],
        options: List[List[str]],
        parsed_answers: List[str],
        incomplete_flags: List[bool],
        avg_with_prompt_dir: str,
        avg_without_prompt_dir: str,
        last_token_outputdir: str,
    ) -> None:
        """Append the current batch's features and metadata to the per-layer ``.pt`` files.

        For each of the three feature families and each layer key, the file
        ``<dir>/<layer>_features.pt`` is created or extended. When the file
        already exists, records whose id is already stored are skipped, which
        protects against re-processing a batch whose features were written but
        whose checkpoint was not. The file is written to ``<path>.tmp`` first
        and then moved into place with ``os.replace``.

        Args:
            feats_avg_with_prompt: ``{layer_key: [vector per sample]}``.
            feats_avg_without_prompt: Same layout as above.
            feats_last_token: Same layout as above.
            labels, ids, questions, true_answers, pred_answers, categories,
            options, parsed_answers, incomplete_flags: Per-sample metadata
                lists, index-aligned with the feature lists.
            avg_with_prompt_dir: Output directory for ``feats_avg_with_prompt``.
            avg_without_prompt_dir: Output directory for ``feats_avg_without_prompt``.
            last_token_outputdir: Output directory for ``feats_last_token``.

        Raises:
            ValueError: If a non-empty feature list does not have exactly one
                vector per id.
        """
        # Metadata columns in the order they are zipped; "ids" must stay first
        # because the de-duplication filter reads the id from position 0.
        meta_columns = {
            "ids": ids,
            "labels": labels,
            "questions": questions,
            "true_answers": true_answers,
            "pred_answers": pred_answers,
            "categories": categories,
            "options": options,
            "parsed_answers": parsed_answers,
            "incomplete_flags": incomplete_flags,
        }
        meta_keys = list(meta_columns)

        def save_features(feats, output_dir, des):
            for k, vec_list in feats.items():
                if not vec_list:
                    continue
                # Checked on every write, not only when merging, so a
                # misaligned first batch cannot be stored silently.
                if len(vec_list) != len(ids):
                    raise ValueError(f"{k}: vec_list({len(vec_list)}) 与 ids({len(ids)}) 长度不一致")

                out_path = os.path.join(output_dir, f"{k}_features.pt")

                if os.path.exists(out_path):
                    # The file being loaded is one this method wrote earlier.
                    existing_data = torch.load(out_path, map_location="cpu")
                    existing_id_set = set(existing_data["ids"])

                    # Pair each vector with its metadata and keep only unseen ids.
                    new_records = [
                        r for r in zip(*meta_columns.values(), vec_list)
                        if r[0] not in existing_id_set
                    ]

                    # Start from what is on disk; labels are stored as a tensor.
                    merged = {
                        key: (existing_data[key].tolist() if key == "labels" else existing_data[key])
                        for key in meta_keys
                    }
                    tensor = existing_data["features"]

                    if new_records:
                        *new_meta, new_vecs = zip(*new_records)
                        tensor_new = torch.stack(list(new_vecs)).cpu()
                        tensor = torch.cat([tensor, tensor_new], dim=0)
                        for key, values in zip(meta_keys, new_meta):
                            merged[key] = merged[key] + list(values)
                else:
                    # First write; moved to CPU like the merge path does.
                    tensor = torch.stack(vec_list).cpu()
                    merged = dict(meta_columns)

                payload = {
                    "features": tensor,
                    "labels": torch.tensor(merged["labels"]),
                }
                payload.update({key: merged[key] for key in meta_keys if key != "labels"})

                # Write to a temporary file, then swap it in with os.replace.
                tmp = out_path + ".tmp"
                torch.save(payload, tmp)
                os.replace(tmp, out_path)
                console.print(f"[green]保存 {k} {des} → {tensor.shape} 到 {out_path}[/green]")

        save_features(feats_avg_with_prompt, avg_with_prompt_dir, des="输出平均值(含prompt)")
        save_features(feats_avg_without_prompt, avg_without_prompt_dir, des="输出平均值(不含prompt)")
        save_features(feats_last_token, last_token_outputdir, des="最后一个token的表征")

    def extract_dataset(
        self,
        dataset_name: str,
        dataset_path: str,
        avg_with_prompt_dir: str,
        avg_without_prompt_dir: str,
        last_token_outputdir: str,
        layer_req: str = "middle",
        max_new_tokens: int = 2048,
        batch_size: int = 4,
        resume: bool = False,
        checkpoint_root: Optional[str] = None,
    ) -> None:
        os.makedirs(avg_with_prompt_dir, exist_ok=True)
        os.makedirs(avg_without_prompt_dir, exist_ok=True)
        os.makedirs(last_token_outputdir, exist_ok=True)

        with open(dataset_path, "r", encoding="utf-8") as f:
            data = [json.loads(l) for l in f]
        if self.debug_mode:
            data = data[:250]
        console.rule(f"处理数据集 → {dataset_name}")

        needed_layers: List[str] = (
            list(self.layer_map.keys()) if layer_req == "all" else [layer_req]
        )

        # 如果没有指定断点目录，则路径结构与 feats 保持一致：{output_dir}/{model_key}_avg_with_prompt/{dataset_name}/checkpoint_meta.json
        checkpoint_dir = os.path.join(os.path.normpath(checkpoint_root), self.model_key, dataset_name)

        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, "checkpoint_meta.json")
        ckpt = CheckpointManager(checkpoint_path, load_existing=resume)

        if not resume:
            ckpt.reset()  # 非断点模式直接从头跑
            # 同时删除历史断点文件和特征文件
            if os.path.exists(checkpoint_path):
                os.remove(checkpoint_path)
            clear_dir(avg_with_prompt_dir)
            clear_dir(avg_without_prompt_dir)
            clear_dir(last_token_outputdir)

        # 2) 预清洗 + 统计无效 ID
        cleaned_data: List[Dict[str, Any]] = []
        bad_id_count = 0
        for item in data:
            _id = get_item_id(item)
            if _id == "":
                bad_id_count += 1
                continue
            # 统一id描述
            item["_resolved_id"] = _id
            cleaned_data.append(item)
        if bad_id_count:
            print(f"[warn] {bad_id_count} 条样本缺少 id/unique_id，已跳过。")

        ckpt.set_total(len(cleaned_data))

        # 3) 基于 checkpoint 的 processed_ids 进行过滤（在排序前）
        if resume:
            remaining = [it for it in cleaned_data if not ckpt.is_processed(it["_resolved_id"])]
        else:
            remaining = cleaned_data  # 完全从头开始

        if not remaining:
            print("[info] 没有剩余样本可处理。")
            return

        # 2. 按照 prompt token 长度从长到短排序（排序在切片之前）
        data = self._sort_data_by_prompt_len(remaining)
        console.print(f"[cyan]已按prompt长度排序[/cyan]")


        console.print(f"加载 {len(data)} 个样本.")

        # 批量处理数据集
        total_batches = (len(data) + batch_size - 1) // batch_size
        console.print(f"使用批量大小 {batch_size}, 总共 {total_batches} 个批次")

        for batch_idx in track(range(total_batches), description="Extracting features"):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(data))
            batch_items = data[start_idx:end_idx]

            console.print(f"处理批次 {batch_idx + 1}/{total_batches}, 样本 {start_idx}-{end_idx-1}")

            # 用于存储当前批次的特征和标签
            batch_feats_avg_with_prompt: Dict[str, List[torch.Tensor]] = {k: [] for k in needed_layers}
            batch_feats_avg_without_prompt: Dict[str, List[torch.Tensor]] = {k: [] for k in needed_layers}
            batch_feats_last_token: Dict[str, List[torch.Tensor]] = {k: [] for k in needed_layers}
            batch_labels: List[int] = []
            batch_ids: List[str] = []
            batch_questions: List[str] = []
            batch_true_answers: List[str] = []
            batch_pred_answers: List[str] = []
            batch_categories: List[str] = []
            batch_options: List[List[str]] = []
            batch_parsed_answers: List[str] = []
            batch_incomplete_flags: List[bool] = []

            # 标记当前批次是否完整处理完毕
            batch_completed = False

            try:
                # 批量前向传播
                batch_layer_vecs, batch_answers, hit_limit_flags = self.forward_batch(
                    batch_items,
                    max_new_tokens=max_new_tokens,
                    needed_layers=needed_layers
                )

                # 处理批次结果
                for i, (item, layer_vecs, answer) in enumerate(zip(batch_items, batch_layer_vecs, batch_answers)):
                    q = item["question"].strip()
                    # console.print(f"\n[green]question {start_idx + i}: {q}[/green]")
                    # console.print(f"model_answer {start_idx + i}: {answer}")

                    batch_ids.append(item.get("_resolved_id", ""))
                    batch_questions.append(q)
                    batch_true_answers.append(item["answer"])
                    batch_options.append(item["options"])
                    batch_categories.append(item["category"])
                    batch_pred_answers.append(answer)

                    # 添加特征（无论正确与否，每个题目都需要添加特征）
                    for k in needed_layers:
                        batch_feats_avg_with_prompt[k].append(layer_vecs[k]["avg_with_prompt"])
                        batch_feats_avg_without_prompt[k].append(layer_vecs[k]["avg_without_prompt"])
                        batch_feats_last_token[k].append(layer_vecs[k]["last_token"])

                    # incomplete 截断
                    if hit_limit_flags[i]:
                        answer = trim_incomplete_sentence(answer)

                    extracted_answer, is_correct = mmlu_parse(answer, item["answer"], hit_limit_flags[i], item["option_answer"])

                    if extracted_answer == "Incomplete" or hit_limit_flags[i]:
                        console.print(f"\n[red]Incomplete answer[/red]")

                        # 测试一下answer输出
                        # print("answer:", answer)
                        # print("pred_answer:", extracted_answer, "true answer:",item["answer"], "is_correct:", is_correct)

                        if extracted_answer is None:
                            print("pred_answer:", extracted_answer, "true answer:",item["answer"], "is_correct:", is_correct)
                        else:
                            if len(extracted_answer) < 500:
                                print("pred_answer:", extracted_answer, "true answer:",item["answer"], "is_correct:", is_correct)
                            else:
                                print("pred_answer:", "Incomplete", "true answer:",item["answer"], "is_correct:", is_correct)
                    else:
                        print("pred_answer:", extracted_answer, "true answer:",item["answer"], "is_correct:", is_correct)

                    # if extracted_answer is None:
                    #     console.print(f"\n[red]pred_answer is None[/red]")

                    # log_cmp(extracted_answer, item["answer"], is_correct)

                    batch_labels.append(int(is_correct))
                    batch_parsed_answers.append(extracted_answer)
                    batch_incomplete_flags.append(hit_limit_flags[i])

                # 标记批次完成
                batch_completed = True

            except Exception as e:
                console.print(f"[red]批次 {batch_idx + 1} 处理失败: {escape(str(e))}[/red]")
                console.print(f"[yellow]回退到单样本处理模式[/yellow]")

                # 仅在OOM时尝试清理
                if "out of memory" in str(e).lower():
                    gc.collect()
                    torch.cuda.empty_cache()

                # 清空当前批次已收集的数据，避免部分数据重复保存
                if batch_ids:
                    console.print(f"[yellow]清空当前批次已收集的数据，避免部分数据重复保存[/yellow]")
                    batch_feats_avg_with_prompt = {k: [] for k in needed_layers}
                    batch_feats_avg_without_prompt = {k: [] for k in needed_layers}
                    batch_feats_last_token = {k: [] for k in needed_layers}
                    batch_labels.clear()
                    batch_ids.clear()
                    batch_questions.clear()
                    batch_true_answers.clear()
                    batch_pred_answers.clear()
                    batch_categories.clear()
                    batch_options.clear()
                    batch_parsed_answers.clear()
                    batch_incomplete_flags.clear()

                # 回退到单样本处理
                for i, item in enumerate(batch_items):
                    try:
                        layer_vecs, answer, hit_limit = self.forward_once(
                            item,
                            max_new_tokens=max_new_tokens,
                            needed_layers=needed_layers
                        )

                        q = item["question"].strip()
                        # console.print(f"\n[green]question {start_idx + i}: {escape(q)}[/green]")
                        # console.print(f"model_answer {start_idx + i}: {answer}", markup=False)

                        batch_ids.append(item.get("_resolved_id", ""))
                        batch_questions.append(q)
                        batch_true_answers.append(item["answer"])
                        batch_options.append(item["options"])
                        batch_categories.append(item["category"])
                        batch_pred_answers.append(answer)

                        # 添加特征
                        for k in needed_layers:
                            batch_feats_avg_with_prompt[k].append(layer_vecs[k]["avg_with_prompt"])
                            batch_feats_avg_without_prompt[k].append(layer_vecs[k]["avg_without_prompt"])
                            batch_feats_last_token[k].append(layer_vecs[k]["last_token"])

                        # 被截断的文本需要去掉最后一句不完整的话
                        if hit_limit:
                            answer = trim_incomplete_sentence(answer)

                        extracted_answer, is_correct = mmlu_parse(answer, item["answer"], hit_limit, item["option_answer"])

                        if extracted_answer == "Incomplete" or hit_limit:
                            console.print(f"\n[red]Incomplete answer[/red]")

                            if extracted_answer is None:
                                print("pred_answer:", extracted_answer, "true answer:",item["answer"], "is_correct:", is_correct)
                            else:
                                if len(extracted_answer) < 500:
                                    print("pred_answer:", extracted_answer, "true answer:",item["answer"], "is_correct:", is_correct)
                                else:
                                    print("pred_answer:", "Incomplete", "true answer:",item["answer"], "is_correct:", is_correct)
                        else:
                            print("pred_answer:", extracted_answer, "true answer:",item["answer"], "is_correct:", is_correct)

                        # log_cmp(extracted_answer, item["answer"], is_correct)

                        # if extracted_answer is None:
                        #     console.print(f"\n[red]pred_answer is None[/red]")

                        batch_parsed_answers.append(extracted_answer)
                        # console.print("pred_answer:", extracted_answer, "true answer:",item["answer"], "is_correct:", is_correct)
                        batch_labels.append(int(is_correct))
                        batch_incomplete_flags.append(hit_limit)

                    except Exception as single_e:
                        console.print(f"[red]样本 {start_idx + i} 处理失败: {escape(str(single_e))}[/red]")
                        # 添加空特征以保持索引一致
                        batch_ids.append(item.get("_resolved_id", ""))
                        batch_questions.append(item["question"].strip())
                        batch_true_answers.append(item["answer"])
                        batch_options.append(item["options"])
                        batch_categories.append(item["category"])
                        batch_pred_answers.append("Failed")
                        batch_parsed_answers.append("Failed")
                        batch_labels.append(0)
                        batch_incomplete_flags.append(False)
                        # 添加零特征
                        hidden_dim = self.model.config.hidden_size
                        _dtype = next(self.model.parameters()).dtype
                        _zero  = torch.zeros(hidden_dim, dtype=_dtype, device="cpu")
                        for k in needed_layers:
                            batch_feats_avg_with_prompt[k].append(_zero.clone())
                            batch_feats_avg_without_prompt[k].append(_zero.clone())
                            batch_feats_last_token[k].append(_zero.clone())

                # 单样本回退模式下，整个batch处理完成才标记完成
                batch_completed = True

            # 只有当批次完整处理完毕时才保存
            if batch_completed:
                # 保存当前批次的特征
                self._save_batch_features(
                    batch_feats_avg_with_prompt,
                    batch_feats_avg_without_prompt,
                    batch_feats_last_token,
                    batch_labels,
                    batch_ids,
                    batch_questions,
                    batch_true_answers,
                    batch_pred_answers,
                    batch_categories,
                    batch_options,
                    batch_parsed_answers,
                    batch_incomplete_flags,
                    avg_with_prompt_dir,
                    avg_without_prompt_dir,
                    last_token_outputdir,
                )

                # 写入断点（即使 resume=False 也会持续写，确保中断可恢复）
                ckpt.mark_batch_processed(batch_ids)

            # 清理GPU内存
            # torch.cuda.empty_cache()

        print(f"[done] 共 {len(remaining)} 条新样本完成。累计完成 {len(ckpt.processed_ids)}/{ckpt.total_samples}.")


def load_datasets(dataset: List[str], data_dir: str) -> Dict[str, str]:
    """Resolve dataset names to the ``<name>.jsonl`` files that exist in *data_dir*.

    Args:
        dataset: Dataset names (file stems, without the ``.jsonl`` suffix).
        data_dir: Directory that is searched for the dataset files.

    Returns:
        A mapping from dataset name to file path, containing only the names
        whose file exists. Every hit and every miss is reported on the console.
    """
    found: Dict[str, str] = {}
    for name in dataset:
        candidate = os.path.join(data_dir, f"{name}.jsonl")
        # Guard clause: report the missing file and move on to the next name.
        if not os.path.exists(candidate):
            console.print(f"[red] 数据集不存在 {candidate}[/red]")
            continue
        found[name] = candidate
        console.print(f"找到数据集: {candidate}")
    return found


def main() -> None:
    """Command-line entry point for hidden-feature extraction with checkpointing.

    Parses the CLI arguments, resolves the datasets (``<name>.jsonl`` files under
    ``--data_dir``) and the model directories (under ``--models_root``), chooses a
    batch size for every model/dataset pair and runs
    ``HiddenExtractor.extract_dataset`` on each pair.

    Batch-size selection:
      * ``--auto_batch off``: ``--batch_size`` is mandatory and is clamped to >= 1.
      * ``--auto_batch p99|max``: the batch size is estimated per dataset split;
        if ``--batch_size`` is also given, the smaller of the two is used. When
        the estimate fails, ``--batch_size`` (or 1 if it was not given) is used.

    Side effects: sets the module-level ``USE_FEW_SHOT`` flag and creates the
    models, output and checkpoint directories if they do not exist.
    """
    p = argparse.ArgumentParser("Hidden feature extractor with checkpoint")
    p.add_argument("--models_root", default=os.path.join(PROJECT_ROOT, "models"), help="本地模型根目录（用于拼接 模型名→路径）")
    p.add_argument("--data_dir", default=os.path.join(PROJECT_ROOT,"new_benchmark","mmlu_pro"))
    p.add_argument("--output_dir", default=os.path.join(PROJECT_ROOT, "feats/mmlu_pro/mmlu_pro_feats_test"))
    p.add_argument("--device", default="cuda")
    p.add_argument("--model", default="all")
    p.add_argument("--dataset", default=None)
    p.add_argument("--use_few_shot", action="store_true")
    p.add_argument("--layer_type", default="all", choices=["middle", "last", "second_last", "all"])
    p.add_argument("--batch_size", type=int, default=None, help="批量处理的batch size")
    p.add_argument("--resume", action="store_true", help="是否从断点恢复")
    p.add_argument("--checkpoint_root", type=str, default=os.path.join(PROJECT_ROOT, "checkpoints/mmlu_pro/mmlu_pro_feats"), help="checkpoint存储目录",)
    # Parameters for the automatic batch-size estimate.
    p.add_argument("--auto_batch", choices=["off", "p99", "max"], default="max", help="按分片自动估计 batch（p99 或 max）")
    p.add_argument("--gen_tokens", type=int, default=2048, help="估算时的最大生成长度")
    p.add_argument("--n_gpus", type=int, default=1, help="参与推理的 GPU 数")
    p.add_argument("--per_gpu_vram_gib", type=float, default=80.0, help="每张卡可用显存 GiB，用于估算")
    p.add_argument("--pct", type=float, default=0.99, help="分位数（默认 0.99，用于 pXX 计算）")
    args = p.parse_args()

    global USE_FEW_SHOT
    USE_FEW_SHOT = args.use_few_shot

    os.makedirs(args.models_root, exist_ok=True)
    os.makedirs(args.output_dir, exist_ok=True)
    if args.dataset is None:
        args.dataset = ["mmlu_pro_train", "mmlu_pro_val", "mmlu_pro_test"]
    if isinstance(args.dataset, str):
        args.dataset = [args.dataset]
    datasets = load_datasets(args.dataset, args.data_dir)

    if not datasets:
        console.print("[red] 数据集不存在[/red]")
        return

    if args.checkpoint_root is None:
        checkpoint_root = os.path.join(args.output_dir, "checkpoints")
    else:
        checkpoint_root = args.checkpoint_root
    os.makedirs(checkpoint_root, exist_ok=True)

    # Fail fast: without --batch_size and with auto estimation disabled there is
    # no way to pick a batch size, so bail out before any model is loaded.
    if args.batch_size is None and args.auto_batch == "off":
        console.print(f"[yellow]警告: 未指定 --batch_size，且 --auto_batch=off，无法执行，请重新设置 batch[/yellow]")
        return

    # User-supplied batch size, clamped to >= 1 once so that every code path
    # (auto off, auto on, estimation failure) sees the same sane value.
    user_bsz: Optional[int] = None
    if args.batch_size is not None:
        user_bsz = max(1, int(args.batch_size))

    if args.model == "all":
        # Treat every sub-directory that looks like a model checkpoint as a model.
        models = []
        for d in sorted(os.listdir(args.models_root)):
            mp = os.path.join(args.models_root, d)
            if os.path.isdir(mp) and (
                os.path.exists(os.path.join(mp, "config.json")) or
                os.path.exists(os.path.join(mp, "tokenizer.json"))
            ):
                models.append(d)
    else:
        # Comma-separated list of model directory names.
        models = [m.strip() for m in str(args.model).split(",") if m.strip()]

    for mk in models:
        model_path = os.path.join(args.models_root, mk)

        if not os.path.exists(model_path):
            console.print(f"[red] {mk} 模型目录不存在：{model_path}[/red]")
            continue

        console.rule(f"📦  Processing model — {mk}")
        extractor = HiddenExtractor(mk, model_path, args.device)

        # The generation budget depends only on the model name, so compute it
        # once per model: the listed model families get a fixed 8192 tokens.
        mk_lower = mk.lower()
        if any(k in mk_lower for k in ["qwen3", "distill", "qwq"]):
            gen_tok = 8192
        else:
            gen_tok = int(args.gen_tokens)

        for dname, dpath in datasets.items():
            avg_with_prompt_dir = os.path.join(args.output_dir, mk+"_avg_with_prompt", dname)
            avg_without_prompt_dir = os.path.join(args.output_dir, mk+"_avg_without_prompt", dname)
            last_token_out_dir = os.path.join(args.output_dir, mk+"_last_token", dname)

            # Default: the user's batch size, or 1 as the conservative fallback.
            effective_bsz = user_bsz if user_bsz is not None else 1

            if args.auto_batch != "off":
                try:
                    # Per-split token statistics; `dname` is passed as the
                    # dataset-with-split identifier.
                    # NOTE(review): (mx, pxx) presumably are the max and the
                    # pct-quantile token length — confirm against PTS_POST.
                    mx, pxx = PTS_POST.token_calculation(
                        model_name=mk,
                        dataset_with_split=dname,
                        pct=args.pct,
                    )

                    bsz_p99, bsz_mx = VRAM_POST.batch_estimation(
                        mx=mx,
                        p99=pxx,
                        model_name=mk,
                        gen_tokens=gen_tok,
                        n_gpus=args.n_gpus,
                        per_gpu_vram_gib=args.per_gpu_vram_gib,
                    )

                    auto_bsz = bsz_p99 if args.auto_batch == "p99" else bsz_mx

                    # If the user also passed --batch_size, take the smaller
                    # value; the clamped `user_bsz` keeps the result >= 1.
                    effective_bsz = max(1, int(auto_bsz))
                    if user_bsz is not None:
                        effective_bsz = min(effective_bsz, user_bsz)

                    if args.auto_batch == "p99":
                        # `:g` avoids float-truncation artefacts such as 0.57 -> "p56".
                        console.print(f"[cyan]AUTO-BATCH(p99)[/cyan] {mk}/{dname}: p{args.pct * 100:g}={pxx} → use batch_size={effective_bsz}")
                    else:
                        console.print(f"[cyan]AUTO-BATCH(max)[/cyan] {mk}/{dname}: mx={mx} → use batch_size={effective_bsz}")

                except Exception as e:
                    # Escape the message: exception text may contain "[...]"
                    # sequences that Rich would otherwise parse as markup.
                    console.print(f"[yellow]AUTO-BATCH 估计失败 {mk}/{dname}：{escape(str(e))}；回退 batch_size={args.batch_size}[/yellow]")
                    effective_bsz = user_bsz if user_bsz is not None else 1
                    console.print(f"[yellow]使用回退 batch_size={effective_bsz}[/yellow]")

            extractor.extract_dataset(
                dname,
                dpath,
                avg_with_prompt_dir,
                avg_without_prompt_dir,
                last_token_out_dir,
                layer_req=args.layer_type,
                max_new_tokens=gen_tok,
                batch_size=effective_bsz,
                resume=args.resume,
                checkpoint_root=checkpoint_root,
            )

        # Release this extractor before the next one is constructed; otherwise
        # the old instance stays referenced until the name is rebound, i.e.
        # while the next model is already being loaded.
        del extractor
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    console.rule("程序结束！")


# Script entry point: run the extraction CLI only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
    # cases = [
    #     # 1) 标准英文：the answer is (X)
    #     dict(desc="标准括号字母",
    #          pred="Reasoning...\nTherefore, the answer is (d).",
    #          true="D", opt="", hit=False, expect_ext="D", expect_ok=True),

    #     # 2) 变体：无 the、无括号、句末标点
    #     dict(desc="变体：answer is D.",
    #          pred="... so answer is d.",
    #          true="D", opt="", hit=False, expect_ext="D", expect_ok=True),

    #     # 3) Markdown/加粗
    #     dict(desc="Markdown 加粗 **(A)**",
    #          pred="Final choice:\n**(A)**",
    #          true="A", opt="", hit=False, expect_ext="A", expect_ok=True),

    #     # 4) 中文锚点
    #     dict(desc="中文：答案是 B",
    #          pred="因此，答案是 b。",
    #          true="B", opt="", hit=False, expect_ext="B", expect_ok=True),

    #     # 5) 末行 LaTeX（字母）
    #     dict(desc="末行 \\boxed{C}",
    #          pred="推导...\n\\boxed{C}\n",
    #          true="C", opt="", hit=False, expect_ext="C", expect_ok=True),

    #     # 6) 末行 LaTeX（数字，与选项文本比对）
    #     dict(desc="末行 \\boxed{4}，与 option 文本精确比",
    #          pred="text ...\nHence: \\boxed{4}\n",
    #          true="H", opt="4", hit=False, expect_ext="4", expect_ok=True),

    #     # 7) 末行“The answer is 3.”（非 LaTeX 数字）
    #     dict(desc="末行英文裸数字",
    #          pred="... Therefore the answer is 3.",
    #          true="D", opt="3", hit=False, expect_ext="3", expect_ok=True),

    #     # 8) 末行包含多个独立字母，取“最后一个”
    #     dict(desc="独立字母回退：A B C D",
    #          pred="Some options: A B C D",
    #          true="D", opt="", hit=False, expect_ext="D", expect_ok=True),

    #     # 9) 数学等式中的字母不应被当作选项；但有动词锚点
    #     dict(desc="等式里的 A=3，不应命中；动词锚点选 B",
    #          pred="We have A=3. Please select option (B).",
    #          true="B", opt="", hit=False, expect_ext="B", expect_ok=True),

    #     # 10) 上一行有 LaTeX，末行是数字；应以“末行”为准
    #     dict(desc="上一行 \\boxed{4}，末行 answer is 3（取 3）",
    #          pred="... see \\boxed{4}\nTherefore the answer is 3.",
    #          true="D", opt="3", hit=False, expect_ext="3", expect_ok=True),

    #     # 11) 仅括号字母 + 句点
    #     dict(desc="(d). 形式",
    #          pred="(d).",
    #          true="D", opt="", hit=False, expect_ext="D", expect_ok=True),

    #     # 12) 中文：故选 D 项
    #     dict(desc="中文：故选 D 项",
    #          pred="综上所述，故选 D 项。",
    #          true="D", opt="", hit=False, expect_ext="D", expect_ok=True),

    #     # 13) 命中长度上限 → Incomplete
    #     dict(desc="不完整：命中长度上限",
    #          pred="the answer might be ...",
    #          true="A", opt="", hit=True, expect_ext="Incomplete", expect_ok=False),

    #     # 14) 空文本
    #     dict(desc="空文本返回 None",
    #          pred="",
    #          true="A", opt="", hit=False, expect_ext=None, expect_ok=False),

    #     # 15) 文内多次 the answer is（以“最后一次”/末行优先）
    #     dict(desc="多次出现，取最后一次/末行",
    #          pred="... the answer is (A)\n... reasoning ...\nthe answer is (B)",
    #          true="B", opt="", hit=False, expect_ext="B", expect_ok=True),

    #     dict(desc="多次出现，取最后一次/末行",
    #          pred="[ \text{The answer is } \textbf{A. 39 ft} \]",
    #          true="A", opt="", hit=False, expect_ext="A", expect_ok=True),
    # ]

    # passed = 0
    # for i, c in enumerate(cases, 1):
    #     extracted, ok = mmlu_parse(c["pred"], c["true"], c["hit"], c["opt"])
    #     ok_match = (ok == c["expect_ok"])
    #     ext_match = (extracted == c["expect_ext"])
    #     status = "PASS" if (ok_match and ext_match) else "FAIL"

    #     console.print(f"[{i:02d}] {status} — {c['desc']}")
    #     console.print(f"  pred: {repr(c['pred'])}")
    #     console.print(f"  -> extracted={repr(extracted)}, correct={ok}  "
    #           f"(expect_ext={repr(c['expect_ext'])}, expect_ok={c['expect_ok']})\n")

    #     passed += (status == "PASS")

    # console.print(f"Summary: {passed}/{len(cases)} passed")
