from __future__ import annotations
import argparse
import json
import os
import shutil
import importlib.util
from typing import Dict, List, Optional, Tuple, Iterable, Any, Set

import torch
import gc
from rich.console import Console
from rich.progress import track
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import logging
import re
import sys
from bbh_lib_prompt import task_description, bbh_lib_prompt
from Qwen_judger import LLMJudge
from rich.markup import escape
from collections import Counter

# Directory containing this script, and the project root two levels above it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.normpath(os.path.join(BASE_DIR, "..", ".."))

# Put the project root at the front of sys.path (once) so the imports below resolve.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from utils.config import CONFIG
from bbh_postprocess import (
    bbh_postprocess,
    bbh_freeform_postprocess_simple,
    _ff__extract_bool_like,
    BBH_TASK_CONFIG,
)

# setdefault: a value already present in the environment is left untouched.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

logging.set_verbosity_error() # show only ERROR messages from transformers, ignore WARNINGs
console = Console()


# When True, parse_input() prepends the few-shot examples returned by bbh_lib_prompt().
USE_FEW_SHOT = False
torch.manual_seed(42)

############# 添加日志记录 ##############################
# logf = open("bbh_forward_hook_logs.txt", "w")
def log_cmp(extracted_answer, item_answer, is_correct):
    """Append one prediction-vs-truth record to the module-level log file.

    The file handle is looked up as the module global ``logf``. When no such
    handle exists, the record is skipped instead of raising ``NameError``.

    Args:
        extracted_answer: Answer extracted from the model output.
        item_answer: Ground-truth answer of the sample.
        is_correct: Result of the comparison.
    """
    log_file = globals().get("logf")
    if log_file is None:
        return

    # Escape embedded newlines so every record stays on a single line.
    ea = str(extracted_answer).replace("\n", "\\n")
    ta = str(item_answer).replace("\n", "\\n")
    line = f"[pred_answer]: {ea} true answer: {ta} is_correct: {is_correct}"
    log_file.write(line + "\n")
    # Flush right away so records survive an interrupted run.
    log_file.flush()

# BBH task names, split by answer format (multiple choice / free form).
BBH_MULTIPLE_CHOICE_SETS = [
    'temporal_sequences',
    'disambiguation_qa',
    'date_understanding',
    'tracking_shuffled_objects_three_objects',
    'penguins_in_a_table',
    'geometric_shapes',
    'snarks',
    'ruin_names',
    'tracking_shuffled_objects_seven_objects',
    'tracking_shuffled_objects_five_objects',
    'logical_deduction_three_objects',
    'hyperbaton',
    'logical_deduction_five_objects',
    'logical_deduction_seven_objects',
    'movie_recommendation',
    'salient_translation_error_detection',
    'reasoning_about_colored_objects',
]
BBH_FREE_FORM_SETS = [
    'multistep_arithmetic_two',
    'navigate',
    'dyck_languages',
    'word_sorting',
    'sports_understanding',
    'boolean_expressions',
    'object_counting',
    'formal_fallacies',
    'causal_judgement',
    'web_of_lies',
]

# Synonym table: maps a lower-cased single token to its canonical boolean-style label.
_CANON_EQUIV = {
    # Yes ("a" is mapped here too — presumably option letter (A) = Yes; TODO confirm)
    "yes":"yes", "y":"yes", "yeah":"yes", "yep":"yes", "a": "yes",
    "plausible":"yes", "likely":"yes", "affirmative":"yes",
    # No ("b" is mapped here too — presumably option letter (B) = No; TODO confirm)
    "no":"no", "n":"no", "nope":"no", "not": "no", "b": "no",
    "implausible":"no", "unlikely":"no", "unknown":"no", "unplausible": "no",
    # True/False
    "true":"true", "false":"false",
    "t":"true", "f":"false",
    # Valid/Invalid
    "valid":"valid", "invalid":"invalid",
}

_PHRASES_FF = [
    r"\bso\s+the\s+answer\s+is\b",
    r"\bthe\s+answer\s+is\b",
    r"\banswer\s+is\b",
    r"\bfinal\s+answer\s+is\b",
    r"\bfinal\s+answer\b\s*[:：]?",
    r"\banswer\b\s*[:：]",
    r"\bcorrect\s+answer\s+is\b",
    r"(?:因此|所以)?\s*答案\s*(?:为|是)\b",
]

_PHRASES_FF_COMPILED = [re.compile(p, re.I) for p in _PHRASES_FF]
_SENT_END_RE = re.compile(r'(?:[。！？!?]|\.{3,}|…|\.(?=(?:\s|\u3000|$|[”"\'’）\)\]】》])))[\s\u3000]*[”"\'’）\)\]】》]*[\s\u3000]*')

def _contains_answer_phrase(text: str) -> bool:
    """Return True when any compiled answer-introducing phrase occurs in *text*."""
    for pattern in _PHRASES_FF_COMPILED:
        if pattern.search(text):
            return True
    return False

# 去掉不完整的结尾句子
def trim_incomplete_sentence(ans: str) -> str:
    """Drop a dangling, unfinished last sentence when it could shadow the real answer.

    The text after the last sentence terminator is cut off only when both that
    tail and the text before it contain an answer-introducing phrase. In every
    other case *ans* is returned unchanged.
    """
    if not ans:
        return ans

    # Walk all sentence terminators and remember where the last one ends.
    cut = None
    for match in _SENT_END_RE.finditer(ans):
        cut = match.end()
    if cut is None:
        return ans

    head, tail = ans[:cut], ans[cut:]
    if _contains_answer_phrase(tail) and _contains_answer_phrase(head):
        return head.strip()
    return ans

# 去除字符串首尾的标点符号
def _strip_punct(s: str) -> str:
    return re.sub(r'^[^\w]+|[^\w]+$', '', (s or "").strip())

# 判断一个字符串（去首尾标点后）是否是单个词。
def _is_single_word(s: str) -> bool:
    """Return True when *s*, minus surrounding punctuation, is non-empty and has no whitespace."""
    core = _strip_punct(s)
    if not core:
        return False
    return re.search(r"\s", core) is None

# 对单个单词进行归一化
def _canon_one_word(s: str) -> str | None:
    """Look up the canonical label of a single word, or None when it has no synonym entry."""
    stripped = _strip_punct(s)
    # Single letters are in the table as well, so "So the answer is N." still resolves.
    return _CANON_EQUIV.get(stripped.lower())

# 根据任务类型返回组别
def _group_for_task(task: str) -> str:
    """Return the answer group of a BBH task as declared in BBH_TASK_CONFIG.

    Precedence: "mcq" (task has ``num_choices``), then a special free-form kind
    ("number" / "dyck" / "word_sorting"), then the first element of
    ``bool_style``, and finally "freeform".
    """
    key = (task or "").strip().lower()
    cfg = BBH_TASK_CONFIG.get(key, {})
    special_kind = cfg.get("special")

    if cfg.get("num_choices"):
        group = "mcq"
    elif special_kind in {"number", "dyck", "word_sorting"}:
        group = special_kind
    elif "bool_style" in cfg:
        group = cfg["bool_style"][0]  # 'yesno' | 'truefalse' | 'validinvalid'
    else:
        group = "freeform"
    return group

def parse_input(item):
    """Build the full prompt for one BBH sample.

    The prompt is the task description of ``item["category"]``, a newline, and
    a question block whose wording depends on USE_FEW_SHOT, on whether the task
    type is "mcq", and (zero-shot only) on whether the task is "word_sorting".
    """
    category = item["category"]
    description = task_description[category]
    is_mcq = BBH_TASK_CONFIG[category]["type"] == "mcq"

    if USE_FEW_SHOT:
        examples = bbh_lib_prompt(category)
        if is_mcq:
            body = f"Follow the given examples and answer the question.\n{examples}\n\nQ: {item['input']}\nA: Let's think step by step. And you must give your final answer(e.g. one of ABCDEFGHIJKLMNOPQRS) by starting with 'So the answer is'.\n"
        else:
            body = f"Follow the given examples and answer the question.\n{examples}\n\nQ: {item['input']}\nA: Let's think step by step. And you must give your final answer by starting with 'So the answer is'.\n"
    elif is_mcq:
        body = f"Question: {item['input']}\nA: Let's think step by step. And You must give your final answer(e.g. one of ABCDEFGHIJKLMNOPQRS) by starting with 'So the answer is'.\n"
    elif category == "word_sorting":
        body = f"Question: {item['input']}\nA: Let's think step by step. And You must give your final answer in one line by starting with 'So the answer is'.\n"
    else:
        body = f"Question: {item['input']}\nA: Let's think step by step. And You must give your final answer by starting with 'So the answer is'.\n"

    return description + "\n" + body

_CANON_BOOL = {"yes","no","true","false","valid","invalid"}
# Normalize a token to one of the canonical boolean-style labels.
def _canon_token(s: str):
    """Return the canonical label for *s*, or None when it is empty or unknown."""
    if not s:
        return None
    collapsed = re.sub(r'[ \t\r\n]+', ' ', s)
    token = collapsed.strip().rstrip(' .。!?:;，；：').lower()
    # Canonical labels pass straight through; anything else goes via the synonym table.
    return token if token in _CANON_BOOL else _CANON_EQUIV.get(token)

def normalize_bbh_answer(text: str, task: str) -> str:
    """Extract the answer from raw model output.

    The task-aware ``bbh_postprocess`` is tried first. Only when it yields a
    falsy result is the simple free-form extractor used as a fallback.
    """
    return bbh_postprocess(text, task=task) or bbh_freeform_postprocess_simple(text)

