#!/usr/bin/env python3
"""
重构后的 VLM 隐藏特征提取主脚本
"""

import argparse
import os
import time
# Configure the CUDA caching allocator before `import torch` further below.
# A value already present in the environment is left untouched.
if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import sys
import json
from typing import Dict, List, Any, Tuple, Optional
import torch
from rich.console import Console
from hidden_extractor import HiddenExtractor
import copy
import gc
import csv

# Project root = two directories above this file; it is prepended to sys.path
# so project-level modules imported lazily later (e.g. `utils`) can be found.
# NOTE(review): `hidden_extractor` is imported above, before this path is
# added, so it must be importable without it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.normpath(os.path.join(BASE_DIR, os.pardir, os.pardir))

if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Shared rich console used for coloured log output throughout the script.
console = Console()
# Fixed seed so torch RNG-dependent behaviour is reproducible across runs.
torch.manual_seed(42)

def load_datasets(dataset: List[str], data_dir: str) -> Dict[str, str]:
    """Map each dataset name to its ``<data_dir>/<name>.jsonl`` file.

    Every name is checked on disk: existing files are reported and added to
    the result, missing ones are reported in red and skipped.

    Args:
        dataset: Dataset names without the ``.jsonl`` extension.
        data_dir: Directory containing the ``.jsonl`` files.

    Returns:
        Insertion-ordered ``{name: path}`` containing only files that exist.
    """
    found: Dict[str, str] = {}
    for name in dataset:
        candidate = os.path.join(data_dir, f"{name}.jsonl")
        if not os.path.exists(candidate):
            console.print(f"[red] 数据集不存在 {candidate}[/red]")
            continue
        found[name] = candidate
        console.print(f"找到数据集: {candidate}")
    return found

# -----------------------------
# OOM helpers
# -----------------------------
def _build_oom_errors() -> tuple:
    errs = []
    for obj in (getattr(torch, "OutOfMemoryError", None), getattr(torch.cuda, "OutOfMemoryError", None)):
        if obj is not None and isinstance(obj, type):
            errs.append(obj)
    return tuple(errs) if errs else tuple()

OOM_ERRORS = _build_oom_errors()

def _maybe_oom_runtime_error(e: RuntimeError) -> bool:
    msg = str(e).lower()
    return ("out of memory" in msg) or ("cuda error" in msg)

def _empty_cache_all(delay_s: float = 0.2) -> None:
    gc.collect()
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            with torch.cuda.device(i):
                torch.cuda.synchronize()     # 等待该卡上所有 CUDA 工作完成
                torch.cuda.empty_cache()     # 清缓存
                torch.cuda.synchronize()     # 再同步一下更保险
        if delay_s > 0:
            time.sleep(delay_s)              # 给驱动/allocator 一点时间

def _iter_jsonl(path: str):
    with open(path, "r", encoding="utf-8") as f:
        for ln in f:
            ln = ln.strip()
            if ln:
                yield json.loads(ln)

def _build_prompt_texts(items: List[Dict[str, Any]]):
    """Convert each dataset item to its prompt text with ``utils.parse_input``.

    The project import is deferred until the first call.
    """
    from utils import parse_input

    return list(map(parse_input, items))

def _update_longest(
    items: List[Dict[str, Any]],
    tokenizer,
    best_item: Optional[Dict[str, Any]],
    best_len: int,
) -> Tuple[Optional[Dict[str, Any]], int]:
    """Tokenize one chunk of items and fold it into the running longest prompt.

    Returns the updated ``(best_item, best_len)``.  The comparison is strict
    (``>``), so on ties the earlier item is kept.
    """
    enc = tokenizer(
        _build_prompt_texts(items),
        add_special_tokens=True,
        truncation=False,
        return_attention_mask=False,
    )
    for item, ids in zip(items, enc["input_ids"]):
        n_tokens = len(ids)
        if n_tokens > best_len:
            best_item, best_len = item, n_tokens
    return best_item, best_len


