from __future__ import annotations
import argparse
import json
import os
from typing import Dict, List, Optional, Tuple

import torch
from rich.console import Console
from rich.progress import track
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers import logging
import re
import sys

# Directory containing this script.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Repository root: two directory levels above this script.
PROJECT_ROOT = os.path.normpath(os.path.join(BASE_DIR, os.pardir, os.pardir))

# Make project-root-level packages (e.g. ``utils``) importable when this file
# is executed directly as a script.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from utils.config import CONFIG
from config import QUERY_TEMPLATE_2, QUERY_TEMPLATE_3, QUERY_TEMPLATE_4, QUERY_TEMPLATE_5, CHOICES

# Silence transformers warnings: only ERROR-level log records are shown.
logging.set_verbosity_error()
# Shared Rich console used for all terminal output in this script.
console = Console()

def safe_rich_print(text: str, style: str = "red") -> None:
    """Print arbitrary text in ``style`` without interpreting it as Rich markup.

    Text such as exception messages or model output may contain square
    brackets (e.g. ``[/INST]``) that Rich would parse as markup tags and raise
    ``MarkupError``. Disabling markup for this call prints the text verbatim
    and applies ``style`` to the whole line. Hand-escaping both ``[`` and ``]``
    is not equivalent: Rich only treats a backslash as an escape before ``[``,
    so escaped ``]`` characters would print with a stray backslash.

    Args:
        text: Value to print; converted with ``str()``.
        style: Rich style string applied to the whole output, e.g. ``"red"``.
    """
    console.print(str(text), style=style, markup=False)

# Local checkpoints live under <project root>/models/<model key>.
MODEL_PATH = os.path.join(PROJECT_ROOT, "models")
# ARC-Challenge splits; load_datasets looks for "<name>.jsonl" in the data dir.
DATASETS = ["arc_challenge_train", "arc_challenge_val", "arc_challenge_test"]

# Registry of model key -> checkpoint directory (the key doubles as the directory name).
MODELS: Dict[str, str] = {
    key: os.path.join(MODEL_PATH, key)
    for key in (
        "llama3.1-8b-instruct",
        "qwen2.5-7b-instruct",
        "phi3.5-mini-4b-instruct",
        "mistral-7b-instruct-v0.2",
        "glm-4-9b-chat",
        "qwen2.5-3b-instruct",
        "qwen2.5-14b-instruct",
        "Qwen3-4B",
    )
}

def parse_input(example):
    """Render the multiple-choice prompt for one ARC example.

    The number of entries in ``example["choices"]["label"]`` selects the
    template (2-5 options). Option texts are taken positionally from
    ``example["choices"]["text"]`` and bound to ``textA``, ``textB``, ...

    Raises:
        ValueError: if the example has fewer than 2 or more than 5 labels.
    """
    question = example["question"]
    option_texts = example["choices"]["text"]
    n_options = len(example["choices"]["label"])

    template = {
        2: QUERY_TEMPLATE_2,
        3: QUERY_TEMPLATE_3,
        4: QUERY_TEMPLATE_4,
        5: QUERY_TEMPLATE_5,
    }.get(n_options)
    if template is None:
        raise ValueError(f"Unsupported number of choices: {n_options}")

    # textA..textE are filled positionally; the templates expect exactly n_options of them.
    fields = {
        f"text{letter}": option_texts[pos]
        for pos, letter in enumerate("ABCDE"[:n_options])
    }
    return template.format(question=question, **fields)

# Contiguous sub-list search over token ids (used to locate the question inside the prompt).
def _find_subseq(haystack: List[int], needle: List[int]) -> Optional[int]:
    n, m = len(haystack), len(needle)
    for i in range(n - m + 1):
        if haystack[i : i + m] == needle:
            return i
    return None


