import os
import sys
import time
import numpy as np

# ====== Fixed configuration (edit as needed) ======
# Prefix of the base indexed dataset (no file extension; ".bin"/".idx" are implied).
BASE_PREFIX = "/root/trainbin1/document1"
# Generated synthetic samples; expected shape [M, EXPECTED_LEN], nominally uint16.
SYNTH_NPY = "/root/trainbin1/all_tokens_merged.npy"
# Prefix of the dataset to write (no file extension).
OUT_PREFIX = "/root/trainbin1/document5"

# Index in the kept base range at which the synthetic block is spliced in
# (= 921,600). Original author's note: "31 -> 8, 70 -> 7, 160 -> 6" --
# presumably the multiplier to use per model size; TODO confirm.
INSERT_POS = 9 * 102_400
# Keep only the first 2,560,000 base samples.
BASE_KEEP = 2_560_000
# Length of every sample (= seq_length + 1).
EXPECTED_LEN = 2049

# Sampling: None inserts every synthetic row; an integer K draws K random rows.
SYNTH_SAMPLE_K = 12_500
# Seed for the sampling RNG.
SYNTH_SAMPLE_SEED = 42
# Sample with replacement? (False in most cases.)
SYNTH_SAMPLE_REPLACE = False
# ===================================================

def _fail(msg):
    print("ERROR:", msg)
    sys.exit(1)

