from __future__ import annotations
import argparse
import json
import os
import shutil
import importlib.util
from typing import Dict, List, Optional, Tuple, Iterable, Any, Set, Literal, Union
from decimal import Decimal, InvalidOperation, getcontext

import torch
import gc
from rich.console import Console
from rich.progress import track
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import logging
import re
import sys
from rich.markup import escape
from dataclasses import dataclass

# Project root = two directories above this file; it is put on sys.path so that
# project-local packages resolve when this file is run directly as a script.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.normpath(os.path.join(BASE_DIR, "..", ".."))

if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

logging.set_verbosity_error()  # transformers logging: show ERROR only, suppress WARNING
console = Console()

# setdefault: only applies when the user has not configured the allocator already.
# NOTE(review): presumably must run before the first CUDA allocation to take effect — confirm.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
torch.manual_seed(42)
getcontext().prec = 50 # Precision (significant digits) of the global Decimal context used by the numeric helpers below

def parse_input(item):
    """Build the math-problem prompt for one dataset record.

    Reads ``item["question"]`` and wraps it in a fixed instruction that asks
    for step-by-step reasoning followed by a final ``Answer:`` line.
    """
    q = item["question"]
    return (
        "Solve this math problem. Give the reasoning steps before giving the final "
        "answer on the last line by itself in the format of \"Answer:\". Do not add "
        "anything other than the integer answer after \"Answer:\".\n\n"
        f"Question:\n{q}\n"
    )

# ===================== 答案抽取逻辑 =====================
# Sentinel placed in ParseResult.value when no numeric answer could be extracted.
INVALID_ANS = "[invalid]"
# Name of the extraction strategy that produced a ParseResult.
Mode = Literal["marker_strict", "marker_loose", "last_line", "tail_number", "none"]

@dataclass
class ParseResult:
    """Outcome of extracting a numeric answer from a model completion."""

    value: Union[int, Decimal, str]  # int / Decimal, or INVALID_ANS when nothing was found
    mode: Mode  # which extraction strategy produced `value`
    evidence: str  # text fragment the value was taken from ("" when none)

# Numeric tokens such as -123, 1,234, 1234, 42.0, 3.5, 1e-3, -2.0E+5.
# The comma-grouped alternative requires at least one ",ddd" group. With `*`
# it also matched the bare prefix "123" of "1234", and since regex alternation
# takes the first alternative that succeeds, unanchored searches split "1234"
# into the tokens "123" and "4".
NUM_TOKEN_RE = re.compile(
    r"[-+]?"
    r"(?:\d{1,3}(?:,\d{3})+|\d+)"
    r"(?:\.\d+)?"
    r"(?:[eE][-+]?\d+)?"
)

# A whole line that is just one number (optional sign, thousands commas,
# decimals, exponent), optionally followed by a period; group(1) is the number.
PURE_NUM_LINE_RE = re.compile(
    r"^\s*"
    r"([-+]?(?:\d{1,3}(?:,\d{3})*|\d+)(?:\.\d+)?(?:[eE][-+]?\d+)?)"
    r"\s*\.?\s*$"
)

def _to_number(num_str: str) -> Optional[Union[int, Decimal]]:
    s = num_str.strip().replace(",", "")
    if s.endswith("."):
        s = s[:-1].strip()
    try:
        d = Decimal(s)
    except InvalidOperation:
        return None
    # 若是整数值（包括 42.0 / 4.2000），直接转 int
    if d == d.to_integral_value():
        return int(d)
    return d

def _extract_first_number(text: str) -> Optional[Union[int, Decimal]]:
    """Return the first NUM_TOKEN_RE token in *text* that parses as a number, else None."""
    parsed = (_to_number(tok.group(0)) for tok in NUM_TOKEN_RE.finditer(text))
    return next((num for num in parsed if num is not None), None)

def _extract_last_number(text: str) -> Optional[Union[int, Decimal]]:
    """Return the last NUM_TOKEN_RE token in *text* that parses as a number, else None."""
    numbers = [
        num
        for num in map(_to_number, (tok.group(0) for tok in NUM_TOKEN_RE.finditer(text)))
        if num is not None
    ]
    return numbers[-1] if numbers else None

# Answer-marker patterns: (name, pattern) pairs whose group(1) is the text
# following the marker. extract_pred_answer_number keeps the match that ends
# latest in the completion.
MARKERS: List[Tuple[str, re.Pattern]] = [
    ("hash4", re.compile(r"(?im)^\s*####\s*(.+?)\s*$")),
    ("answer", re.compile(r"(?im)^\s*answer\s*[:：]\s*(.+?)\s*$")),
    # `\bfinal\b` rather than `\bfinal` so that words such as "Finally, 3 + 4 = 7"
    # are not mistaken for an answer marker.
    ("final", re.compile(r"(?is)\bfinal\b\s*(?:answer)?\s*(?:is|=|:)?\s*([^\n\r]*)")),
    ("the_answer", re.compile(r"(?is)\bthe\s+answer\s*(?:is|=|:)\s*([^\n\r]*)")),
    ("zh_answer", re.compile(r"(?is)(?:最终答案|答案)\s*(?:是|为|=|:|：)\s*([^\n\r]*)")),
    ("boxed", re.compile(r"(?is)\\boxed\{\s*([^}]*)\s*\}")),
]

def extract_pred_answer_number(completion: str, tail_chars: int = 800) -> ParseResult:
    """Extract the numeric final answer from a model completion.

    Strategies, in priority order:

    1. Marker patterns (MARKERS). Among all matches of all patterns, the one whose
       match ends furthest into the text wins; ties go to the match seen later
       (later pattern, then later match). The capture is first parsed whole
       ("marker_strict"), otherwise its first number is used ("marker_loose").
       Captures with no number are ignored.
    2. The last non-blank line, if it is just a number ("last_line").
    3. The last number in the final ``max(50, tail_chars)`` characters
       ("tail_number"); the evidence is the last 200 characters of that window.

    Returns ``ParseResult(INVALID_ANS, "none", "")`` for empty input or when
    nothing matches.
    """
    if not completion or not completion.strip():
        return ParseResult(INVALID_ANS, "none", "")

    text = completion.replace("\r\n", "\n").replace("\r", "\n")

    winner = None  # (match_end, mode, value, evidence)
    for marker_name, pattern in MARKERS:
        for match in pattern.finditer(text):
            captured = match.group(1).strip()
            value = _to_number(captured)
            mode = "marker_strict"
            if value is None:
                value = _extract_first_number(captured)
                mode = "marker_loose"
            if value is None:
                continue
            if winner is None or match.end() >= winner[0]:
                winner = (match.end(), mode, value, f"{marker_name}: {match.group(0).strip()}")

    if winner is not None:
        return ParseResult(winner[2], winner[1], winner[3])

    non_blank = [line.strip() for line in text.split("\n") if line.strip()]
    if non_blank:
        final_line = non_blank[-1]
        line_match = PURE_NUM_LINE_RE.match(final_line)
        if line_match:
            value = _to_number(line_match.group(1))
            if value is not None:
                return ParseResult(value, "last_line", final_line)

    window = text[-max(50, int(tail_chars)):]
    value = _extract_last_number(window)
    if value is not None:
        return ParseResult(value, "tail_number", window[-200:])

    return ParseResult(INVALID_ANS, "none", "")

def normalize_gold_number(gold: Union[int, float, str, Decimal]) -> Union[int, Decimal]:
    """Normalise a gold answer to ``int`` (integral values) or ``Decimal``.

    Floats go through ``str()`` before becoming a Decimal, so 0.1 becomes
    Decimal("0.1"). Strings are parsed with ``_to_number``.

    Raises:
        ValueError: a string that is not numeric.
        TypeError: any other input type.
    """
    if isinstance(gold, int):
        return gold
    if isinstance(gold, str):
        parsed = _to_number(gold)
        if parsed is None:
            raise ValueError(f"Gold not numeric: {gold!r}")
        return parsed
    if isinstance(gold, float):
        gold = Decimal(str(gold))
    elif not isinstance(gold, Decimal):
        raise TypeError(f"Unsupported gold type: {type(gold)}")
    return int(gold) if gold == gold.to_integral_value() else gold