# When the extracted answer is more than a clean single word, an LLM judge may be consulted.
_JUDGE = LLMJudge()  # the API key can be supplied via the SILICONFLOW_API_KEY environment variable
def normalize_with_llm_judge_if_needed(text: str, task: str, id: str) -> str:
    """Extract the answer from *text* and, for boolean-style tasks, canonicalize it.

    Returns "" when nothing was extracted. For tasks without a ``bool_style``
    entry in BBH_TASK_CONFIG the extracted answer is returned as is. For
    boolean-style tasks the order is: local synonym table (single words), the
    ``_ff__extract_bool_like`` heuristic, then the LLM judge. If the judge
    raises, the heuristic result or the extracted answer is returned.

    NOTE(review): ``ans.strip`` below assumes normalize_bbh_answer never returns
    None — confirm against bbh_freeform_postprocess_simple.
    """
    ans = normalize_bbh_answer(text, task)
    ans = ans.strip(".,!?;") # drop irrelevant surrounding punctuation

    # 1) Nothing extracted: return empty without calling the LLM.
    if not ans or ans.strip() == "":
        return ""

    # Only "boolean-style" free-form tasks are allowed to reach the LLM judge.
    cfg = BBH_TASK_CONFIG.get((task or "").strip().lower(), {})
    is_bool_task = ("bool_style" in cfg)

    # Non-boolean tasks: return the extracted answer untouched, single- or multi-word.
    if not is_bool_task:
        return ans

    # 2) Single word => canonicalize locally via the synonym table, no LLM call.
    if _is_single_word(ans):
        canon = _canon_one_word(ans)
        # Synonym hit: return the capitalized canonical label directly.
        if canon is not None:
            return {
                "yes":"Yes","no":"No",
                "true":"True","false":"False",
                "valid":"Valid","invalid":"Invalid"
            }[canon]

        # No synonym hit -> try the heuristic extractor first (zero LLM calls).
        guess = _ff__extract_bool_like(ans, task)
        if guess:
            return guess

        # Heuristic failed too -> ask the LLM judge.
        try:
            group = _group_for_task(task)
            judged, _ = _JUDGE.judge(ans, group=group, task=task, id=id)
            console.print(f"[green]ans: {escape(str(ans))}, judged: {escape(str(judged))}[/green]")
            return judged
        except Exception:
            # Fallback: keep the answer as is; bbh_parse still tries group mapping / strict equality.
            return ans

    # 3) Multi-word / full sentence => ask the LLM judge.
    try:
        # Parse web_of_lies / navigate answers locally first to reduce judge calls.
        if task == "web_of_lies" or task == "navigate":
            extract_ans = _ff__extract_bool_like(ans, task)
            if extract_ans in ["Yes", "No", "yes", "no", "YES", "NO"]:
                console.print(f"[green]web_of_lies: ans: {escape(str(ans))}, extract_ans: {escape(extract_ans)}[/green]")
                return extract_ans

        if len(ans.split()) >= 100:
            # Very long outputs are usually incomplete and repetitive; pass only the
            # first five "."-separated pieces to the judge.
            ans = ans.split(".")
            ans = ". ".join(ans[:min(5, len(ans))])

        group = _group_for_task(task)
        judged, _ = _JUDGE.judge(ans, group=group, task=task, id=id)
        console.print(f"[green]ans: {escape(str(ans))}, judged: {escape(str(judged))}[/green]")
        return judged
    except Exception as e:
        console.print(f"[yellow]LLMJudge failed: {escape(str(e))}; fallback to heuristic[/yellow]")
        # Prefer the local boolean extraction; if that fails too, return the (possibly shortened) answer.
        guess = _ff__extract_bool_like(ans, task)
        return guess or ans


def dyck_parse(pred_clean, true, question):
    """Score a dyck_languages prediction by comparing bracket sequences only.

    A prediction is correct when its brackets equal the ground-truth completion,
    or equal the question's brackets followed by the completion. Without a
    question, a matching suffix is enough.

    Returns:
        ``(answer, is_correct)``. On success the answer is the space-joined
        ground-truth brackets. On failure it is the space-joined predicted
        brackets, or ``pred_clean`` itself when it contains no brackets.
    """
    def only_brackets(text):
        # Keep bracket characters in their original order, drop everything else.
        return re.findall(r"[()\[\]{}<>]", text or "")

    predicted = only_brackets(pred_clean)
    expected = only_brackets(true)
    prefix = None if question is None else only_brackets(question)
    canonical = " ".join(expected)

    # Case 1: the prediction is just the completion.
    if predicted == expected:
        return canonical, True

    # Case 2: the prediction repeats the question's brackets before the completion.
    if expected and predicted:
        if prefix:
            matched = predicted == prefix + expected
        else:
            # No question brackets available: fall back to a suffix check.
            matched = (
                len(predicted) >= len(expected)
                and predicted[-len(expected):] == expected
            )
        if matched:
            return canonical, True

    shown = " ".join(predicted) if predicted else pred_clean
    return shown, False

def _canon_to_group(tok: str, group: str) -> str | None:
    """Map a boolean-like token onto the vocabulary of *group*.

    Returns None when the token is not boolean-like. Raises KeyError for a
    group other than "yesno", "truefalse" or "validinvalid".
    """
    base = _canon_token(tok)  # canonical label, or None when not boolean-like
    if base is None:
        return None

    vocab = {
        "yesno": ("yes", "no"),
        "truefalse": ("true", "false"),
        "validinvalid": ("valid", "invalid"),
    }
    affirmative, negative = vocab[group]
    polarity = {
        "yes": affirmative, "true": affirmative, "valid": affirmative,
        "no": negative, "false": negative, "invalid": negative,
    }
    return polarity[base]

# question仅在dyck_languages任务中使用
def bbh_parse(pred, true, category, hit_limit=False, question=None, id=None):
    """Extract the answer from a model output and compare it with the ground truth.

    Args:
        pred: Raw model output.
        true: Ground-truth answer, or a list of acceptable answers.
        category: BBH task name.
        hit_limit: True when generation stopped at the token limit. An empty
            extraction is then reported as "Incomplete".
        question: Question text; only used by the dyck_languages task.
        id: Sample id, forwarded to the LLM judge.

    Returns:
        ``(extracted_answer, is_correct)``.
    """
    pred_clean = normalize_with_llm_judge_if_needed(pred, category, id)

    if not pred_clean:
        if hit_limit is True:
            return "Incomplete", False
        pred_clean = ""

    task_key = (category or "").strip().lower()
    if task_key == "dyck_languages":
        return dyck_parse(pred_clean, true, question)

    cfg = BBH_TASK_CONFIG.get(task_key, {})
    is_mcq = bool(cfg.get("num_choices"))
    bool_grp = cfg.get("bool_style", (None,))[0]  # 'yesno' | 'truefalse' | 'validinvalid' | None

    # Treat a single ground truth as a one-element list so all three branches
    # accept list-valued answers (the free-form branch used to reject them).
    candidates = true if isinstance(true, list) else [true]

    # MCQ: case-insensitive equality only, no boolean equivalence.
    if is_mcq:
        guess = pred_clean.upper()
        return pred_clean, any(guess == t.upper() for t in candidates)

    # Boolean free-form: map both sides into the task's group, then compare.
    if bool_grp:
        def _eq_bool(a, b):
            ca, cb = _canon_to_group(a, bool_grp), _canon_to_group(b, bool_grp)
            if ca is not None and cb is not None:
                return ca == cb
            return a == b  # either side is not a boolean token: strict equality

        return pred_clean, any(_eq_bool(pred_clean, t) for t in candidates)

    # Other free-form tasks: strict equality.
    return pred_clean, any(pred_clean == t for t in candidates)

# 清空当前目录下的所有文件和子目录
def clear_dir(dir_path: str):
    """Delete every file, symlink and sub-directory inside *dir_path*.

    The directory itself is kept. A path that is not a directory is ignored.
    A failure on one entry is printed and the loop moves on to the next.
    """
    if os.path.isdir(dir_path):
        for entry in os.listdir(dir_path):
            file_path = os.path.join(dir_path, entry)
            try:
                is_plain = os.path.isfile(file_path) or os.path.islink(file_path)
                if is_plain:
                    os.remove(file_path)       # regular file or symlink
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)   # whole sub-tree
            except Exception as e:
                print(f"删除失败: {file_path}, 原因: {e}")

def get_item_id(item: Dict[str, Any]) -> str:
    """Return the sample id as a stripped string, or "" when none is present.

    Keys are tried in order: ``unique_id`` (MATH), ``question_id`` (MMLU_PRO),
    ``id`` (BBH). The first value that is neither None nor "" wins.
    """
    val = None
    for key in ("unique_id", "question_id", "id"):
        val = item.get(key)
        if not (val is None or val == ""):
            break
    if val is None:
        return ""
    return str(val).strip()  # "0" (or 0) is a perfectly valid id

# 自动化 batch 估计
def _load_module(module_path: str, module_name: str):
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import {module_name} from {module_path}")
    mod = importlib.util.module_from_spec(spec)
    import sys as _sys
    _sys.modules[module_name] = mod
    spec.loader.exec_module(mod)
    return mod

# Load the two VRAM-estimation helper modules by file path.
PTS_PATH = os.path.join(PROJECT_ROOT, "utils", "vram_estimation", "prompt_token_states_post.py")
VRAM_PATH = os.path.join(PROJECT_ROOT, "utils", "vram_estimation", "vram_estimate_post.py")
PTS_POST = _load_module(PTS_PATH, "prompt_token_states_post")
VRAM_POST = _load_module(VRAM_PATH, "vram_estimate_post")
# Check that the loaded modules expose the entry points used later.
# NOTE(review): assert statements are stripped under `python -O`.
assert hasattr(PTS_POST, "token_calculation"), "prompt_token_states_post 模块缺少 token_calculation"
assert hasattr(VRAM_POST, "batch_estimation"), "vram_estimate_post 模块缺少 batch_estimation"

# eos 规范化
def _normalize_to_list(x):
    if x is None:
        return []
    if isinstance(x, (list, tuple, set)):
        return list(x)
    return [x]

class CheckpointManager:
    """Persist progress as the set of already-processed sample ids.

    With ``load_existing=False`` an existing checkpoint file is ignored and is
    overwritten by the first save, i.e. "start from scratch but still write
    checkpoints".
    """

    def __init__(self, path: str, load_existing: bool = True):
        self.path = path
        self.processed_ids: Set[str] = set()
        self.completed: bool = False
        self.total_samples: int = 0
        # When not loading, an old file stays on disk until the first save replaces it.
        if load_existing:
            self._load()

    @staticmethod
    def _as_key(_id: Any) -> str:
        """Normalize an id to its stripped string form; None becomes ""."""
        return "" if _id is None else str(_id).strip()

    def _load(self) -> None:
        """Read state from disk; an unreadable file is reported and ignored."""
        if not os.path.exists(self.path):
            return
        try:
            with open(self.path, "r", encoding="utf-8") as f:
                data = json.load(f)
            self.processed_ids = {str(x) for x in data.get("processed_ids", [])}
            self.completed = bool(data.get("completed", False))
            self.total_samples = int(data.get("total_samples", 0))
        except Exception as e:
            print(f"[ckpt] 读取失败，忽略旧文件：{self.path} ({e})")

    def reset(self) -> None:
        """Clear in-memory progress; the file on disk is only replaced by the next save."""
        # total_samples is kept on purpose; a later set_total() refreshes it.
        self.completed = False
        self.processed_ids.clear()

    def set_total(self, n: int) -> None:
        """Record the dataset size and save, but only when the value changed."""
        if self.total_samples == n:
            return
        self.total_samples = int(n)
        self._atomic_save()

    def mark_batch_processed(self, ids: Iterable[str]) -> None:
        """Add a batch of ids (empty/None ids are skipped) and save."""
        for raw in ids:
            key = self._as_key(raw)
            if key:
                self.processed_ids.add(key)
        if self.total_samples > 0 and len(self.processed_ids) >= self.total_samples:
            self.completed = True
        self._atomic_save()

    def is_processed(self, _id: Any) -> bool:
        """Return True when *_id* is non-empty and already recorded."""
        key = self._as_key(_id)
        return bool(key) and (key in self.processed_ids)

    def _atomic_save(self) -> None:
        """Write state to ``<path>.tmp`` and move it over the real file with os.replace."""
        payload = {
            "processed_ids": sorted(self.processed_ids),
            "completed": self.completed,
            "total_samples": self.total_samples,
        }
        scratch = self.path + ".tmp"
        os.makedirs(os.path.dirname(self.path) or ".", exist_ok=True)
        with open(scratch, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False)
        os.replace(scratch, self.path)