# Builds the answer regex; extract_arc_answer compiles it once per choice set and caches it in _RE_CACHE.
def _build_answer_regex(choices: str) -> re.Pattern:
    """
    生成形如 'Answer: C' '答案是C' 'ANSWER: C.' 等格式的正则；
    choices 例如 'ABCDE'
    """
    choice_re = f"[{choices}]"
    return re.compile(
        rf"""
        # 关键词，中英文均可，大小写不敏感
        (?:答案|answer|final\s+answer|the\s+answer)
        (?:\s*(?:is|为|:?|：)?\s*)?   # 可选 'is' '为' 冒号等
        \(?                          # 可选左括号
        (?<![A-Za-z0-9])({choice_re})(?![A-Za-z0-9])                # 捕获答案字母
        \)?                          # 可选右括号
        """,
        re.IGNORECASE | re.VERBOSE,
    )

_RE_CACHE: dict[str, re.Pattern] = {}  # 缓存不同 choices 的 regex

def extract_arc_answer(text: str, choices: str = "ABCDE") -> Optional[str]:
    """Extract the ARC answer letter from a model completion.

    Args:
        text: Decoded completion, including special tokens.
        choices: Allowed answer letters for this question, e.g. ``"ABCD"``.

    Returns:
        * ``"Incomplete"`` if none of ``CONFIG["EOS_TOKEN"]`` occurs in ``text``
          (presumably generation hit the token limit);
        * the letter from an explicit answer phrase ("Answer: C", "答案是C", ...),
          scanning lines from last to first so the final statement wins;
        * otherwise the first standalone upper-case choice letter on the last
          non-empty line;
        * ``None`` if nothing matches.
    """
    if not any(token in text for token in CONFIG["EOS_TOKEN"]):
        return "Incomplete"

    regex = _RE_CACHE.get(choices)
    if regex is None:
        regex = _RE_CACHE[choices] = _build_answer_regex(choices)

    lines = [line.strip() for line in text.strip().splitlines()]
    lines = [line for line in lines if line]

    # 1) Explicit answer phrase, last line first to reduce false positives.
    for line in reversed(lines):
        match = regex.search(line)
        if match:
            return match.group(1).upper()

    if not lines:
        return None

    # 2) Fallback: a standalone choice letter on the last line. Matching is
    #    case-sensitive so the English article "a" is not read as choice A.
    match = re.search(
        rf"(?<![A-Za-z0-9])([{re.escape(choices)}])(?![A-Za-z0-9])", lines[-1]
    )
    return match.group(1).upper() if match else None