def _num_equal(g, p, tol=Decimal("1e-6")) -> bool:
    # g/p: int or Decimal
    if isinstance(g, int):
        if isinstance(p, int):
            return p == g
        return p == Decimal(g)  # 允许 3.0 == 3
    # g 是小数：容差
    gd = Decimal(g) if isinstance(g, int) else g
    pd = Decimal(p) if isinstance(p, int) else p
    return abs(pd - gd) <= tol
# =============================================================================

def gsm8k_parse(pred, true, hit_limit):
    """Score one GSM8K completion against its gold answer.

    Args:
        pred: model completion text.
        true: gold answer (int / float / numeric str / Decimal).
        hit_limit: truthy when generation was cut off at the token limit.

    Returns:
        ``(parsed_prediction, is_correct)``. When no number can be extracted, the
        first element is "Incomplete" if ``hit_limit`` is truthy, else INVALID_ANS,
        and ``is_correct`` is False.

    Raises:
        RuntimeError: the gold answer could not be normalised; the original error
            is attached as ``__cause__``.
    """
    pred_norm = extract_pred_answer_number(pred).value

    if not isinstance(pred_norm, (int, Decimal)):
        return ("Incomplete" if hit_limit else INVALID_ANS), False

    try:
        true_norm = normalize_gold_number(true)
    except Exception as err:
        raise RuntimeError("true answer解析错误，请检查原始文件") from err

    return pred_norm, _num_equal(true_norm, pred_norm, tol=Decimal("1e-6"))

# Empty the given directory (not the current working directory) in place.
def clear_dir(dir_path: str):
    """Delete every entry inside *dir_path* while keeping the directory itself.

    Regular files and symlinks are unlinked (a symlink to a directory removes only
    the link); real subdirectories are removed recursively. A path that is not an
    existing directory is a no-op. Per-entry failures are printed and skipped.
    """
    if not os.path.isdir(dir_path):
        return
    for entry_name in os.listdir(dir_path):
        target = os.path.join(dir_path, entry_name)
        try:
            if os.path.isfile(target) or os.path.islink(target):
                remover = os.remove
            elif os.path.isdir(target):
                remover = shutil.rmtree
            else:
                continue
            remover(target)
        except Exception as e:
            print(f"删除失败: {target}, 原因: {e}")


def get_item_id(item: Dict[str, Any]) -> str:
    """Return the sample's ID as a stripped string, or "" if it has none.

    Keys are tried in order:
    - "unique_id" (MATH)
    - "question_id" (MMLU_PRO)
    - "id" (BBH / gsm8k)

    The first two are skipped when their value is None or "". The "id" value is
    used as-is unless it is None. Falsy IDs such as 0 are kept, so 0 becomes "0".
    """
    for key in ("unique_id", "question_id"):
        candidate = item.get(key)
        if not (candidate is None or candidate == ""):
            return str(candidate).strip()
    fallback = item.get("id")
    return "" if fallback is None else str(fallback).strip()


# 自动化 batch 估计
def _load_module(module_path: str, module_name: str):
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import {module_name} from {module_path}")
    mod = importlib.util.module_from_spec(spec)
    import sys as _sys

    _sys.modules[module_name] = mod
    spec.loader.exec_module(mod)
    return mod


PTS_PATH = os.path.join(PROJECT_ROOT, "utils", "vram_estimation", "prompt_token_states_post.py")
VRAM_PATH = os.path.join(PROJECT_ROOT, "utils", "vram_estimation", "vram_estimate_post.py")
PTS_POST = _load_module(PTS_PATH, "prompt_token_states_post")
VRAM_POST = _load_module(VRAM_PATH, "vram_estimate_post")
# Check that the helper modules expose the entry points this script uses.
# An explicit raise (instead of `assert`) keeps the check active under `python -O`.
if not hasattr(PTS_POST, "token_calculation"):
    raise ImportError("prompt_token_states_post 模块缺少 token_calculation")
if not hasattr(VRAM_POST, "batch_estimation"):
    raise ImportError("vram_estimate_post 模块缺少 batch_estimation")


# eos 规范化
def _normalize_to_list(x):
    if x is None:
        return []
    if isinstance(x, (list, tuple, set)):
        return list(x)
    return [x]


class CheckpointManager:
    """
    Persist progress as the set of already-processed sample IDs (JSON file).

    With ``load_existing=False`` any checkpoint already on disk is ignored and is
    overwritten by the first save, i.e. "start from scratch but keep writing
    checkpoints".
    """

    def __init__(self, path: str, load_existing: bool = True):
        self.path = path
        self.processed_ids: Set[str] = set()
        self.completed: bool = False
        self.total_samples: int = 0
        if load_existing:
            self._load()
        # Otherwise an old file, if any, is overwritten on the first _atomic_save.

    def _load(self) -> None:
        """Load state from ``self.path``; unreadable/malformed files are reported and ignored."""
        if not os.path.exists(self.path):
            return
        try:
            with open(self.path, "r", encoding="utf-8") as f:
                data = json.load(f)
            processed = set(map(str, data.get("processed_ids", [])))
            completed = bool(data.get("completed", False))
            total = int(data.get("total_samples", 0))
        except Exception as e:
            print(f"[ckpt] 读取失败，忽略旧文件：{self.path} ({e})")
            return
        # Commit only after every field parsed, so a malformed file never leaves
        # a half-loaded state behind.
        self.processed_ids = processed
        self.completed = completed
        self.total_samples = total

    def reset(self) -> None:
        """Clear in-memory progress; the next _atomic_save overwrites the file on disk."""
        self.processed_ids.clear()
        self.completed = False
        # total_samples is kept; a later set_total updates it.
        # Nothing is written now, to avoid clobbering the file before the first real save.

    def _all_processed(self) -> bool:
        """True when a positive total is known and at least that many IDs are recorded."""
        return self.total_samples > 0 and len(self.processed_ids) >= self.total_samples

    def set_total(self, n: int) -> None:
        """Record the dataset size; re-derive ``completed`` and save when it changes.

        Re-deriving matters when resuming: a checkpoint marked completed for a
        smaller dataset must not stay completed after samples were added.
        """
        n = int(n)
        if self.total_samples != n:
            self.total_samples = n
            self.completed = self._all_processed()
            self._atomic_save()

    def mark_batch_processed(self, ids: Iterable[str]) -> None:
        """Add non-empty (stripped) IDs to the processed set and save."""
        for _id in ids:
            sid = str(_id).strip() if _id is not None else ""
            if sid:
                self.processed_ids.add(sid)
        if self._all_processed():
            self.completed = True
        self._atomic_save()

    def is_processed(self, _id: Any) -> bool:
        """True if the stripped string form of *_id* is non-empty and already recorded."""
        sid = str(_id).strip() if _id is not None else ""
        return bool(sid) and (sid in self.processed_ids)

    def _atomic_save(self) -> None:
        """Write state to a temp file, fsync it, then atomically replace ``self.path``."""
        tmp = self.path + ".tmp"
        data = {
            "processed_ids": sorted(self.processed_ids),
            "completed": self.completed,
            "total_samples": self.total_samples,
        }
        os.makedirs(os.path.dirname(self.path) or ".", exist_ok=True)
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False)
            # Make sure the bytes are on disk before the rename publishes them.
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, self.path)