class HiddenExtractor:
    def __init__(self, model_key: str, model_path: str, device: str = "cuda") -> None:
        """Load tokenizer and model from *model_path* and initialise hook state.

        Model loading is attempted in three steps, each falling back to the next
        on any exception: (1) device_map="auto" with an explicit attention
        implementation, (2) device_map="auto" with the default attention,
        (3) no device map, then moved to *device*.

        Args:
            model_key: Short model identifier; stored on the instance.
            model_path: Path or hub name passed to ``from_pretrained``.
            device: Target device for the single-device fallback.
        """
        self.model_key = model_key
        self.device = device
        console.print(f"[bold]加载模型:[/bold] {model_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        try:
            attn_implementation = "flash_attention_2"
            # Paths containing "gemma" use sdpa instead of flash-attention.
            if "gemma" in model_path:
                attn_implementation = "sdpa"
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                trust_remote_code=True,
                # use_flash_attn=True,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                attn_implementation=attn_implementation
            )
        except Exception as e:
            console.print(f"[yellow]Warning: Failed to load model with auto device map and flash-attention: {escape(str(e))}[/yellow]")
            # Some models do not support flash-attention.
            try:
                # Retry without requesting a specific attention implementation.
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                )
            except Exception as e2:
                console.print(f"[yellow]Warning: Failed to load model with auto device map: {escape(str(e2))}[/yellow]")
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    torch_dtype=torch.bfloat16,
                    device_map=None,             # single device
                ).to(self.device)

        self.model.eval()
        self.layer_map = self._build_layer_index()
        # Public alias of the private helper.
        self.calculate_right_padding_length = self._calculate_right_padding_length

        # Hook-related state.
        self.hooks = []
        self.captured_hidden_states = {}
        self.is_batch_mode = False
        self.current_sample_idx = 0

        # Debugging state.
        self.debug_mode = False
        self.debug_info = {}

    def _build_layer_index(self) -> Dict[str, int]:
        n_layers = self.model.config.num_hidden_layers
        mapping = {
            "middle": n_layers // 2,
            "last": n_layers - 1,
            "second_last": n_layers - 2,
        }
        console.print(
            f"总层数: {n_layers} → middle={mapping['middle']}, "
            f"second_last={mapping['second_last']}, last={mapping['last']}"
        )
        return mapping

    def _calculate_right_padding_length(self, total_sequence) -> int:
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()
        right_pad_len = 0
        pad_id = self.tokenizer.pad_token_id
        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = list(set(gc_eos + tok_eos))

        # 从末尾开始，计算连续 pad token 的数量，全部截断
        n = len(total_sequence)
        j = n - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        if j >= 0 and total_sequence[j] not in eos_ids and right_pad_len > 0:
            # 确实存在 pad 并且 pad 跟 eos 相同导致多截了一个 token
            return right_pad_len - 1

        # 继续从末尾计算连续 eos token 的数量
        eos_count = 0
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            eos_count += 1
            i -= 1
        # 保留 1 个 EOS，其余视作右侧填充
        right_pad_len += max(0, eos_count - 1)

        return right_pad_len

    # 钩子注册
    def _register_hooks(self, needed_layers: List[str]) -> None:
        """注册forward hooks来捕获指定层的隐状态"""
        self._clear_hooks()  # 先清除之前的hooks
        self.captured_hidden_states = {k: [] for k in needed_layers}

        def create_hook(layer_name: str):
            def hook_fn(module, input, output):

                if isinstance(output, tuple):
                    hidden_state = output[0]  # 通常隐状态是第一个元素
                else:
                    hidden_state = output

                # 复制到CPU后立即删除GPU引用，但使用更安全的方式
                hidden_state_cpu = hidden_state.detach().to('cpu')

                # 显式删除GPU引用以节省显存，但要确保Flash Attention已经完成计算
                if hidden_state.is_cuda:
                    del hidden_state

                # if self.debug_mode:
                #     console.print(f"[cyan]{layer_name} hidden state shape: {hidden_state_cpu.shape}[/cyan]")

                # 存储隐藏状态 - 不需要步骤计数，按顺序存储即可
                self.captured_hidden_states[layer_name].append(hidden_state_cpu)

            return hook_fn

        # 根据模型架构获取正确的层
        if hasattr(self.model, 'model'):  # 如 LlamaForCausalLM
            layers = self.model.model.layers
        elif hasattr(self.model, 'transformer'):  # 如 GPT类模型
            layers = self.model.transformer.h
        else:
            # 尝试其他可能的属性名
            for attr in ['layers', 'h', 'decoder_layers']:
                if hasattr(self.model, attr):
                    layers = getattr(self.model, attr)
                    break
            else:
                raise AttributeError("无法找到模型的层结构")

        # 修复问题1：直接根据layer_name计算索引，避免重复计算
        for layer_name in needed_layers:
            if layer_name == "last":
                layer_idx = len(layers)
            elif layer_name == "second_last":
                layer_idx = len(layers) - 2
            elif layer_name == "middle":
                layer_idx = len(layers) // 2
            else:
                # 如果是其他层名，可以从layer_map获取
                layer_idx = self.layer_map.get(layer_name, 0)

            if 0 <= layer_idx < len(layers):
                hook = layers[layer_idx].register_forward_hook(create_hook(layer_name))
                self.hooks.append(hook)
                # console.print(f"注册hook for {layer_name} (layer {layer_idx})")

    def _clear_hooks(self) -> None:
        """清除所有hooks"""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        # 注意：这里不清除captured_hidden_states，因为后续还需要使用
        # self.captured_hidden_states.clear()

    # 方便统一处理chat/base模型
    def _build_input_ids(self, question: str) -> Tuple[torch.Tensor, list[int], str]:
        """
        统一构建 input_ids, 返回:
        - input_ids  : (1, seq_len) tensor on self.device
        - prompt_ids : list[int] 同一内容, 用于后续定位
        - prompt_txt : str, decode 后的完整 prompt 文本
        """
        # chat / instruct 模型：tokenizer 支持 chat template
        has_tpl = bool(getattr(self.tokenizer, "chat_template", None))
        if hasattr(self.tokenizer, "apply_chat_template") and has_tpl:
            prompt_struct = [{"role": "user", "content": question}]
            if "qwen3" in self.model_key.lower():
                try:
                    encoded = self.tokenizer.apply_chat_template(
                        prompt_struct,
                        tokenize=True,
                        add_generation_prompt=True,
                        return_tensors="pt",
                        enable_thinking=True,
                    )
                except:
                    encoded = self.tokenizer.apply_chat_template(
                        prompt_struct,
                        tokenize=True,
                        add_generation_prompt=True,
                        return_tensors="pt",
                    )
            else:
                encoded = self.tokenizer.apply_chat_template(
                    prompt_struct,
                    tokenize=True,
                    add_generation_prompt=True,
                    return_tensors="pt",
                )
        else:
            batch = self.tokenizer(
                question,
                add_special_tokens=True,      # 先让tokenizer自己加他心目中的特殊符号
                return_tensors="pt",
                truncation=False,
            )
            input_ids = batch["input_ids"]   # shape: (1, L)
            bos_id = self.tokenizer.bos_token_id
            if bos_id is not None and input_ids[0, 0].item() != bos_id:
                bos_tensor = torch.tensor([[bos_id]], dtype=input_ids.dtype)
                input_ids = torch.cat([bos_tensor, input_ids], dim=1)

            encoded = input_ids

        embed_dev = self.model.get_input_embeddings().weight.device
        encoded = encoded.to(embed_dev)
        prompt_ids = encoded[0].tolist()
        prompt_txt = self.tokenizer.decode(encoded[0], skip_special_tokens=False)
        return encoded, prompt_ids, prompt_txt

    def _build_batch_input_ids(self, questions: List[str]) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]], List[str], torch.Tensor]:
        """
            批量构建 input_ids, 返回:
            - input_ids  : (batch_size, max_seq_len) tensor on self.device
            - padded_attention_mask : 填充的 attention_mask
            - prompt_ids_list : List[List[int]] 每个样本的prompt_ids
            - prompt_txts : List[str] 每个样本的prompt文本
            - left_pad_lens : 左填充长度
        """
        self.tokenizer.padding_side = "left" # 左填充方便后续处理（直接取最后一个非pad token的隐状态即可）
        # 统一设置pad_token，确保与后续处理一致
        if self.tokenizer.pad_token_id is None:
            # 添加专用的pad_token
            try:
                special_tokens_dict = {'pad_token': '<|pad|>'}
                num_added = self.tokenizer.add_special_tokens(special_tokens_dict)
                if num_added > 0:
                    self.model.resize_token_embeddings(len(self.tokenizer))
                    # 将新添加的pad embedding置零
                    with torch.no_grad():
                        emb = self.model.get_input_embeddings().weight
                        emb[self.tokenizer.pad_token_id].zero_()
            except:
                # 如果添加失败，使用eos_token作为pad_token
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
                console.print("[yellow]Warn: fallback pad uses eos embedding (not zeroed).[/yellow]")

        # 同步到 model / generation_config
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id
        # self.model.generation_config.eos_token_id = self.tokenizer.eos_token_id

        prompt_txts: List[str] = []
        add_special = True
        has_tpl = bool(getattr(self.tokenizer, "chat_template", None))
        if hasattr(self.tokenizer, "apply_chat_template") and has_tpl:
            add_special = False
            for q in questions:
                messages = [{"role": "user", "content": q}]
                if "qwen3" in self.model_key.lower():
                    txt = self.tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,                # 先拿到纯文本模板展开
                        add_generation_prompt=True,
                        enable_thinking=True,
                    )
                else:
                    txt = self.tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                prompt_txts.append(txt)
        else:
            # 无 chat 模板时，直接用原问题文本
            prompt_txts = list(questions)

        batch_enc = self.tokenizer(
            prompt_txts,
            padding="longest",                 # 自动对齐到本 batch 最长
            return_tensors="pt",
            return_attention_mask=True,
            add_special_tokens=add_special,    # base model 需要添加特殊符号, chat model 不需要
            truncation=False,                  # 避免静默截断
        )
        padded_input_ids = batch_enc["input_ids"]
        padded_attention_mask = batch_enc["attention_mask"]

        prompt_ids_list: List[List[int]] = [
            self.tokenizer.encode(txt, add_special_tokens=add_special)
            for txt in prompt_txts
        ]

        # 转换为tensor并移到设备
        embed_dev = self.model.get_input_embeddings().weight.device
        padded_input_ids = padded_input_ids.to(embed_dev)
        padded_attention_mask = padded_attention_mask.to(embed_dev)

        # 根据 attention_mask 填充长度
        B, T = padded_attention_mask.shape
        prompt_lens = padded_attention_mask.sum(dim=1)              # 每个样本真实长度
        left_pad_lens = T - prompt_lens

        return padded_input_ids, padded_attention_mask, prompt_ids_list, prompt_txts, left_pad_lens

    # 特征提取
    def _extract_features_from_hooks(self, prompt_len: int, output_len: int, needed_layers: List[str], sample_idx: int = 0, hit_limit: bool=False) -> Dict[str, Dict[str, torch.Tensor]]:
        """从hooks收集的隐状态中提取特征"""
        layer_features: Dict[str, Dict[str, torch.Tensor]] = {k: {} for k in needed_layers}

        if self.debug_mode:
            console.print(f"[bold magenta]=== 提取特征调试信息 ===[/bold magenta]")
            console.print(f"Prompt长度: {prompt_len}, 输出长度: {output_len}, 样本索引: {sample_idx}")
            console.print(f"捕获的隐状态层数: {len(self.captured_hidden_states)}")
            for k, states in self.captured_hidden_states.items():
                console.print(f"  {k}: {len(states)} 个状态")

        if output_len <= 0:
            console.log(f"[yellow]Warn: no output tokens generated.[/yellow]")
            hidden_dim = self.model.config.hidden_size
            _dtype = next(self.model.parameters()).dtype
            _zero  = torch.zeros(hidden_dim, dtype=_dtype, device="cpu")
            for k in needed_layers:
                layer_features[k]["avg_with_prompt"]  = _zero.clone()
                layer_features[k]["avg_without_prompt"] = _zero.clone()
                layer_features[k]["last_token"] = _zero.clone()
            return layer_features

        for k in needed_layers:
            if k not in self.captured_hidden_states:
                continue

            layer_states = self.captured_hidden_states[k]
            if not layer_states:
                console.log(f"[yellow]Warn: no hidden states captured for {k}.[/yellow]")
                hidden_dim = self.model.config.hidden_size
                _dtype = next(self.model.parameters()).dtype
                _zero  = torch.zeros(hidden_dim, dtype=_dtype, device="cpu")
                layer_features[k]["avg_with_prompt"] = _zero.clone()
                layer_features[k]["avg_without_prompt"] = _zero.clone()
                layer_features[k]["last_token"] = _zero.clone()
                continue

            # 在自回归生成中，每个前向传播都会产生隐藏状态
            # 第一个前向传播：处理整个prompt + 生成第一个token
            # 后续前向传播：每次处理之前的序列 + 新生成的token

            # 获取第一个前向传播的结果（包含prompt的处理）
            first_step_state = layer_states[0]

            if self.debug_mode:
                console.print(f"[cyan]处理 {k} 层的第一个状态: shape={first_step_state.shape}[/cyan]")

            # 处理张量维度 - 更健壮的维度处理
            if first_step_state.dim() == 3:
                # 批处理模式: [batch_size, seq_len, hidden_dim]
                first_step_hidden = first_step_state[sample_idx]  # [seq_len, hidden_dim]
                if self.debug_mode:
                    console.print(f"[cyan]批处理模式，提取样本 {sample_idx}: {first_step_hidden.shape}[/cyan]")
            elif first_step_state.dim() == 2:
                # 单样本模式: [seq_len, hidden_dim]
                first_step_hidden = first_step_state
                if self.debug_mode:
                    console.print(f"[cyan]单样本模式: {first_step_hidden.shape}[/cyan]")
            else:
                # 一维情况，可能是错误的
                console.print(f"[red]Unexpected hidden state dimension: {first_step_state.dim()}[/red]")
                hidden_dim = self.model.config.hidden_size
                _dtype = next(self.model.parameters()).dtype
                _zero  = torch.zeros(hidden_dim, dtype=_dtype, device="cpu")
                layer_features[k]["avg_with_prompt"] = _zero.clone()
                layer_features[k]["avg_without_prompt"] = _zero.clone()
                layer_features[k]["last_token"] = _zero.clone()
                continue

            # 更准确的prompt处理
            if self.is_batch_mode:
                # 在批处理模式中，需要处理左侧padding
                # 第一个前向传播包含了padding + prompt
                # 我们需要提取真正的prompt部分
                if first_step_hidden.shape[0] >= prompt_len:
                    # 从右侧提取prompt_len长度的序列（因为是左padding）
                    actual_prompt_hidden = first_step_hidden[-prompt_len:]
                    if self.debug_mode:
                        console.print(f"[cyan]批处理模式：从右侧提取prompt，原始长度={first_step_hidden.shape[0]}, prompt长度={prompt_len}[/cyan]")
                else:
                    # 如果长度不够，使用全部
                    actual_prompt_hidden = first_step_hidden
                    if self.debug_mode:
                        console.print(f"[yellow]警告：批处理模式长度不足，使用全部序列[/yellow]")
            else:
                # 单样本模式，第一个前向传播的长度应该等于prompt_len
                if first_step_hidden.shape[0] >= prompt_len:
                    actual_prompt_hidden = first_step_hidden[:prompt_len]
                    if self.debug_mode:
                        console.print(f"[cyan]单样本模式：提取前{prompt_len}个token[/cyan]")
                else:
                    actual_prompt_hidden = first_step_hidden
                    if self.debug_mode:
                        console.print(f"[yellow]警告：单样本模式长度不足[/yellow]")

            # 收集输出部分的隐状态
            output_states = []

            # 从第一个前向传播中提取第一个生成token的表示
            # 这是prompt处理后的最后一个位置的隐状态，这部分暂定，感觉加不加影响不大，虽然CoE里面加了

            # if len(actual_prompt_hidden) > 0:
            #     # 修复：第一个生成token的隐状态应该是prompt最后一个位置的隐状态
            #     output_states.append(actual_prompt_hidden[-1:])  # [1, hidden_dim]

            # 处理后续的前向传播结果
            # 每个后续的前向传播都会产生一个新的token的隐状态
            for i in range(1, len(layer_states)):
                # if len(output_states) >= output_len:
                #     break  # 已经收集够了

                step_state = layer_states[i]

                # 处理维度
                if step_state.dim() == 3:
                    # 批处理模式: [batch_size, seq_len, hidden_dim]
                    step_hidden = step_state[sample_idx]  # [seq_len, hidden_dim]
                elif step_state.dim() == 2:
                    # 单样本模式: [seq_len, hidden_dim]
                    step_hidden = step_state
                else:
                    continue

                # 新生成的token的隐状态在序列的最后一个位置
                if len(step_hidden) > 0:
                    output_states.append(step_hidden[-1:])  # [1, hidden_dim]

            if self.debug_mode:
                console.print(f"[yellow]对齐检查: output_len={output_len}, captured={len(output_states)}[/yellow]")

            # 计算真正最后一个 token 的隐状态
            if hit_limit:
                effective_len = min(output_len, len(output_states))
            else:
                effective_len = min(output_len-1, len(output_states))

            # 考虑极短输出（output_len=1）的情况
            if  output_len == 1:
                # 此时输出压根没有隐状态，所以直接用prompt的最后一个token的隐状态
                output_hidden = actual_prompt_hidden[-1:].contiguous()
                # avg_without_prompt：没有生成 token，置零向量（或按你需要的策略）
                hidden_dim = self.model.config.hidden_size
                _dtype = next(self.model.parameters()).dtype
                _zero  = torch.zeros(hidden_dim, dtype=_dtype, device="cpu")
                vec_avg_without_prompt = _zero.clone()
                # avg_with_prompt：只对 prompt 做平均
                vec_avg_with_prompt = actual_prompt_hidden.mean(dim=0)
                # last_token：沿用占位（注意语义是 prompt[-1]，必要时可加一个标记位）
                vec_last_token = output_hidden[0]

                if self.debug_mode:
                    console.print(f"[cyan]{k} 层输出状态统计（G=1）[/cyan]")
                    console.print(f"  Prompt hidden shape: {actual_prompt_hidden.shape}")
                    console.print(f"  avg_with_prompt mean={vec_avg_with_prompt.mean():.4f}, std={vec_avg_with_prompt.std():.4f}")
                    console.print(f"  avg_without_prompt mean={vec_avg_without_prompt.mean():.4f}, std={vec_avg_without_prompt.std():.4f}")
                    console.print(f"  last_token mean={vec_last_token.mean():.4f}, std={vec_last_token.std():.4f}")

                layer_features[k]["avg_with_prompt"] = vec_avg_with_prompt
                layer_features[k]["avg_without_prompt"] = vec_avg_without_prompt
                layer_features[k]["last_token"] = vec_last_token
            elif len(output_states) > 0 and effective_len > 0:
                # 只取实际输出长度的隐状态
                output_states = output_states[:effective_len]
                normed = []
                for t in output_states:
                    if t.dim() == 1:
                        t = t.unsqueeze(0)
                    normed.append(t)
                # 拼接得到 [effective_len, hidden_dim]
                output_hidden = torch.cat(normed, dim=0)

                # output_hidden = torch.cat(output_states, dim=0)  # [output_len, hidden_dim]

                if self.debug_mode:
                    console.print(f"[cyan]{k} 层输出状态统计:[/cyan]")
                    console.print(f"  实际输出状态数: {len(output_states)}")
                    console.print(f"  输出hidden shape: {output_hidden.shape}")
                    console.print(f"  Prompt hidden shape: {actual_prompt_hidden.shape}")

                # 1. 输出部分token平均值（不包含prompt）
                vec_avg_without_prompt = output_hidden.mean(dim=0)

                # 2. 输出部分token平均值（包含prompt）
                all_hidden = torch.cat([actual_prompt_hidden, output_hidden], dim=0)
                vec_avg_with_prompt = all_hidden.mean(dim=0)

                # 3. 最后一个token的表征
                vec_last_token = output_hidden[-1]

                if self.debug_mode:
                    console.print(f"  avg_with_prompt: mean={vec_avg_with_prompt.mean():.4f}, std={vec_avg_with_prompt.std():.4f}")
                    console.print(f"  avg_without_prompt: mean={vec_avg_without_prompt.mean():.4f}, std={vec_avg_without_prompt.std():.4f}")
                    console.print(f"  last_token: mean={vec_last_token.mean():.4f}, std={vec_last_token.std():.4f}")

                layer_features[k]["avg_with_prompt"] = vec_avg_with_prompt
                layer_features[k]["avg_without_prompt"] = vec_avg_without_prompt
                layer_features[k]["last_token"] = vec_last_token
            else:
                console.log(f"[yellow]Warn: insufficient output states for {k}, got {len(output_states)}, need {output_len}.[/yellow]")
                hidden_dim = self.model.config.hidden_size
                _dtype = next(self.model.parameters()).dtype
                _zero  = torch.zeros(hidden_dim, dtype=_dtype, device="cpu")
                layer_features[k]["avg_with_prompt"] = _zero.clone()
                layer_features[k]["avg_without_prompt"] = _zero.clone()
                layer_features[k]["last_token"] = _zero.clone()

        return layer_features

    # 单次前向传播同时获取隐状态和label
    @torch.inference_mode()
    def forward_once(
        self,
        item: dict,
        max_new_tokens: int = 2048,
        needed_layers: Optional[List[str]] = None,
    ) -> Tuple[Dict[str, Dict[str, torch.Tensor]], str, bool]:
        """Generate an answer for one item and pool hooked hidden states in the same pass.

        Args:
            item: Dataset record; turned into a prompt string by ``parse_input``.
            max_new_tokens: Generation budget forwarded to ``model.generate``.
            needed_layers: Keys of ``self.layer_map`` to hook. ``None`` selects
                every key of ``self.layer_map``.

        Returns:
            A ``(layer_features, answer_txt, hit_limit)`` tuple.
            ``layer_features`` is whatever ``_extract_features_from_hooks``
            returns for sample 0. ``answer_txt`` is the decoded continuation
            (special tokens skipped). ``hit_limit`` is True only when at least
            ``max_new_tokens`` tokens were produced and the last one is not an
            EOS id.
            If generation or feature extraction raises, the error is printed
            and the method returns zero vectors for every requested layer,
            the string ``"Failed"`` and ``False``.

        Raises:
            ValueError: If ``self.model_key`` contains "qwen3", "distill" or
                "qwq" (case-insensitive) and ``max_new_tokens`` is not 8192.
        """
        if needed_layers is None:
            needed_layers = list(self.layer_map.keys())

        # 1) Build the prompt text for this record.
        prompt = parse_input(item)

        # 2) Tokenize. The rendered prompt text (third element) is not needed here.
        input_ids, prompt_ids, _ = self._build_input_ids(prompt)

        # Make sure a pad token is defined (the original note says this mirrors
        # the batch path). A freshly added pad embedding is zeroed.
        if self.tokenizer.pad_token_id is None:
            try:
                special_tokens_dict = {'pad_token': '<|pad|>'}
                num_added = self.tokenizer.add_special_tokens(special_tokens_dict)
                if num_added > 0:
                    self.model.resize_token_embeddings(len(self.tokenizer))
                    with torch.no_grad():
                        emb = self.model.get_input_embeddings().weight
                        emb[self.tokenizer.pad_token_id].zero_()
            except Exception:
                # Best-effort fallback: reuse EOS as pad. Catch Exception rather
                # than everything so Ctrl-C / SystemExit still propagate.
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
                console.print("[yellow]Warn: fallback pad uses eos embedding (not zeroed).[/yellow]")

        pad_id = self.tokenizer.pad_token_id

        # These model families are only accepted with an 8192-token budget.
        mk_lower = self.model_key.lower()
        if any(k in mk_lower for k in ["qwen3", "distill", "qwq"]):
            if max_new_tokens != 8192:
                raise ValueError(f"模型 {self.model_key} 需要设置 max_new_tokens=8192，但当前传入的是 {max_new_tokens}")

        # 3) Register hooks in single-sample mode.
        self.is_batch_mode = False
        self.current_sample_idx = 0
        self._register_hooks(needed_layers)

        # 4) Generate.
        try:
            # Union of the EOS ids known to the generation config and the
            # tokenizer; reused below to decide whether generation ended on EOS.
            gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
            tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
            eos_ids = list(set(gc_eos + tok_eos))
            attn_mask = torch.ones_like(input_ids)
            gen_out = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                eos_token_id=(eos_ids if eos_ids else None),
                pad_token_id=pad_id,
                return_dict_in_generate=True,
                attention_mask=attn_mask,
            )

            # 5) Work out how many tokens were really generated by dropping
            #    any right padding reported by calculate_right_padding_length.
            prompt_len = len(prompt_ids)
            total_sequence = gen_out.sequences[0]
            right_pad_len = self.calculate_right_padding_length(total_sequence)
            valid_total_len = len(total_sequence) - right_pad_len
            output_len = valid_total_len - prompt_len
            if output_len < 0:
                output_len = 0

            valid_sequence = total_sequence[:-right_pad_len] if right_pad_len > 0 else total_sequence
            answer_txt = self.tokenizer.decode(valid_sequence[prompt_len:].tolist(), skip_special_tokens=True)

            # Did the generation stop on an EOS token?
            ended_with_eos = False
            if output_len > 0:
                last_gen_id = int(valid_sequence[prompt_len + output_len - 1].item())
                if last_gen_id in eos_ids:
                    ended_with_eos = True

            # Only flag the answer as truncated when the budget was exhausted
            # AND the last token is not EOS.
            hit_limit = (output_len >= max_new_tokens) and not ended_with_eos

            # 6) Pool the captured hidden states into per-layer features.
            layer_features = self._extract_features_from_hooks(prompt_len, output_len, needed_layers, sample_idx=0, hit_limit=hit_limit)

        except Exception as e:
            console.print(f"[red]Single sample processing failed: {escape(str(e))}[/red]")
            # Fall back to zero vectors so the caller still gets one entry
            # per requested layer.
            hidden_dim = self.model.config.hidden_size
            layer_features = {}
            _dtype = next(self.model.parameters()).dtype
            _zero  = torch.zeros(hidden_dim, dtype=_dtype, device="cpu")
            for k in needed_layers:
                layer_features[k] = {
                    "avg_with_prompt": _zero.clone(),
                    "avg_without_prompt": _zero.clone(),
                    "last_token": _zero.clone()
                }
            answer_txt = "Failed"
            hit_limit = False
        finally:
            # 7) Always remove hooks and drop captured states, even on failure.
            self._clear_hooks()
            self.captured_hidden_states.clear()

        return layer_features, answer_txt, hit_limit

    # 批量前向传播同时获取隐状态和label
    @torch.inference_mode()
    def forward_batch(
        self,
        items: List[dict],
        max_new_tokens: int = 2048,
        needed_layers: Optional[List[str]] = None,
    ) -> Tuple[List[Dict[str, Dict[str, torch.Tensor]]], List[str], List[bool]]:
        """Generate answers for a batch of items and pool hooked hidden states per sample.

        The tokenizer is switched to left padding (``self.tokenizer.padding_side``
        is set to ``"left"`` and not restored).

        Args:
            items: Dataset records; each is turned into a prompt by ``parse_input``.
            max_new_tokens: Generation budget forwarded to ``model.generate``.
            needed_layers: Keys of ``self.layer_map`` to hook. ``None`` selects
                every key of ``self.layer_map``.

        Returns:
            ``(batch_layer_features, batch_answers, hit_limit_flags)``, three
            lists with one entry per prompt: the result of
            ``_extract_features_from_hooks`` for that sample, the decoded
            continuation (special tokens skipped), and whether the sample used
            up ``max_new_tokens`` without ending on an EOS id.

        Raises:
            ValueError: If ``self.model_key`` contains "qwen3", "distill" or
                "qwq" (case-insensitive) and ``max_new_tokens`` is not 8192.
            Exception: Any error raised during generation or extraction is
                printed and re-raised unchanged; hooks are cleared first.
        """
        if needed_layers is None:
            needed_layers = list(self.layer_map.keys())

        self.tokenizer.padding_side = "left"

        # 1) Build one prompt per item.
        prompts = [parse_input(item) for item in items]

        # 2) Batch tokenize. Per the original note, _build_batch_input_ids is
        #    also responsible for making sure a pad token is configured.
        batch_input_ids, padded_attention_mask, prompt_ids_list, _, left_pad_lens = self._build_batch_input_ids(prompts)

        pad_id = self.tokenizer.pad_token_id

        # These model families are only accepted with an 8192-token budget.
        mk_lower = self.model_key.lower()
        if any(k in mk_lower for k in ["qwen3", "distill", "qwq"]):
            if max_new_tokens != 8192:
                raise ValueError(f"模型 {self.model_key} 需要设置 max_new_tokens=8192，但当前传入的是 {max_new_tokens}")

        # 3) Register hooks in batch mode.
        self.is_batch_mode = True
        self._register_hooks(needed_layers)

        # 4) Generate.
        try:
            # Union of the EOS ids known to the generation config and the
            # tokenizer; reused below for the per-sample EOS check.
            gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
            tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
            eos_ids = list(set(gc_eos + tok_eos))
            gen_out = self.model.generate(
                input_ids=batch_input_ids,
                max_new_tokens=max_new_tokens,
                attention_mask=padded_attention_mask,
                do_sample=False,
                eos_token_id=(eos_ids if eos_ids else None),
                pad_token_id=pad_id,
                return_dict_in_generate=True,
            )

            batch_answers = []          # 5) decoded continuations
            batch_layer_features = []   # 6) pooled hidden-state features
            hit_limit_flags = []        # whether each sample was cut off by the budget

            for i, _ in enumerate(prompt_ids_list):
                sample_sequence = gen_out.sequences[i]

                # Strip the left padding reported by the tokenization step.
                left_pad_len = int(left_pad_lens[i].item())
                actual_sequence = sample_sequence[left_pad_len:]

                # Strip right padding.
                right_pad_len = self.calculate_right_padding_length(actual_sequence)
                if right_pad_len > 0:
                    valid_sequence = actual_sequence[:-right_pad_len]
                else:
                    valid_sequence = actual_sequence

                valid_seq_len = len(valid_sequence)

                # Prompt length is taken from the attention mask (count of
                # attended positions), not from len(prompt_ids).
                prompt_len = int(padded_attention_mask[i].sum().item())
                output_len = valid_seq_len - prompt_len

                # A prompt longer than the unpadded sequence means the lengths
                # are inconsistent; force output_len to 0 so zero vectors are used.
                if prompt_len > valid_seq_len:
                    console.print(f"[red]错误: 样本{i}的prompt长度({prompt_len})超过有效序列长度({valid_seq_len})[/red]")
                    output_len = 0

                ended_with_eos = False
                if output_len > 0:
                    last_gen_id = int(valid_sequence[prompt_len + output_len - 1].item())
                    if last_gen_id in eos_ids:
                        ended_with_eos = True

                # Truncated only if the budget was exhausted and the last
                # token is not EOS.
                hit_limit_tmp = output_len >= max_new_tokens and not ended_with_eos
                hit_limit_flags.append(hit_limit_tmp)

                layer_features = self._extract_features_from_hooks(prompt_len, output_len, needed_layers, sample_idx=i, hit_limit=hit_limit_tmp)
                batch_answers.append(self.tokenizer.decode(valid_sequence[prompt_len:].tolist(), skip_special_tokens=True))
                batch_layer_features.append(layer_features)

        except Exception as e:
            console.print(f"[red]Batch processing failed: {escape(str(e))}[/red]")
            # Propagate to the caller. No fallback values are built here:
            # they could never be returned once the exception is re-raised,
            # and allocating them would only add memory pressure.
            raise
        finally:
            # 7) Always remove hooks and drop captured states, even on failure.
            self._clear_hooks()
            self.captured_hidden_states.clear()

        return batch_layer_features, batch_answers, hit_limit_flags

    #  把原始数据按 prompt token 长度从长到短排序，尽可能避免长短序列混搭导致大批次OOM
    def _sort_data_by_prompt_len(self, data):
        """Return the records of ``data`` ordered from longest to shortest prompt.

        Prompt length is the number of token ids the tokenizer produces for the
        text built by ``parse_input`` (rendered through the chat template when
        the tokenizer has one). Records with equal length are ordered by
        descending original position.

        Args:
            data: List of dataset records.

        Returns:
            A new list containing the same records in sorted order.
        """
        chunk_size = 256  # tokenize in slices so a single call never gets the whole dataset

        def _render(records):
            """Build the prompt strings for ``records`` and say whether the
            tokenizer should add special tokens itself."""
            raw_prompts = [parse_input(rec) for rec in records]

            template_present = bool(getattr(self.tokenizer, "chat_template", None))
            if not (hasattr(self.tokenizer, "apply_chat_template") and template_present):
                # Plain tokenizer: use the raw text and let it add special tokens.
                return raw_prompts, True

            rendered = []
            for text in raw_prompts:
                chat = [{"role": "user", "content": text}]
                try:
                    # Only qwen3 model keys get the extra enable_thinking flag.
                    extra = {"enable_thinking": True} if "qwen3" in str(self.model_key).lower() else {}
                    out = self.tokenizer.apply_chat_template(
                        chat,
                        tokenize=False,
                        add_generation_prompt=True,
                        **extra,
                    )
                except TypeError:
                    # Fallback for tokenizers whose template call rejects these keywords.
                    out = self.tokenizer.apply_chat_template(chat, tokenize=False)
                rendered.append(out)
            # The template output is tokenized without adding special tokens again.
            return rendered, False

        total = len(data)
        lengths = [0] * total
        for start in range(0, total, chunk_size):
            texts, with_special = _render(data[start:start + chunk_size])
            encoded = self.tokenizer(
                texts,
                add_special_tokens=with_special,
                truncation=False,
                return_attention_mask=False,
                return_token_type_ids=False,
            )
            for offset, token_ids in enumerate(encoded["input_ids"]):
                lengths[start + offset] = len(token_ids)

        # Descending by (length, original position).
        ranked = sorted(
            ((length, pos) for pos, length in enumerate(lengths)),
            reverse=True,
        )
        return [data[pos] for _, pos in ranked]

    def _save_batch_features(
        self,
        feats_avg_with_prompt: Dict[str, List[torch.Tensor]],
        feats_avg_without_prompt: Dict[str, List[torch.Tensor]],
        feats_last_token: Dict[str, List[torch.Tensor]],
        labels: List[int],
        ids: List[str],
        questions: List[str],
        true_answers: List[str],
        pred_answers: List[str],
        categories: List[str],
        incomplete_flags: List[bool],
        avg_with_prompt_dir: str,
        avg_without_prompt_dir: str,
        last_token_outputdir: str,
    ) -> None:
        """Append one batch of features and metadata to the per-layer ``.pt`` files.

        For each of the three feature families, and each layer key with at
        least one vector, the file ``{output_dir}/{layer}_features.pt`` is
        created or extended. Records whose id is already stored in the file
        are skipped. Each file holds a dict with the keys ``features``,
        ``labels``, ``ids``, ``questions``, ``true_answers``,
        ``pred_answers``, ``categories`` and ``incomplete_flags``.

        Args:
            feats_avg_with_prompt: Layer key -> list of per-sample vectors.
            feats_avg_without_prompt: Layer key -> list of per-sample vectors.
            feats_last_token: Layer key -> list of per-sample vectors.
            labels, ids, questions, true_answers, pred_answers, categories,
                incomplete_flags: Per-sample metadata, index-aligned with the
                vector lists.
            avg_with_prompt_dir, avg_without_prompt_dir, last_token_outputdir:
                Output directory for the corresponding feature family.

        Raises:
            ValueError: If a non-empty vector list does not have the same
                length as ``ids``.
        """
        def save_features(feats, output_dir, des):
            for k, vec_list in feats.items():
                if not vec_list:
                    continue

                # Validate alignment on every write (first write included), so
                # misaligned features/metadata are never persisted.
                if len(vec_list) != len(ids):
                    raise ValueError(f"{k}: vec_list({len(vec_list)}) 与 ids({len(ids)}) 长度不一致")

                out_path = os.path.join(output_dir, f"{k}_features.pt")

                if os.path.exists(out_path):
                    # Merge with what is already on disk.
                    existing_data = torch.load(out_path, map_location="cpu")

                    # De-duplicate by id: guards against a batch being re-run
                    # after its features were written but its ids were not
                    # yet recorded as processed.
                    existing_ids = existing_data["ids"]
                    existing_id_set = set(existing_ids)

                    # Bundle each vector with its metadata, then keep only
                    # records whose id is not stored yet.
                    new_records = [
                        r for r in zip(ids, labels, questions, true_answers, pred_answers,
                                       categories, incomplete_flags, vec_list)
                        if r[0] not in existing_id_set
                    ]

                    if new_records:
                        n_ids, n_labels, n_qs, n_trues, n_preds, n_cats, n_incomp, n_vecs = zip(*new_records)
                        tensor_new = torch.stack(list(n_vecs)).cpu()
                        tensor = torch.cat([existing_data["features"], tensor_new], dim=0)

                        merged_ids = existing_ids + list(n_ids)
                        merged_labels = existing_data["labels"].tolist() + list(n_labels)
                        merged_questions = existing_data["questions"] + list(n_qs)
                        merged_true_answers = existing_data["true_answers"] + list(n_trues)
                        merged_pred_answers = existing_data["pred_answers"] + list(n_preds)
                        merged_categories = existing_data["categories"] + list(n_cats)
                        merged_incomplete = existing_data["incomplete_flags"] + list(n_incomp)
                    else:
                        # Nothing new: write the existing content back unchanged.
                        tensor = existing_data["features"]
                        merged_ids = existing_ids
                        merged_labels = existing_data["labels"].tolist()
                        merged_questions = existing_data["questions"]
                        merged_true_answers = existing_data["true_answers"]
                        merged_pred_answers = existing_data["pred_answers"]
                        merged_categories = existing_data["categories"]
                        merged_incomplete = existing_data["incomplete_flags"]
                else:
                    # First write for this layer. Move to CPU so the stored
                    # tensor's device matches what the merge path produces.
                    tensor = torch.stack(vec_list).cpu()
                    merged_ids = ids
                    merged_labels = labels
                    merged_questions = questions
                    merged_true_answers = true_answers
                    merged_pred_answers = pred_answers
                    merged_categories = categories
                    merged_incomplete = incomplete_flags

                # Write to a sibling temp file and swap it in with os.replace,
                # so out_path is never left holding a half-written file.
                tmp = out_path + ".tmp"
                torch.save({
                    "features": tensor,
                    "labels": torch.tensor(merged_labels),
                    "ids": merged_ids,
                    "questions": merged_questions,
                    "true_answers": merged_true_answers,
                    "pred_answers": merged_pred_answers,
                    "categories": merged_categories,
                    "incomplete_flags": merged_incomplete
                }, tmp)
                os.replace(tmp, out_path)
                console.print(f"[green]保存 {k} {des} → {tensor.shape} 到 {out_path}[/green]")

        save_features(feats_avg_with_prompt, avg_with_prompt_dir, des="输出平均值(含prompt)")
        save_features(feats_avg_without_prompt, avg_without_prompt_dir, des="输出平均值(不含prompt)")
        save_features(feats_last_token, last_token_outputdir, des="最后一个token的表征")

    def extract_dataset(
        self,
        dataset_name: str,
        dataset_path: str,
        avg_with_prompt_dir: str,
        avg_without_prompt_dir: str,
        last_token_outputdir: str,
        layer_req: str = "middle",
        max_new_tokens: int = 2048,
        batch_size: int = 4,
        resume: bool = False,
        checkpoint_root: Optional[str] = None,
    ) -> None:
        os.makedirs(avg_with_prompt_dir, exist_ok=True)
        os.makedirs(avg_without_prompt_dir, exist_ok=True)
        os.makedirs(last_token_outputdir, exist_ok=True)

        with open(dataset_path, "r", encoding="utf-8") as f:
            data = [json.loads(l) for l in f]
        console.rule(f"处理数据集 → {dataset_name}")

        ################## Qwen Judger测试 ##################
        # Judger_type = ["navigate", "sports_understanding", "boolean_expressions", "formal_fallacies", "causal_judgement", "web_of_lies"]
        # # 只取 Jugder type 类别的数据进行测试
        # data = [item for item in data if item["category"] in Judger_type]
        # print(f"[info] 仅处理 Judger 类型数据，共 {len(data)} 条样本。")
        #####################################################

        needed_layers: List[str] = (
            list(self.layer_map.keys()) if layer_req == "all" else [layer_req]
        )

        # 如果没有指定断点目录，则路径结构与 feats 保持一致：{output_dir}/{model_key}_avg_with_prompt/{dataset_name}/checkpoint_meta.json
        checkpoint_dir = os.path.join(os.path.normpath(checkpoint_root), self.model_key, dataset_name)

        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, "checkpoint_meta.json")
        ckpt = CheckpointManager(checkpoint_path, load_existing=resume)

        if not resume:
            ckpt.reset()  # 非断点模式直接从头跑
            # 同时删除历史断点文件和特征文件
            if os.path.exists(checkpoint_path):
                os.remove(checkpoint_path)
            clear_dir(avg_with_prompt_dir)
            clear_dir(avg_without_prompt_dir)
            clear_dir(last_token_outputdir)

        # 2) 预清洗 + 统计无效 ID
        cleaned_data: List[Dict[str, Any]] = []
        bad_id_count = 0
        for item in data:
            _id = get_item_id(item)
            if _id == "":
                bad_id_count += 1
                continue
            # 统一id描述
            item["_resolved_id"] = _id
            cleaned_data.append(item)
        if bad_id_count:
            print(f"[warn] {bad_id_count} 条样本缺少 id/unique_id，已跳过。")

        ckpt.set_total(len(cleaned_data))

        # 3) 基于 checkpoint 的 processed_ids 进行过滤（在排序前）
        if resume:
            remaining = [it for it in cleaned_data if not ckpt.is_processed(it["_resolved_id"])]
        else:
            remaining = cleaned_data  # 完全从头开始

        if not remaining:
            print("[info] 没有剩余样本可处理。")
            return

        # 2. 按照 prompt token 长度从长到短排序（排序在切片之前）
        data = self._sort_data_by_prompt_len(remaining)
        console.print(f"[cyan]已按prompt长度排序[/cyan]")

        if self.debug_mode:
            data = data[:16]
        console.print(f"加载 {len(data)} 个样本.")

        # 批量处理数据集
        total_batches = (len(data) + batch_size - 1) // batch_size
        console.print(f"使用批量大小 {batch_size}, 总共 {total_batches} 个批次")

        for batch_idx in track(range(total_batches), description="Extracting features"):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(data))
            batch_items = data[start_idx:end_idx]

            console.print(f"处理批次 {batch_idx + 1}/{total_batches}, 样本 {start_idx}-{end_idx-1}")
            batch_feats_avg_with_prompt: Dict[str, List[torch.Tensor]] = {k: [] for k in needed_layers}
            batch_feats_avg_without_prompt: Dict[str, List[torch.Tensor]] = {k: [] for k in needed_layers}
            batch_feats_last_token: Dict[str, List[torch.Tensor]] = {k: [] for k in needed_layers}
            batch_labels: List[int] = []
            batch_ids: List[str] = []
            batch_questions: List[str] = []
            batch_true_answers: List[str] = []
            batch_pred_answers: List[str] = []
            batch_categories: List[str] = []
            # 额外存储一个数组标记是否hit_limit
            batch_incomplete_flags: List[bool] = []

            # 标记当前批次是否完整处理完毕
            batch_completed = False

            try:
                # 批量前向传播
                batch_layer_vecs, batch_answers, hit_limit_flags = self.forward_batch(
                    batch_items,
                    max_new_tokens=max_new_tokens,
                    needed_layers=needed_layers
                )

                # 处理批次结果
                for i, (item, layer_vecs, answer) in enumerate(zip(batch_items, batch_layer_vecs, batch_answers)):
                    q = item["input"].strip()
                    # console.print(f"\n[green]question {start_idx + i}: {q}[/green]")
                    # console.print(f"model_answer {start_idx + i}: {answer}")

                    batch_ids.append(item.get("_resolved_id", ""))
                    batch_questions.append(q)
                    batch_true_answers.append(item["target"])
                    batch_categories.append(item["category"])
                    batch_pred_answers.append(answer)

                    # 添加特征（无论正确与否，每个题目都需要添加特征）
                    for k in needed_layers:
                        batch_feats_avg_with_prompt[k].append(layer_vecs[k]["avg_with_prompt"])
                        batch_feats_avg_without_prompt[k].append(layer_vecs[k]["avg_without_prompt"])
                        batch_feats_last_token[k].append(layer_vecs[k]["last_token"])

                    # 被截断的文本需要去掉最后一句不完整的话
                    if hit_limit_flags[i]:
                        answer = trim_incomplete_sentence(answer)

                    if item["category"] == "dyck_languages":
                        extracted_answer, is_correct = bbh_parse(answer, item["target"], item["category"], hit_limit_flags[i], question=item["input"], id=item["id"])
                    else:
                        extracted_answer, is_correct = bbh_parse(answer, item["target"], item["category"], hit_limit_flags[i], question=item["input"], id=item["id"])

                    # if item["category"] == "word_sorting":
                    #     console.print(f"\n[green]answer: {answer}[/green]")
                    # if item["category"] == "dyck_languages":
                    #     console.print(f"\n[green]answer: {answer}[/green]")

                    if extracted_answer == "Incomplete" or hit_limit_flags[i]:
                        console.print(f"\n[red]Incomplete answer[/red]")

                        if extracted_answer is None:
                            print("pred_answer:", extracted_answer, "true answer:",item["target"], "is_correct:", is_correct)
                        else:
                            if len(extracted_answer) < 500:
                                print("pred_answer:", extracted_answer, "true answer:",item["target"], "is_correct:", is_correct)
                            else:
                                print("pred_answer:", "Incomplete", "true answer:",item["target"], "is_correct:", is_correct)
                    else:
                        print("pred_answer:", extracted_answer, "true answer:",item["target"], "is_correct:", is_correct)

                    # log_cmp(extracted_answer, item["target"], is_correct)

                    batch_labels.append(int(is_correct))
                    batch_incomplete_flags.append(hit_limit_flags[i])

                # 标记批次完成
                batch_completed = True

            except Exception as e:
                console.print(f"[red]批次 {batch_idx + 1} 处理失败: {escape(str(e))}[/red]")
                console.print(f"[yellow]回退到单样本处理模式[/yellow]")

                # 仅在OOM时尝试清理
                if "out of memory" in str(e).lower():
                    gc.collect()
                    torch.cuda.empty_cache()

                # 清空当前批次已收集的数据，避免部分数据重复保存
                if batch_ids:
                    console.print(f"[yellow]清空当前批次已收集的数据，避免部分数据重复保存[/yellow]")
                    batch_feats_avg_with_prompt = {k: [] for k in needed_layers}
                    batch_feats_avg_without_prompt = {k: [] for k in needed_layers}
                    batch_feats_last_token = {k: [] for k in needed_layers}
                    batch_labels.clear()
                    batch_ids.clear()
                    batch_questions.clear()
                    batch_true_answers.clear()
                    batch_pred_answers.clear()
                    batch_categories.clear()
                    batch_incomplete_flags.clear()

                # 回退到单样本处理
                for i, item in enumerate(batch_items):
                    try:
                        layer_vecs, answer, hit_limit = self.forward_once(
                            item,
                            max_new_tokens=max_new_tokens,
                            needed_layers=needed_layers
                        )

                        q = item["input"].strip()
                        # console.print(f"\n[green]question {start_idx + i}: {escape(q)}[/green]")
                        # console.print(f"model_answer {start_idx + i}: {answer}", markup=False)

                        batch_ids.append(item.get("_resolved_id", ""))
                        batch_questions.append(q)
                        batch_true_answers.append(item["target"])
                        batch_categories.append(item["category"])
                        batch_pred_answers.append(answer)

                        # 添加特征
                        for k in needed_layers:
                            batch_feats_avg_with_prompt[k].append(layer_vecs[k]["avg_with_prompt"])
                            batch_feats_avg_without_prompt[k].append(layer_vecs[k]["avg_without_prompt"])
                            batch_feats_last_token[k].append(layer_vecs[k]["last_token"])

                        # 被截断的文本需要去掉最后一句不完整的话
                        if hit_limit:
                            answer = trim_incomplete_sentence(answer)

                        if item["category"] == "dyck_languages":
                            extracted_answer, is_correct = bbh_parse(answer, item["target"], item["category"], hit_limit, question=item["input"], id=item["id"])
                        else:
                            extracted_answer, is_correct = bbh_parse(answer, item["target"], item["category"], hit_limit, question=item["input"], id=item["id"])

                        if extracted_answer == "Incomplete" or hit_limit:
                            console.print(f"\n[red]Incomplete answer[/red]")

                            if extracted_answer is None:
                                print("pred_answer:", extracted_answer, "true answer:",item["target"], "is_correct:", is_correct)
                            else:
                                if len(extracted_answer) < 500:
                                    print("pred_answer:", extracted_answer, "true answer:",item["target"], "is_correct:", is_correct)
                                else:
                                    print("pred_answer:", "Incomplete", "true answer:",item["target"], "is_correct:", is_correct)
                        else:
                            print("pred_answer:", extracted_answer, "true answer:",item["target"], "is_correct:", is_correct)

                        # if item["category"] == "word_sorting" and hit_limit == True:
                        #     console.print(f"\n[green]answer: {answer}[/green]")
                        # if item["category"] == "dyck_languages":
                        #     console.print(f"\n[green]answer: {answer}[/green]")

                        # log_cmp(extracted_answer, item["target"], is_correct)

                        batch_labels.append(int(is_correct))
                        batch_incomplete_flags.append(hit_limit)
                        # console.print("pred_answer:", extracted_answer, "true answer:",item["target"], "is_correct:", is_correct)

                    except Exception as single_e:
                        console.print(f"[red]样本 {start_idx + i} 处理失败: {escape(str(single_e))}[/red]")
                        # 添加空特征以保持索引一致
                        batch_ids.append(item.get("_resolved_id", ""))
                        batch_questions.append(item["input"].strip())
                        batch_true_answers.append(item["target"])
                        batch_pred_answers.append("Failed")
                        batch_labels.append(0)
                        batch_incomplete_flags.append(False)
                        batch_categories.append(item["category"])

                        # 添加零特征
                        hidden_dim = self.model.config.hidden_size
                        _dtype = next(self.model.parameters()).dtype
                        _zero  = torch.zeros(hidden_dim, dtype=_dtype, device="cpu")
                        for k in needed_layers:
                            batch_feats_avg_with_prompt[k].append(_zero.clone())
                            batch_feats_avg_without_prompt[k].append(_zero.clone())
                            batch_feats_last_token[k].append(_zero.clone())

                # 单样本回退模式下，整个batch处理完成才标记完成
                batch_completed = True

            # 只有当批次完整处理完毕时才保存
            if batch_completed:
                # 保存当前批次的特征
                self._save_batch_features(
                    batch_feats_avg_with_prompt,
                    batch_feats_avg_without_prompt,
                    batch_feats_last_token,
                    batch_labels,
                    batch_ids,
                    batch_questions,
                    batch_true_answers,
                    batch_pred_answers,
                    batch_categories,
                    batch_incomplete_flags,
                    avg_with_prompt_dir,
                    avg_without_prompt_dir,
                    last_token_outputdir,
                )

                # 写入断点（即使 resume=False 也会持续写，确保中断可恢复）
                ckpt.mark_batch_processed(batch_ids)

            # 清理GPU内存
            # torch.cuda.empty_cache()

        print(f"[done] 共 {len(remaining)} 条新样本完成。累计完成 {len(ckpt.processed_ids)}/{ckpt.total_samples}.")