class HiddenExtractor:
    """Generate answers for ARC examples and collect prompt hidden-state features.

    Each prompt is decoded greedily with ``model.generate``. From the hidden
    states of the first generation step (the forward pass over the prompt),
    three vectors are pooled at every requested layer (``middle``,
    ``second_last``, ``last``):

    * ``avg``           -- mean over the question tokens located in the prompt,
    * ``last_question`` -- the last question token,
    * ``last_prompt``   -- the final prompt token, right before generation.

    The extracted answer letter is compared with ``item["answer"]`` to give a
    0/1 correctness label per example.
    """

    # Pooling strategies, in the order their output directories are written.
    _POOLS: Tuple[str, ...] = ("last_question", "avg", "last_prompt")

    def __init__(self, model_key: str, model_path: str, device: str = "cuda") -> None:
        self.model_key = model_key
        self.device = device
        console.print(f"[bold]加载模型:[/bold] {model_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        load_kwargs = dict(
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                attn_implementation="flash_attention_2",
                **load_kwargs,
            )
        except Exception as err:
            # Some models do not support FlashAttention-2; retry with the default
            # attention implementation. Catching Exception (not a bare except)
            # keeps KeyboardInterrupt working and the reason is reported.
            safe_rich_print(f"flash_attention_2 加载失败, 使用默认注意力实现: {err}", "yellow")
            self.model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
        self.layer_map = self._build_layer_index()

    def _build_layer_index(self) -> Dict[str, int]:
        """Map layer names to indices into the per-step ``hidden_states`` tuple.

        ``last`` uses index ``n_layers``. This assumes index 0 of the tuple is
        the embedding output, as in Transformers, so the tuple has
        ``n_layers + 1`` entries.
        """
        n_layers = self.model.config.num_hidden_layers
        mapping = {
            "middle": n_layers // 2,
            "last": n_layers,
            "second_last": n_layers - 1,
        }
        console.print(
            f"总层数: {n_layers} → middle={mapping['middle']}, "
            f"second_last={mapping['second_last']}, last={mapping['last']}"
        )
        return mapping

    # ------------------------------------------------------------------ helpers

    def _is_qwen3(self) -> bool:
        """Return True for Qwen3 models; the key is compared case-insensitively.

        MODELS registers "Qwen3-4B". A case-sensitive ``"qwen3"`` test never
        matched it, so ``enable_thinking=False`` and the larger token budget
        were silently skipped.
        """
        return "qwen3" in self.model_key.lower()

    def _pad_id(self) -> int:
        """Return the padding id: the tokenizer's pad id, else its EOS id, else 0."""
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        # Padded positions are always masked out, so any id works as a last resort.
        return pad_id if pad_id is not None else 0

    def _encode_prompt(self, question: str) -> torch.Tensor:
        """Tokenize one prompt into a 1-D tensor of token ids.

        Chat/instruct tokenizers get the chat template with a generation prompt.
        Qwen3 additionally gets ``enable_thinking=False``. Other tokenizers
        encode the raw text and make sure it starts with BOS.
        """
        if hasattr(self.tokenizer, "apply_chat_template"):
            template_kwargs = {"enable_thinking": False} if self._is_qwen3() else {}
            encoded = self.tokenizer.apply_chat_template(
                [{"role": "user", "content": question}],
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
                **template_kwargs,
            )
            return encoded[0]

        batch = self.tokenizer(
            question,
            add_special_tokens=True,  # let the tokenizer add its own special tokens first
            return_tensors="pt",
            truncation=True,
        )
        input_ids = batch["input_ids"]  # (1, L)
        bos_id = self.tokenizer.bos_token_id
        if bos_id is not None and input_ids[0, 0].item() != bos_id:
            bos_tensor = torch.tensor([[bos_id]], dtype=input_ids.dtype)
            input_ids = torch.cat([bos_tensor, input_ids], dim=1)
        return input_ids[0]

    def _build_input_ids(self, question: str) -> Tuple[torch.Tensor, list[int], str]:
        """Build the model input for a single prompt (chat and base models alike).

        Returns:
            input_ids  : (1, seq_len) tensor on ``self.device``
            prompt_ids : the same ids as a list, used to locate the question
            prompt_txt : the decoded prompt, special tokens included
        """
        ids = self._encode_prompt(question)
        return (
            ids.unsqueeze(0).to(self.device),
            ids.tolist(),
            self.tokenizer.decode(ids, skip_special_tokens=False),
        )

    def _build_batch_input_ids(self, questions: List[str]) -> Tuple[torch.Tensor, List[List[int]], List[str]]:
        """Build a LEFT-padded batch of prompts.

        Left padding keeps every row's last prompt token in the final column,
        so generation continues from real prompt text rather than from padding.

        Returns:
            input_ids       : (batch_size, max_seq_len) tensor on ``self.device``
            prompt_ids_list : unpadded token ids of each prompt
            prompt_txts     : decoded text of each prompt
        """
        encoded = [self._encode_prompt(q) for q in questions]
        prompt_ids_list = [ids.tolist() for ids in encoded]
        prompt_txts = [self.tokenizer.decode(ids, skip_special_tokens=False) for ids in encoded]

        width = max((len(ids) for ids in prompt_ids_list), default=0)
        padded = torch.full((len(prompt_ids_list), width), self._pad_id(), dtype=torch.long)
        for row, ids in enumerate(prompt_ids_list):
            padded[row, width - len(ids):] = torch.tensor(ids, dtype=torch.long)
        return padded.to(self.device), prompt_ids_list, prompt_txts

    @staticmethod
    def _left_pad_attention_mask(lengths: List[int], width: int, device: "torch.device | str") -> torch.Tensor:
        """Build an attention mask (1 = real token) for left-padded rows of the given lengths.

        Built from lengths rather than ``ids != pad_id``. When the pad id falls
        back to EOS, the prompt itself may contain that token, and comparing
        ids would mask real prompt tokens.
        """
        positions = torch.arange(width, device=device).unsqueeze(0)                     # (1, width)
        starts = torch.tensor([width - n for n in lengths], device=device).unsqueeze(1)  # (B, 1)
        return (positions >= starts).long()

    @staticmethod
    def _trim_generated(token_ids: List[int], eos_id: Optional[int]) -> List[int]:
        """Cut generated ids right after the first EOS.

        Finished rows in a batch are filled with padding after their EOS.
        Unfinished rows contain no EOS and are returned unchanged.
        """
        if eos_id is not None and eos_id in token_ids:
            return token_ids[: token_ids.index(eos_id) + 1]
        return token_ids

    def _decode_answer(self, sequence: torch.Tensor, prompt_width: int) -> str:
        """Decode the generated part of one output row, keeping special tokens (EOS included)."""
        generated = self._trim_generated(sequence[prompt_width:].tolist(), self.tokenizer.eos_token_id)
        return self.tokenizer.decode(generated, skip_special_tokens=False).lstrip()

    def _generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, max_new_tokens: int):
        """Run greedy generation and return per-step hidden states.

        Attention weights and logit scores are not requested. Nothing reads
        them, and generate would keep them for every generated step.
        """
        if self._is_qwen3():
            max_new_tokens = 4096
        return self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self._pad_id(),
            output_hidden_states=True,
            return_dict_in_generate=True,
            generation_config=GenerationConfig(),
        )

    def _pool_features(
        self,
        hidden_step0: Tuple[torch.Tensor, ...],
        row: int,
        offset: int,
        prompt_ids: List[int],
        question: str,
    ) -> Dict[str, Dict[str, torch.Tensor]]:
        """Pool one row's prompt hidden states into avg / last_question / last_prompt vectors.

        Args:
            hidden_step0: step-0 hidden states; each entry is indexed as
                ``[row, position, :]``.
            row: batch row to read.
            offset: number of left-padding positions before the prompt in this row.
            prompt_ids: unpadded prompt token ids of this row.
            question: raw question text to locate inside the prompt.
        """
        question_ids = self.tokenizer(question, add_special_tokens=False).input_ids
        q_start = _find_subseq(prompt_ids, question_ids) if question_ids else None
        if q_start is None:
            # Question not found verbatim (standalone tokenization can differ from
            # in-context tokenization) or empty: use the last prompt token for both
            # question pools instead of an empty or out-of-range slice.
            q_start = last_q_idx = len(prompt_ids) - 1
        else:
            last_q_idx = q_start + len(question_ids) - 1

        # Shift from unpadded prompt coordinates to padded row coordinates.
        q_start += offset
        last_q_idx += offset
        last_prompt_idx = offset + len(prompt_ids) - 1

        features: Dict[str, Dict[str, torch.Tensor]] = {}
        for name, layer_idx in self.layer_map.items():
            layer_hidden = hidden_step0[layer_idx][row]  # (padded_seq_len, hidden_dim)
            features[name] = {
                "avg": layer_hidden[q_start:last_q_idx + 1].mean(dim=0).to("cpu"),
                "last_question": layer_hidden[last_q_idx].to("cpu"),
                "last_prompt": layer_hidden[last_prompt_idx].to("cpu"),
            }
        return features

    # ----------------------------------------------------------- forward passes

    @torch.inference_mode()
    def forward_once(
        self,
        item: dict,
        max_new_tokens: int = 2048,
    ) -> Tuple[Dict[str, Dict[str, torch.Tensor]], str]:
        """Generate for one example and return (layer features, answer text)."""
        prompt = parse_input(item)
        input_ids, prompt_ids, _ = self._build_input_ids(prompt)
        gen_out = self._generate(input_ids, torch.ones_like(input_ids), max_new_tokens)

        answer_txt = self._decode_answer(gen_out.sequences[0], input_ids.shape[1])
        # hidden_states[0]: prompt forward pass; tuple of per-layer (batch, seq_len, hidden_dim).
        features = self._pool_features(gen_out.hidden_states[0], 0, 0, prompt_ids, item["question"])
        return features, answer_txt

    @torch.inference_mode()
    def forward_batch(
        self,
        items: List[dict],
        max_new_tokens: int = 2048,
    ) -> Tuple[List[Dict[str, Dict[str, torch.Tensor]]], List[str]]:
        """Generate for a batch of examples; returns per-item features and answer texts."""
        if not items:
            return [], []

        prompts = [parse_input(item) for item in items]
        batch_input_ids, prompt_ids_list, _ = self._build_batch_input_ids(prompts)
        width = batch_input_ids.shape[1]
        attention_mask = self._left_pad_attention_mask(
            [len(ids) for ids in prompt_ids_list], width, batch_input_ids.device
        )

        gen_out = self._generate(batch_input_ids, attention_mask, max_new_tokens)

        answers = [self._decode_answer(gen_out.sequences[row], width) for row in range(len(items))]
        hidden_step0 = gen_out.hidden_states[0]
        features = [
            self._pool_features(hidden_step0, row, width - len(prompt_ids), prompt_ids, item["question"])
            for row, (item, prompt_ids) in enumerate(zip(items, prompt_ids_list))
        ]
        return features, answers

    # ------------------------------------------------------- per-sample records

    def _make_record(
        self,
        item: dict,
        layer_vecs: Dict[str, Dict[str, torch.Tensor]],
        answer: str,
        needed_layers: List[str],
    ) -> dict:
        """Compute every output field for one sample without touching the output lists.

        Everything is computed before anything is appended, so an exception here
        cannot leave the per-sample lists misaligned or duplicated.
        """
        extracted = extract_arc_answer(answer, CHOICES[len(item["choices"]["label"])])
        return {
            "id": item.get("id", ""),
            "question": parse_input(item),
            "true_answer": item["answer"],
            "pred_answer": "Incomplete" if extracted == "Incomplete" else answer,
            "extracted": extracted,
            # A missing extraction (None) is always labelled 0.
            "label": int(extracted is not None and extracted == item["answer"]),
            "features": {
                pool: {k: layer_vecs[k][pool] for k in needed_layers} for pool in self._POOLS
            },
        }

    def _failed_record(self, item: dict, needed_layers: List[str]) -> dict:
        """Placeholder row with zero features, keeping outputs index-aligned with the dataset."""
        try:
            question = parse_input(item)
        except Exception:
            # parse_input may be what failed in the first place; keep the raw question.
            question = str(item.get("question", ""))
        hidden_dim = self.model.config.hidden_size
        return {
            "id": item.get("id", ""),
            "question": question,
            "true_answer": item.get("answer", ""),
            "pred_answer": "Failed",
            "extracted": None,
            "label": 0,
            "features": {
                pool: {k: torch.zeros(hidden_dim, dtype=self.model.dtype) for k in needed_layers}
                for pool in self._POOLS
            },
        }

    @staticmethod
    def _log_extracted(sample_idx: int, record: dict) -> None:
        """Print the extracted vs. true answer (markup disabled: values are printed verbatim)."""
        console.print(
            f"extracted_answer {sample_idx}: {record['extracted']}, true_answer: {record['true_answer']}",
            style="green",
            markup=False,
        )

    def _process_single(
        self,
        item: dict,
        sample_idx: int,
        max_new_tokens: int,
        needed_layers: List[str],
    ) -> dict:
        """Fallback path: run one example alone; on failure return a zero-feature placeholder."""
        try:
            layer_vecs, answer = self.forward_once(item, max_new_tokens=max_new_tokens)
            # markup=False: questions and model output may contain '[...]'.
            console.print(f"\nquestion {sample_idx}: {item['question'].strip()}", style="green", markup=False)
            console.print(f"model_answer {sample_idx}: {answer}", markup=False)
            record = self._make_record(item, layer_vecs, answer, needed_layers)
        except Exception as single_e:
            safe_rich_print(f"样本 {sample_idx} 处理失败: {single_e}", "red")
            return self._failed_record(item, needed_layers)
        self._log_extracted(sample_idx, record)
        return record

    @staticmethod
    def _append_record(record: dict, feats: Dict[str, Dict[str, List[torch.Tensor]]], meta: Dict[str, list]) -> None:
        """Append one finished record to the per-pool feature lists and metadata lists."""
        meta["ids"].append(record["id"])
        meta["questions"].append(record["question"])
        meta["true_answers"].append(record["true_answer"])
        meta["pred_answers"].append(record["pred_answer"])
        meta["labels"].append(record["label"])
        for pool, per_layer in record["features"].items():
            for layer_name, vec in per_layer.items():
                feats[pool][layer_name].append(vec)

    @staticmethod
    def _save_features(feats_by_layer: Dict[str, List[torch.Tensor]], out_dir: str, meta: Dict[str, list]) -> None:
        """Save one ``<layer>_features.pt`` per layer: stacked (N, hidden_dim) features plus metadata."""
        labels = torch.tensor(meta["labels"])
        for layer_name, vec_list in feats_by_layer.items():
            tensor = torch.stack(vec_list)  # (N, hidden_dim)
            out_path = os.path.join(out_dir, f"{layer_name}_features.pt")
            torch.save({
                "features": tensor,
                "labels": labels,
                "ids": meta["ids"],
                "questions": meta["questions"],
                "true_answers": meta["true_answers"],
                "pred_answers": meta["pred_answers"],
            }, out_path)
            console.print(f"[green]保存 {layer_name} → {tensor.shape} 到 {out_path}")

    # ---------------------------------------------------------------- dataset

    def extract_dataset(
        self,
        dataset_path: str,
        last_question_outputdir: str,
        avg_outputdir: str,
        last_prompt_outputdir: str,
        layer_req: str = "middle",
        max_new_tokens: int = 2048,
        batch_size: int = 4,
        max_samples: Optional[int] = 64,
    ) -> None:
        """Extract features and labels for a JSONL dataset and save one file per (pool, layer).

        Args:
            dataset_path: JSONL file with one ARC example per line (blank lines skipped).
            last_question_outputdir / avg_outputdir / last_prompt_outputdir:
                output directory for each pooling strategy (created if missing).
            layer_req: a key of ``self.layer_map`` or ``"all"``.
            max_new_tokens: generation budget (Qwen3 models always use 4096).
            batch_size: examples per generate call; a failing batch is retried one by one.
            max_samples: keep only the first N examples (default 64, as before);
                ``None`` processes the whole file.

        Raises:
            ValueError: for ``batch_size < 1`` or an unknown ``layer_req``.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if layer_req != "all" and layer_req not in self.layer_map:
            raise ValueError(
                f"Unknown layer_req {layer_req!r}; expected 'all' or one of {list(self.layer_map)}"
            )

        output_dirs = {
            "last_question": last_question_outputdir,
            "avg": avg_outputdir,
            "last_prompt": last_prompt_outputdir,
        }
        for out_dir in output_dirs.values():
            os.makedirs(out_dir, exist_ok=True)

        dataset_name = os.path.basename(dataset_path).lower()
        console.rule(f"处理数据集 → {dataset_name}")

        needed_layers: List[str] = list(self.layer_map) if layer_req == "all" else [layer_req]
        feats: Dict[str, Dict[str, List[torch.Tensor]]] = {
            pool: {k: [] for k in needed_layers} for pool in self._POOLS
        }
        meta: Dict[str, list] = {
            "labels": [], "ids": [], "questions": [], "true_answers": [], "pred_answers": [],
        }

        with open(dataset_path, "r", encoding="utf-8") as f:
            data = [json.loads(line) for line in f if line.strip()]
        if max_samples is not None:
            data = data[:max_samples]
        console.print(f"加载 {len(data)} 个样本.")
        if not data:
            # torch.stack would fail on empty lists; nothing to save.
            console.print(f"[yellow]{dataset_name} 没有样本, 跳过[/yellow]")
            return

        total_batches = (len(data) + batch_size - 1) // batch_size
        console.print(f"使用批量大小 {batch_size}, 总共 {total_batches} 个批次")

        for batch_idx in track(range(total_batches), description="Extracting features"):
            start_idx = batch_idx * batch_size
            batch_items = data[start_idx:start_idx + batch_size]
            end_idx = start_idx + len(batch_items)
            console.print(f"处理批次 {batch_idx + 1}/{total_batches}, 样本 {start_idx}-{end_idx-1}")

            try:
                batch_layer_vecs, batch_answers = self.forward_batch(batch_items, max_new_tokens=max_new_tokens)
                records = [
                    self._make_record(item, layer_vecs, answer, needed_layers)
                    for item, layer_vecs, answer in zip(batch_items, batch_layer_vecs, batch_answers)
                ]
            except Exception as e:
                safe_rich_print(f"批次 {batch_idx + 1} 处理失败: {e}", "red")
                console.print(f"[yellow]回退到单样本处理模式[/yellow]")
                records = [
                    self._process_single(item, start_idx + i, max_new_tokens, needed_layers)
                    for i, item in enumerate(batch_items)
                ]
            else:
                for i, record in enumerate(records):
                    self._log_extracted(start_idx + i, record)
                torch.cuda.empty_cache()

            # Append only complete records, so a failure can never duplicate or misalign rows.
            for record in records:
                self._append_record(record, feats, meta)

        for pool, out_dir in output_dirs.items():
            self._save_features(feats[pool], out_dir, meta)


def load_datasets(data_dir: str) -> Dict[str, str]:
    """Locate ``<name>.jsonl`` for every split listed in ``DATASETS`` under ``data_dir``.

    Missing files are reported and skipped. Returns split name -> file path,
    in ``DATASETS`` order.
    """
    found: Dict[str, str] = {}
    for name in DATASETS:
        candidate = os.path.join(data_dir, f"{name}.jsonl")
        if not os.path.exists(candidate):
            console.print(f"[red] 数据集不存在 {candidate}")
            continue
        found[name] = candidate
        console.print(f"找到数据集: {candidate}")
    return found


def main() -> None:
    """CLI entry point: extract hidden-state features for the selected models and ARC splits."""
    import gc

    p = argparse.ArgumentParser("Hidden feature extractor")
    p.add_argument("--data_dir", default=os.path.join(PROJECT_ROOT, "new_benchmark", "arc_challenge"))
    p.add_argument("--output_dir", default=os.path.join(PROJECT_ROOT, "feats/arc_challenge/arc_feats"))
    p.add_argument("--device", default="cuda")
    p.add_argument("--model", default="all")
    # "all" processes every split found in --data_dir; naming a split restricts the run to it.
    # (The previous default, "arc_challenge", was not among the choices and the value was ignored.)
    p.add_argument("--dataset", default="all", choices=DATASETS + ["all"])
    p.add_argument(
        "--layer_type",
        default="all",
        choices=["middle", "last", "second_last", "all"],
    )
    p.add_argument("--batch_size", type=int, default=4, help="批量处理的batch size")
    p.add_argument("--max_new_tokens", type=int, default=2048, help="max tokens generated per sample")
    args = p.parse_args()

    if args.batch_size < 1:
        p.error("--batch_size must be >= 1")

    os.makedirs(args.output_dir, exist_ok=True)
    datasets = load_datasets(args.data_dir)
    if args.dataset != "all":
        datasets = {name: path for name, path in datasets.items() if name == args.dataset}

    if not datasets:
        console.print("[red] 数据集不存在")
        return

    models = list(MODELS) if args.model == "all" else [args.model]
    for mk in models:
        if mk not in MODELS:
            console.print(f"[red] {mk} 模型不存在")
            continue

        console.rule(f"📦  Processing model – {mk}")
        extractor = HiddenExtractor(mk, MODELS[mk], args.device)
        try:
            for dname, dpath in datasets.items():
                last_question_out_dir = os.path.join(args.output_dir, mk + "_last_question", dname)
                avg_out_dir = os.path.join(args.output_dir, mk + "_avg", dname)
                last_prompt_out_dir = os.path.join(args.output_dir, mk + "_last_prompt", dname)
                extractor.extract_dataset(
                    dpath,
                    last_question_out_dir,
                    avg_out_dir,
                    last_prompt_out_dir,
                    layer_req=args.layer_type,
                    max_new_tokens=args.max_new_tokens,
                    batch_size=args.batch_size,
                )
        finally:
            # Drop this model before the next one is constructed. Otherwise the old
            # extractor is still referenced while the new weights load, and both
            # models occupy GPU memory at once.
            del extractor
            gc.collect()
            torch.cuda.empty_cache()

    console.rule("程序结束！")


if __name__ == "__main__":
    main()
