from typing import Any, Dict, List, Optional, Tuple, Set
import os
os.environ["TORCH_COMPILE_DISABLE"] = "1"

import torch
import torch.nn as nn
import json
import gc
from rich.console import Console
from rich.markup import escape
from rich.progress import track
import sys

from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoProcessor,
    AutoImageProcessor,
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoModel,
)
import transformers as _tf

from backends import create_backend
from utils import (
    parse_input,
    get_options_list,
    get_option_answer_text,
    seedbench_parse,
    trim_incomplete_sentence,
    _normalize_to_list,
    split_model_internvl,
    split_model_internvl3,
    CheckpointManager,
    clear_dir,
    get_item_id,
)

try:
    from transformers import (
        CLIPImageProcessor,
        LlavaNextProcessor,
        LlavaNextForConditionalGeneration
    )
except Exception:
    pass

def _ensure_transformers_pytorch_utils_shims():
    """
    Re-create torch version flags that newer transformers releases removed.

    Some remote-code models (e.g. MiniCPM-V) run
    ``from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13``.
    When that symbol is missing, loading with ``trust_remote_code`` fails with
    an ImportError. This function adds the missing names to
    ``transformers.pytorch_utils`` before any model is loaded.

    Upstream these names were plain booleans used in ``if`` statements, so the
    installed value evaluates to the real comparison result in a boolean
    context. It is also callable (returning the same bool) so code that
    invoked the previous function-style shim keeps working.

    Existing attributes are never overwritten. If ``torch`` or
    ``transformers`` cannot be imported the function silently does nothing.
    """
    try:
        import re
        import torch
        import transformers.pytorch_utils as ptu
    except Exception:
        # Best effort: nothing to patch when either library is unavailable.
        return

    def _ver_tuple(v: str):
        # e.g. "2.1.0+cu121" -> (2, 1, 0); unparsable input -> (0, 0, 0)
        s = (v or "").split("+")[0]
        m = re.match(r"^\s*(\d+)\.(\d+)(?:\.(\d+))?", s)
        if not m:
            return (0, 0, 0)
        return (int(m.group(1)), int(m.group(2)), int(m.group(3) or 0))

    class _VersionFlag:
        """Truth value of a version comparison; also callable for compatibility."""

        __slots__ = ("_value",)

        def __init__(self, value):
            self._value = bool(value)

        def __bool__(self):
            return self._value

        def __call__(self):
            return self._value

        def __repr__(self):
            return repr(self._value)

    installed = _ver_tuple(torch.__version__)
    shims = [
        ("is_torch_greater_or_equal_than_1_13", "1.13.0"),
        ("is_torch_greater_or_equal_than_2_0", "2.0.0"),
    ]
    for name, ver in shims:
        if not hasattr(ptu, name):
            setattr(ptu, name, _VersionFlag(installed >= _ver_tuple(ver)))

def _norm_cuda(dev: str) -> str:
    """Return ``"cuda:0"`` for the bare ``"cuda"`` device string; any other value is returned unchanged."""
    return "cuda:0" if dev == "cuda" else dev


# Make the project root (two directories above this file) importable.
# NOTE(review): this runs after the `backends` / `utils` imports near the top
# of the file, so it cannot help resolve those two imports — confirm whether
# the path setup should be moved above them.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.normpath(os.path.join(BASE_DIR, os.pardir, os.pardir))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Shared rich console used for all status output in this module.
console = Console()