def find_worst_item_by_prompt_tokens(
    dataset_path: str,
    tokenizer,
    chunk_size: int = 256,
) -> Tuple[Dict[str, Any], int]:
    """Find the item whose prompt tokenizes to the most tokens.

    The JSONL file is streamed line by line, and prompts are tokenized in
    chunks of ``chunk_size`` items (values < 1 are treated as 1), so at most
    one chunk is buffered at a time.  Counts include special tokens and are
    not truncated.  On ties the first such item in file order wins.

    Args:
        dataset_path: Path to the ``.jsonl`` dataset shard.
        tokenizer: Callable tokenizer returning a mapping with ``"input_ids"``.
        chunk_size: Number of items tokenized per call.

    Returns:
        ``(item, token_count)`` for the longest prompt.

    Raises:
        RuntimeError: the file contains no items.
    """
    chunk_size = max(1, int(chunk_size))
    best_item: Optional[Dict[str, Any]] = None
    best_len = -1
    buf: List[Dict[str, Any]] = []

    for item in _iter_jsonl(dataset_path):
        buf.append(item)
        # Tokenize only once a full chunk has accumulated.
        if len(buf) >= chunk_size:
            best_item, best_len = _update_longest(buf, tokenizer, best_item, best_len)
            buf.clear()

    if buf:  # trailing partial chunk
        best_item, best_len = _update_longest(buf, tokenizer, best_item, best_len)

    if best_item is None:
        raise RuntimeError(f"Empty jsonl: {dataset_path}")
    return best_item, best_len

def estimate_safe_batch_size_by_oom(
    extractor,
    dataset_path: str,
    max_new_tokens_target: int,
    cap: int = 150,
    safety_factor: float = 0.8,
    verbose: bool = True,
    oom_estimate: bool = False,
    probe_max_new_tokens: int = 66,
) -> int:
    """Estimate a batch size that should not run out of GPU memory.

    Procedure:
      1. Find the sample whose prompt has the most tokens.
      2. Probe ``extractor.backend.generate_batch`` with ``bs`` copies of that
         sample.  ``bs_estimate_gen_cfg`` asks the backend to lengthen the
         prefill by ``max_new_tokens_target`` extra tokens (presumably to
         emulate a full-length generation's memory — depends on the backend).
         The probe itself requests up to ``probe_max_new_tokens`` new tokens.
      3. Binary-search the largest ``bs`` in ``[1, cap]`` that does not OOM,
         then scale it by ``safety_factor``.

    Args:
        extractor: Object exposing ``backend.generate_batch`` and a tokenizer
            (``backend.tokenizer`` preferred, else ``extractor.tokenizer``).
        dataset_path: ``.jsonl`` shard to search for the worst-case sample.
        max_new_tokens_target: Extra prefill tokens requested during probing.
        cap: Upper bound of the search; additionally clamped to ``[1, 150]``.
        safety_factor: Multiplier applied to the largest batch that fit.
        verbose: Print progress lines.
        oom_estimate: Forwarded to ``generate_batch``.
        probe_max_new_tokens: ``max_new_tokens`` passed to each probe call.

    Returns:
        ``max(1, int(rough_max * safety_factor))``; ``1`` when CUDA is not
        available or even ``bs=1`` runs out of memory.

    Raises:
        RuntimeError: a non-OOM runtime error from the backend, or an empty
            dataset file.
    """
    if not torch.cuda.is_available():
        return 1

    # Hard upper bound of 150 regardless of the requested cap.
    cap = max(1, min(int(cap), 150))

    # Count tokens with the backend's tokenizer when it has one (some
    # backends, e.g. qwen, use processor.tokenizer); else extractor.tokenizer.
    tok = getattr(getattr(extractor, "backend", None), "tokenizer", None) or extractor.tokenizer

    worst_item, worst_len = find_worst_item_by_prompt_tokens(
        dataset_path=dataset_path,
        tokenizer=tok,
        chunk_size=256,
    )
    if verbose:
        print(f"[BS-EST] worst prompt token len = {worst_len}")

    bs_estimate_gen_cfg: Dict[str, int] = {"_prefill_extra_tokens": int(max_new_tokens_target)}
    # Optional padding token for the extra prefill: the (first) EOS id, if any.
    eos = getattr(tok, "eos_token_id", None)
    if isinstance(eos, (list, tuple)):
        eos = eos[0] if eos else None
    if eos is not None:
        bs_estimate_gen_cfg["_prefill_token_id"] = int(eos)

    def try_bs(bs: int) -> bool:
        """Run one probe with ``bs`` copies of the worst sample; False on OOM."""
        bs = int(bs)
        if bs <= 0:
            return False

        _empty_cache_all()  # start each probe from a clean allocator state
        items = None
        out = None
        try:
            items = [copy.deepcopy(worst_item) for _ in range(bs)]
            with torch.inference_mode():
                torch.cuda.synchronize()
                out = extractor.backend.generate_batch(
                    items=items,
                    max_new_tokens=int(probe_max_new_tokens),
                    gen_cfg={},
                    oom_estimate=oom_estimate,
                    bs_estimate_gen_cfg=bs_estimate_gen_cfg,
                )
                torch.cuda.synchronize()
            return True
        except OOM_ERRORS:
            return False
        except RuntimeError as e:
            if _maybe_oom_runtime_error(e):
                return False
            raise
        finally:
            # Drop references to probe inputs/outputs before clearing caches,
            # whether the probe succeeded or failed.
            del out, items
            _empty_cache_all()

    if verbose:
        print(f"[BS-EST] Binary search in [1, {cap}] ...")

    if not try_bs(1):
        if verbose:
            print("[BS-EST] even bs=1 OOM -> fallback 1")
        return 1

    if cap == 1 or try_bs(cap):  # bs=1 was already verified above
        rough_max = cap
    else:
        # Invariant: `low` fits, `high` does not.
        low, high = 1, cap
        while low + 1 < high:
            mid = (low + high) // 2
            ok = try_bs(mid)
            if verbose:
                print(f"[BS-EST] bs={mid} {'OK' if ok else 'OOM'}")
            if ok:
                low = mid
            else:
                high = mid
        rough_max = low

    safe_bs = max(1, int(rough_max * float(safety_factor)))
    if verbose:
        print(f"[BS-EST] rough_max={rough_max}, safe_bs={safe_bs} (safety={safety_factor})")
    return safe_bs

