# tools/verify_perm_900_998.py
# -*- coding: utf-8 -*-
import pickle
import numpy as np
from pathlib import Path

# ===== Edit the values below to match your actual setup =====
DATA_PREFIX  = "/root/trainbin1/document1"      # dataset prefix, without the .idx/.bin suffix
DATA_IMPL    = "mmap"                               # indexed-dataset implementation; usually mmap
SEQ_LENGTH   = 2048                                 # seq_length used for training (keep in sync with your config)
PERM_PATH    ="/root/trainbin1/neworder.npy"    # the generated permutation (.npy) to verify
PKL_PATH     = "/root/trainbin1/influence_copytarget_qkvo1123.pkl"        # pkl holding the selected samples
BASE_OFFSET  = 1024000                              # offset added to every index read from the pkl
# Number of permutation positions per step; step boundaries below are
# converted to positions as step * STEP_SIZE.
STEP_SIZE    = 1024
TARGET_START_STEP   = 900
PRESERVE_FROM_STEP  = 998  # positions from step 998 onwards are expected to stay unchanged
# =================================

def load_selected_from_pkl(pkl_path: str) -> np.ndarray:
    """Load the selected sample indices stored in a pickle file.

    The pickled object may use any of these layouts, tried in order:

    * a list, tuple or NumPy array of integers;
    * a list, tuple or NumPy object array of dicts; the index is read from
      the first of ``sample_index``, ``index``, ``idx``, ``global_index``
      that is present in the first dict (entries that lack the key or whose
      value cannot be converted to ``int`` are skipped);
    * a dict holding one of the above under ``positive_influencers``,
      ``positives``, ``positive_samples``, ``selected_positive_indices``,
      ``selected_indices``, ``indices``, ``items`` or ``data``;
    * a dict whose ``analysis_results`` entry is a dict holding one of the
      above under ``positive_influencers``, ``indices``, ``items`` or
      ``data``.

    Args:
        pkl_path: Path of the pickle file to read.

    Returns:
        The indices as an int64 array, in the order they are stored.

    Raises:
        ValueError: If no supported layout matches (this includes empty
            sequences).
    """
    # NOTE: pickle.load can execute arbitrary code while deserialising;
    # only point this at files you produced yourself.
    with open(pkl_path, "rb") as f:
        obj = pickle.load(f)

    def from_list(lst):
        """Return an int64 array parsed from ``lst``, or None if unsupported."""
        # Index arrays are often pickled as ndarrays rather than lists;
        # tolist() yields plain Python values so the checks below apply.
        if isinstance(lst, np.ndarray):
            lst = lst.tolist()
        if not isinstance(lst, (list, tuple)) or len(lst) == 0:
            return None
        first = lst[0]
        if isinstance(first, (int, np.integer)):
            return np.asarray(lst, dtype=np.int64)
        if isinstance(first, dict):
            for key in ["sample_index", "index", "idx", "global_index"]:
                if key in first:
                    out = []
                    for d in lst:
                        if isinstance(d, dict) and key in d:
                            try:
                                out.append(int(d[key]))
                            except (TypeError, ValueError, OverflowError):
                                # Best effort: skip values that are not
                                # convertible to an integer.
                                continue
                    return np.asarray(out, dtype=np.int64)
        return None

    # from_list() rejects anything that is not a sequence, so it is safe to
    # try it on the top-level object whatever its type.
    arr = from_list(obj)
    if arr is None and isinstance(obj, dict):
        for k in ["positive_influencers", "positives", "positive_samples",
                  "selected_positive_indices", "selected_indices", "indices", "items", "data"]:
            if k in obj:
                arr = from_list(obj[k])
                if arr is not None:
                    break
        if arr is None and "analysis_results" in obj and isinstance(obj["analysis_results"], dict):
            ar = obj["analysis_results"]
            for k in ["positive_influencers", "indices", "items", "data"]:
                if k in ar:
                    arr = from_list(ar[k])
                    if arr is not None:
                        break

    if arr is None:
        raise ValueError("无法从 PKL 解析出索引，请调整 load_selected_from_pkl() 的解析逻辑。")
    return arr.astype(np.int64, copy=False)

