# tools/build_order_insert_shift.py
# -*- coding: utf-8 -*-
import os, struct, pickle
import numpy as np
from pathlib import Path

# ===== Edit these paths / parameters for your setup =====
PKL_PATH    = r"influence_copytarget_qkvo124_L3H3.pkl"  # pkl with the selected samples (relative indices)
BASE_OFFSET = 1024000                            # added to each relative index to get a global index
DATASET_SIZE = None                              # None: auto-detect from DATA_PREFIX + ".idx"
DATA_PREFIX  = r"trainbin1\document1"  # no file extension; required when DATASET_SIZE is None

STEP_SIZE   = 1024                               # multiplied by START_STEP / END_STEP to get sample positions
START_STEP  = 900                                # first step of the insertion window (inclusive)
END_STEP    = 1400                               # last step of the insertion window (exclusive)
ORDER_OUT   = r""                                # path handed to np.save for the resulting order array

# Insert only the first K elements, in original PKL order; None inserts all of them.
INSERT_TOP_K = 12500

# Randomness: RANDOM_SEED makes runs reproducible. SHUFFLE_SELECTED controls
# whether the selected samples themselves are shuffled (default: no, keep the
# original PKL order).
RANDOM_SEED = 42
SHUFFLE_SELECTED = False
# =================================

def detect_size(prefix: str) -> int:
    """Read the entry count stored in the header of ``prefix + ".idx"``.

    Two header layouts are recognised by their magic bytes:

    * ``b"MMIDIDX\\x00\\x00"`` (9 bytes): the next 8 + 1 bytes are skipped
      (presumably version and dtype code -- confirm against the writer) and
      the following little-endian uint64 is returned.
    * ``b"TNTIDX\\x00\\x00"`` (8 bytes): one uint64 plus 16 more bytes are
      skipped and the following little-endian uint64 is returned.

    Args:
        prefix: Path of the index file without the ``.idx`` extension.

    Returns:
        The count read from the header, as a Python int.

    Raises:
        ValueError: If neither magic matches.
        struct.error: If the file ends before the count field.
    """
    index_path = prefix + ".idx"
    with open(index_path, "rb") as handle:
        if handle.read(9) == b"MMIDIDX\x00\x00":
            # Skip the 8-byte and 1-byte fields that precede the count.
            handle.read(8)
            handle.read(1)
            (count,) = struct.unpack("<Q", handle.read(8))
            return int(count)

        # Not the 9-byte magic: rewind and try the 8-byte one.
        handle.seek(0)
        if handle.read(8) != b"TNTIDX\x00\x00":
            raise ValueError("unknown idx format")
        # One uint64 field (decoded but unused), then 16 bytes to skip.
        struct.unpack("<Q", handle.read(8))
        handle.read(16)
        (count,) = struct.unpack("<Q", handle.read(8))
        return int(count)