class HiddenExtractor:
    def __init__(self, model_key: str, model_path: str, device: str = "cuda") -> None:
        """Load the tokenizer and causal LM and initialise hook/debug state.

        Model loading is attempted up to three times:
        1. ``device_map="auto"`` with an explicit attention implementation
           (SDPA for Gemma models, FlashAttention-2 otherwise);
        2. ``device_map="auto"`` with the library's default attention;
        3. ``device_map=None`` followed by ``.to(device)``.
        Failures of the first two attempts are printed as warnings; a failure
        of the third propagates.

        Args:
            model_key: short model identifier. It is also checked for "gemma";
                other methods use it for model-specific behaviour such as
                Qwen3 chat-template options.
            model_path: path or hub id passed to ``from_pretrained``.
            device: target device for the single-device fallback.
        """
        self.model_key = model_key
        self.device = device
        console.print(f"[bold]加载模型:[/bold] {model_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        # Gemma is loaded with SDPA instead of FlashAttention-2. Match
        # case-insensitively on both key and path so names like "Gemma-2-9B"
        # or a local directory with a descriptive model_key are recognised.
        model_hint = f"{model_key} {model_path}".lower()
        attn_implementation = "sdpa" if "gemma" in model_hint else "flash_attention_2"
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                attn_implementation=attn_implementation,
            )
        except Exception as e:
            console.print(
                f"[yellow]Warning: Failed to load model with auto device map and flash-attention: {escape(str(e))}[/yellow]"
            )
            # Some models do not support the requested attention implementation:
            # retry with the library default.
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                )
            except Exception as e2:
                console.print(
                    f"[yellow]Warning: Failed to load model with auto device map: {escape(str(e2))}[/yellow]"
                )
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    torch_dtype=torch.bfloat16,
                    device_map=None,  # single device
                ).to(self.device)

        self.model.eval()
        self.layer_map = self._build_layer_index()
        # Public alias of the right-padding helper.
        self.calculate_right_padding_length = self._calculate_right_padding_length

        # Forward-hook state
        self.hooks = []
        self.captured_hidden_states = {}
        self.is_batch_mode = False
        self.current_sample_idx = 0

        # Debug state
        self.debug_mode = False
        self.debug_info = {}

    def _build_layer_index(self) -> Dict[str, int]:
        """Map symbolic layer names to decoder-layer indices and print them.

        Indices are derived from ``model.config.num_hidden_layers``. The key
        order of the returned dict matters: callers use ``list(layer_map.keys())``
        as their default layer list.
        """
        total = self.model.config.num_hidden_layers
        index = dict(
            quarter=total // 4,
            middle=total // 2,
            three_quarters=(3 * total) // 4,
            last=total - 1,
            second_last=total - 2,
            first=0,
        )
        console.print(
            f"总层数(推断): {total} → quarter={index['quarter']}, middle={index['middle']}, "
            f"three_quarters={index['three_quarters']}, second_last={index['second_last']}, last={index['last']}"
        )
        return index

    # 钩子注册
    def _register_hooks(self, needed_layers: List[str]) -> None:
        """注册forward hooks来捕获指定层的隐状态"""
        self._clear_hooks()  # 先清除之前的hooks
        self.captured_hidden_states = {k: [] for k in needed_layers}

        def create_hook(layer_name: str):
            def hook_fn(module, input, output):

                if isinstance(output, tuple):
                    hidden_state = output[0]  # 通常隐状态是第一个元素
                else:
                    hidden_state = output

                # 复制到CPU后立即删除GPU引用，但使用更安全的方式
                hidden_state_cpu = hidden_state.detach().to("cpu")

                # 显式删除GPU引用以节省显存，但要确保Flash Attention已经完成计算
                if hidden_state.is_cuda:
                    del hidden_state

                # 存储隐藏状态 - 不需要步骤计数，按顺序存储即可
                self.captured_hidden_states[layer_name].append(hidden_state_cpu)

            return hook_fn

        # 根据模型架构获取正确的层
        if hasattr(self.model, "model"):  # 如 LlamaForCausalLM
            layers = self.model.model.layers
        elif hasattr(self.model, "transformer"):  # 如 GPT类模型
            layers = self.model.transformer.h
        else:
            # 尝试其他可能的属性名
            for attr in ["layers", "h", "decoder_layers"]:
                if hasattr(self.model, attr):
                    layers = getattr(self.model, attr)
                    break
            else:
                raise AttributeError("无法找到模型的层结构")

        for layer_name in needed_layers:
            if layer_name == "first":
                layer_idx = 0
            elif layer_name == "last":
                layer_idx = len(layers) - 1
            elif layer_name == "second_last":
                layer_idx = len(layers) - 2
            elif layer_name == "middle":
                layer_idx = len(layers) // 2
            elif layer_name == "quarter":
                layer_idx = len(layers) // 4
            elif layer_name == "three_quarters":
                layer_idx = (3 * len(layers)) // 4
            else:
                layer_idx = self.layer_map.get(layer_name, 0)

            if 0 <= layer_idx < len(layers):
                hook = layers[layer_idx].register_forward_hook(create_hook(layer_name))
                self.hooks.append(hook)

    def _calculate_right_padding_length(self, total_sequence) -> int:
        """Count trailing positions of a generated sequence that should be discarded.

        EOS ids are the union of ``generation_config.eos_token_id`` and
        ``tokenizer.eos_token_id``.

        1. Count the trailing run of ``pad_token_id`` tokens.
        2. If that run is non-empty and is preceded by a non-EOS token, return
           the run length minus one. This covers a pad id that equals an EOS id:
           the first counted token is the real terminating EOS and is kept.
        3. Otherwise also count the run of EOS ids before the pads and discard
           all of them except one.
        """
        seq = total_sequence.tolist() if isinstance(total_sequence, torch.Tensor) else total_sequence
        pad_id = self.tokenizer.pad_token_id
        eos_set = set(
            _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
            + _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        )

        pos = len(seq) - 1
        n_pad = 0
        while pos >= 0 and seq[pos] == pad_id:
            n_pad += 1
            pos -= 1

        if n_pad > 0 and pos >= 0 and seq[pos] not in eos_set:
            return n_pad - 1

        n_eos = 0
        while pos >= 0 and seq[pos] in eos_set:
            n_eos += 1
            pos -= 1

        return n_pad + max(0, n_eos - 1)

    def _clear_hooks(self) -> None:
        """清除所有hooks"""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        # 注意：这里不清除captured_hidden_states，因为后续还需要使用
        # self.captured_hidden_states.clear()

    # 方便统一处理chat/base模型
    def _build_input_ids(self, question: str) -> Tuple[torch.Tensor, list[int], str]:
        """Tokenize a single prompt, using the chat template when the tokenizer has one.

        For Qwen3 models the template is first asked for ``enable_thinking=True``
        and called again without that flag if this raises. Without a chat
        template the text is tokenized directly and a BOS id is prepended if the
        tokenizer defines one and did not add it.

        Returns:
            input_ids: (1, seq_len) tensor on the input-embedding device.
            prompt_ids: the same ids as a Python list.
            prompt_txt: ``input_ids`` decoded with special tokens kept.
        """
        tok = self.tokenizer
        has_template = bool(getattr(tok, "chat_template", None))
        if hasattr(tok, "apply_chat_template") and has_template:
            messages = [{"role": "user", "content": question}]

            def render(**extra):
                return tok.apply_chat_template(
                    messages,
                    tokenize=True,
                    add_generation_prompt=True,
                    return_tensors="pt",
                    **extra,
                )

            if "qwen3" in self.model_key.lower():
                try:
                    encoded = render(enable_thinking=True)
                except Exception:
                    encoded = render()
            else:
                encoded = render()
        else:
            encoded = tok(
                question,
                add_special_tokens=True,  # let the tokenizer add its own special tokens first
                return_tensors="pt",
                truncation=False,
            )["input_ids"]  # shape: (1, L)
            bos = tok.bos_token_id
            if bos is not None and encoded[0, 0].item() != bos:
                encoded = torch.cat([torch.tensor([[bos]], dtype=encoded.dtype), encoded], dim=1)

        encoded = encoded.to(self.model.get_input_embeddings().weight.device)
        return encoded, encoded[0].tolist(), tok.decode(encoded[0], skip_special_tokens=False)

    def _build_batch_input_ids(
        self, questions: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]], List[str], torch.Tensor]:
        """Tokenize a batch of prompts with left padding.

        Side effects:
        - sets ``tokenizer.padding_side = "left"``;
        - if the tokenizer has no pad token, adds ``<|pad|>``, resizes the
          embeddings and zeroes the new row; if that raises, it reuses the EOS
          id instead and prints a warning;
        - copies the pad id into ``model.config`` and ``model.generation_config``.

        Returns:
            padded_input_ids: (batch, max_len) ids on the input-embedding device.
            padded_attention_mask: matching attention mask on the same device.
            prompt_ids_list: un-padded token ids of each prompt.
            prompt_txts: the exact text tokenized for each prompt.
            left_pad_lens: per-row ``max_len - real_length`` (amount of left padding).
        """
        tok = self.tokenizer
        # Left padding puts every prompt's last real token at position -1.
        tok.padding_side = "left"

        if tok.pad_token_id is None:
            try:
                if tok.add_special_tokens({"pad_token": "<|pad|>"}) > 0:
                    self.model.resize_token_embeddings(len(tok))
                    # Zero the freshly added pad embedding row.
                    with torch.no_grad():
                        self.model.get_input_embeddings().weight[tok.pad_token_id].zero_()
            except Exception:
                tok.pad_token_id = tok.eos_token_id
                console.print(
                    "[yellow]Warn: fallback pad uses eos embedding (not zeroed).[/yellow]"
                )

        for cfg in (self.model.config, self.model.generation_config):
            cfg.pad_token_id = tok.pad_token_id

        has_template = bool(getattr(tok, "chat_template", None))
        use_chat = hasattr(tok, "apply_chat_template") and has_template
        if use_chat:
            extra = {"enable_thinking": True} if "qwen3" in self.model_key.lower() else {}
            prompt_txts: List[str] = [
                tok.apply_chat_template(
                    [{"role": "user", "content": q}],
                    tokenize=False,  # expand the template to plain text first
                    add_generation_prompt=True,
                    **extra,
                )
                for q in questions
            ]
        else:
            # No chat template: the raw question text is the prompt.
            prompt_txts = list(questions)
        # Base models need special tokens added here; chat templates already contain them.
        add_special = not use_chat

        batch_enc = tok(
            prompt_txts,
            padding="longest",  # pad to the longest prompt in this batch
            return_tensors="pt",
            return_attention_mask=True,
            add_special_tokens=add_special,
            truncation=False,  # never truncate silently
        )
        prompt_ids_list: List[List[int]] = [
            tok.encode(txt, add_special_tokens=add_special) for txt in prompt_txts
        ]

        embed_dev = self.model.get_input_embeddings().weight.device
        padded_input_ids = batch_enc["input_ids"].to(embed_dev)
        padded_attention_mask = batch_enc["attention_mask"].to(embed_dev)

        # Real lengths come from the mask; the rest of each row is left padding.
        max_len = padded_attention_mask.shape[1]
        left_pad_lens = max_len - padded_attention_mask.sum(dim=1)

        return padded_input_ids, padded_attention_mask, prompt_ids_list, prompt_txts, left_pad_lens

    # 特征提取
    def _extract_features_from_hooks(self, right_pad_len: int, needed_layers: List[str], sample_idx: int = 0) -> Dict[str, Dict[str, torch.Tensor]]:
        """Build per-layer feature vectors from the states captured by the forward hooks.

        For each name in ``needed_layers``, ``self.captured_hidden_states[name]``
        holds one entry per hooked forward call. Three vectors are produced:

        - ``prompt_last_token``: last sequence position of the first entry.
        - ``answer_first_token``: first kept per-step state of the later entries.
        - ``last_token``: last kept per-step state of the later entries.

        Each later entry contributes only its last sequence position. The final
        ``right_pad_len`` of those per-step states are dropped. If there are not
        more of them than ``right_pad_len``, none are kept.

        NOTE(review): this presumably relies on generate() running one forward per
        new token with a KV cache, so entry 0 covers the prompt and each later entry
        covers one generated token fed back as input. Confirm for the models in use.

        Args:
            right_pad_len: number of trailing per-step states to discard.
            needed_layers: layer names to extract.
            sample_idx: row to read when a captured entry is 3-D (batched).

        Returns:
            ``{layer_name: {feature_name: 1-D tensor}}``. Unavailable features are
            zero vectors of length ``model.config.hidden_size`` on CPU, in the dtype
            of the model's first parameter. A layer missing from
            ``captured_hidden_states`` maps to an empty dict.

        Raises:
            RuntimeError: ``sample_idx`` is out of range for a 3-D later entry.
        """
        layer_features: Dict[str, Dict[str, torch.Tensor]] = {k: {} for k in needed_layers}

        # Zero vector used as the placeholder for any unavailable feature
        hidden_dim = self.model.config.hidden_size
        _dtype = next(self.model.parameters()).dtype
        _zero = torch.zeros(hidden_dim, dtype=_dtype, device="cpu")

        for k in needed_layers:
            if k not in self.captured_hidden_states:
                continue

            layer_states = self.captured_hidden_states[k]
            if not layer_states:
                console.log("[yellow]Warn: no output tokens generated.[/yellow]")
                layer_features[k]["prompt_last_token"] = _zero.clone()
                layer_features[k]["answer_first_token"] = _zero.clone()
                layer_features[k]["last_token"] = _zero.clone()
                continue

            first_step_state = layer_states[0]

            # prompt_last_token: last position of the first captured forward output
            if first_step_state.dim() == 3:
                # [batch, seq_len, hidden]
                if sample_idx < first_step_state.size(0) and first_step_state.size(1) > 0:
                    prompt_last = first_step_state[sample_idx, -1]      # [hidden]
                else:
                    prompt_last = _zero
            elif first_step_state.dim() == 2:
                # [seq_len, hidden], single-sample case
                if first_step_state.size(0) > 0:
                    prompt_last = first_step_state[-1]                  # [hidden]
                else:
                    prompt_last = _zero
            else:
                prompt_last = _zero

            layer_features[k]["prompt_last_token"] = prompt_last.clone()

            # Per-step states of the later forward calls
            output_states = []

            # Every later forward call contributes the state at its last position
            for i in range(1, len(layer_states)):
                step_state = layer_states[i]

                # Select this sample's rows according to the tensor rank
                if step_state.dim() == 3:
                    # Batched: [batch_size, seq_len, hidden_dim]
                    if sample_idx >= step_state.size(0):
                        raise RuntimeError("样本索引越界，请检查批处理逻辑")
                    step_hidden = step_state[sample_idx]  # [seq_len, hidden_dim]
                elif step_state.dim() == 2:
                    # Single sample: [seq_len, hidden_dim]
                    step_hidden = step_state
                else:
                    continue

                # Keep only the last position of this step
                if len(step_hidden) > 0:
                    output_states.append(step_hidden[-1:])  # [1, hidden_dim]

            # Debug hook point: log the sample's per-step state count and right_pad_len here if needed.

            if len(output_states) > right_pad_len:
                # Drop the trailing right_pad_len steps (presumably padding / surplus EOS — see caller)
                if right_pad_len > 0:
                    output_states = output_states[:-right_pad_len]
            else:
                output_states = []

            if len(output_states) >= 1:
                output_first_step = output_states[0]    # [1, D] or [D]
                if output_first_step.dim() == 2 and output_first_step.size(0) > 0:
                    answer_first = output_first_step[-1]          # [D]
                elif output_first_step.dim() == 1:
                    answer_first = output_first_step
                else:
                    answer_first = _zero
                layer_features[k]["answer_first_token"] = answer_first.clone()
            else:
                layer_features[k]["answer_first_token"] = _zero.clone()

            # last_token: state of the last kept step
            if len(output_states) >= 1:
                normed = []
                for t in output_states:
                    if t.dim() == 1:
                        t = t.unsqueeze(0)
                    normed.append(t)
                # Concatenate into [effective_len, hidden_dim]
                output_hidden = torch.cat(normed, dim=0)
                # Representation of the last kept token
                vec_last_token = output_hidden[-1]
                layer_features[k]["last_token"] = vec_last_token.clone()
            else:
                layer_features[k]["last_token"] = _zero.clone()
        return layer_features

    # 单次前向传播同时获取隐状态和label
    @torch.inference_mode()
    def forward_once(
        self,
        item: dict,
        max_new_tokens: int = 1024,
        needed_layers: List[str] = None,
    ) -> Tuple[Dict[str, Dict[str, torch.Tensor]], str, bool]:
        """Greedy-generate an answer for one sample while capturing hidden states.

        Args:
            item: dataset record passed to ``parse_input`` (reads ``item["question"]``).
            max_new_tokens: generation budget. Models whose key contains "qwen3",
                "distill" or "qwq" must use 4096.
            needed_layers: layer names to hook; defaults to all keys of ``self.layer_map``.

        Returns:
            ``(layer_features, answer_txt, hit_limit)``. ``answer_txt`` is the
            generated text with special tokens removed. ``hit_limit`` is True only
            when the budget was used up and the last kept token is not an EOS id.

        Raises:
            ValueError: the max_new_tokens constraint above is violated.
            RuntimeError: generation or feature extraction failed; the original
                exception is chained as ``__cause__``.

        The forward hooks are always removed before returning (also on error),
        so later forward passes of the model are not captured.
        """
        if needed_layers is None:
            needed_layers = list(self.layer_map.keys())

        # 1) Build the prompt
        prompt = parse_input(item)
        # 2) Tokenize
        input_ids, prompt_ids, _prompt_txt = self._build_input_ids(prompt)

        # Make sure a pad token exists (same policy as the batched path)
        if self.tokenizer.pad_token_id is None:
            try:
                special_tokens_dict = {"pad_token": "<|pad|>"}
                num_added = self.tokenizer.add_special_tokens(special_tokens_dict)
                if num_added > 0:
                    self.model.resize_token_embeddings(len(self.tokenizer))
                    with torch.no_grad():
                        emb = self.model.get_input_embeddings().weight
                        emb[self.tokenizer.pad_token_id].zero_()
            except Exception:
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
                console.print("[yellow]Warn: fallback pad uses eos embedding (not zeroed).[/yellow]")

        pad_id = self.tokenizer.pad_token_id
        mk_lower = self.model_key.lower()
        if any(k in mk_lower for k in ["qwen3", "distill", "qwq"]):
            if max_new_tokens != 4096:
                raise ValueError(f"模型 {self.model_key} 需要设置 max_new_tokens=4096，但当前传入的是 {max_new_tokens}")

        # 3) Register hooks
        self.is_batch_mode = False
        self.current_sample_idx = 0
        self._register_hooks(needed_layers)

        # 4) Generate
        try:
            gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
            tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
            eos_ids = list(set(gc_eos + tok_eos))

            gen_out = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                eos_token_id=(eos_ids if eos_ids else None),
                pad_token_id=pad_id,
                return_dict_in_generate=True,
                attention_mask=torch.ones_like(input_ids),
            )

            # 5) Strip right padding / surplus EOS, then decode the answer
            prompt_len = len(prompt_ids)
            total_sequence = gen_out.sequences[0]
            right_pad_len = self.calculate_right_padding_length(total_sequence)
            output_len = max(0, len(total_sequence) - right_pad_len - prompt_len)

            valid_sequence = (total_sequence[:-right_pad_len] if right_pad_len > 0 else total_sequence)
            answer_txt = self.tokenizer.decode(valid_sequence[prompt_len:].tolist(), skip_special_tokens=True)

            # Incomplete only when the budget was hit AND the last kept token is not an EOS
            ended_with_eos = (
                output_len > 0
                and int(valid_sequence[prompt_len + output_len - 1].item()) in eos_ids
            )
            hit_limit = output_len >= max_new_tokens and not ended_with_eos

            # 6) Features from the captured hidden states
            layer_features = self._extract_features_from_hooks(right_pad_len, needed_layers, sample_idx=0)

        except Exception as e:
            raise RuntimeError("Single sample processing failed: " + str(e)) from e
        finally:
            # Hooks must not outlive this call; captured states stay available.
            self._clear_hooks()

        return layer_features, answer_txt, hit_limit

    # Batched forward pass: generate answers and capture hidden-state features in one call.
    @torch.inference_mode()
    def forward_batch(
        self,
        items: List[dict],
        max_new_tokens: int = 1024,
        needed_layers: Optional[List[str]] = None,
    ) -> Tuple[List[Dict[str, Dict[str, torch.Tensor]]], List[str], List[bool]]:
        """Greedily generate answers for a batch of items and extract per-layer features.

        Args:
            items: raw dataset records; each is turned into a prompt via ``parse_input``.
            max_new_tokens: generation budget. Model keys containing "qwen3",
                "distill" or "qwq" must use exactly 4096.
            needed_layers: keys of ``self.layer_map`` to capture; defaults to all keys.

        Returns:
            ``(features, answers, hit_limit_flags)``, one entry per item:
            the per-layer feature dicts returned by ``_extract_features_from_hooks``,
            the decoded generated text (special tokens skipped), and whether the
            output used the whole ``max_new_tokens`` budget without ending on EOS.

        Raises:
            ValueError: if the model requires max_new_tokens=4096 and another value was passed.
            RuntimeError: wrapping any failure during generation or feature extraction;
                the original exception is chained as ``__cause__``.
        """
        if needed_layers is None:
            needed_layers = list(self.layer_map.keys())

        # Pad on the left so generated tokens follow each prompt directly.
        self.tokenizer.padding_side = "left"
        prompts = [parse_input(item) for item in items]
        # Per the original note, pad_token_id is configured inside _build_batch_input_ids.
        batch_input_ids, padded_attention_mask, prompt_ids_list, prompt_txts, left_pad_lens = (
            self._build_batch_input_ids(prompts)
        )
        pad_id = self.tokenizer.pad_token_id

        mk_lower = self.model_key.lower()
        if any(k in mk_lower for k in ["qwen3", "distill", "qwq"]):
            if max_new_tokens != 4096:
                raise ValueError(f"模型 {self.model_key} 需要设置 max_new_tokens=4096，但当前传入的是 {max_new_tokens}")

        self.is_batch_mode = True
        self._register_hooks(needed_layers)

        try:
            # Union of EOS ids from the generation config and the tokenizer. Computed once and
            # used both to stop generation and to check afterwards whether a sample ended on EOS.
            gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
            tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
            eos_ids = list(set(gc_eos + tok_eos))
            gen_out = self.model.generate(
                input_ids=batch_input_ids,
                max_new_tokens=max_new_tokens,
                attention_mask=padded_attention_mask,
                do_sample=False,
                eos_token_id=(eos_ids if eos_ids else None),
                pad_token_id=pad_id,
                return_dict_in_generate=True,
            )

            batch_answers: List[str] = []
            batch_layer_features = []
            hit_limit_flags: List[bool] = []

            for i, _prompt_ids in enumerate(prompt_ids_list):
                sample_sequence = gen_out.sequences[i]
                # Strip this sample's left padding, then any right padding (presumably filled
                # in after this sample finished while longer samples kept generating).
                left_pad_len = int(left_pad_lens[i].item())
                actual_sequence = sample_sequence[left_pad_len:]
                right_pad_len = self.calculate_right_padding_length(actual_sequence)
                valid_sequence = actual_sequence[:-right_pad_len] if right_pad_len > 0 else actual_sequence
                valid_seq_len = len(valid_sequence)

                # Real (unpadded) prompt length, taken from the attention mask.
                prompt_len = int(padded_attention_mask[i].sum().item())
                output_len = valid_seq_len - prompt_len
                if prompt_len > valid_seq_len:
                    console.print(f"[red]错误: 样本{i}的prompt长度({prompt_len})超过有效序列长度({valid_seq_len})[/red]")
                    output_len = 0  # treat as empty output

                ended_with_eos = False
                if output_len > 0:
                    last_gen_id = int(valid_sequence[prompt_len + output_len - 1].item())
                    ended_with_eos = last_gen_id in eos_ids

                # Only an output that used the whole budget AND did not stop on EOS counts as truncated.
                hit_limit_flags.append(output_len >= max_new_tokens and not ended_with_eos)
                batch_layer_features.append(
                    self._extract_features_from_hooks(right_pad_len, needed_layers, sample_idx=i)
                )
                batch_answers.append(
                    self.tokenizer.decode(valid_sequence[prompt_len:].tolist(), skip_special_tokens=True)
                )
        except Exception as e:
            raise RuntimeError("Batch processing failed: " + str(e)) from e

        return batch_layer_features, batch_answers, hit_limit_flags

    # Order samples by prompt token length (longest first) so each batch holds
    # prompts of similar length, which limits padding and large-batch OOM risk.
    def _sort_data_by_prompt_len(self, data):
        """Return a new list containing ``data`` ordered by tokenized prompt length, longest first.

        Input: data (List[Dict]); output: the same dicts, reordered.
        Prompts are rendered with ``parse_input`` and, when the tokenizer exposes a
        chat template, with ``apply_chat_template`` (``enable_thinking=True`` for
        qwen3 model keys). This is intended to mirror how prompts are built for
        generation. Equal lengths are ordered by original position, later items
        first. Tokenization runs on CPU in chunks of 256 items.
        """

        def render_chunk(chunk):
            # Returns (texts_to_tokenize, add_special_tokens).
            raw_prompts = [parse_input(entry) for entry in chunk]
            template_ok = bool(getattr(self.tokenizer, "chat_template", None))
            if not (hasattr(self.tokenizer, "apply_chat_template") and template_ok):
                # No chat template: measure the raw prompt and let the tokenizer add special tokens.
                return raw_prompts, True

            rendered = []
            for prompt in raw_prompts:
                conversation = [{"role": "user", "content": prompt}]
                try:
                    if "qwen3" in str(self.model_key).lower():
                        text = self.tokenizer.apply_chat_template(
                            conversation,
                            tokenize=False,
                            add_generation_prompt=True,
                            enable_thinking=True,
                        )
                    else:
                        text = self.tokenizer.apply_chat_template(
                            conversation,
                            tokenize=False,
                            add_generation_prompt=True,
                        )
                except TypeError:
                    # Fallback for older tokenizers that reject the extra keyword arguments.
                    text = self.tokenizer.apply_chat_template(conversation, tokenize=False)
                rendered.append(text)
            # Templated text is tokenized without extra special tokens.
            return rendered, False

        chunk_size = 256
        token_counts = [0] * len(data)
        for offset in range(0, len(data), chunk_size):
            texts, add_special = render_chunk(data[offset : offset + chunk_size])
            encoded = self.tokenizer(
                texts,
                add_special_tokens=add_special,
                truncation=False,
                return_attention_mask=False,
                return_token_type_ids=False,
            )
            for position, token_ids in enumerate(encoded["input_ids"], start=offset):
                token_counts[position] = len(token_ids)

        ranked = sorted(enumerate(token_counts), key=lambda pair: (pair[1], pair[0]), reverse=True)
        return [data[position] for position, _ in ranked]

    def _save_batch_features(
        self,
        feats_prompt_last_token: Dict[str, List[torch.Tensor]],
        feats_answer_first_token: Dict[str, List[torch.Tensor]],
        feats_last_token: Dict[str, List[torch.Tensor]],
        labels: List[int],
        ids: List[str],
        questions: List[str],
        true_answers: List[str],
        pred_answers: List[str],
        solutions: List[str],
        incomplete_flags: List[bool],
        prompt_last_token_dir: str,
        answer_first_token_dir: str,
        last_token_outputdir: str,
    ) -> None:
        """保存当前批次的特征"""

        def save_features(feats, output_dir, des):
            for k, vec_list in feats.items():
                if not vec_list:
                    continue
                out_path = os.path.join(output_dir, f"{k}_features.pt")

                # 如果文件已存在，加载并合并
                if os.path.exists(out_path):
                    existing_data = torch.load(out_path, map_location="cpu")

                    # Id 去重，防止在 batch 处理完，pt写入但是 id 没有及时写入导致的重复处理
                    existing_ids = existing_data["ids"]
                    existing_id_set = set(existing_ids)

                    if len(vec_list) != len(ids):
                        raise ValueError(f"{k}: vec_list({len(vec_list)}) 与 ids({len(ids)}) 长度不一致")

                    # 以当前要保存的这一路特征 vec_list 为例，先把它和同索引的 meta 列表打包
                    new_records = list(
                        zip(
                            ids,
                            labels,
                            questions,
                            true_answers,
                            pred_answers,
                            solutions,
                            incomplete_flags,
                            vec_list,
                        )
                    )

                    # 过滤：仅保留还未出现过的 id
                    new_records = [r for r in new_records if r[0] not in existing_id_set]

                    if new_records:
                        n_ids, n_labels, n_qs, n_trues, n_preds, n_sols, n_incomp, n_vecs = zip(*new_records)
                        tensor_new = torch.stack(list(n_vecs)).cpu()
                        tensor = torch.cat([existing_data["features"], tensor_new], dim=0)

                        merged_ids = existing_ids + list(n_ids)
                        merged_labels = existing_data["labels"].tolist() + list(n_labels)
                        merged_questions = existing_data["questions"] + list(n_qs)
                        merged_true_answers = existing_data["true_answers"] + list(n_trues)
                        merged_pred_answers = existing_data["pred_answers"] + list(n_preds)
                        merged_solutions = existing_data["solutions"] + list(n_sols)
                        merged_incomplete = existing_data["incomplete_flags"] + list(n_incomp)
                    else:
                        # 没有新增；直接复用 existing_data
                        tensor = existing_data["features"]
                        merged_ids = existing_ids
                        merged_labels = existing_data["labels"].tolist()
                        merged_questions = existing_data["questions"]
                        merged_true_answers = existing_data["true_answers"]
                        merged_pred_answers = existing_data["pred_answers"]
                        merged_solutions = existing_data["solutions"]
                        merged_incomplete = existing_data["incomplete_flags"]
                else:
                    # 首次写入
                    tensor = torch.stack(vec_list)
                    merged_ids = ids
                    merged_labels = labels
                    merged_questions = questions
                    merged_true_answers = true_answers
                    merged_pred_answers = pred_answers
                    merged_solutions = solutions
                    merged_incomplete = incomplete_flags

                tmp = out_path + ".tmp"
                torch.save(
                    {
                        "features": tensor,
                        "labels": torch.tensor(merged_labels),
                        "ids": merged_ids,
                        "questions": merged_questions,
                        "true_answers": merged_true_answers,
                        "pred_answers": merged_pred_answers,
                        "solutions": merged_solutions,
                        "incomplete_flags": merged_incomplete,
                    },
                    tmp,
                )
                os.replace(tmp, out_path)

        save_features(feats_prompt_last_token, prompt_last_token_dir, des="prompt最后一个token的表征")
        save_features(feats_answer_first_token, answer_first_token_dir, des="answer第一个token的表征")
        save_features(feats_last_token, last_token_outputdir, des="最后一个token的表征")

    def extract_dataset(
        self,
        dataset_name: str,
        dataset_path: str,
        prompt_last_token_dir: str,
        answer_first_token_dir: str,
        last_token_outputdir: str,
        layer_req: str = "middle",
        max_new_tokens: int = 1024,
        batch_size: int = 4,
        resume: bool = False,
        checkpoint_root: Optional[str] = None,
    ) -> None:
        """Generate answers for a JSONL dataset and persist hidden-state features batch by batch.

        Each non-blank line of ``dataset_path`` is parsed as JSON. The fields read
        here are ``question``, ``answer`` and ``solution``, plus whatever id
        ``get_item_id`` resolves; items whose id resolves to "" are skipped.

        Remaining samples are:
        1. sorted by prompt length (longest first);
        2. processed ``batch_size`` at a time with ``forward_batch``, falling back
           to ``forward_once`` per sample if the batch fails;
        3. scored with ``gsm8k_parse``;
        4. appended to the per-layer feature files in the three output directories.

        A checkpoint under ``<checkpoint_root>/<model_key>/<dataset_name>/`` is
        updated after every saved batch. With ``resume=False``, the checkpoint and
        all three output directories are cleared first.

        Args:
            layer_req: a single ``layer_map`` key, or "all" for every layer.
            max_new_tokens: generation budget forwarded to the forward passes.
            batch_size: number of samples per generation batch (must be >= 1).
            resume: skip ids already recorded in the checkpoint instead of starting over.
            checkpoint_root: root directory for checkpoint files (required).

        Raises:
            ValueError: if ``checkpoint_root`` is None or ``batch_size`` < 1.
            RuntimeError: if a sample fails in single-sample fallback mode.
        """
        if checkpoint_root is None:
            # os.path.normpath(None) would otherwise fail later with an opaque TypeError.
            raise ValueError("extract_dataset requires checkpoint_root (got None)")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        os.makedirs(prompt_last_token_dir, exist_ok=True)
        os.makedirs(answer_first_token_dir, exist_ok=True)
        os.makedirs(last_token_outputdir, exist_ok=True)

        # Skip blank lines (e.g. a trailing newline) that would make json.loads fail.
        with open(dataset_path, "r", encoding="utf-8") as f:
            raw_data = [json.loads(line) for line in f if line.strip()]
        console.rule(f"处理数据集 → {dataset_name}")

        needed_layers: List[str] = list(self.layer_map.keys()) if layer_req == "all" else [layer_req]
        checkpoint_dir = os.path.join(os.path.normpath(checkpoint_root), self.model_key, dataset_name)
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, "checkpoint_meta.json")
        ckpt = CheckpointManager(checkpoint_path, load_existing=resume)

        if not resume:
            # Fresh run: drop checkpoint state and feature files left by earlier runs.
            ckpt.reset()
            if os.path.exists(checkpoint_path):
                os.remove(checkpoint_path)
            clear_dir(prompt_last_token_dir)
            clear_dir(answer_first_token_dir)
            clear_dir(last_token_outputdir)

        # Drop samples without a resolvable id; store the resolved id on the item.
        cleaned_data: List[Dict[str, Any]] = []
        bad_id_count = 0
        for item in raw_data:
            _id = get_item_id(item)
            if _id == "":
                bad_id_count += 1
                continue
            item["_resolved_id"] = _id
            cleaned_data.append(item)
        if bad_id_count:
            print(f"[warn] {bad_id_count} 条样本缺少 id/unique_id，已跳过。")

        ckpt.set_total(len(cleaned_data))

        # Filter by checkpoint before sorting, so only unprocessed samples get tokenized.
        if resume:
            remaining = [it for it in cleaned_data if not ckpt.is_processed(it["_resolved_id"])]
        else:
            remaining = cleaned_data

        if not remaining:
            print("[info] 没有剩余样本可处理。")
            return

        data = self._sort_data_by_prompt_len(remaining)
        console.print("[cyan]已按prompt长度排序[/cyan]")

        if self.debug_mode:
            data = data[:16]
        console.print(f"加载 {len(data)} 个样本.")

        total_batches = (len(data) + batch_size - 1) // batch_size
        console.print(f"使用批量大小 {batch_size}, 总共 {total_batches} 个批次")

        for batch_idx in track(range(total_batches), description="Extracting features"):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(data))
            batch_items = data[start_idx:end_idx]

            console.print(f"处理批次 {batch_idx + 1}/{total_batches}, 样本 {start_idx}-{end_idx-1}")

            buf = self._empty_batch_buffers(needed_layers)
            batch_error: Optional[str] = None
            try:
                batch_layer_vecs, batch_answers, hit_limit_flags = self.forward_batch(
                    batch_items,
                    max_new_tokens=max_new_tokens,
                    needed_layers=needed_layers,
                )
                for item, layer_vecs, answer, hit_limit in zip(
                    batch_items, batch_layer_vecs, batch_answers, hit_limit_flags
                ):
                    self._record_sample(buf, item, layer_vecs, answer, hit_limit, needed_layers)
                # The results now live in ``buf``; drop the extra references.
                del batch_layer_vecs, batch_answers, hit_limit_flags
            except Exception as e:
                # Keep only the message. The exception's traceback references the frames
                # (and tensors) of the failed generate() call. Running the fallback inside
                # this ``except`` clause would keep that memory alive during recovery.
                batch_error = str(e)

            if batch_error is not None:
                console.print(f"[red]批次 {batch_idx + 1} 处理失败: {escape(batch_error)}[/red]")
                console.print("[yellow]回退到单样本处理模式[/yellow]")

                # Only try to reclaim memory after an OOM.
                if "out of memory" in batch_error.lower():
                    gc.collect()
                    torch.cuda.empty_cache()

                # Discard partially collected results so nothing from the failed batch is saved twice.
                if buf["ids"]:
                    console.print("[yellow]清空当前批次已收集的数据，避免部分数据重复保存[/yellow]")
                buf = self._empty_batch_buffers(needed_layers)

                for i, item in enumerate(batch_items):
                    try:
                        layer_vecs, answer, hit_limit = self.forward_once(
                            item,
                            max_new_tokens=max_new_tokens,
                            needed_layers=needed_layers,
                        )
                        self._record_sample(buf, item, layer_vecs, answer, hit_limit, needed_layers)
                    except Exception as single_e:
                        raise RuntimeError(f"样本 {start_idx + i} 处理失败: {single_e}") from single_e

            self._save_batch_features(
                buf["prompt_last_token"],
                buf["answer_first_token"],
                buf["last_token"],
                buf["labels"],
                buf["ids"],
                buf["questions"],
                buf["true_answers"],
                buf["pred_answers"],
                buf["solutions"],
                buf["incomplete_flags"],
                prompt_last_token_dir,
                answer_first_token_dir,
                last_token_outputdir,
            )

            # Checkpoint after every saved batch (also when resume=False) so an interrupted run can resume.
            ckpt.mark_batch_processed(buf["ids"])

        # len(data) is the number actually processed (debug mode may truncate ``remaining``).
        print(f"[done] 共 {len(data)} 条新样本完成。累计完成 {len(ckpt.processed_ids)}/{ckpt.total_samples}.")

    @staticmethod
    def _empty_batch_buffers(needed_layers: List[str]) -> Dict[str, Any]:
        """Return fresh per-batch accumulators: three per-layer feature dicts plus metadata lists."""
        return {
            "prompt_last_token": {k: [] for k in needed_layers},
            "answer_first_token": {k: [] for k in needed_layers},
            "last_token": {k: [] for k in needed_layers},
            "labels": [],
            "ids": [],
            "questions": [],
            "true_answers": [],
            "pred_answers": [],
            "solutions": [],
            "incomplete_flags": [],
        }

    @staticmethod
    def _record_sample(
        buf: Dict[str, Any],
        item: Dict[str, Any],
        layer_vecs: Dict[str, Dict[str, torch.Tensor]],
        answer: str,
        hit_limit: bool,
        needed_layers: List[str],
    ) -> None:
        """Score one generated answer with ``gsm8k_parse`` and append it to ``buf``.

        Every field is read and the answer is scored before anything is appended.
        If a lookup or the scoring raises, ``buf`` is left untouched for this sample.
        """
        sample_id = item["_resolved_id"]  # resolved by get_item_id in extract_dataset
        question = item["question"].strip()
        true_answer = item["answer"]
        solution = item["solution"]
        per_layer = [
            (
                k,
                layer_vecs[k]["prompt_last_token"],
                layer_vecs[k]["answer_first_token"],
                layer_vecs[k]["last_token"],
            )
            for k in needed_layers
        ]

        extracted_answer, is_correct = gsm8k_parse(answer, true_answer, hit_limit)
        if extracted_answer == "Incomplete" or hit_limit:
            console.print("\n[red]Incomplete answer[/red]")
        print("pred_answer:", extracted_answer, "true answer:", true_answer, "is_correct:", is_correct,)

        buf["ids"].append(sample_id)
        buf["questions"].append(question)
        buf["true_answers"].append(true_answer)
        buf["solutions"].append(solution)
        buf["incomplete_flags"].append(hit_limit)
        buf["pred_answers"].append(answer)
        # Features are stored for every sample, correct or not.
        for k, prompt_last, answer_first, last in per_layer:
            buf["prompt_last_token"][k].append(prompt_last)
            buf["answer_first_token"][k].append(answer_first)
            buf["last_token"][k].append(last)
        buf["labels"].append(int(is_correct))