def load_batch_size_csv(csv_path: str) -> Dict[Tuple[str, str], int]:
    """Load per-(model, dataset) batch sizes from a CSV file.

    Expected header: ``model,dataset,safe_batch_size``.  A ``batch_size``
    column is used when ``safe_batch_size`` is empty or absent.

    - Rows with an empty model/dataset, an empty or unparseable size, or a
      size <= 0 are skipped.  Values such as ``"12.0"`` are truncated to int.
    - Duplicate keys keep the smaller value (conservative, to avoid OOM).
    - The file is decoded as ``utf-8-sig`` so a leading byte-order mark (e.g.
      an Excel export) does not corrupt the first column name.  Header names
      are whitespace-stripped.

    Args:
        csv_path: CSV path; a falsy value yields an empty table.

    Returns:
        ``{(model, dataset): batch_size}``.

    Raises:
        FileNotFoundError: ``csv_path`` is non-empty but does not exist.
    """
    table: Dict[Tuple[str, str], int] = {}
    if not csv_path:
        return table
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"batch csv not found: {csv_path}")

    with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        if not reader.fieldnames:
            return table
        # Tolerate headers like "model, dataset, safe_batch_size".
        reader.fieldnames = [(name or "").strip() for name in reader.fieldnames]

        for row in reader:
            model = (row.get("model") or "").strip()
            dataset = (row.get("dataset") or "").strip()
            bs_raw = (row.get("safe_batch_size") or row.get("batch_size") or "").strip()
            if not model or not dataset or not bs_raw:
                continue
            try:
                bs = int(float(bs_raw))
            except (ValueError, OverflowError):  # non-numeric, "nan", "inf"
                continue
            if bs <= 0:
                continue

            key = (model, dataset)
            table[key] = min(table[key], bs) if key in table else bs
    return table


