import os, struct, numpy as np

MERGED_IDX = "/root/trainbin1/document.idx"       # merged idx; its header is reused and the window length is inferred from it
SHARD_BIN  = "/root/trainbin1/document1.bin"      # shard bin (left unmodified)
OUT_IDX    = "/root/trainbin1/document1.idx"      # target idx (will be created/overwritten)
ITEMSIZE   = 2                                     # the data is actually uint16 -> 2 bytes per token

def read_idx_header_and_sizes(idx_path, read_all_sizes=True, max_preview=10000000):
    """Read the fixed header and the leading ``sizes`` array of an idx file.

    The layout parsed here is little-endian: a 9-byte magic, a uint64
    version, a uint8 dtype code, a uint64 count of size entries and a uint64
    document count, immediately followed by the size entries as uint32.
    Anything stored after the sizes is never read by this function.

    Args:
        idx_path: Path of the idx file to parse.
        read_all_sizes: When true, read every size entry announced by the
            header; otherwise read at most ``max_preview`` entries.
        max_preview: Upper bound on the entries read in preview mode.

    Returns:
        ``(magic, version, dtype_code, sizes)`` where ``sizes`` is a NumPy
        array of little-endian unsigned 32-bit integers.

    Raises:
        struct.error: If the file is shorter than the fixed header.
        ValueError: If the file ends before the requested sizes were read.
    """
    header_fmt = "<9sQBQQ"  # "<" means no padding, matching the on-disk layout
    with open(idx_path, "rb") as f:
        # The document count is part of the header but is not needed here.
        magic, version, dtype_code, sizes_len, _doc_count = struct.unpack(
            header_fmt, f.read(struct.calcsize(header_fmt))
        )
        wanted = sizes_len if read_all_sizes else min(sizes_len, max_preview)
        # The header is little-endian, so pin the array byte order as well.
        sizes = np.fromfile(f, dtype=np.dtype("<u4"), count=wanted)

    # np.fromfile returns a short array on a truncated file instead of
    # failing; surface that so callers never work on partial data.
    if len(sizes) < wanted:
        raise ValueError(
            f"truncated idx {idx_path}: expected {wanted} sizes, got {len(sizes)}"
        )
    return magic, version, dtype_code, sizes

def write_idx(idx_path, magic, version, dtype_code, sizes, pointers, doc_idx=None):
    """Write an idx file: fixed header, then sizes, pointers and doc index.

    The header is little-endian: ``magic`` bytes, uint64 ``version``, uint8
    ``dtype_code``, uint64 number of sizes and uint64 number of doc-index
    entries (0 when ``doc_idx`` is None). It is followed by ``sizes`` as
    uint32, ``pointers`` as uint64 and, if given, ``doc_idx`` as int32.

    Args:
        idx_path: Destination path; an existing file is overwritten.
        magic: Raw magic bytes written verbatim at the start of the file.
        version: Format version stored in the header.
        dtype_code: One-byte dtype code stored in the header.
        sizes: Per-sample sizes.
        pointers: Per-sample start offsets; must match ``sizes`` in length.
        doc_idx: Optional document index entries.

    Raises:
        ValueError: If ``sizes`` and ``pointers`` differ in length.
    """
    # Convert and validate everything before opening the output, so a bad
    # argument cannot leave a truncated file behind. Byte order is pinned to
    # little-endian to agree with the "<"-packed header.
    size_arr = np.asarray(sizes, dtype=np.dtype("<u4"))
    ptr_arr = np.asarray(pointers, dtype=np.dtype("<u8"))
    if len(size_arr) != len(ptr_arr):
        raise ValueError(
            f"sizes/pointers length mismatch: {len(size_arr)} != {len(ptr_arr)}"
        )
    # NOTE(review): doc_idx is stored as int32 here — confirm the consumer of
    # this file expects that width.
    doc_arr = None if doc_idx is None else np.asarray(doc_idx, dtype=np.dtype("<i4"))
    doc_count = 0 if doc_arr is None else len(doc_arr)

    with open(idx_path, "wb") as f:
        f.write(magic)
        f.write(struct.pack("<QBQQ", version, dtype_code, len(size_arr), doc_count))
        size_arr.tofile(f)
        ptr_arr.tofile(f)
        if doc_arr is not None:
            doc_arr.tofile(f)