def load_selected(pkl_path: str) -> np.ndarray:
    """Load sample indices from a pickle file, preserving their stored order.

    Layouts tried, in this order:

    1. The pickled object itself is a list / tuple / 1-D ndarray of integers,
       or of dicts carrying the index under one of ``"sample_index"``,
       ``"index"``, ``"idx"``, ``"global_index"`` (first key found in the
       first element wins).
    2. The pickled object is a dict holding such a collection under one of the
       known top-level keys (``"positive_influencers"``, ``"positives"``, ...).
    3. The pickled object is a dict whose ``"analysis_results"`` dict holds
       such a collection.

    Dict entries that lack the chosen key, or whose value cannot be converted
    with ``int()``, are skipped.

    Args:
        pkl_path: Path of the pickle file to read.

    Returns:
        1-D ``np.int64`` array of the indices exactly as stored (no offset,
        no de-duplication).

    Raises:
        ValueError: If none of the supported layouts matches.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only open PKL files that come from a trusted source.
    with open(pkl_path, "rb") as f:
        obj = pickle.load(f)

    index_keys = ("sample_index", "index", "idx", "global_index")

    def from_list(lst):
        """Return an int64 index array parsed from ``lst``, or None if its shape is unsupported."""
        if isinstance(lst, np.ndarray):
            # Treat a 1-D array like the equivalent list so indices saved as
            # ndarray are accepted as well.
            if lst.ndim != 1:
                return None
            lst = lst.tolist()
        if not isinstance(lst, (list, tuple)) or len(lst) == 0:
            return None
        first = lst[0]
        if isinstance(first, (int, np.integer)):
            return np.asarray(lst, dtype=np.int64)
        if isinstance(first, dict):
            for k in index_keys:
                if k not in first:
                    continue
                out = []
                for d in lst:
                    if isinstance(d, dict) and k in d:
                        try:
                            out.append(int(d[k]))
                        except (TypeError, ValueError, OverflowError):
                            # Best effort: drop entries whose value is not
                            # integer-convertible (None, "abc", nan, inf, ...).
                            continue
                return np.asarray(out, dtype=np.int64)
        return None

    arr = None
    if isinstance(obj, (list, tuple, np.ndarray)):
        arr = from_list(obj)
    if arr is None and isinstance(obj, dict):
        for k in ("positive_influencers", "positives", "positive_samples",
                  "selected_positive_indices", "selected_indices", "indices", "items", "data"):
            if k in obj:
                arr = from_list(obj[k])
                if arr is not None:
                    break
        if arr is None and isinstance(obj.get("analysis_results"), dict):
            nested = obj["analysis_results"]
            for k in ("positive_influencers", "indices", "items", "data"):
                if k in nested:
                    arr = from_list(nested[k])
                    if arr is not None:
                        break
    if arr is None:
        raise ValueError("无法从 PKL 解析出索引，请按实际结构调整解析")
    return arr.astype(np.int64, copy=False)

def main():
    """Build a sample order with the selected samples scattered into a step window and save it.

    Steps:
      1. Resolve the dataset size N (DATASET_SIZE, or the header of
         DATA_PREFIX + ".idx").
      2. Load the selected indices from PKL_PATH, add BASE_OFFSET, drop those
         outside [0, N) and duplicates (first occurrence wins), then keep the
         first INSERT_TOP_K.
      3. Clamp the window [START_STEP*STEP_SIZE, END_STEP*STEP_SIZE) to [0, N].
      4. Draw one random gap inside the window for every selected sample and
         insert the samples there; the base order 0..N-1 itself is unchanged,
         so the result has N + K entries.
      5. Print sanity checks and write the array with np.save.

    Raises:
        ValueError: If ORDER_OUT is empty, or if DATASET_SIZE is None and
            DATA_PREFIX is empty.
    """
    # Fail before any heavy work: with an empty path np.save would append
    # ".npy" and silently write a hidden file named ".npy" in the working dir.
    out_file = os.fspath(ORDER_OUT)
    if not out_file:
        raise ValueError("ORDER_OUT 为空，请设置输出 .npy 路径")
    # np.save appends ".npy" when it is missing; do it here so the directory
    # creation and the final message refer to the file actually written.
    if not out_file.endswith(".npy"):
        out_file += ".npy"

    # N: total number of samples in the dataset.
    if DATASET_SIZE is None:
        if not DATA_PREFIX:
            raise ValueError("DATASET_SIZE=None 时需要设置 DATA_PREFIX")
        N = detect_size(DATA_PREFIX)
    else:
        N = int(DATASET_SIZE)
    print(f"[info] N={N}")

    # Selected samples: keep PKL order, shift to global indices, drop
    # out-of-range values and duplicates, then truncate to the first K.
    offset = int(BASE_OFFSET)
    seen = set()
    selected = []
    for i in load_selected(PKL_PATH):
        gi = int(i) + offset
        if 0 <= gi < N and gi not in seen:
            selected.append(gi)
            seen.add(gi)
    if INSERT_TOP_K is not None:
        selected = selected[:int(INSERT_TOP_K)]
    selected = np.asarray(selected, dtype=np.int64)
    K = int(selected.size)
    print(f"[info] selected(after offset/dedup/topK)={K}")

    # Insertion window [start, end) in sample positions, clamped to [0, N].
    start = int(START_STEP) * int(STEP_SIZE)
    end   = int(END_STEP)   * int(STEP_SIZE)

    if start < 0:
        print(f"[warn] START_STEP*STEP_SIZE < 0，已截断到 0（原 start={start})")
        start = 0
    if end > N:
        print(f"[warn] END_STEP*STEP_SIZE 超过数据集长度，已截断到 N（原 end={end}, N={N})")
        end = N
    if start > N:
        print(f"[warn] START_STEP*STEP_SIZE 超过数据集长度，已截断到 N（原 start={start}, N={N})")
        start = N
    if end < start:
        # Inverted window: collapse it so everything is inserted at `start`.
        print(f"[warn] 插入区间为空或非法，调整为在 start 处集中插入（start={start}, end={end})")
        end = start

    # The generator is created before the optional shuffle so that the slot
    # draw below sees the same random stream for a given seed and setting.
    rng = np.random.default_rng(int(RANDOM_SEED))
    if SHUFFLE_SELECTED and K > 0:
        rng.shuffle(selected)

    # Split the base order 0..N-1 around the window.
    pre  = np.arange(0, start, dtype=np.int64)
    mid  = np.arange(start, end, dtype=np.int64)
    tail = np.arange(end, N, dtype=np.int64)
    M = int(mid.size)

    # Pick K of the M+1 gaps of `mid` uniformly at random, with repetition.
    # Slot p means "before mid[p]"; slot M means "after the last element".
    # Sorting the slots and inserting `selected` in its current order hands
    # the samples out gap by gap from left to right, keeping their relative
    # order and leaving `mid` itself untouched.
    if K > 0:
        slots = rng.integers(0, M + 1, size=K)
        merged = np.insert(mid, np.sort(slots), selected)
    else:
        merged = mid

    order = np.concatenate([pre, merged, tail], axis=0)

    # Basic sanity checks (reported only, never fatal).
    ok_len = (len(order) == N + K)
    ok_pre = np.array_equal(order[:start], pre)
    ok_tail = np.array_equal(order[order.size - tail.size:], tail)
    print(f"[check] len={ok_len}, pre_ok={ok_pre}, tail_ok={ok_tail}, "
          f"insert_range_steps=[{START_STEP},{END_STEP}), "
          f"start={start}, end={end}, M={M}, K={K}")

    Path(out_file).parent.mkdir(parents=True, exist_ok=True)
    np.save(out_file, order)
    print(f"[done] wrote: {out_file}")

# Run the tool only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