def load_datasets(dataset: List[str], data_dir: str) -> Dict[str, str]:
    """Map each requested dataset name to its ``<data_dir>/<name>.jsonl`` file.

    Names whose file does not exist are reported on the console and left out
    of the result; names whose file exists are reported and included.

    Args:
        dataset: Dataset names (file stems, without the ``.jsonl`` suffix).
        data_dir: Directory that is expected to contain the ``.jsonl`` files.

    Returns:
        A dict from dataset name to file path, in the order the names were
        given, containing only the datasets found on disk.
    """
    found: Dict[str, str] = {}
    for name in dataset:
        candidate = os.path.join(data_dir, f"{name}.jsonl")
        if not os.path.exists(candidate):
            # Missing file: report it and move on to the next name.
            console.print(f"[red] 数据集不存在 {candidate}[/red]")
            continue
        found[name] = candidate
        console.print(f"找到数据集: {candidate}")
    return found


def main() -> None:
    """Command-line entry point: run feature extraction for every model/dataset pair.

    Parses the CLI arguments, resolves the dataset files and the model
    directories under ``--models_root``, then for each model builds a
    ``HiddenExtractor`` and calls ``extract_dataset`` once per dataset.

    Batch size selection:
        * ``--auto_batch off`` uses ``--batch_size`` (floored at 1). If
          ``--batch_size`` is also missing the run is aborted *before* any
          model is loaded.
        * ``--auto_batch p99|max`` estimates a batch size per dataset; when
          ``--batch_size`` is also given the smaller of the two is used. If
          the estimate fails, ``--batch_size`` (or 1) is used instead.

    Side effects:
        Creates ``--models_root``, ``--output_dir`` and the checkpoint
        directory if they do not exist, and sets the module-level
        ``USE_FEW_SHOT`` flag from ``--use_few_shot``.
    """
    global USE_FEW_SHOT

    p = argparse.ArgumentParser("Hidden feature extractor with checkpoint")
    p.add_argument("--models_root", default=os.path.join(PROJECT_ROOT, "models"), help="本地模型根目录（用于拼接 模型名→路径）")
    p.add_argument("--data_dir", default=os.path.join(PROJECT_ROOT,"new_benchmark","bbh"))
    p.add_argument("--output_dir", default=os.path.join(PROJECT_ROOT, "feats/bbh/bbh_feats_test"))
    p.add_argument("--device", default="cuda")
    p.add_argument("--model", default="all")
    p.add_argument("--dataset", default=None)
    p.add_argument("--use_few_shot", action="store_true")
    p.add_argument("--layer_type", default="all", choices=["middle", "last", "second_last", "all"])
    p.add_argument("--batch_size", type=int, default=None, help="批量处理的batch size")
    p.add_argument("--resume", action="store_true", help="是否从断点恢复")
    p.add_argument("--checkpoint_root", type=str, default=os.path.join(PROJECT_ROOT, "checkpoints/bbh/bbh_feats"), help="checkpoint存储目录",)
    # Parameters for the automatic batch-size estimate.
    p.add_argument("--auto_batch", choices=["off", "p99", "max"], default="off", help="按分片自动估计 batch（p99 或 max）")
    p.add_argument("--gen_tokens", type=int, default=2048, help="估算时的最大生成长度")
    p.add_argument("--n_gpus", type=int, default=1, help="参与推理的 GPU 数")
    p.add_argument("--per_gpu_vram_gib", type=float, default=80.0, help="每张卡可用显存 GiB，用于估算")
    p.add_argument("--pct", type=float, default=0.99, help="分位数（默认 0.99，用于 pXX 计算）")
    args = p.parse_args()

    USE_FEW_SHOT = args.use_few_shot

    os.makedirs(args.models_root, exist_ok=True)
    os.makedirs(args.output_dir, exist_ok=True)
    if args.dataset is None:
        args.dataset = ["bbh_train", "bbh_val", "bbh_test"]
    if isinstance(args.dataset, str):
        args.dataset = [args.dataset]
    datasets = load_datasets(args.dataset, args.data_dir)

    if not datasets:
        console.print("[red] 数据集不存在[/red]")
        return

    if args.checkpoint_root is None:
        checkpoint_root = os.path.join(args.output_dir, "checkpoints")
    else:
        checkpoint_root = args.checkpoint_root
    os.makedirs(checkpoint_root, exist_ok=True)

    # "all" means every sub-directory of models_root that looks like a model
    # (has a config.json or tokenizer.json); otherwise a comma-separated list.
    if args.model == "all":
        models = []
        for d in sorted(os.listdir(args.models_root)):
            mp = os.path.join(args.models_root, d)
            if os.path.isdir(mp) and (
                os.path.exists(os.path.join(mp, "config.json")) or
                os.path.exists(os.path.join(mp, "tokenizer.json"))
            ):
                models.append(d)
    else:
        models = [m.strip() for m in str(args.model).split(",") if m.strip()]

    # Fail fast: with neither an explicit batch size nor an automatic estimate
    # there is nothing to run, so bail out before paying for any model load.
    if args.batch_size is None and args.auto_batch == "off":
        console.print("[yellow]警告: 未指定 --batch_size，且 --auto_batch=off，无法执行，请重新设置 batch[/yellow]")
        return

    # Batch size used when no estimate is requested or the estimate fails:
    # the user's value floored at 1, or 1 when none was given.
    fallback_bsz = max(1, int(args.batch_size)) if args.batch_size is not None else 1

    for mk in models:
        model_path = os.path.join(args.models_root, mk)
        if not os.path.exists(model_path):
            console.print(f"[red] {mk} 模型目录不存在：{model_path}[/red]")
            continue

        console.rule(f"📦  Processing model — {mk}")
        extractor = HiddenExtractor(mk, model_path, args.device)

        # Maximum generation length depends only on the model name, so it is
        # computed once per model: names containing qwen3/distill/qwq get a
        # fixed 8192-token budget, everything else uses --gen_tokens.
        mk_lower = mk.lower()
        if any(k in mk_lower for k in ["qwen3", "distill", "qwq"]):
            gen_tok = 8192
        else:
            gen_tok = int(args.gen_tokens)

        for dname, dpath in datasets.items():
            avg_with_prompt_dir = os.path.join(args.output_dir, mk+"_avg_with_prompt", dname)
            avg_without_prompt_dir = os.path.join(args.output_dir, mk+"_avg_without_prompt", dname)
            last_token_out_dir = os.path.join(args.output_dir, mk+"_last_token", dname)

            effective_bsz = fallback_bsz

            if args.auto_batch != "off":
                try:
                    # Token statistics (max, percentile) for this split; dname is
                    # passed straight through as dataset_with_split.
                    mx, pxx = PTS_POST.token_calculation(
                        model_name=mk,
                        dataset_with_split=dname,
                        pct=args.pct,
                    )

                    bsz_p99, bsz_mx = VRAM_POST.batch_estimation(
                        mx=mx,
                        p99=pxx,
                        model_name=mk,
                        gen_tokens=gen_tok,
                        n_gpus=args.n_gpus,
                        per_gpu_vram_gib=args.per_gpu_vram_gib,
                    )

                    auto_bsz = bsz_p99 if args.auto_batch == "p99" else bsz_mx

                    # If the user also passed --batch_size, take the smaller of the
                    # two. Compare against the floored value so the result can never
                    # drop below 1.
                    effective_bsz = max(1, int(auto_bsz))
                    if args.batch_size is not None:
                        effective_bsz = min(effective_bsz, fallback_bsz)

                    if args.auto_batch == "p99":
                        # ":g" avoids float-truncation artefacts such as 0.29 -> "p28".
                        console.print(f"[cyan]AUTO-BATCH(p99)[/cyan] {mk}/{dname}: p{args.pct * 100:g}={pxx} → use batch_size={effective_bsz}")
                    else:
                        console.print(f"[cyan]AUTO-BATCH(max)[/cyan] {mk}/{dname}: mx={mx} → use batch_size={effective_bsz}")

                except Exception as e:
                    # Escape the exception text: brackets in it would otherwise be
                    # parsed as Rich markup and could raise inside this handler.
                    console.print(f"[yellow]AUTO-BATCH 估计失败 {mk}/{dname}：{escape(str(e))}；回退 batch_size={args.batch_size}[/yellow]")
                    effective_bsz = fallback_bsz
                    console.print(f"[yellow]使用回退 batch_size={effective_bsz}[/yellow]")

            extractor.extract_dataset(
                dname,
                dpath,
                avg_with_prompt_dir,
                avg_without_prompt_dir,
                last_token_out_dir,
                layer_req=args.layer_type,
                max_new_tokens=gen_tok,
                batch_size=effective_bsz,
                resume=args.resume,
                checkpoint_root=checkpoint_root,
            )

        # Drop this model before the next one is constructed; otherwise the old
        # extractor stays referenced until the new one has finished loading.
        del extractor
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    console.rule("程序结束！")
    # logf.close()


