# genflexbin_uniform.py
# 将合成样本平均分散插入到指定 step 区间 [START_STEP, END_STEP) 内
# 单点插入逻辑已完全移除；每个 step = STEP_UNIT 个样本位置（默认 1024）

import os
import sys
import time
import numpy as np

# ====== Configuration ======
# Dataset prefixes are written without the .bin / .idx extension.
BASE_PREFIX = "/root/trainbin1/document1"              # prefix of the source dataset
SYNTH_NPY = "/root/trainbin1/all_tokens_merged.npy"    # synthetic rows, a 2-D [M, T] .npy array
OUT_PREFIX = "/root/trainbin1/document12500"           # prefix of the dataset to be written

BASE_KEEP = 2560000     # copy only base rows [0, BASE_KEEP) into the output
EXPECTED_LEN = 2049     # required row length (seq_length + 1); a falsy value skips the check

# Sampling of synthetic rows: None inserts all of them, an integer K draws K rows.
SYNTH_SAMPLE_K = 12500
SYNTH_SAMPLE_SEED = 42
SYNTH_SAMPLE_REPLACE = False

# Insertion window, in training steps; positions = step * STEP_UNIT.
STEP_UNIT = 1024        # sample positions consumed per step
START_STEP = 900        # first step of the window (inclusive)
END_STEP = 1400         # end of the window (exclusive); None means "up to BASE_KEEP"
# ===========================

def _fail(msg):
    """Report a fatal error and terminate the script with exit status 1.

    The message is written to stderr (prefixed with ``ERROR:``) so it is not
    interleaved with, or lost alongside, redirected progress output on stdout.

    Args:
        msg: human-readable description of the failure.

    Raises:
        SystemExit: always, with code 1.
    """
    print("ERROR:", msg, file=sys.stderr)
    sys.exit(1)

def _even_slot_counts(total, slots):
    """Spread ``total`` items over ``slots`` gaps as evenly as possible.

    Gap ``j`` receives ``round((j + 1) * total / slots) - round(j * total / slots)``
    items (round-half-up, computed exactly with integers). The counts therefore
    sum to ``total`` and differ by at most one. The non-empty gaps are spaced
    evenly and centered across the whole range, rather than packed at the front.

    Args:
        total: number of items to distribute (>= 0).
        slots: number of gaps (>= 1).

    Returns:
        int64 numpy array of length ``slots``.

    Raises:
        ValueError: if ``slots < 1`` or ``total < 0``.
    """
    total = int(total)
    slots = int(slots)
    if slots < 1:
        raise ValueError(f"slots must be >= 1, got {slots}")
    if total < 0:
        raise ValueError(f"total must be >= 0, got {total}")
    j = np.arange(slots + 1, dtype=np.int64)
    # edges[j] = round(j * total / slots); 2*j*total stays far below int64 max
    # for any realistic dataset size.
    edges = (2 * j * total + slots) // (2 * slots)
    return np.diff(edges)


def _checked_row(values, length, dtype, label):
    """Return ``values`` as a flat 1-D array of ``dtype``.

    Exits via _fail() when the flattened length is not ``length``. ``label``
    (e.g. ``"base[3]"``) identifies the row in that error message.
    """
    row = np.asarray(values).reshape(-1)
    if row.shape[0] != length:
        _fail(f"{label} 长度 {row.shape[0]} != {length}")
    if row.dtype != dtype:
        row = row.astype(dtype, copy=False)
    return row