def infer_window_length_from_sizes(sizes):
    """Infer the window length as the most frequent value in ``sizes``.

    No agreement threshold is enforced here; the share of entries equal to
    the winning value is only reported so the caller can judge it. When
    several values tie for most frequent, the smallest one wins, because
    ``np.unique`` returns values sorted and ``np.argmax`` picks the first
    maximum.

    Args:
        sizes: Array-like of sample sizes.

    Returns:
        ``(L, ratio, vals, counts)``: the inferred length as ``int``, the
        fraction of entries equal to it as ``float``, and the unique values
        with their occurrence counts as returned by ``np.unique``.

    Raises:
        ValueError: If ``sizes`` is empty.
    """
    vals, counts = np.unique(sizes, return_counts=True)
    if counts.size == 0:
        # Fail with a clear message instead of numpy's argmax error.
        raise ValueError("cannot infer window length from an empty sizes array")
    top = int(np.argmax(counts))
    window = int(vals[top])
    ratio = float(counts[top]) / float(counts.sum())
    return window, ratio, vals, counts

def main():
    """Generate ``OUT_IDX`` describing ``SHARD_BIN`` as fixed-length windows.

    The header fields (magic, version, dtype code) are copied from
    ``MERGED_IDX``. The window length is the most frequent value among the
    previewed sizes of the merged idx. The shard is then described as
    ``shard_tokens // L`` consecutive windows of ``L`` tokens each; tokens
    left over after the last full window are not indexed. No document index
    is written.

    Raises:
        FileNotFoundError: If ``MERGED_IDX`` or ``SHARD_BIN`` is missing.
        ValueError: If the inferred window length is not positive or the
            shard size is not a multiple of ``ITEMSIZE``.
        RuntimeError: If the shard holds fewer tokens than one window.
    """
    # Explicit raises rather than asserts: these validate external inputs and
    # must still run under ``python -O``.
    for required in (MERGED_IDX, SHARD_BIN):
        if not os.path.exists(required):
            raise FileNotFoundError(f"not found: {required}")

    # Preview mode reads only the leading sizes (bounded by the reader's
    # max_preview default); pass read_all_sizes=True to confirm on the full
    # array when the extra I/O is acceptable.
    magic, version, dtype_code, sizes_preview = read_idx_header_and_sizes(MERGED_IDX, read_all_sizes=False)
    L, ratio, vals, counts = infer_window_length_from_sizes(sizes_preview)
    print(f"merged idx preview sizes unique={vals.tolist()} counts={counts.tolist()}")
    print(f"inferred window length L={L}, agree_ratio={ratio:.4f}")
    if L <= 0:
        # A zero window would divide by zero below.
        raise ValueError(f"invalid inferred window length: {L}")

    # Derive the token count from the shard bin's size on disk.
    shard_bytes = os.path.getsize(SHARD_BIN)
    if shard_bytes % ITEMSIZE != 0:
        raise ValueError(f"bin size {shard_bytes} 不是 ITEMSIZE={ITEMSIZE} 的倍数")
    shard_tokens = shard_bytes // ITEMSIZE
    num_samples = shard_tokens // L
    if num_samples == 0:
        raise RuntimeError(f"shard too small: tokens={shard_tokens}, window={L}")

    # Build sizes/pointers, both expressed in tokens.
    # NOTE(review): pointers are emitted in token units (multiples of L), not
    # bytes. Some readers of this header layout treat pointers as byte
    # offsets — confirm which unit the consumer of OUT_IDX expects.
    sizes_out = np.full(num_samples, L, dtype=np.uint32)
    pointers_out = np.arange(num_samples, dtype=np.uint64) * np.uint64(L)

    # Internal sanity check; holds by construction since
    # num_samples = shard_tokens // L.
    last_end = int(pointers_out[-1]) + int(sizes_out[-1])
    assert last_end <= shard_tokens, f"越界: last_end={last_end} > shard_tokens={shard_tokens}"

    # Reuse the merged idx header (magic/version/dtype_code); the document
    # index is omitted, so its length is recorded as 0.
    write_idx(OUT_IDX, magic, version, dtype_code, sizes_out, pointers_out, doc_idx=None)
    print(f"wrote {OUT_IDX}")
    print(f"samples: {num_samples}, shard_tokens: {shard_tokens}, last_end: {last_end}")

# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()