####################### 测试 #######################
def run_mcq_cases():
    """Run ``bbh_parse`` over the multiple-choice fixtures and print a PASS/FAIL report.

    Iterates the module-level ``cases_mcq`` list; each case is a dict with
    ``pred``, ``true`` and ``category`` keys. A case passes when the answer
    extracted by ``bbh_parse`` equals ``true`` exactly. The second value
    returned by ``bbh_parse`` is ignored.

    NOTE(review): ``cases_mcq`` must be in module scope before calling this;
    the only import of it visible here is commented out in the ``__main__``
    guard.
    """
    console.print("\n===== MCQ (bbh_parse; EM: extracted == true) =====")
    passed = 0
    for i, c in enumerate(cases_mcq, 1):
        extracted, _ = bbh_parse(
            pred=c["pred"],
            true=c["true"],
            category=c["category"],   # only these three arguments are passed
        )
        ok = (extracted == c["true"])
        # markup=False on every line that carries raw prediction text: bracketed
        # content would otherwise be parsed as Rich style tags.
        console.print(f"[MCQ {i:02d}] {'PASS' if ok else 'FAIL'}", markup=False)
        console.print(f"  pred: {repr(c['pred'])}", markup=False)
        console.print(f"  -> extracted={repr(extracted)} | true={repr(c['true'])}\n", markup=False)
        passed += int(ok)
    console.print(f"Summary (MCQ): {passed}/{len(cases_mcq)} passed\n")