def load_datasets(dataset: List[str], data_dir: str) -> Dict[str, str]:
    """Map each requested dataset name to its ``<data_dir>/<name>.jsonl`` path.

    Every name is reported on the console; names whose file does not exist are
    left out of the returned mapping.
    """
    found: Dict[str, str] = {}
    for name in dataset:
        path = os.path.join(data_dir, f"{name}.jsonl")
        if not os.path.exists(path):
            console.print(f"[red] 数据集不存在 {path}[/red]")
            continue
        found[name] = path
        console.print(f"找到数据集: {path}")
    return found


def _discover_models(models_root: str) -> List[str]:
    """Return, sorted, the sub-directories of models_root that contain config.json or tokenizer.json."""
    found: List[str] = []
    for d in sorted(os.listdir(models_root)):
        mp = os.path.join(models_root, d)
        if os.path.isdir(mp) and (
            os.path.exists(os.path.join(mp, "config.json"))
            or os.path.exists(os.path.join(mp, "tokenizer.json"))
        ):
            found.append(d)
    return found


def _split_names(value: Any) -> List[str]:
    """Split a comma-separated CLI value into stripped, non-empty names."""
    return [x.strip() for x in str(value).split(",") if x.strip()]


def _resolve_batch_size(args: argparse.Namespace, mk: str, dname: str, gen_tok: int) -> int:
    """Choose the batch size for model ``mk`` on dataset split ``dname``.

    - ``--auto_batch off``: returns ``--batch_size``; the caller must ensure it is set.
    - Otherwise: estimates a size via ``PTS_POST.token_calculation`` and
      ``VRAM_POST.batch_estimation``, choosing the p99- or max-based value.
      The estimate is capped by ``--batch_size`` when that is given.
    - If estimation raises: falls back to ``--batch_size``, or 1 when unset.

    The result is always >= 1.
    """
    if args.auto_batch == "off":
        return max(1, int(args.batch_size))

    try:
        # Compute (mx, pXX) for this split; dname is passed as dataset_with_split, e.g. "gsm8k_test".
        mx, pxx = PTS_POST.token_calculation(
            model_name=mk,
            dataset_with_split=dname,
            pct=args.pct,
        )

        bsz_p99, bsz_mx = VRAM_POST.batch_estimation(
            mx=mx,
            p99=pxx,
            model_name=mk,
            gen_tokens=gen_tok,
            n_gpus=args.n_gpus,
            per_gpu_vram_gib=args.per_gpu_vram_gib,
        )

        auto_bsz = bsz_p99 if args.auto_batch == "p99" else bsz_mx

        # If the user also passed --batch_size, take the smaller (safer) value; never drop below 1.
        effective_bsz = max(1, int(auto_bsz))
        if args.batch_size is not None:
            effective_bsz = min(effective_bsz, max(1, int(args.batch_size)))

        if args.auto_batch == "p99":
            # round() rather than int(): e.g. 0.29 * 100 == 28.999999999999996.
            console.print(f"[cyan]AUTO-BATCH(p99)[/cyan] {mk}/{dname}: p{round(args.pct * 100)}={pxx} → use batch_size={effective_bsz}")
        else:
            console.print(f"[cyan]AUTO-BATCH(max)[/cyan] {mk}/{dname}: mx={mx} → use batch_size={effective_bsz}")

    except Exception as e:
        console.print(f"[yellow]AUTO-BATCH 估计失败 {mk}/{dname}：{e}；回退 batch_size={args.batch_size}[/yellow]")
        if args.batch_size is not None:
            effective_bsz = max(1, int(args.batch_size))
        else:
            effective_bsz = 1  # conservative last-resort fallback
        console.print(f"[yellow]使用回退 batch_size={effective_bsz}[/yellow]")

    return effective_bsz