def _str2bool(value) -> bool:
    """Parse a command-line boolean such as "true"/"false", "1"/"0", "yes"/"no".

    ``argparse`` with ``type=bool`` treats any non-empty string, including
    ``"False"``, as True, so booleans are parsed explicitly instead.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _load_bs_table(csv_path: str) -> Dict[Tuple[str, str], int]:
    """Load the batch-size CSV, tolerating a missing file.

    A relative path is looked up in the current directory first, then next to
    this script.  When neither exists, a warning is printed and an empty table
    is returned instead of aborting the whole run.
    """
    if not csv_path:
        return {}
    candidates = [csv_path]
    if not os.path.isabs(csv_path):
        candidates.append(os.path.join(BASE_DIR, csv_path))
    for cand in candidates:
        if os.path.exists(cand):
            table = load_batch_size_csv(cand)
            console.print(f"[cyan][BS-CSV] loaded {len(table)} (model,dataset)->batch entries from {cand}[/cyan]")
            return table
    console.print(f"[yellow][BS-CSV] {csv_path} not found, skipping CSV batch sizes[/yellow]")
    return {}


def _discover_models(model_arg: str, models_root: str) -> List[str]:
    """Return the model directory names to process.

    ``"all"`` selects every sub-directory of ``models_root`` (sorted) that
    contains ``config.json`` or ``tokenizer.json``.  Any other value is parsed
    as a comma-separated list of model names.
    """
    if model_arg != "all":
        return [m.strip() for m in str(model_arg).split(",") if m.strip()]
    models = []
    for name in sorted(os.listdir(models_root)):
        mp = os.path.join(models_root, name)
        if os.path.isdir(mp) and (
            os.path.exists(os.path.join(mp, "config.json"))
            or os.path.exists(os.path.join(mp, "tokenizer.json"))
        ):
            models.append(name)
    return models


def _resolve_batch_size(args, bs_table, extractor, mk: str, dname: str, dpath: str) -> int:
    """Choose the batch size for one (model, dataset) pair.

    Priority: ``--batch_size`` > CSV entry > OOM probe (when
    ``--oom_bs_search`` is on) > 1.  The final fallback guarantees a value is
    always defined for every dataset.
    """
    if args.batch_size is not None:
        return max(1, int(args.batch_size))

    key = (mk, dname)
    if key in bs_table:
        bsz = max(1, int(bs_table[key]))
        console.print(f"[cyan][BS-CSV] {mk} / {dname}: batch_size = {bsz}[/cyan]")
        return bsz
    if bs_table:
        console.print(f"[yellow][BS-CSV] {mk} / {dname}: not found in csv, fallback...[/yellow]")

    if args.oom_bs_search:
        bsz = estimate_safe_batch_size_by_oom(
            extractor=extractor,
            dataset_path=dpath,
            max_new_tokens_target=args.max_new_tokens,
            cap=args.oom_bs_cap,
            safety_factor=args.oom_bs_safety,
            verbose=True,
            oom_estimate=True,
        )
        console.print(f"[yellow][BS-EST] {mk} / {dname}: effective_batch_size = {bsz}")
        return bsz

    console.print(f"[yellow][BS] {mk} / {dname}: no batch size from --batch_size / CSV / OOM search, using 1[/yellow]")
    return 1


def main() -> None:
    """Command-line entry point for hidden-feature extraction.

    For every selected model directory under ``--models_root`` and every
    existing ``<data_dir>/<name>.jsonl`` dataset, a ``HiddenExtractor`` is
    built and ``extract_dataset`` is called with a batch size chosen by
    ``_resolve_batch_size``.  Each model's extractor is released, and CUDA
    caches emptied, before the next model is loaded.
    """
    p = argparse.ArgumentParser("Hidden feature extractor with checkpoint")
    p.add_argument("--models_root", default=os.path.join(PROJECT_ROOT, "models"), help="本地模型根目录（用于拼接 模型名→路径）")
    p.add_argument("--data_dir", default=os.path.join(PROJECT_ROOT, "new_benchmark", "seedbench_plus2"))
    p.add_argument("--output_dir", default=os.path.join(PROJECT_ROOT, "feats/seedbench_plus2/seedbench_plus2_feats"))
    p.add_argument("--device", default="cuda")
    p.add_argument("--model", default="ovis2-4b")
    p.add_argument("--dataset", default="seedbench_plus2_val", nargs="+", help="数据集名称列表（不含扩展名），默认使用 seedbench_plus2_val")
    p.add_argument("--use_few_shot", action="store_true")
    p.add_argument("--layer_type", default="all", choices=["quarter", "middle", "three_quarters", "last", "second_last", "first", "all"])
    p.add_argument("--batch_size", type=int, default=None, help="批量处理的batch size")
    p.add_argument("--resume", action="store_true", help="是否从断点恢复")
    p.add_argument("--checkpoint_root", type=str, default=os.path.join(PROJECT_ROOT, "checkpoints/seedbench_plus2/seedbench_plus2_feats"), help="checkpoint存储目录")
    p.add_argument("--batch_csv", type=str, default="oom_bs_results.csv", help="CSV文件：model,dataset,safe_batch_size；若提供则优先用CSV中的batch_size")
    p.add_argument("--max_new_tokens", type=int, default=1024, help="生成的最大新 token 数（同时作为 OOM 探测时 prefill 额外拼接的长度）")

    # ===== OOM batch size search switches =====
    # Explicit boolean parsing: "--oom_bs_search False" must mean False, and a
    # bare "--oom_bs_search" (no value) enables the search.
    p.add_argument("--oom_bs_search", type=_str2bool, nargs="?", const=True, default=False, help="开启：在 worst sample 上二分 OOM 探测 batch_size（用于自动估计最大 batch）")
    p.add_argument("--oom_bs_cap", type=int, default=150, help="OOM 探测的 batch_size 上限（会再被 clamp 到 <=150）")
    p.add_argument("--oom_bs_safety", type=float, default=0.80, help="安全系数：最终 batch = int(max_ok * safety)")

    args = p.parse_args()

    bs_table = _load_bs_table(args.batch_csv)

    os.makedirs(args.models_root, exist_ok=True)
    os.makedirs(args.output_dir, exist_ok=True)
    if args.dataset is None:
        args.dataset = ["seedbench_plus2_train", "seedbench_plus2_val", "seedbench_plus2_test"]
    if isinstance(args.dataset, str):  # the argparse default is a plain string
        args.dataset = [args.dataset]
    datasets = load_datasets(args.dataset, args.data_dir)

    if not datasets:
        console.print("[red] 数据集不存在[/red]")
        return

    checkpoint_root = args.checkpoint_root or os.path.join(args.output_dir, "checkpoints")
    os.makedirs(checkpoint_root, exist_ok=True)

    models = _discover_models(args.model, args.models_root)
    print(f"[info] 发现 {len(models)} 个模型: {models}")

    for mk in models:
        model_path = os.path.join(args.models_root, mk)

        if not os.path.exists(model_path):
            console.print(f"[red] {mk} 模型目录不存在：{model_path}[/red]")
            continue

        console.rule(f"📦  Processing model — {mk}")
        extractor = HiddenExtractor(mk, model_path, args.device)
        try:
            for dname, dpath in datasets.items():
                effective_bsz = _resolve_batch_size(args, bs_table, extractor, mk, dname, dpath)
                extractor.extract_dataset(
                    dataset_name=dname,
                    dataset_path=dpath,
                    prompt_last_token_dir=os.path.join(args.output_dir, mk + "_prompt_last_token", dname),
                    answer_first_token_dir=os.path.join(args.output_dir, mk + "_answer_first_token", dname),
                    last_token_outputdir=os.path.join(args.output_dir, mk + "_last_token", dname),
                    layer_req=args.layer_type,
                    max_new_tokens=args.max_new_tokens,
                    batch_size=effective_bsz,
                    resume=args.resume,
                    checkpoint_root=checkpoint_root,
                    oom_estimate=False,
                )
        finally:
            # Drop this model before the next one is constructed; otherwise
            # both stay referenced during the next HiddenExtractor(...) call.
            del extractor
            _empty_cache_all()
    console.rule("程序结束！")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