def run_dyck_cases():
    """Run ``bbh_parse`` over the Dyck-language fixtures and print a PASS/FAIL report.

    Iterates the module-level ``cases_dyck`` list; each case is a dict with
    ``pred``, ``true`` and ``category`` keys. A case passes when the answer
    extracted by ``bbh_parse`` equals ``true`` exactly. The second value
    returned by ``bbh_parse`` is ignored.

    NOTE(review): ``cases_dyck`` must be in module scope before calling this;
    the only import of it visible here is commented out in the ``__main__``
    guard.
    """
    console.print("\n===== Dyck (bbh_parse; EM: extracted == true) =====")
    passed = 0
    for i, c in enumerate(cases_dyck, 1):
        extracted, _ = bbh_parse(
            pred=c["pred"],
            true=c["true"],
            category=c["category"],   # only these three arguments are passed
        )
        ok = (extracted == c["true"])
        # markup=False on every line that carries raw prediction text: bracketed
        # content would otherwise be parsed as Rich style tags.
        console.print(f"[Dyck {i:02d}] {'PASS' if ok else 'FAIL'}", markup=False)
        console.print(f"  pred: {repr(c['pred'])}", markup=False)
        console.print(f"  -> extracted={repr(extracted)} | true={repr(c['true'])}\n", markup=False)
        passed += int(ok)
    console.print(f"Summary (Dyck): {passed}/{len(cases_dyck)} passed\n")


if __name__ == "__main__":
    # Script entry point: run the feature-extraction CLI.
    main()
    # Parser self-tests (disabled). To run them, uncomment the lines below so
    # the fixtures are imported before run_mcq_cases / run_dyck_cases are called.
    # from test import cases_mcq, cases_dyck, _sp_brackets
    # run_mcq_cases()
    # run_dyck_cases()
