# tools/build_order_insert_shift.py
# -*- coding: utf-8 -*-
import os, struct, pickle
import numpy as np
from pathlib import Path

# ===== Edit these paths / parameters for your setup =====
PKL_PATH    = r"influence_copytarget_qkvo124_L3H3.pkl"  # pickle holding the selected samples (indices are relative)
BASE_OFFSET = 0                            # added to every relative index to turn it into a global index
DATASET_SIZE = None                              # leave as None to auto-detect from DATA_PREFIX + ".idx"
DATA_PREFIX  = r"document1"    # dataset path without extension; required when DATASET_SIZE is None
STEP_SIZE   = 1024  # multiplied by START_STEP to get the insertion position
START_STEP  = 200                                # step at which the selected block is inserted
ORDER_OUT   = r"\.npy"  # path passed to np.save; NOTE(review): value looks like a placeholder - confirm
# =================================

def detect_size(prefix: str) -> int:
    r"""Return the entry count stored in the header of ``prefix + ".idx"``.

    Two header layouts are recognised by their leading magic bytes:

    * ``b"MMIDIDX\x00\x00"`` (9 bytes): the next 8 + 1 bytes are skipped and
      the following little-endian uint64 is returned.
    * ``b"TNTIDX\x00\x00"`` (8 bytes): the next 8 + 16 bytes are skipped and
      the following little-endian uint64 is returned.

    NOTE(review): these layouts look like Megatron-style and fairseq-legacy
    index headers respectively - confirm against whatever wrote the file.

    Args:
        prefix: Dataset path without the ``.idx`` suffix (``str`` or
            ``os.PathLike``).

    Returns:
        The count read from the header, as a Python ``int``.

    Raises:
        ValueError: If the magic bytes match neither layout, or the file ends
            before the count field.
        OSError: If the index file cannot be opened.
    """
    idx_path = os.fspath(prefix) + ".idx"

    def read_u64(f, what):
        # A short read means the header is truncated; report it clearly
        # instead of letting struct.unpack fail with an opaque struct.error.
        raw = f.read(8)
        if len(raw) != 8:
            raise ValueError(f"truncated idx file {idx_path!r}: missing {what}")
        return struct.unpack("<Q", raw)[0]

    with open(idx_path, "rb") as f:
        head = f.read(9)
        if head == b"MMIDIDX\x00\x00":
            # Skip the 8-byte and 1-byte fields that precede the count.
            f.seek(9 + 8 + 1)
            return int(read_u64(f, "entry count"))
        if head[:8] == b"TNTIDX\x00\x00":
            # Skip the 8-byte field and the 16 bytes that precede the count.
            f.seek(8 + 8 + 16)
            return int(read_u64(f, "entry count"))
        raise ValueError("unknown idx format")