def main():
    """Write OUT_PREFIX.{bin,idx}: base rows [0, BASE_KEEP) with synthetic rows interleaved.

    The sampled synthetic rows are spread evenly over the gaps around base
    positions [START_STEP * STEP_UNIT, END_STEP * STEP_UNIT), clamped to
    [0, BASE_KEEP]. There is one gap before each base row of the window and one
    right after it. Base rows outside the window are copied unchanged and in
    order. Any validation error exits through _fail(). If writing fails midway,
    the partially written output files are deleted before the exception
    propagates.
    """
    try:
        from megatron.data.indexed_dataset import (
            make_dataset,
            infer_dataset_impl,
            MMapIndexedDatasetBuilder,
        )
    except Exception as e:  # top-level boundary: any import-time failure gets the same hint
        _fail(f"无法导入 megatron.data.indexed_dataset，请在 GPT‑NeoX 工程环境下运行。detail: {e}")

    # 1) Load the base dataset.
    impl = infer_dataset_impl(BASE_PREFIX)
    if impl is None:
        _fail(f"无法推断数据实现，检查路径：{BASE_PREFIX}")
    print(f"[load] base: {BASE_PREFIX} (impl={impl})")
    base_ds = make_dataset(BASE_PREFIX, impl, skip_warmup=True)
    if base_ds is None:
        _fail("加载原始数据失败")

    base_n = len(base_ds)
    if BASE_KEEP > base_n:
        _fail(f"BASE_KEEP={BASE_KEEP} 超过原始总数 {base_n}")

    # Token dtype of the base data. Compare against None explicitly instead of
    # chaining with `or`: truthiness is not a reliable "is set" test for numpy
    # dtype objects (non-structured dtypes have len() == 0).
    base_dtype = getattr(getattr(base_ds, "_index", None), "dtype", None)
    if base_dtype is None:
        base_dtype = getattr(base_ds, "dtype", None)
    if base_dtype is None:
        base_dtype = np.asarray(base_ds[0]).dtype
    base_dtype = np.dtype(base_dtype)

    # 2) Load the synthetic rows memory-mapped; nothing is copied into RAM here.
    print(f"[load] synth npy: {SYNTH_NPY}")
    synth = np.load(SYNTH_NPY, mmap_mode="r")
    if synth.ndim != 2:
        _fail(f"synth 形状必须是 [M, T]，当前 {synth.shape}")
    M, T = synth.shape
    print(f"[synth] shape={synth.shape}, dtype={synth.dtype}")
    if EXPECTED_LEN and T != EXPECTED_LEN:
        _fail(f"synth 每条长度 T={T} != 期望 {EXPECTED_LEN}")

    # 2.5) Choose which synthetic rows to insert.
    if SYNTH_SAMPLE_K is None:
        sel_idx = np.arange(M)
    else:
        K = int(SYNTH_SAMPLE_K)
        if K < 0:
            _fail(f"SYNTH_SAMPLE_K 必须为非负整数，当前 {SYNTH_SAMPLE_K}")
        rng = np.random.default_rng(SYNTH_SAMPLE_SEED)
        if SYNTH_SAMPLE_REPLACE:
            sel_idx = rng.integers(0, M, size=K)
        else:
            if K > M:
                print(f"[warn] 采样 K={K} 大于可用样本 M={M}，将使用 K=M")
                K = M
            sel_idx = rng.choice(M, size=K, replace=False)
            sel_idx.sort()  # keep the source order of the chosen rows
    S = int(sel_idx.shape[0])
    print(f"[sample] use {S} / {M} synth rows (seed={SYNTH_SAMPLE_SEED}, replace={SYNTH_SAMPLE_REPLACE})")

    # dtype alignment. Only the rows that will actually be written must fit the
    # target dtype. They are scanned in small chunks, and each row is cast when
    # it is written, so the memmap is never converted as a whole.
    if synth.dtype != base_dtype:
        iinfo = np.iinfo(base_dtype)
        chunk = 1024
        for lo in range(0, S, chunk):
            rows = np.asarray(synth[sel_idx[lo:lo + chunk]])
            mn = int(rows.min())
            mx = int(rows.max())
            if mn < iinfo.min or mx > iinfo.max:
                _fail(f"synth token 超出目标 dtype 范围：[{mn},{mx}] vs {iinfo}")
        print(f"[synth] 转换 dtype: {synth.dtype} -> {base_dtype}")

    # 3) Step window -> sample-position window [start_pos, end_pos).
    start_pos = int(max(0, START_STEP) * int(STEP_UNIT))
    if END_STEP is None:
        end_pos = int(BASE_KEEP)
    else:
        end_pos = int(max(0, END_STEP) * int(STEP_UNIT))
    start_pos = max(0, min(start_pos, BASE_KEEP))
    end_pos = max(0, min(end_pos, BASE_KEEP))
    if end_pos < start_pos:
        _fail(f"平均插入区间非法：start_pos={start_pos} > end_pos={end_pos}")
    print(f"[uniform] steps=[{START_STEP},{'end' if END_STEP is None else END_STEP}), "
          f"unit={STEP_UNIT}, positions=[{start_pos},{end_pos})")

    expected_total = BASE_KEEP + S
    print(f"[plan] keep base[0:{BASE_KEEP}), uniformly spread synth({S}) in [{start_pos},{end_pos}) -> total={expected_total}")

    # 4) Write the new prefix.
    out_bin = OUT_PREFIX + ".bin"
    out_idx = OUT_PREFIX + ".idx"
    if os.path.exists(out_bin) or os.path.exists(out_idx):
        _fail(f"输出文件已存在：{out_bin} / {out_idx}，为避免覆盖，请先删或换 OUT_PREFIX")

    # Gap j (0 <= j < window) sits just before base[start_pos + j]; the last gap
    # sits right after base[end_pos - 1]. Computed before the output file is
    # opened, so a failure here leaves nothing behind.
    window = end_pos - start_pos
    slots = window + 1
    counts = _even_slot_counts(S, slots)

    print(f"[write] -> {out_bin} / {out_idx} (dtype={base_dtype})")
    builder = MMapIndexedDatasetBuilder(out_bin, dtype=base_dtype.type)
    t0 = time.time()
    synth_ptr = 0

    def write_synth(n):
        # Append the next n selected synthetic rows, each as its own document.
        nonlocal synth_ptr
        for _ in range(n):
            k = int(sel_idx[synth_ptr])
            builder.add_item(_checked_row(synth[k], T, base_dtype, f"synth[{k}]"))
            builder.end_document()
            synth_ptr += 1

    def write_base(i, tag):
        # Append base row i as its own document and report progress periodically.
        builder.add_item(_checked_row(base_ds[i], T, base_dtype, f"base[{i}]"))
        builder.end_document()
        if (i + 1) % 100000 == 0:
            dt = time.time() - t0
            print(f"[{tag}] {i+1}/{BASE_KEEP} rows written, {dt:.1f}s")

    try:
        for i in range(start_pos):
            write_base(i, "base-pre")
        for j, i in enumerate(range(start_pos, end_pos)):
            write_synth(int(counts[j]))
            write_base(i, "base-mid")
        write_synth(int(counts[-1]))  # the gap right after the window
        for i in range(end_pos, BASE_KEEP):
            write_base(i, "base-tail")
        builder.finalize(out_idx)
    except BaseException:
        # _fail() raises SystemExit and Ctrl-C raises KeyboardInterrupt. Both
        # files were verified absent above, so anything present now is our own
        # partial output. Remove it so the existence check does not block the
        # next run.
        for path in (out_bin, out_idx):
            if os.path.exists(path):
                print(f"[cleanup] 删除不完整输出：{path}")
                try:
                    os.remove(path)
                except OSError as err:
                    print(f"[cleanup] 删除失败：{path}: {err}")
        raise

    print(f"[done] wrote: {out_bin} / {out_idx}")
    print(f"[stat] total={expected_total}, synth_written={synth_ptr}, base_kept={BASE_KEEP}")
    print(f"[stat] uniform_slots={slots}, avg_per_slot≈{S/float(slots):.3f}")

    print("请用你的 inspect 脚本验证：impl、dtype/num_samples/sample_len 等是否一致。")

# Run the conversion only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