class HiddenExtractor:
    def __init__(self, model_key: str, model_path: str, device: str = "cuda") -> None:
        """
        Load everything needed to run a model and capture its hidden states.

        Loads the tokenizer, config, model and (when available) processor for
        ``model_path``, builds the generation backend via ``create_backend``,
        locates the text transformer layers and precomputes the hidden size.

        Args:
            model_key: Identifier used for model-family specific branches
                (matched case-insensitively by substring) and passed to the
                backend factory.
            model_path: Local directory or hub repo id given to the various
                ``from_pretrained`` calls.
            device: Device the model is moved to when it was loaded without a
                device map; also passed to the backend.
        """
        self.model_key = model_key
        self.device = device
        console.print(f"[bold]加载模型:[/bold] {model_path}")
        _ensure_transformers_pytorch_utils_shims()

        # 1) Tokenizer
        if "vila" in model_key.lower():
            # Two layouts are supported:
            #   1) local checkout:  /.../VILA1.5-3b       (tokenizer in <path>/llm)
            #   2) hub repo id:     Efficient-Large-Model/VILA1.5-3b
            llm_path_local = os.path.join(model_path, "llm")
            if os.path.isdir(llm_path_local):
                llm_path = llm_path_local                   # local directory
            else:
                llm_path = model_path.rstrip("/") + "/llm"  # hub sub-directory

            self.tokenizer = AutoTokenizer.from_pretrained(
                llm_path,
                # Do not force the fast tokenizer implementation here.
                use_fast=False,
                trust_remote_code=True,
            )
        else:
            # Every other model: AutoTokenizer with remote code, left padding.
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                trust_remote_code=True,
                padding_side="left"
            )

        # Make sure the tokenizer has a padding token.
        self._setup_pad_token()

        # 2) Config + model
        self.config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

        if "llama3" in model_path and "llava-next" in model_path:
            # NOTE: LlavaNextForConditionalGeneration comes from the guarded
            # import at the top of the file; if that import failed, this line
            # raises NameError.
            self.model = LlavaNextForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
        else:
            self.model = self._load_model(model_path)

        # Only move the model explicitly when it was not placed by a device map.
        if getattr(self.model, "hf_device_map", None) is None:
            self.model.to(self.device)

        self.model.eval()

        # 3) Processor (for multimodal models; consumed by the backend)
        if "qtunevl" in model_key.lower():
            if "2b" in model_key.lower():
                self.processor = AutoImageProcessor.from_pretrained(model_path, trust_remote_code=True)
            else:
                self.processor = AutoProcessor.from_pretrained(model_path, use_fast=True)
                self.tokenizer = self.processor.tokenizer
        elif "llama3" in model_path and "llava-next" in model_path:
            self.processor = LlavaNextProcessor.from_pretrained(model_path, trust_remote_code=True)
        elif "wepoints/points-" in model_key.lower():
            # POINTS is driven as "image processor + model.chat".
            self.processor = CLIPImageProcessor.from_pretrained(model_path)
        else:
            # Text-only or unsupported checkpoints simply get no processor.
            try:
                self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
            except Exception:
                self.processor = None

        if "molmo" in model_key.lower() and hasattr(self.processor, "tokenizer"):
            self.tokenizer = self.processor.tokenizer
            self._setup_pad_token()

        if "vila" in model_key.lower():
            # Use the tokenizer owned by the loaded model. This used to be
            # assigned to the misspelled attribute `self.tokenier`, so the
            # backend below never received it.
            self.tokenizer = self.model.tokenizer

        # Fallback repair for LLaVA / ShareGPT4V processors.
        if getattr(self.config, "model_type", "").lower() == "llava" or "sharegpt4v" in model_key.lower():
            self._fix_llava_processor()

        # 4) Backend
        self.backend = create_backend(
            model_key=model_key,
            model=self.model,
            tokenizer=self.tokenizer,
            processor=self.processor,
            device=device,
        )

        # 5) Locate the text backbone layers and build the name -> index map.
        self.text_layers = self._locate_text_layers()
        self.layer_map = self._build_layer_index()
        self.hooks = []
        self.captured_hidden_states = {}

        # 6) Precompute the hidden size and a CPU zero vector used as the
        #    placeholder feature when nothing was captured.
        self.hidden_dim = self._get_hidden_dim()
        self.hidden_zero = torch.zeros(self.hidden_dim, dtype=next(self.model.parameters()).dtype, device="cpu")

        # Flags kept for compatibility with the original code.
        self.is_batch_mode = False
        self.debug_mode = False
        self.debug_info = {}

        console.print(f"[green]文本骨干层数: {len(self.text_layers)}[/green]")
        console.print(f"[green]隐藏维度: {self.hidden_dim}[/green]")

    def _fix_llava_processor(self):
        """
        兜底修复 LLaVA/ShareGPT4V 一类 processor 缺关键字段导致的
        `height // processor.patch_size` 报错。
        """
        proc = getattr(self, "processor", None)
        if proc is None:
            return

        # 让 processor 用同一个 tokenizer（有些 processor 里 tokenizer=None）
        if getattr(proc, "tokenizer", None) is None and getattr(self, "tokenizer", None) is not None:
            try:
                proc.tokenizer = self.tokenizer
            except Exception:
                pass

        cfg = getattr(self.model, "config", None)

        # patch_size：优先从 model.config.vision_config.patch_size 补齐
        ps = getattr(proc, "patch_size", None)
        if ps is None:
            ps = getattr(getattr(cfg, "vision_config", None), "patch_size", None)
            if ps is None:
                # 兜底：ShareGPT4V/LLaVA-1.5 常见为 14
                ps = 14
            try:
                proc.patch_size = int(ps)
            except Exception:
                pass

        # vision_feature_select_strategy：从 config 补（ShareGPT4V config 里是 "default"）
        if getattr(proc, "vision_feature_select_strategy", None) is None:
            vfs = getattr(cfg, "vision_feature_select_strategy", None)
            if vfs is None:
                vfs = "default"
            try:
                proc.vision_feature_select_strategy = vfs
            except Exception:
                pass

        # num_additional_tokens：CLIP 通常有 1 个 [CLS] 额外 token（HF 官方人员也这么建议）
        if getattr(proc, "num_additional_tokens", None) is None:
            try:
                proc.num_additional_tokens = 1
            except Exception:
                pass

    def _setup_pad_token(self):
        """确保tokenizer有padding token"""
        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                try:
                    self.tokenizer.pad_token = self.tokenizer.eos_token
                    console.print("[yellow]警告: tokenizer缺少pad_token，已设置为eos_token[/yellow]")
                except ValueError as e:
                    console.print(f"[yellow]警告: 无法设置pad_token ({e})，将使用eos_token_id作为pad_token_id[/yellow]")
                    if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
                        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            else:
                try:
                    self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                    console.print("[yellow]警告: tokenizer缺少pad_token和eos_token，已添加[PAD][/yellow]")
                except ValueError as e:
                    raise RuntimeError(f"[yellow]警告: 无法添加pad_token ({e})，将使用0作为pad_token_id[/yellow]")

    def _load_model(self, model_path: str):
        """
        Load the model weights, trying several transformers classes in turn.

        Order of attempts:
          0. For Llama-3.2 Vision (``model_type == "mllama"``):
             ``MllamaForConditionalGeneration``, with and then without the
             chosen attention implementation. If both fail, the generic
             attempts below still run.
          1. ``AutoModelForCausalLM`` with, then without, ``attn_implementation``.
          2. ``AutoModelForVision2Seq``.
          3. ``AutoModelForImageTextToText`` (only if this transformers has it).
          4. ``AutoModel``.

        The first attempt that succeeds is returned.

        Environment:
            FORCE_SINGLE_GPU: ``"1"`` pins the whole model to ``self.device``;
                ``"0"`` (default) uses ``device_map="auto"``. MiniCPM and
                InternVL models ignore this and use their own placement.

        Raises:
            RuntimeError: for mllama models when the installed transformers
                lacks ``MllamaForConditionalGeneration``.
            Exception: the error of the last failed attempt when none succeeds.
        """
        image_text_cls = getattr(_tf, "AutoModelForImageTextToText", None)
        mllama_cls = getattr(_tf, "MllamaForConditionalGeneration", None)

        path_lc = model_path.lower()
        attn_impl = "sdpa" if "gemma" in path_lc else "flash_attention_2"
        console.print(f"[yellow][DEBUG] 使用注意力实现: {attn_impl}[/yellow]")

        # Model key and path folded into one lowercase string for substring tests.
        mk = f"{self.model_key} {model_path}".lower()

        force_single_gpu = bool(int(os.environ.get("FORCE_SINGLE_GPU", "0")))

        # Device placement. MiniCPM-V must live on a single device: with
        # device_map="auto" its llm / vision parts can land on different GPUs
        # while the backend moves all inputs to self.device, which ends in a
        # mixed-device error.
        if "minicpm" in mk:
            device_map = {"": _norm_cuda(self.device)}
        elif "internvl3" in path_lc:
            device_map = split_model_internvl3(model_path)
        elif "internvl" in path_lc:
            device_map = split_model_internvl(model_path)
        elif force_single_gpu:
            device_map = {"": _norm_cuda(self.device)}
        else:
            device_map = "auto"

        # dtype policy: R-4B runs in fp32, OmChat in fp16, everything else bf16.
        if "r-4b" in mk:
            torch_dtype = torch.float32
        elif "omchat" in mk:
            torch_dtype = torch.float16
        else:
            torch_dtype = torch.bfloat16

        base_kwargs = dict(
            trust_remote_code=True,
            torch_dtype=torch_dtype,
            device_map=device_map,
        )
        attn_kwargs = dict(base_kwargs, attn_implementation=attn_impl)

        # Llama-3.2 Vision must not be picked up by AutoModelForCausalLM first,
        # which would load it as a text-only model.
        cfg_model_type = str(getattr(self.config, "model_type", "")).lower()
        is_mllama = (cfg_model_type == "mllama") or (("llama-3.2" in mk) and ("vision" in mk))

        if is_mllama:
            if mllama_cls is None:
                raise RuntimeError(
                    "检测到 Llama-3.2 Vision / mllama，但当前 transformers 没有 MllamaForConditionalGeneration。"
                    "请升级 transformers（一般 >=4.45 才有）。"
                )

            mllama_attempts = (
                (attn_kwargs, "[green]MllamaForConditionalGeneration (with attn_impl) 成功![/green]"),
                (base_kwargs, "[green]MllamaForConditionalGeneration (no attn_impl) 成功![/green]"),
            )
            for kwargs, ok_msg in mllama_attempts:
                try:
                    model = mllama_cls.from_pretrained(model_path, **kwargs)
                    console.print(ok_msg)
                except Exception:
                    continue
                return model

        # Generic fallback ladder: (loader class, kwargs, success message).
        ladder = (
            (AutoModelForCausalLM, attn_kwargs, "[green]AutoModelForCausalLM (with attn_impl) 成功![/green]"),
            (AutoModelForCausalLM, base_kwargs, "[green]AutoModelForCausalLM (no attn_impl) 成功![/green]"),
            (AutoModelForVision2Seq, base_kwargs, "[green]AutoModelForVision2Seq 成功![/green]"),
            (image_text_cls, base_kwargs, "[green]AutoModelForImageTextToText 成功![/green]"),
            (AutoModel, base_kwargs, "[green]AutoModel 兜底成功![/green]"),
        )

        last_err = None
        for loader, kwargs, ok_msg in ladder:
            if loader is None:
                # Class not provided by the installed transformers version.
                continue
            try:
                model = loader.from_pretrained(model_path, **kwargs)
                console.print(ok_msg)
            except Exception as exc:
                last_err = exc
                continue
            return model

        raise last_err or RuntimeError("所有模型加载方式都失败了")

    def _get_hidden_dim(self) -> int:
        """
        兼容各种 VLM config 的隐藏维度获取：
        - 先看 config.hidden_size / text_hidden_size / llm_hidden_size
        - 再看嵌套的 text_config.hidden_size 等
        - 最后从模型参数 shape 推断
        """
        cfg = getattr(self.model, "config", None)

        # 1) 直接字段
        for name in ["hidden_size", "text_hidden_size", "llm_hidden_size"]:
            if cfg is not None and hasattr(cfg, name):
                val = getattr(cfg, name)
                if isinstance(val, int):
                    return val

        # 2) 嵌套的 text_config / language_config / llm_config
        for sub_name in ["text_config", "language_config", "llm_config"]:
            sub_cfg = getattr(cfg, sub_name, None)
            if sub_cfg is not None and hasattr(sub_cfg, "hidden_size"):
                val = getattr(sub_cfg, "hidden_size")
                if isinstance(val, int):
                    return val

        # 3) 从参数推断（取第一个 >=2 维参数的最后一维）
        for p in self.model.parameters():
            if p.ndim >= 2:
                return int(p.shape[-1])

        # 4) 从模型参数量和常见配置启发式估计
        param_count = sum(p.numel() for p in self.model.parameters())
        if param_count > 0:
            # 根据参数量估计 hidden_dim：3B→2048, 7B→4096, 14B→5120, 32B→7168, 72B→8192
            param_b = param_count / 1e9
            if param_b <= 4:  # ~3B models
                return 2048
            elif param_b <= 10:  # ~7B models
                return 4096
            elif param_b <= 20:  # ~14B models
                return 5120
            elif param_b <= 40:  # ~32B models
                return 7168
            else:  # ~70B+ models
                return 8192

        # 实在不行给个保守默认值
        return 4096

    # ====== 文本层定位 ======
    def _locate_text_layers(self):
        """
        Find the list of text transformer layers inside the loaded model.

        Hooks are registered only on the returned layers, so the search tries
        to stay out of the vision encoder. Three strategies are used in order:

          1. Exact submodule paths (``model.transformer.blocks`` /
             ``transformer.blocks``), used by Molmo-style models.
          2. A shallow search for ``layers`` / ``h`` / ``decoder_layers`` /
             ``block`` on the model and its usual wrapper attributes.
          3. A depth-limited DFS that skips children whose names look like a
             vision branch, then picks the candidate whose length is closest to
             ``config.num_hidden_layers`` (or the longest one when that is
             unknown), preferring shallower matches.

        Returns:
            The indexable container of layers that was found.

        Raises:
            RuntimeError: when no candidate can be located.
        """
        m = self.model

        # 1) Exact paths take priority.
        for p in ("model.transformer.blocks", "transformer.blocks"):
            try:
                layers = m.get_submodule(p)
                _ = layers[0]          # make sure it is indexable and non-empty
                console.print(f"[cyan]检测到文本层(优先路径): {p} (len={len(layers)})[/cyan]")
                return layers
            except Exception:
                pass

        # 2) Shallow search over common wrappers
        #    (InternVL / Qwen-VL / LLaVA / XComposer / MiniCPM-V ...).
        submods = [
            m,
            getattr(m, "model", None),
            getattr(m, "language_model", None),
            getattr(m, "llm", None),
            getattr(m, "text_model", None),
            getattr(m, "decoder", None),
            getattr(m, "transformer", None),
        ]

        seen = set()
        for sub in submods:
            if sub is None or id(sub) in seen:
                continue
            seen.add(id(sub))
            for attr in ["layers", "h", "decoder_layers", "block"]:
                layers = getattr(sub, attr, None)
                if layers is None:
                    continue
                # Probe element 0 to check this behaves like a non-empty
                # ModuleList. A failed probe means the attribute is not a
                # usable layer list, so move on instead of returning it.
                try:
                    _ = layers[0]
                except Exception:
                    continue
                console.print(f"[cyan]检测到文本层: {sub.__class__.__name__}.{attr} (len={len(layers)})[/cyan]")
                return layers

        # 3) Depth-limited DFS fallback, e.g. for language_model.model.layers.
        console.print("[yellow]浅层搜索未找到文本层，尝试做有限深度 DFS 搜索...[/yellow]")

        # Each candidate: (parent_module, attr_name, layers, depth, path)
        candidates = []

        def is_valid_layer_list(l):
            # A non-empty ModuleList/list/tuple whose first element is a module.
            if not isinstance(l, (nn.ModuleList, list, tuple)):
                return False
            if len(l) == 0:
                return False
            if not isinstance(l[0], nn.Module):
                return False
            return True

        def dfs(mod: nn.Module, depth: int, max_depth: int, path: str):
            if not isinstance(mod, nn.Module):
                return
            if depth > max_depth:
                return

            # Look for layers/h/decoder_layers/block on the current module.
            for attr in ["layers", "h", "decoder_layers", "block"]:
                layers = getattr(mod, attr, None)
                if layers is None:
                    continue
                if not is_valid_layer_list(layers):
                    continue
                candidates.append((mod, attr, layers, depth, f"{path}.{attr}"))

            # Recurse into children.
            for name, child in mod.named_children():
                # Roughly skip vision branches so a vision tower's blocks are
                # not mistaken for text layers.
                lname = name.lower()
                if any(k in lname for k in ["vision", "visual", "clip_image", "vit", "resnet", "image"]):
                    continue
                dfs(child, depth + 1, max_depth, f"{path}.{name}")

        # Start from the submodules that look most like a text backbone.
        dfs_roots = [
            ("language_model", getattr(m, "language_model", None)),
            ("llm", getattr(m, "llm", None)),
            ("text_model", getattr(m, "text_model", None)),
            ("decoder", getattr(m, "decoder", None)),
            ("transformer", getattr(m, "transformer", None)),
            ("model", getattr(m, "model", None)),
            ("self", m),
        ]

        visited: Set[int] = set()
        for name, root in dfs_roots:
            if root is None or id(root) in visited:
                continue
            visited.add(id(root))
            dfs(root, depth=0, max_depth=4, path=name)

        if not candidates:
            raise RuntimeError("无法自动定位文本 Transformer 层，请手动修改 _locate_text_layers 适配该模型结构。")

        # Prefer the candidate whose length is closest to
        # config.num_hidden_layers; without that value take the longest one.
        # Ties are broken by smaller depth.
        try:
            target_n = getattr(self.model.config, "num_hidden_layers", None)
        except Exception:
            target_n = None

        if target_n is not None:
            candidates.sort(key=lambda x: (abs(len(x[2]) - target_n), x[3]))
        else:
            candidates.sort(key=lambda x: (-len(x[2]), x[3]))

        mod, attr, layers, depth, path = candidates[0]
        console.print(
            f"[cyan]DFS 检测到文本层: {mod.__class__.__name__}.{attr} "
            f"(len={len(layers)}, depth={depth}, path={path})[/cyan]"
        )
        return layers

    def _build_layer_index(self) -> Dict[str, int]:
        """
        Build the mapping from symbolic layer names to layer indices.

        The layer count is ``model.config.num_hidden_layers`` when that is
        available, otherwise ``len(self.text_layers)``. The resulting mapping
        is printed and returned with the keys ``quarter``, ``middle``,
        ``three_quarters``, ``last``, ``second_last`` and ``first``.
        """
        try:
            total = getattr(self.model.config, "num_hidden_layers", None)
        except Exception:
            total = None
        if total is None:
            total = len(self.text_layers)

        index_of = {}
        index_of["quarter"] = total // 4                # 1/4 depth
        index_of["middle"] = total // 2
        index_of["three_quarters"] = (3 * total) // 4   # 3/4 depth
        index_of["last"] = total - 1
        index_of["second_last"] = total - 2
        index_of["first"] = 0

        console.print(
            f"总层数(推断): {total} → quarter={index_of['quarter']}, middle={index_of['middle']}, "
            f"three_quarters={index_of['three_quarters']}, second_last={index_of['second_last']}, last={index_of['last']}"
        )
        return index_of

    # 钩子注册
    def _register_hooks(self, needed_layers: List[str]) -> None:
        """只在文本骨干 self.text_layers 上挂 hook"""
        self._clear_hooks()
        self.captured_hidden_states = {k: [] for k in needed_layers}

        def create_hook(layer_name: str):
            def hook_fn(module, input, output):
                if isinstance(output, tuple):
                    hidden_state = output[0]
                else:
                    hidden_state = output
                hidden_cpu = hidden_state.detach().to("cpu")
                if hidden_state.is_cuda:
                    del hidden_state
                self.captured_hidden_states[layer_name].append(hidden_cpu)

            return hook_fn

        layers = self.text_layers
        for layer_name in needed_layers:
            if layer_name == "first":
                layer_idx = 0
            elif layer_name == "last":
                layer_idx = len(layers) - 1
            elif layer_name == "second_last":
                layer_idx = len(layers) - 2
            elif layer_name == "middle":
                layer_idx = len(layers) // 2
            elif layer_name == "quarter":
                layer_idx = len(layers) // 4
            elif layer_name == "three_quarters":
                layer_idx = (3 * len(layers)) // 4
            else:
                layer_idx = self.layer_map.get(layer_name, 0)

            if 0 <= layer_idx < len(layers):
                hook = layers[layer_idx].register_forward_hook(create_hook(layer_name))
                self.hooks.append(hook)

    def _clear_hooks(self) -> None:
        for h in self.hooks:
            h.remove()
        self.hooks.clear()

    # 特征提取
    def _extract_features_from_hooks(self, right_pad_len: int, needed_layers: List[str], sample_idx: int = 0, hit_limit: bool=False) -> Dict[str, Dict[str, torch.Tensor]]:
        """从hooks收集的隐状态中提取特征"""
        layer_features: Dict[str, Dict[str, torch.Tensor]] = {k: {} for k in needed_layers}

        # 使用预计算的零张量，避免重复创建
        _zero = self.hidden_zero

        for k in needed_layers:
            if k not in self.captured_hidden_states:
                continue

            layer_states = self.captured_hidden_states[k]
            if not layer_states:
                console.log("[yellow]Warn: no output tokens generated.[/yellow]")
                layer_features[k]["prompt_last_token"] = _zero.clone()
                layer_features[k]["answer_first_token"] = _zero.clone()
                layer_features[k]["last_token"] = _zero.clone()
                continue

            first_step_state = layer_states[0]

            # 计算 prompt_last_token
            if first_step_state.dim() == 3:
                # [batch, seq_len, hidden]
                if sample_idx < first_step_state.size(0) and first_step_state.size(1) > 0:
                    prompt_last = first_step_state[sample_idx, -1]      # [hidden]
                else:
                    prompt_last = _zero
            elif first_step_state.dim() == 2:
                # [seq_len, hidden]，单样本的情况
                if first_step_state.size(0) > 0:
                    prompt_last = first_step_state[-1]                  # [hidden]
                else:
                    prompt_last = _zero
            else:
                prompt_last = _zero

            layer_features[k]["prompt_last_token"] = prompt_last.clone()

            # 收集输出部分的隐状态
            output_states = []

            # 处理后续的前向传播结果
            # 每个后续的前向传播都会产生一个新的token的隐状态
            for i in range(1, len(layer_states)):
                step_state = layer_states[i]

                # 处理维度
                if step_state.dim() == 3:
                    # 批处理模式: [batch_size, seq_len, hidden_dim]
                    if sample_idx >= step_state.size(0):
                        raise RuntimeError("样本索引越界，请检查批处理逻辑")
                    step_hidden = step_state[sample_idx]  # [seq_len, hidden_dim]
                elif step_state.dim() == 2:
                    # 单样本模式: [seq_len, hidden_dim]
                    step_hidden = step_state
                else:
                    continue

                # 新生成的token的隐状态在序列的最后一个位置
                if len(step_hidden) > 0:
                    output_states.append(step_hidden[-1:])  # [1, hidden_dim]

            # console.log(f"[blue][DEBUG] 样本 {sample_idx} 层 {k} 生成 token 数: {len(output_states)}[/blue], right_pad_len={right_pad_len}, hit_limit={hit_limit}")

            if len(output_states) > right_pad_len:
                # 只取实际输出长度的隐状态
                if right_pad_len > 0:
                    output_states = output_states[:-right_pad_len]
            else:
                output_states = []

            if len(output_states) >= 1:
                output_first_step = output_states[0]    # [1, D] 或 [D]
                if output_first_step.dim() == 2 and output_first_step.size(0) > 0:
                    answer_first = output_first_step[-1]          # [D]
                elif output_first_step.dim() == 1:
                    answer_first = output_first_step
                else:
                    answer_first = _zero
                layer_features[k]["answer_first_token"] = answer_first.clone()
            else:
                layer_features[k]["answer_first_token"] = _zero.clone()

            # 计算真正最后一个 token 的隐状态
            if len(output_states) >= 1:
                normed = []
                for t in output_states:
                    if t.dim() == 1:
                        t = t.unsqueeze(0)
                    normed.append(t)
                # 拼接得到 [effective_len, hidden_dim]
                output_hidden = torch.cat(normed, dim=0)
                # 3. 最后一个token的表征
                vec_last_token = output_hidden[-1]
                layer_features[k]["last_token"] = vec_last_token.clone()
            else:
                layer_features[k]["last_token"] = _zero.clone()
        return layer_features

    def _sort_data_by_prompt_len(self, data):
        """
        输入: data(List[Dict]) 多模态数据  输出: 排序后的 data(List[Dict])
        注意：现在仅支持多模态数据，每个样本必须包含图像
        """
        # 验证所有样本都包含图像
        for i, item in enumerate(data):
            if not (("image_path" in item) or ("image" in item) or ("img_path" in item)):
                raise ValueError(f"样本 {i} 缺少图像字段，当前仅支持多模态数据")

        def _build_prompt_texts(items):
            # 用 parse_input 拿最终 prompt 文本
            questions = [parse_input(it) for it in items]

            # 如果 tokenizer 支持 chat template，就按后续 generate 的方式展开
            prompt_txts = []
            add_special = True
            has_tpl = bool(getattr(self.tokenizer, "chat_template", None))
            if hasattr(self.tokenizer, "apply_chat_template") and has_tpl:
                add_special = False
                for q in questions:
                    messages = [{"role": "user", "content": q}]
                    try:
                        if "qwen3" in str(self.model_key).lower():
                            txt = self.tokenizer.apply_chat_template(
                                messages,
                                tokenize=False,
                                add_generation_prompt=True,
                                enable_thinking=True,
                            )
                        else:
                            txt = self.tokenizer.apply_chat_template(
                                messages,
                                tokenize=False,
                                add_generation_prompt=True,
                            )
                    except TypeError:
                        # 老版本 tokenizer 兼容
                        txt = self.tokenizer.apply_chat_template(messages, tokenize=False)
                    prompt_txts.append(txt)
            else:
                # 非 chat tokenizer，直接用文本
                prompt_txts = questions

            return prompt_txts, add_special

        # 纯 CPU 分词测长度，分块以避免一次性处理太大
        LENS_CHUNK = 256
        n = len(data)
        lens = [0] * n
        for i in range(0, n, LENS_CHUNK):
            sub = data[i:i + LENS_CHUNK]
            prompt_txts, add_special = _build_prompt_texts(sub)
            enc = self.tokenizer(
                prompt_txts,
                add_special_tokens=add_special,
                truncation=False,
                return_attention_mask=False,
                return_token_type_ids=False,
            )
            for k, ids in enumerate(enc["input_ids"]):
                lens[i + k] = len(ids)

        # 长 → 短 排序
        order = sorted(range(n), key=lambda t: (lens[t], t), reverse=True)
        return [data[idx] for idx in order]

    def forward_batch(
        self,
        items: List[dict],
        max_new_tokens: int = 1024,
        needed_layers: Optional[List[str]] = None,
        oom_estimate: bool = False,
        bs_estimate_gen_cfg: Optional[Dict] = None,
    ) -> Tuple[List[Dict[str, Dict[str, torch.Tensor]]], List[str], List[bool]]:
        """Run batched generation and collect hooked features for every sample.

        This method only manages hooks and delegates the actual work (input
        construction, model call, length computation) to
        ``self.backend.generate_batch``.

        Args:
            items: Samples of the batch, forwarded unchanged to the backend.
            max_new_tokens: Generation budget forwarded to the backend.
            needed_layers: Layer keys to hook. ``None`` selects every key of
                ``self.layer_map``.
            oom_estimate: Forwarded unchanged to the backend.
            bs_estimate_gen_cfg: Forwarded to the backend. ``None`` (the
                default) is replaced by a fresh empty dict on every call so
                that no dict object is shared between calls.

        Returns:
            A tuple ``(batch_layer_features, batch_answers, hit_limit_flags)``:
            one feature mapping per sample as returned by
            ``self._extract_features_from_hooks``, the generated texts after
            ``trim_incomplete_sentence``, and the per-sample flags returned by
            the backend.

        Raises:
            RuntimeError: Wraps any exception raised while generating or
                extracting features (the original is chained as the cause).
        """
        if needed_layers is None:
            needed_layers = list(self.layer_map.keys())
        if bs_estimate_gen_cfg is None:
            # A fresh dict per call: a ``{}`` default in the signature would be
            # one shared object that the backend could mutate across calls.
            bs_estimate_gen_cfg = {}

        # 1) Register hooks (flag the instance as running in batch mode first).
        self.is_batch_mode = True
        self._register_hooks(needed_layers)

        try:
            # 2) Input building + model call + length bookkeeping live in the backend.
            outputs, right_pad_lens, hit_limit_flags = self.backend.generate_batch(
                items=items,
                max_new_tokens=max_new_tokens,
                gen_cfg={},
                oom_estimate=oom_estimate,
                bs_estimate_gen_cfg=bs_estimate_gen_cfg,
            )

            batch_layer_features = []
            batch_answers = []

            # 3) Walk the batch and pull each sample's features out of the hooks.
            #    Indexing (rather than zip) keeps a short backend result loud:
            #    it raises IndexError, which is wrapped below.
            for i in range(len(items)):
                layer_features = self._extract_features_from_hooks(
                    right_pad_len=right_pad_lens[i],
                    needed_layers=needed_layers,
                    sample_idx=i,
                    hit_limit=hit_limit_flags[i],
                )
                batch_layer_features.append(layer_features)
                batch_answers.append(trim_incomplete_sentence(outputs[i]))
        except Exception as e:
            raise RuntimeError(f"批量推理出错: {e}") from e

        return batch_layer_features, batch_answers, hit_limit_flags

    def forward_once(
        self,
        item: dict,
        max_new_tokens: int = 1024,
        needed_layers: List[str] = None,
    ) -> Tuple[Dict[str, Dict[str, torch.Tensor]], str, bool]:
        """Run generation for a single sample and collect its hooked features.

        Only hook management happens here; building the inputs, calling the
        model and computing lengths is delegated to
        ``self.backend.generate_one``.

        Args:
            item: The sample, forwarded unchanged to the backend.
            max_new_tokens: Generation budget forwarded to the backend.
            needed_layers: Layer keys to hook. ``None`` selects every key of
                ``self.layer_map``.

        Returns:
            ``(layer_features, pred, hit_limit)``: the mapping returned by
            ``self._extract_features_from_hooks`` for sample index 0, the
            generated text after ``trim_incomplete_sentence``, and the flag
            returned by the backend.

        Raises:
            RuntimeError: Wraps any exception raised during generation or
                feature extraction (the original is chained as the cause).
        """
        layers = needed_layers if needed_layers is not None else list(self.layer_map.keys())

        # NOTE(review): forward_batch sets self.is_batch_mode = True and this
        # method does not reset it - confirm the hooks behave correctly when
        # this runs right after a failed batch.
        self._register_hooks(layers)

        try:
            generated, pad_len, truncated = self.backend.generate_one(
                item=item,
                max_new_tokens=max_new_tokens,
                gen_cfg={},
            )
            pred = trim_incomplete_sentence(generated)

            # Single sample, so the features always live at batch index 0.
            features = self._extract_features_from_hooks(
                pad_len,
                needed_layers=layers,
                sample_idx=0,
                hit_limit=truncated,
            )
        except Exception as e:
            raise RuntimeError(f"单样本推理出错: {e}") from e

        return features, pred, truncated


    def _save_batch_features(
        self,
        feats_prompt_last_token: Dict[str, List[torch.Tensor]],
        feats_answer_first_token: Dict[str, List[torch.Tensor]],
        feats_last_token: Dict[str, List[torch.Tensor]],
        labels: List[int],
        ids: List[str],
        questions: List[str],
        true_answers: List[str],
        pred_answers: List[str],
        categories: List[str],
        options: List[List[str]],
        parsed_answers: List[str],
        incomplete_flags: List[bool],
        prompt_last_token_dir: str,
        answer_first_token_dir: str,
        last_token_outputdir: str,
    ) -> None:
        """Append one batch of features and metadata to the per-layer ``.pt`` files.

        For each of the three feature families a file
        ``<output_dir>/<layer>_features.pt`` is written. If the file already
        exists it is loaded, records whose id is already stored are dropped,
        and the rest is appended. Files are written to ``<path>.tmp`` first and
        then moved into place with ``os.replace``.

        Args:
            feats_prompt_last_token: Layer key -> per-sample vectors.
            feats_answer_first_token: Layer key -> per-sample vectors.
            feats_last_token: Layer key -> per-sample vectors.
            labels, ids, questions, true_answers, pred_answers, categories,
            options, parsed_answers, incomplete_flags: Per-sample metadata,
                index-aligned with every non-empty vector list.
            prompt_last_token_dir: Output directory for ``feats_prompt_last_token``.
            answer_first_token_dir: Output directory for ``feats_answer_first_token``.
            last_token_outputdir: Output directory for ``feats_last_token``.

        Raises:
            ValueError: If a non-empty vector list or any metadata list does
                not have the same length as ``ids``.
        """
        # Metadata columns keyed by their on-disk name; "ids" comes first so
        # that each zipped record carries its id at position 0.
        meta_columns = {
            "ids": ids,
            "labels": labels,
            "questions": questions,
            "true_answers": true_answers,
            "pred_answers": pred_answers,
            "categories": categories,
            "options": options,
            "parsed_answers": parsed_answers,
            "incomplete_flags": incomplete_flags,
        }
        meta_names = list(meta_columns)

        def save_features(feats, output_dir, des):
            for k, vec_list in feats.items():
                if not vec_list:
                    continue

                # Validate on every write (first write included): zip() below
                # would otherwise silently truncate misaligned columns.
                if len(vec_list) != len(ids):
                    raise ValueError(f"{k}: vec_list({len(vec_list)}) 与 ids({len(ids)}) 长度不一致")
                for name, column in meta_columns.items():
                    if len(column) != len(ids):
                        raise ValueError(f"{k}: {name}({len(column)}) 与 ids({len(ids)}) 长度不一致")

                out_path = os.path.join(output_dir, f"{k}_features.pt")
                records = list(zip(*meta_columns.values(), vec_list))

                if os.path.exists(out_path):
                    # NOTE: torch.load unpickles the file; this is expected to
                    # be a file previously written by this same method.
                    existing_data = torch.load(out_path, map_location="cpu")
                    merged = {
                        name: list(existing_data[name])
                        for name in meta_names
                        if name != "labels"
                    }
                    # Labels are stored as a tensor; go back to a plain list to extend.
                    merged["labels"] = existing_data["labels"].tolist()
                    feature_parts = [existing_data["features"]]

                    # Drop ids already on disk: guards against a batch whose
                    # features were written but whose ids never reached the
                    # checkpoint, which would otherwise be stored twice.
                    existing_id_set = set(merged["ids"])
                    records = [r for r in records if r[0] not in existing_id_set]
                else:
                    merged = {name: [] for name in meta_names}
                    feature_parts = []

                if records:
                    *new_meta, new_vecs = zip(*records)
                    for name, column in zip(meta_names, new_meta):
                        merged[name].extend(column)
                    # Always store on CPU so first-write and merged files agree.
                    feature_parts.append(torch.stack(list(new_vecs)).cpu())

                if not feature_parts:
                    continue
                if len(feature_parts) == 1:
                    tensor = feature_parts[0]
                else:
                    tensor = torch.cat(feature_parts, dim=0)

                # Write-then-rename so a crash never leaves a half-written file.
                tmp = out_path + ".tmp"
                torch.save({
                    "features": tensor,
                    "labels": torch.tensor(merged["labels"]),
                    "ids": merged["ids"],
                    "questions": merged["questions"],
                    "true_answers": merged["true_answers"],
                    "pred_answers": merged["pred_answers"],
                    "categories": merged["categories"],
                    "options": merged["options"],
                    "parsed_answers": merged["parsed_answers"],
                    "incomplete_flags": merged["incomplete_flags"]
                }, tmp)
                os.replace(tmp, out_path)
                console.print(f"[green]保存 {k} {des} → {tensor.shape} 到 {out_path}[/green]")

        save_features(feats_prompt_last_token, prompt_last_token_dir, des="prompt最后一个token的表征")
        save_features(feats_answer_first_token, answer_first_token_dir, des="answer第一个token的表征")
        save_features(feats_last_token, last_token_outputdir, des="最后一个token的表征")

    def extract_dataset(
        self,
        dataset_name: str,
        dataset_path: str,
        prompt_last_token_dir: str,
        answer_first_token_dir: str,
        last_token_outputdir: str,
        layer_req: str = "middle",
        max_new_tokens: int = 1024,
        batch_size: int = 4,
        resume: bool = False,
        checkpoint_root: Optional[str] = None,
        oom_estimate: bool = False
    ) -> None:
        """Extract features for a whole JSONL dataset, batch by batch.

        Each batch is first run through ``forward_batch``; if that (or the
        post-processing of its results) raises, the batch is redone sample by
        sample with ``forward_once``. After every batch the features are
        saved and the processed ids are recorded in the checkpoint.

        Args:
            dataset_name: Name used for logging and the checkpoint path.
            dataset_path: Path to a JSONL file (one JSON object per line;
                blank lines are ignored).
            prompt_last_token_dir: Output directory for prompt-last-token features.
            answer_first_token_dir: Output directory for answer-first-token features.
            last_token_outputdir: Output directory for last-token features.
            layer_req: ``"all"`` selects every key of ``self.layer_map``; any
                other value is used as the single layer key.
            max_new_tokens: Generation budget forwarded to the forward calls.
            batch_size: Number of samples per batch; must be positive.
            resume: When True, samples already recorded in the checkpoint are
                skipped. When False, the checkpoint is reset, its file removed
                and the three feature directories are cleared.
            checkpoint_root: Root directory for checkpoints; the file lives at
                ``<root>/<model_key>/<dataset_name>/checkpoint_meta.json``.
                When None, ``<parent of prompt_last_token_dir>/checkpoints``
                is used as the root.
            oom_estimate: Forwarded unchanged to ``forward_batch``.

        Raises:
            ValueError: If ``batch_size`` is not positive.
            RuntimeError: If a sample still fails in single-sample fallback.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size 必须为正整数，当前为 {batch_size}")

        os.makedirs(prompt_last_token_dir, exist_ok=True)
        os.makedirs(answer_first_token_dir, exist_ok=True)
        os.makedirs(last_token_outputdir, exist_ok=True)

        with open(dataset_path, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. a trailing newline) instead of crashing in json.loads.
            data = [json.loads(line) for line in f if line.strip()]
        if self.debug_mode:
            data = data[:250]
        console.rule(f"处理数据集 → {dataset_name}")

        needed_layers: List[str] = (
            list(self.layer_map.keys()) if layer_req == "all" else [layer_req]
        )

        # Without an explicit checkpoint root, keep checkpoints next to the
        # feature directories (normpath(None) would raise TypeError).
        if checkpoint_root is None:
            checkpoint_root = os.path.join(
                os.path.dirname(os.path.normpath(prompt_last_token_dir)), "checkpoints"
            )
        checkpoint_dir = os.path.join(os.path.normpath(checkpoint_root), self.model_key, dataset_name)

        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, "checkpoint_meta.json")
        ckpt = CheckpointManager(checkpoint_path, load_existing=resume)

        if not resume:
            ckpt.reset()  # fresh run: start from scratch
            # Also drop the old checkpoint file and any previously saved features.
            if os.path.exists(checkpoint_path):
                os.remove(checkpoint_path)
            clear_dir(prompt_last_token_dir)
            clear_dir(answer_first_token_dir)
            clear_dir(last_token_outputdir)

        # Pre-clean: drop samples without a usable id and count them.
        cleaned_data: List[Dict[str, Any]] = []
        bad_id_count = 0
        for item in data:
            _id = get_item_id(item)
            if _id == "":
                bad_id_count += 1
                continue
            # Store the resolved id under one uniform key.
            item["_resolved_id"] = _id
            cleaned_data.append(item)
        if bad_id_count:
            print(f"[warn] {bad_id_count} 条样本缺少 id/unique_id，已跳过。")

        ckpt.set_total(len(cleaned_data))

        # Filter out samples the checkpoint already lists as processed.
        if resume:
            remaining = [it for it in cleaned_data if not ckpt.is_processed(it["_resolved_id"])]
        else:
            remaining = cleaned_data

        if not remaining:
            print("[info] 没有剩余样本可处理。")
            return

        # Sorting by prompt length is currently disabled; samples keep file order.
        data = remaining

        console.print(f"加载 {len(data)} 个样本.")

        total_batches = (len(data) + batch_size - 1) // batch_size
        console.print(f"使用批量大小 {batch_size}, 总共 {total_batches} 个批次")

        def _new_buffers() -> Dict[str, Any]:
            # Fresh, empty per-batch accumulators.
            return {
                "prompt_last_token": {k: [] for k in needed_layers},
                "answer_first_token": {k: [] for k in needed_layers},
                "last_token": {k: [] for k in needed_layers},
                "labels": [],
                "ids": [],
                "questions": [],
                "true_answers": [],
                "pred_answers": [],
                "categories": [],
                "options": [],
                "parsed_answers": [],
                "incomplete_flags": [],
            }

        def _record_sample(buf: Dict[str, Any], item: dict, layer_vecs, answer: str, hit_limit) -> None:
            # Append one sample's metadata, features and parsed result to buf.
            q = item["question"].strip()

            buf["ids"].append(item.get("_resolved_id", ""))
            buf["questions"].append(q)
            buf["true_answers"].append(item["answer"])
            buf["options"].append(get_options_list(item))
            buf["categories"].append(item.get("subpart", ""))  # subpart doubles as the category
            buf["pred_answers"].append(answer)

            # Features are stored for every sample, correct or not.
            for k in needed_layers:
                buf["prompt_last_token"][k].append(layer_vecs[k]["prompt_last_token"])
                buf["answer_first_token"][k].append(layer_vecs[k]["answer_first_token"])
                buf["last_token"][k].append(layer_vecs[k]["last_token"])

            # Truncated output: drop the trailing incomplete sentence before parsing.
            if hit_limit:
                answer = trim_incomplete_sentence(answer)

            extracted_answer, is_correct = seedbench_parse(
                answer, item["answer"], hit_limit, get_option_answer_text(item)
            )

            shown = extracted_answer
            if extracted_answer == "Incomplete" or hit_limit:
                console.print("\n[red]Incomplete answer[/red]")
                # Avoid dumping very long unparsed text to the log.
                if extracted_answer is not None and len(extracted_answer) >= 500:
                    shown = "Incomplete"
            print("pred_answer:", shown, "true answer:", item["answer"], "is_correct:", is_correct)

            buf["labels"].append(int(is_correct))
            buf["parsed_answers"].append(extracted_answer)
            buf["incomplete_flags"].append(hit_limit)

        for batch_idx in track(range(total_batches), description="Extracting features"):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(data))
            batch_items = data[start_idx:end_idx]

            console.print(f"处理批次 {batch_idx + 1}/{total_batches}, 样本 {start_idx}-{end_idx-1}")

            buf = _new_buffers()
            batch_failed = False
            failed_with_oom = False

            try:
                batch_layer_vecs, batch_answers, hit_limit_flags = self.forward_batch(
                    batch_items,
                    max_new_tokens=max_new_tokens,
                    needed_layers=needed_layers,
                    oom_estimate=oom_estimate,
                    bs_estimate_gen_cfg={}
                )

                for i, (item, layer_vecs, answer) in enumerate(zip(batch_items, batch_layer_vecs, batch_answers)):
                    _record_sample(buf, item, layer_vecs, answer, hit_limit_flags[i])
            except Exception as e:
                console.print(f"[red]批次 {batch_idx + 1} 处理失败: {escape(str(e))}[/red]")
                console.print("[yellow]回退到单样本处理模式[/yellow]")
                batch_failed = True
                failed_with_oom = "out of memory" in str(e).lower()

            # The fallback runs outside the except block on purpose: while the
            # handler is active the exception's traceback keeps the failed
            # batch's frames (and their tensors) alive, so freeing memory
            # there would have little effect.
            if batch_failed:
                # Discard partially collected results so nothing is saved twice.
                if buf["ids"]:
                    console.print("[yellow]清空当前批次已收集的数据，避免部分数据重复保存[/yellow]")
                buf = _new_buffers()

                # Only try to reclaim memory when the failure was an OOM.
                if failed_with_oom:
                    gc.collect()
                    torch.cuda.empty_cache()

                for i, item in enumerate(batch_items):
                    try:
                        layer_vecs, answer, hit_limit = self.forward_once(
                            item,
                            max_new_tokens=max_new_tokens,
                            needed_layers=needed_layers
                        )
                        _record_sample(buf, item, layer_vecs, answer, hit_limit)
                    except Exception as single_e:
                        raise RuntimeError(f"样本 {start_idx + i} 处理失败: {single_e}") from single_e

            # Reaching this point means the whole batch was processed (either
            # in batch mode or via the fallback), so it is safe to persist it.
            self._save_batch_features(
                buf["prompt_last_token"],
                buf["answer_first_token"],
                buf["last_token"],
                buf["labels"],
                buf["ids"],
                buf["questions"],
                buf["true_answers"],
                buf["pred_answers"],
                buf["categories"],
                buf["options"],
                buf["parsed_answers"],
                buf["incomplete_flags"],
                prompt_last_token_dir,
                answer_first_token_dir,
                last_token_outputdir,
            )

            # Record progress after the features are on disk (also when
            # resume=False, so an interrupted run can be resumed later).
            ckpt.mark_batch_processed(buf["ids"])

        print(f"[done] 共 {len(remaining)} 条新样本完成。累计完成 {len(ckpt.processed_ids)}/{ckpt.total_samples}.")
