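# build_new_dataset_with_insertion.py
# Insert locally generated uint16 samples (an npy of shape [M, 2049]) into the
# middle of the original document1.{bin,idx} dataset, keeping only the first
# BASE_KEEP original samples, and write the result as a new mmap dataset prefix
# (MMIDIDX format, dtype=uint16).
#
# Usage:
# 1) Edit the constants below (BASE_PREFIX / SYNTH_NPY / OUT_PREFIX, etc.)
# 2) python build_new_dataset_with_insertion.py
# 3) Verify the new prefix with your inspect script: impl=mmap, dtype=uint16,
#    num_samples=BASE_KEEP + M, sample_len=2049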

import os
import sys
import time
import numpy as np

# ====== Fixed configuration (edit as needed) ======
# Source dataset prefix (no file extension).
BASE_PREFIX = "/root/trainbin1/document1"
# Synthetic samples to insert; expected to be a [M, 2049] uint16 npy file.
SYNTH_NPY = "/root/trainbin1/all_tokens_merged.npy"
# Output dataset prefix (no file extension).
OUT_PREFIX = "/root/trainbin1/document2"

# Index within the kept base samples at which the synthetic rows are inserted.
INSERT_POS = 9 * 102_400  # = 921,600
# Only the first BASE_KEEP (2,560,000) base samples are copied to the output.
BASE_KEEP = 2_560_000
# Tokens per sample (= seq_length + 1).
EXPECTED_LEN = 2_049
# ==================================================

def _fail(msg):
    """Report a fatal error and terminate the process.

    Prints ``ERROR: <msg>`` to stdout, then raises ``SystemExit`` with
    exit status 1.  Never returns.
    """
    print("ERROR:", msg)
    raise SystemExit(1)

def _resolve_dtype(dataset):
    """Return the element dtype of an indexed dataset as a ``np.dtype``.

    Tries ``dataset._index.dtype`` first, then ``dataset.dtype``, and only
    reads sample 0 when neither attribute is present.  Candidates are tested
    with ``is None`` instead of being chained with ``or``, because the
    truthiness of a dtype object is not a reliable "is it present" test.
    """
    candidate = getattr(getattr(dataset, "_index", None), "dtype", None)
    if candidate is None:
        candidate = getattr(dataset, "dtype", None)
    if candidate is None:
        # Last resort: infer the dtype from the first sample.
        candidate = np.asarray(dataset[0]).dtype
    return np.dtype(candidate)


def main():
    """Write OUT_PREFIX.{bin,idx} from the base dataset with synth rows spliced in.

    The output contains, in order: base[0:INSERT_POS), every row of the array
    stored in SYNTH_NPY, then base[INSERT_POS:BASE_KEEP).  Each row is written
    as its own document.  The output uses the dtype detected on the base
    dataset; synth rows are converted to it row by row after a range check.

    Raises:
        SystemExit: via ``_fail`` when the megatron import fails, the base
            dataset cannot be loaded, the synth array has the wrong shape or
            out-of-range values, the configured positions are inconsistent,
            the output files already exist, or a row has an unexpected length.
    """
    try:
        from megatron.data.indexed_dataset import (
            make_dataset,
            infer_dataset_impl,
            MMapIndexedDatasetBuilder,
        )
    except Exception as e:
        _fail(f"无法导入 megatron.data.indexed_dataset，请在 GPT‑NeoX 工程环境下运行。detail: {e}")

    # 1) Load the base dataset (mmap is expected; other impls only warn).
    impl = infer_dataset_impl(BASE_PREFIX)
    if impl is None:
        _fail(f"无法推断数据实现，检查路径：{BASE_PREFIX}")
    if impl != "mmap":
        print(f"WARNING: 检测到 impl={impl}（非 mmap）。将尝试继续，但建议使用 mmap。")

    print(f"[load] base: {BASE_PREFIX} (impl={impl})")
    base_ds = make_dataset(BASE_PREFIX, impl, skip_warmup=True)
    if base_ds is None:
        _fail("加载原始数据失败")

    base_n = len(base_ds)
    target_dtype = _resolve_dtype(base_ds)
    if target_dtype != np.uint16:
        print(f"WARNING: base dtype = {target_dtype}, 不是 uint16。将尝试按原 dtype 写出。")

    # 2) Open the synthetic samples memory-mapped so they are never fully
    #    loaded into RAM.
    print(f"[load] synth npy: {SYNTH_NPY}")
    synth = np.load(SYNTH_NPY, mmap_mode="r")
    if synth.ndim != 2:
        _fail(f"synth 形状必须是 [M, T]，当前 {synth.shape}")
    M, T = synth.shape
    print(f"[synth] shape={synth.shape}, dtype={synth.dtype}")
    if EXPECTED_LEN and T != EXPECTED_LEN:
        _fail(f"synth 每条长度 T={T} != 期望 {EXPECTED_LEN}")
    if synth.dtype != target_dtype:
        # Range-check only when it is meaningful: np.iinfo rejects non-integer
        # dtypes and np.min/np.max reject empty arrays.
        if synth.size and np.issubdtype(target_dtype, np.integer):
            iinfo = np.iinfo(target_dtype)
            mn = int(np.min(synth))
            mx = int(np.max(synth))
            if mn < iinfo.min or mx > iinfo.max:
                _fail(f"synth token 超出目标 dtype 范围：[{mn},{mx}] vs {iinfo}")
        # The conversion itself happens per row in copy_rows, which avoids
        # materialising a converted copy of the whole array.
        print(f"[synth] 转换 dtype: {synth.dtype} -> {target_dtype}")

    # 3) Sanity-check the configured positions.
    if BASE_KEEP > base_n:
        _fail(f"BASE_KEEP={BASE_KEEP} 超过原始总数 {base_n}")
    if not (0 <= INSERT_POS <= BASE_KEEP):
        _fail(f"INSERT_POS={INSERT_POS} 不在 [0, {BASE_KEEP}] 范围内")

    expected_total = BASE_KEEP + M
    print(f"[plan] keep base[0:{BASE_KEEP}), insert synth({M}) at pos={INSERT_POS} -> total={expected_total}")

    # 4) Write the new prefix; refuse to overwrite existing output.
    out_bin = OUT_PREFIX + ".bin"
    out_idx = OUT_PREFIX + ".idx"
    if os.path.exists(out_bin) or os.path.exists(out_idx):
        _fail(f"输出文件已存在：{out_bin} / {out_idx}，为避免覆盖，请先删或换 OUT_PREFIX")

    print(f"[write] -> {out_bin} / {out_idx} (dtype={target_dtype})")
    # The builder is given the scalar type (e.g. numpy.uint16), not the
    # np.dtype instance.
    builder = MMapIndexedDatasetBuilder(out_bin, dtype=target_dtype.type)

    def copy_rows(source, start, stop, tag, total, log_every, started):
        """Append source[start:stop] to the builder, one document per row."""
        for i in range(start, stop):
            row = np.asarray(source[i]).reshape(-1)
            if row.shape[0] != T:
                _fail(f"{tag}[{i}] 长度 {row.shape[0]} != {T}")
            if row.dtype != target_dtype:
                row = row.astype(target_dtype, copy=False)
            builder.add_item(row)
            builder.end_document()
            if (i + 1) % log_every == 0:
                dt = time.time() - started
                print(f"[{tag}] {i+1}/{total} rows written, {dt:.1f}s")

    # Three sequential phases guarantee the synth block is written exactly
    # once for every valid INSERT_POS, including 0 and BASE_KEEP.
    base_t0 = time.time()
    copy_rows(base_ds, 0, INSERT_POS, "base", BASE_KEEP, 100000, base_t0)

    copy_rows(synth, 0, M, "synth", M, 10000, time.time())
    print(f"[synth] done ({M} rows).")

    copy_rows(base_ds, INSERT_POS, BASE_KEEP, "base", BASE_KEEP, 100000, base_t0)

    # 5) Finalize the index file.
    builder.finalize(out_idx)
    print(f"[done] wrote: {out_bin} / {out_idx}")
    print(f"[stat] total={expected_total}, synth={M}, base_kept={BASE_KEEP}, insert_pos={INSERT_POS}")
    # Report the values actually written instead of hard-coded expectations.
    print(f"请用你的 inspect 脚本验证：impl=mmap, dtype={target_dtype}, num_samples={expected_total}, sample_len={T}。")

# Run the build only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