def load_selected(pkl_path: str) -> np.ndarray:
    """Load sample indices from a pickle file and return them as an int64 array.

    Accepted payload shapes, tried in this order:

    1. A non-empty list/tuple of ints, or of dicts carrying one of the keys
       ``sample_index`` / ``index`` / ``idx`` / ``global_index`` (the first of
       those keys present in the first dict is used for every element).
    2. A dict holding such a list under one of several well-known keys.
    3. A dict whose ``"analysis_results"`` entry is a dict holding such a list.

    Dict elements that lack the chosen key, or whose value cannot be converted
    with ``int()``, are skipped.

    Args:
        pkl_path: Path of the pickle file to read.

    Returns:
        A one-dimensional ``np.int64`` array of the indices, in file order.

    Raises:
        ValueError: If no index list can be located in the payload.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only point this at pickles you produced or otherwise trust.
    with open(pkl_path, "rb") as f:
        obj = pickle.load(f)

    index_keys = ("sample_index", "index", "idx", "global_index")
    top_level_keys = ("positive_influencers", "positives", "positive_samples",
                      "selected_positive_indices", "selected_indices",
                      "indices", "items", "data")
    nested_keys = ("positive_influencers", "indices", "items", "data")

    def from_list(lst):
        # Returns an int64 array, or None when `lst` is not a recognised shape.
        if not isinstance(lst, (list, tuple)) or not lst:
            return None
        first = lst[0]
        if isinstance(first, (int, np.integer)):
            return np.asarray(lst, dtype=np.int64)
        if isinstance(first, dict):
            key = next((k for k in index_keys if k in first), None)
            if key is None:
                return None
            out = []
            for d in lst:
                if not (isinstance(d, dict) and key in d):
                    continue
                try:
                    out.append(int(d[key]))
                except (TypeError, ValueError, OverflowError):
                    # Best-effort parse: drop unconvertible values, but do not
                    # swallow KeyboardInterrupt/SystemExit like a bare except.
                    continue
            return np.asarray(out, dtype=np.int64)
        return None

    def first_parsable(container, keys):
        # First key (in priority order) whose value parses into an index array.
        for k in keys:
            if k in container:
                parsed = from_list(container[k])
                if parsed is not None:
                    return parsed
        return None

    arr = None
    if isinstance(obj, (list, tuple)):
        arr = from_list(obj)
    elif isinstance(obj, dict):
        arr = first_parsable(obj, top_level_keys)
        nested = obj.get("analysis_results")
        if arr is None and isinstance(nested, dict):
            arr = first_parsable(nested, nested_keys)

    if arr is None:
        raise ValueError("无法从 PKL 解析出索引，请按实际结构调整解析")
    return arr.astype(np.int64, copy=False)

def main():
    """Build the sample order with the selected block inserted and save it.

    The resulting order is ``[0 .. start-1] + selected + [start .. N-1]`` where
    ``start = START_STEP * STEP_SIZE``; its length is ``N + K`` (K = number of
    selected indices kept after offsetting, range-filtering and de-duplication).
    The array is written to ``ORDER_OUT`` with ``np.save``.

    Raises:
        ValueError: If neither DATASET_SIZE nor DATA_PREFIX is set, or if the
            insertion position lies beyond the dataset size.
    """
    # Dataset size: an explicit DATASET_SIZE wins, otherwise probe the index file.
    if DATASET_SIZE is not None:
        total = int(DATASET_SIZE)
    elif DATA_PREFIX:
        total = detect_size(DATA_PREFIX)
    else:
        raise ValueError("DATASET_SIZE=None 时需要设置 DATA_PREFIX")
    print(f"[info] N={total}")

    # Selected samples: apply the offset, keep in-range values, and drop
    # duplicates while preserving first-seen order (dict keys keep insertion order).
    offset = int(BASE_OFFSET)
    shifted = (int(rel) + offset for rel in load_selected(PKL_PATH))
    kept = dict.fromkeys(g for g in shifted if 0 <= g < total)
    selected = np.asarray(list(kept), dtype=np.int64)
    count = int(selected.size)
    print(f"[info] selected(after offset/dedup)={count}")

    start = START_STEP * STEP_SIZE
    if start > total:
        raise ValueError(f"START_STEP 太大：start={start} > N={total}")

    # New order: untouched head, then the inserted block, then the shifted tail.
    head = np.arange(start, dtype=np.int64)
    rest = np.arange(start, total, dtype=np.int64)
    order = np.concatenate([head, selected, rest], axis=0)

    # Sanity checks: total length is N + K, the inserted block equals
    # `selected`, and everything after it is the original start..N-1 run.
    block_end = start + count
    ok_len = (len(order) == total + count)
    ok_blk = np.array_equal(order[start:block_end], selected)
    ok_tail = np.array_equal(order[block_end:], np.arange(start, total, dtype=np.int64))
    print(f"[check] len={ok_len}, block={ok_blk}, shifted_tail={ok_tail}")

    # Make sure the output directory exists before writing.
    out_dir = Path(ORDER_OUT).parent
    out_dir.mkdir(parents=True, exist_ok=True)
    np.save(ORDER_OUT, order)
    print(f"[done] wrote: {ORDER_OUT}")

if __name__ == "__main__":
    main()