def main() -> None:
    """CLI entry point: extract hidden-state features for every (model, dataset) pair.

    ``--model`` and ``--dataset`` accept comma-separated lists. ``--model all``
    processes every model directory under ``--models_root``; no ``--dataset``
    means the gsm8k train/val/test splits. Each model is loaded once, run over
    all datasets, and released before the next model is loaded.
    """
    p = argparse.ArgumentParser("Hidden feature extractor with checkpoint")
    p.add_argument("--models_root", default=os.path.join(PROJECT_ROOT, "models"), help="本地模型根目录（用于拼接 模型名→路径）",)
    p.add_argument("--data_dir", default=os.path.join(PROJECT_ROOT, "new_benchmark", "gsm8k"))
    p.add_argument("--output_dir", default=os.path.join(PROJECT_ROOT, "feats/gsm8k/gsm8k_feats"))
    p.add_argument("--device", default="cuda")
    p.add_argument("--model", default="all")
    p.add_argument("--dataset", default=None)
    p.add_argument("--layer_type", default="all", choices=["quarter", "middle", "three_quarters", "last", "second_last", "first", "all"])
    p.add_argument("--batch_size", type=int, default=None, help="批量处理的batch size")
    p.add_argument("--resume", action="store_true", help="是否从断点恢复")
    p.add_argument("--checkpoint_root", type=str, default=os.path.join(PROJECT_ROOT, "checkpoints/gsm8k/gsm8k_feats"), help="checkpoint存储目录",)
    # Automatic batch-size estimation options.
    p.add_argument("--auto_batch", choices=["off", "p99", "max"], default="max", help="按分片自动估计 batch（p99 或 max）",)
    p.add_argument("--gen_tokens", type=int, default=1024, help="估算时的最大生成长度")
    p.add_argument("--n_gpus", type=int, default=1, help="参与推理的 GPU 数")
    p.add_argument("--per_gpu_vram_gib", type=float, default=80.0, help="每张卡可用显存 GiB，用于估算",)
    p.add_argument("--pct", type=float, default=0.99, help="分位数（默认 0.99，用于 pXX 计算）")
    args = p.parse_args()

    os.makedirs(args.models_root, exist_ok=True)
    os.makedirs(args.output_dir, exist_ok=True)
    if args.dataset is None:
        args.dataset = ["gsm8k_train", "gsm8k_val", "gsm8k_test"]
    elif isinstance(args.dataset, str):
        # Comma-separated, consistent with --model; a single name still works.
        args.dataset = _split_names(args.dataset)
    datasets = load_datasets(args.dataset, args.data_dir)

    if not datasets:
        console.print("[red] 数据集不存在[/red]")
        return

    # Fail fast, before any model is loaded, when no batch size can be determined.
    if args.batch_size is None and args.auto_batch == "off":
        console.print("[yellow]警告: 未指定 --batch_size，且 --auto_batch=off，无法执行，请重新设置 batch[/yellow]")
        return

    if args.checkpoint_root is None:
        checkpoint_root = os.path.join(args.output_dir, "checkpoints")
    else:
        checkpoint_root = args.checkpoint_root
    os.makedirs(checkpoint_root, exist_ok=True)

    models = _discover_models(args.models_root) if args.model == "all" else _split_names(args.model)

    for mk in models:
        model_path = os.path.join(args.models_root, mk)

        if not os.path.exists(model_path):
            console.print(f"[red] {mk} 模型目录不存在：{model_path}[/red]")
            continue

        console.rule(f"📦  Processing model — {mk}")
        extractor = HiddenExtractor(mk, model_path, args.device)

        for dname, dpath in datasets.items():
            prompt_last_token_dir = os.path.join(args.output_dir, mk+"_prompt_last_token", dname)
            answer_first_token_dir = os.path.join(args.output_dir, mk+"_answer_first_token", dname)
            last_token_out_dir = os.path.join(args.output_dir, mk+"_last_token", dname)

            # Long-reasoning models must use 4096 new tokens (forward_batch enforces this).
            mk_lower = mk.lower()
            if any(k in mk_lower for k in ["qwen3", "distill", "qwq"]):
                gen_tok = 4096
            else:
                gen_tok = int(args.gen_tokens)

            effective_bsz = _resolve_batch_size(args, mk, dname, gen_tok)

            print("真正使用的batch: ", effective_bsz, "max_new_tokens: ", gen_tok)
            extractor.extract_dataset(
                dataset_name=dname,
                dataset_path=dpath,
                prompt_last_token_dir=prompt_last_token_dir,
                answer_first_token_dir=answer_first_token_dir,
                last_token_outputdir=last_token_out_dir,
                layer_req=args.layer_type,
                max_new_tokens=gen_tok,
                batch_size=effective_bsz,
                resume=args.resume,
                checkpoint_root=checkpoint_root,
            )

        # Release this model before the next one is constructed. Otherwise the old
        # extractor stays referenced until HiddenExtractor(...) returns, and both
        # models would be resident at the same time.
        del extractor
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    console.rule("程序结束！")

# Script entry point: run the CLI only when executed directly, not when imported.
if __name__ == "__main__":
    main()
