# make_pos_top100k_mask.py
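# Reads the influence-analysis pkl, takes the indices of the top TOP_K_POS positive-influence samples, and writes a boolean mask .npy (True = keep, False = masked out).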

import os
import pickle
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

# ======== Configuration (edit as needed) ========
# Pickle produced by the influence analysis.
PKL_FILE_PATH: str = r"influence_copytarget_qkvo124_L3H3.pkl"
# If this names an existing file, its row count (shape[0]) becomes the mask
# length; otherwise pkl["config"]["NUM_TRAIN_SAMPLES"] is tried, then ORIG_LEN.
# NOTE(review): looks like a placeholder path — set it before running.
NPY_PATH: str = r"\.npy"
# Fallback mask length when neither the NPY file nor the pkl config gives one.
ORIG_LEN: int = 100_000
# How many top positive-influence samples to mask out.
TOP_K_POS: int = 80_000
# Where the boolean mask is written (True = keep, False = masked).
# NOTE(review): looks like a placeholder path — set it before running.
OUTPUT_MASK_PATH: str = r".npy"

# ======== 工具 ========
def ensure_dir_for_file(path: str):
    """Create the parent directory of *path* (plus missing ancestors) if needed.

    A bare file name has no directory component; "." is used in that case,
    so the call is a harmless no-op.
    """
    parent = os.path.dirname(path)
    if not parent:
        parent = "."
    os.makedirs(parent, exist_ok=True)

def load_pickle(path: str) -> Dict[str, Any]:
    """Deserialize and return the object stored in the pickle file at *path*.

    Security note: unpickling can execute arbitrary code. Only load files you
    produced yourself or otherwise trust.
    """
    with open(path, mode="rb") as handle:
        obj = pickle.load(handle)
    return obj

def _first_not_none(item: Dict[str, Any], keys: Tuple[str, ...]) -> Any:
    """Return ``item[key]`` for the first key in *keys* whose value is not None, else None."""
    for key in keys:
        value = item.get(key)
        if value is not None:
            return value
    return None


def norm_items_from_list_dict(items: List[Dict[str, Any]]) -> List[Tuple[int, float]]:
    """Normalize a list of influence records into ``(sample_index, score)`` pairs.

    Supported key layouts. For each field, the first key holding a non-None
    value wins, in the order listed:
      - index: ``sample_index`` -> ``sample_index_original_dataset`` -> ``original_index``
      - score: ``projection_score`` -> ``total_influence`` -> ``score``

    Records missing an index or a score are skipped. Records whose score is NaN
    are also skipped: NaN compares False with everything, which would corrupt
    the descending sort used later to pick the top-K.

    Args:
        items: List of dict records. ``None`` or an empty list yields ``[]``.

    Returns:
        ``(int index, float score)`` tuples in input order.
    """
    import math  # local import keeps this block self-contained

    out: List[Tuple[int, float]] = []
    for it in (items or []):
        idx = _first_not_none(it, ("sample_index", "sample_index_original_dataset", "original_index"))
        score = _first_not_none(it, ("projection_score", "total_influence", "score"))
        if idx is None or score is None:
            continue
        score_f = float(score)
        if math.isnan(score_f):
            continue
        out.append((int(idx), score_f))
    return out

def _topk_unique_indices(items: List[Tuple[int, float]], k: int) -> List[int]:
    """Return up to *k* distinct indices from ``(index, score)`` pairs, highest score first.

    The sort is stable, so ties keep their input order. An index that appears
    several times is counted once, at its highest-scoring occurrence. Walking
    continues past duplicates until k distinct indices are collected.
    """
    if k <= 0:
        return []
    seen = set()
    uniq: List[int] = []
    for idx, _ in sorted(items, key=lambda x: x[1], reverse=True):
        if idx in seen:
            continue
        seen.add(idx)
        uniq.append(idx)
        if len(uniq) >= k:
            break
    return uniq


def pick_topk_positive_indices_from_pkl(d: Dict[str, Any], k: int) -> List[int]:
    """Pick up to *k* distinct sample indices with the highest positive influence.

    Prefers the ``positive_influencers`` list. Otherwise it falls back to the
    legacy ``top_k_influence`` list and prints a warning, because that list has
    no positive/negative distinction. Either way, records are normalized with
    norm_items_from_list_dict and ranked by descending score.

    Args:
        d: Object loaded from the influence pkl (expected to be a dict).
        k: Maximum number of indices to return; ``k <= 0`` yields ``[]``.

    Returns:
        Distinct indices in descending-score order. May be shorter than *k* if
        the pkl holds fewer usable records.

    Raises:
        RuntimeError: If neither list is present and non-empty.
    """
    # Preferred: the newer positive_influencers field.
    pos_raw = d.get("positive_influencers", None)
    if isinstance(pos_raw, list) and pos_raw:
        return _topk_unique_indices(norm_items_from_list_dict(pos_raw), k)

    # Legacy structure: only top_k_influence exists, with no notion of "positive".
    old = d.get("top_k_influence", None)
    if isinstance(old, list) and old:
        uniq = _topk_unique_indices(norm_items_from_list_dict(old), k)
        print("警告: pkl 不含 positive_influencers，使用旧字段 top_k_influence 代替（无正/负区分）。")
        return uniq

    raise RuntimeError("无法从 PKL 中解析正向 Top-K（缺少 positive_influencers / top_k_influence 字段）。")

def get_total_len(
    pkl_data: Dict[str, Any],
    npy_path: Optional[str] = None,
    orig_len: Optional[int] = None,
) -> int:
    """Determine the length of the mask to build.

    Resolution order:
      1. ``shape[0]`` of the ``.npy`` file at *npy_path*, if that path is an existing file.
      2. ``pkl_data["config"]["NUM_TRAIN_SAMPLES"]``, if present and convertible to int.
      3. *orig_len* as the final fallback.

    Args:
        pkl_data: Object loaded from the influence pkl. A missing, None or
            non-mapping ``config`` is ignored instead of raising.
        npy_path: NPY file whose row count defines the length. ``None``
            (default) uses the module-level NPY_PATH; ``""`` skips this step.
        orig_len: Fallback length. ``None`` (default) uses the module-level ORIG_LEN.

    Returns:
        The mask length as an int.
    """
    from collections.abc import Mapping  # local import keeps this block self-contained

    if npy_path is None:
        npy_path = NPY_PATH
    if orig_len is None:
        orig_len = ORIG_LEN

    # 1) Prefer the NPY file. mmap_mode="r" avoids loading the array data into memory.
    #    isfile (not exists) so a directory path does not make np.load fail.
    if npy_path and os.path.isfile(npy_path):
        arr = np.load(npy_path, mmap_mode="r")
        return int(arr.shape[0])

    # 2) Then pkl["config"]["NUM_TRAIN_SAMPLES"]; config may be absent or not a mapping.
    cfg = pkl_data.get("config") if isinstance(pkl_data, Mapping) else None
    if isinstance(cfg, Mapping) and "NUM_TRAIN_SAMPLES" in cfg:
        raw = cfg["NUM_TRAIN_SAMPLES"]
        try:
            return int(raw)
        except (TypeError, ValueError):
            # Do not fail silently: report the bad value, then use the fallback.
            print(f"警告: config.NUM_TRAIN_SAMPLES 无法转换为整数 ({raw!r})，改用兜底长度。")

    # 3) Fallback length.
    return int(orig_len)

# ======== 主逻辑 ========
def main():
    """Build and save a keep/mask array that masks the top positive-influence samples.

    Steps:
      1. Load the influence results from PKL_FILE_PATH.
      2. Resolve the mask length with get_total_len().
      3. Take up to TOP_K_POS positive indices with pick_topk_positive_indices_from_pkl().
      4. Start from an all-True mask and set those indices to False.
      5. Save the mask with np.save and, best-effort, write the masked indices as text.

    Every failure is reported with print() and ends the function early. It
    returns None in all cases.
    """
    # --- Validate inputs and prepare the output directory ---
    if not os.path.exists(PKL_FILE_PATH):
        print(f"错误: 找不到 PKL 文件: {PKL_FILE_PATH}")
        return
    ensure_dir_for_file(OUTPUT_MASK_PATH)

    print(f"读取 PKL: {PKL_FILE_PATH}")
    try:
        data = load_pickle(PKL_FILE_PATH)
    except Exception as e:  # top-level boundary: report any load failure and stop
        print(f"加载 PKL 出错: {e}")
        return

    total_len = get_total_len(data)
    print(f"掩码长度 = {total_len}")
    print(f"将屏蔽正向前 {TOP_K_POS} 条样本 (False=屏蔽, True=保留)")

    # --- Top-K positive indices ---
    try:
        chosen = pick_topk_positive_indices_from_pkl(data, TOP_K_POS)
    except Exception as e:  # top-level boundary: report any parse failure and stop
        print(f"解析 Top-K 正向索引出错: {e}")
        return

    if not chosen:
        print("错误: 未从 PKL 中提取到任何正向索引。")
        return
    if len(chosen) < TOP_K_POS:
        print(f"警告: 仅解析到 {len(chosen)} 个正向索引，少于 TOP_K_POS={TOP_K_POS}。")

    # --- Bounds check on both ends ---
    # numpy fancy indexing silently wraps negative indices (-1 = last element),
    # so a negative index would mask the wrong sample instead of raising.
    lo, hi = min(chosen), max(chosen)
    if lo < 0:
        print(f"错误: 索引越界，最小索引={lo} < 0")
        return
    if hi >= total_len:
        print(f"错误: 索引越界，最大索引={hi} ≥ 掩码长度={total_len}")
        return

    # --- Build the mask: True everywhere, False at the chosen top-K indices ---
    mask = np.ones(total_len, dtype=bool)
    mask[np.asarray(chosen, dtype=np.int64)] = False

    # np.save appends ".npy" when the file name lacks it; compute the real path
    # so the log message and the sidecar file name match what is written.
    if OUTPUT_MASK_PATH.endswith(".npy"):
        mask_path = OUTPUT_MASK_PATH
    else:
        mask_path = OUTPUT_MASK_PATH + ".npy"
    np.save(mask_path, mask)
    print(f"已保存掩码: {mask_path}")
    print(f"统计: 长度={len(mask)}, True(保留)={int(mask.sum())}, False(屏蔽)={int((~mask).sum())}")

    # --- Optional: export the masked indices for inspection (remove if unneeded) ---
    # Replace only the trailing ".npy", so directory names containing ".npy" stay intact.
    idx_list_path = mask_path[: -len(".npy")] + "_masked_indices.txt"
    try:
        with open(idx_list_path, "w", encoding="utf-8") as f:
            f.write("".join(f"{idx}\n" for idx in chosen))
        print(f"已导出被屏蔽索引列表: {idx_list_path}")
    except OSError as e:
        # Best effort: the mask itself is already saved.
        print(f"导出被屏蔽索引列表失败: {e}")

if __name__ == "__main__":
    # Script entry point: build the mask only when run directly, not when imported.
    main()