def main():
    """Verify the permutation at PERM_PATH against the samples in PKL_PATH.

    Checks performed, in order:

    1. ``perm[i] == i`` for every position from
       ``PRESERVE_FROM_STEP * STEP_SIZE`` onwards.
    2. The window starting at ``TARGET_START_STEP * STEP_SIZE`` begins with
       the selected samples (pkl index + ``BASE_OFFSET``, de-duplicated,
       in pkl order), truncated to the window capacity.
    3. When the selected samples fit in the window, each of them is located
       inside the window.
    4. For a handful of probe positions, ``PermutationDataset(base, perm)[i]``
       yields the same ``"text"`` as ``base[perm[i]]``.

    Every result and a final verdict are printed; nothing is returned.

    Raises:
        ValueError: If the indexed dataset length differs from ``len(perm)``.
    """
    perm = np.load(PERM_PATH)
    N = perm.shape[0]
    start = TARGET_START_STEP * STEP_SIZE
    tail = PRESERVE_FROM_STEP * STEP_SIZE
    # Number of positions in [start, tail), clipped to the length of perm.
    capacity = max(0, min(N, tail) - min(N, start))

    print(f"[info] perm.shape={perm.shape}, N={N}, window=[{start},{tail}), cap={capacity}")

    # 1) The tail must be untouched: perm[i] == i for every i >= tail.
    tail_ok = np.array_equal(perm[tail:], np.arange(tail, N, dtype=np.int64))
    print(f"[check] 尾部保持不变（step >= {PRESERVE_FROM_STEP}）: {tail_ok}")

    # 2) Load the pkl, shift indices by BASE_OFFSET, drop anything outside
    #    [0, tail) (the tail stays in place) and de-duplicate, keeping the
    #    first occurrence so the pkl order is preserved.
    sel_local = load_selected_from_pkl(PKL_PATH)
    seen = set()
    selected = []
    for i in sel_local:
        gi = int(i) + int(BASE_OFFSET)
        if 0 <= gi < tail and gi not in seen:
            selected.append(gi)
            seen.add(gi)
    selected = np.asarray(selected, dtype=np.int64)
    print(f"[info] pkl选中条数(过滤尾部/去重后)={len(selected)}")

    # 3) The window must start with the selected samples, in pkl order,
    #    truncated to the window capacity.
    take = min(len(selected), capacity)
    prefix_ok = np.array_equal(perm[start:start+take], selected[:take])
    print(f"[check] 窗口前缀 == 选中样本（数量={take}）: {prefix_ok}")

    # 4) If the selection fits in the window, every selected sample must be
    #    located inside it.
    within_ok = True
    if len(selected) <= capacity:
        # Inverse map value -> position. -1 marks values that never occur in
        # perm (possible if perm is not a bijection); an uninitialised buffer
        # would yield arbitrary positions for them.
        inv = np.full(N, -1, dtype=np.int64)
        inv[perm] = np.arange(N, dtype=np.int64)
        # When tail > N the filter above can let through indices >= N; they
        # cannot be looked up in inv and are by definition not in the window.
        in_dataset = selected < N
        pos = inv[selected[in_dataset]]
        within_ok = bool(in_dataset.all() and np.all((pos >= start) & (pos < tail)))
        print(f"[check] 全部选中样本位于窗口 [{start},{tail}) 内: {within_ok}")
    else:
        print(f"[warn] 选中样本({len(selected)})超过窗口容量({capacity})，可能有溢出放到前缀区域；跳过此检查。")

    # 5) Content spot check: ds[i]['text'] must equal base[perm[i]]['text'].
    #    MaskedGPT2Dataset has to be built the same way as during training.
    from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
    from megatron.data.masked_gpt2_dataset import MaskedGPT2Dataset
    try:
        from megatron.data.permutation_dataset import PermutationDataset
    except Exception:
        # The wrapper module is not available: fall back to a minimal
        # equivalent defined on the spot.
        from torch.utils.data import Dataset

        class PermutationDataset(Dataset):
            """View of ``base_ds`` whose item i is ``base_ds[order[i]]``."""

            def __init__(self, base_ds, order):
                self.base = base_ds
                self.order = np.asarray(order, dtype=np.int64)

            def __len__(self):
                return len(self.order)

            def __getitem__(self, i):
                return self.base[int(self.order[i])]

    indexed = make_indexed_dataset(DATA_PREFIX, DATA_IMPL, skip_warmup=True)
    # Explicit raise instead of assert so the check survives `python -O`.
    if len(indexed) != N:
        raise ValueError(f"数据集大小 {len(indexed)} 与 perm 长度 {N} 不一致")
    documents = np.arange(N, dtype=np.int32)
    base = MaskedGPT2Dataset(
        name="train_verify",
        data_prefix=DATA_PREFIX,
        documents=documents,
        indexed_dataset=indexed,
        num_samples=N,
        seq_length=SEQ_LENGTH,
        seed=1234,
        build_index_mappings=True,
        mask_npy_path=None,
        mask_start_global_idx=0,
    )
    ds = PermutationDataset(base, perm)

    # Probe both edges of the window, a point inside it, and points in the tail.
    probes = [start, start+1, start+123, tail-1, tail, min(tail+123, N-1), N-1]
    probes = [i for i in probes if 0 <= i < N]
    ok_cnt = 0
    for i in probes:
        a = ds[i]["text"]
        b = base[int(perm[i])]["text"]
        same = np.array_equal(a, b)
        print(f"[probe] i={i} -> perm[i]={int(perm[i])} | match={same}")
        ok_cnt += int(same)
    print(f"[result] 内容抽查通过 {ok_cnt}/{len(probes)}")

    # within_ok stays True when check 4 was skipped, so it can be used as is.
    all_ok = tail_ok and prefix_ok and within_ok and (ok_cnt == len(probes))
    print(f"\n结论：{'验证通过（置换与内容均符合预期）' if all_ok else '验证未通过，请检查 perm 生成或参数设置'}")

# Run the verification only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