def main():
    """Write a new mmap indexed dataset with synthetic rows spliced into the base data.

    The output at ``OUT_PREFIX`` contains, in order:

    1. ``base[0:INSERT_POS)``
    2. the selected synthetic rows from ``SYNTH_NPY``
    3. ``base[INSERT_POS:BASE_KEEP)``

    Every sample is written as its own document. All inputs come from the
    module-level configuration constants; nothing is returned.

    Any validation failure is reported through ``_fail`` (which exits the
    process). If writing fails part-way, the partially written output files
    are removed before the error propagates, so a rerun is not blocked by the
    "output already exists" check.
    """
    try:
        from megatron.data.indexed_dataset import (
            make_dataset,
            infer_dataset_impl,
            MMapIndexedDatasetBuilder,
        )
    except Exception as e:
        _fail(f"无法导入 megatron.data.indexed_dataset，请在 GPT‑NeoX 工程环境下运行。detail: {e}")

    # 1) Load the base dataset (mmap is expected; anything else only warns).
    impl = infer_dataset_impl(BASE_PREFIX)
    if impl is None:
        _fail(f"无法推断数据实现，检查路径：{BASE_PREFIX}")
    if impl != "mmap":
        print(f"WARNING: 检测到 impl={impl}（非 mmap）。将尝试继续，但建议使用 mmap。")

    print(f"[load] base: {BASE_PREFIX} (impl={impl})")
    base_ds = make_dataset(BASE_PREFIX, impl, skip_warmup=True)
    if base_ds is None:
        _fail("加载原始数据失败")

    base_n = len(base_ds)
    # Prefer the dtype recorded by the dataset/index; otherwise read one sample.
    base_dtype = getattr(getattr(base_ds, "_index", None), "dtype", None) or getattr(base_ds, "dtype", None)
    if base_dtype is None:
        base_dtype = np.asarray(base_ds[0]).dtype
    target_dtype = np.dtype(base_dtype)
    if target_dtype != np.uint16:
        print(f"WARNING: base dtype = {target_dtype}, 不是 uint16。将尝试按原 dtype 写出。")

    # 2) Open the synthetic samples memory-mapped so only the rows that are
    #    actually written get pulled into RAM.
    print(f"[load] synth npy: {SYNTH_NPY}")
    synth = np.load(SYNTH_NPY, mmap_mode="r")
    if synth.ndim != 2:
        _fail(f"synth 形状必须是 [M, T]，当前 {synth.shape}")
    M, T = synth.shape
    print(f"[synth] shape={synth.shape}, dtype={synth.dtype}")
    if EXPECTED_LEN and T != EXPECTED_LEN:
        _fail(f"synth 每条长度 T={T} != 期望 {EXPECTED_LEN}")
    if np.dtype(synth.dtype) != target_dtype:
        # Range-check before casting. np.iinfo only accepts integer dtypes and
        # min/max reject empty arrays, so guard both cases.
        if synth.size > 0 and np.issubdtype(target_dtype, np.integer):
            iinfo = np.iinfo(target_dtype)
            mn = int(np.min(synth))
            mx = int(np.max(synth))
            if mn < iinfo.min or mx > iinfo.max:
                _fail(f"synth token 超出目标 dtype 范围：[{mn},{mx}] vs {iinfo}")
        # The cast itself happens per row at write time, so the memory-mapped
        # array is never copied into RAM as a whole.
        print(f"[synth] 转换 dtype: {synth.dtype} -> {target_dtype}")

    # 2.5) Choose which synthetic rows to insert.
    if SYNTH_SAMPLE_K is None:
        sel_idx = np.arange(M)
    else:
        K = int(SYNTH_SAMPLE_K)
        if K < 0:
            _fail(f"SYNTH_SAMPLE_K 必须为非负整数，当前 {SYNTH_SAMPLE_K}")
        rng = np.random.default_rng(SYNTH_SAMPLE_SEED)
        if SYNTH_SAMPLE_REPLACE:
            if M == 0:
                # rng.integers(0, 0, ...) has an empty range; report it clearly.
                if K > 0:
                    _fail(f"synth 为空（M=0），无法进行有放回采样 K={K}")
                sel_idx = np.empty(0, dtype=np.int64)
            else:
                sel_idx = rng.integers(0, M, size=K)
        else:
            if K > M:
                print(f"WARNING: 采样 K={K} 大于可用样本 M={M}，将使用 K=M。")
                K = M
            sel_idx = rng.choice(M, size=K, replace=False)
            sel_idx.sort()
    S = int(sel_idx.shape[0])
    print(f"[sample] use {S} / {M} synth rows (seed={SYNTH_SAMPLE_SEED}, replace={SYNTH_SAMPLE_REPLACE})")

    # 3) Validate the plan against the base dataset size.
    if BASE_KEEP > base_n:
        _fail(f"BASE_KEEP={BASE_KEEP} 超过原始总数 {base_n}")
    if not (0 <= INSERT_POS <= BASE_KEEP):
        _fail(f"INSERT_POS={INSERT_POS} 不在 [0, {BASE_KEEP}] 范围内")

    expected_total = BASE_KEEP + S
    print(f"[plan] keep base[0:{BASE_KEEP}), insert synth({S}) at pos={INSERT_POS} -> total={expected_total}")

    # 4) Write the new dataset; refuse to overwrite existing output.
    out_bin = OUT_PREFIX + ".bin"
    out_idx = OUT_PREFIX + ".idx"
    if os.path.exists(out_bin) or os.path.exists(out_idx):
        _fail(f"输出文件已存在：{out_bin} / {out_idx}，为避免覆盖，请先删或换 OUT_PREFIX")

    print(f"[write] -> {out_bin} / {out_idx} (dtype={target_dtype})")
    builder = MMapIndexedDatasetBuilder(out_bin, dtype=target_dtype.type)

    def write_synth_block():
        """Append every selected synthetic row as its own document."""
        if S == 0:
            print("[synth] skip (S=0).")
            return
        t0 = time.time()
        for j, k in enumerate(sel_idx):
            # synth is validated as 2-D with row length T, so each row is
            # already 1-D of length T; only the dtype may need converting.
            row = np.asarray(synth[k], dtype=target_dtype)
            builder.add_item(row)
            builder.end_document()
            if (j + 1) % 10000 == 0:
                dt = time.time() - t0
                print(f"[synth] {j+1}/{S} rows written, {dt:.1f}s")
        print(f"[synth] done ({S} rows).")

    def write_base_range(start, stop, t0):
        """Append base[start:stop), one document per sample."""
        for i in range(start, stop):
            arr = np.asarray(base_ds[i])
            if arr.ndim != 1:
                arr = arr.reshape(-1)
            if arr.shape[0] != T:
                _fail(f"base[{i}] 长度 {arr.shape[0]} != {T}")
            if arr.dtype != target_dtype:
                arr = arr.astype(target_dtype, copy=False)
            builder.add_item(arr)
            builder.end_document()

            if (i + 1) % 100000 == 0:
                dt = time.time() - t0
                print(f"[base] {i+1}/{BASE_KEEP} rows written, {dt:.1f}s")

    try:
        # Three explicit phases guarantee the synth block is written exactly
        # once for every valid INSERT_POS, including INSERT_POS == BASE_KEEP == 0.
        base_t0 = time.time()
        write_base_range(0, INSERT_POS, base_t0)
        write_synth_block()
        write_base_range(INSERT_POS, BASE_KEEP, base_t0)

        # 5) Finalize the index.
        builder.finalize(out_idx)
    except BaseException:
        # Neither output file existed before this run (checked above), so
        # anything present now is a partial result of this run. Remove it so a
        # rerun is not blocked; cleanup is best-effort.
        for path in (out_bin, out_idx):
            try:
                os.remove(path)
            except OSError:
                pass
        raise

    print(f"[done] wrote: {out_bin} / {out_idx}")
    print(f"[stat] total={expected_total}, synth_written={S}, base_kept={BASE_KEEP}, insert_pos={INSERT_POS}")
    print(f"请用你的 inspect 脚本验证：impl=mmap, dtype={target_dtype}, num_samples={expected_total}, sample_len={T}。")

# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
