# fixidx.py — 将 /root/trainbin1/document1.idx 的 pointers 从“元素偏移”改为“字节偏移”（×2）
import os, struct
import numpy as np

# All file paths share one prefix: the index, the token data, and the index backup.
PREFIX = "/root/trainbin1/document1"
IDX, BIN, BAK = (PREFIX + suffix for suffix in (".idx", ".bin", ".idx.bak"))
# Bytes per token element; the .bin tokens are uint16.
ITEMSIZE = np.dtype(np.uint16).itemsize

# Fixed-size header, little-endian: 9-byte magic, u64 version, u8 dtype code,
# u64 sequence count, u64 document count (34 bytes, no padding with "<").
_HEADER = struct.Struct("<9sQBQQ")


def read_index(path):
    """Parse the index file at *path*.

    Returns ``(magic, version, dtype_code, n, doc_n, sizes, pointers, trailer)``.
    ``sizes`` (uint32) and ``pointers`` (uint64) are read-only views of length
    ``n``. ``trailer`` is every byte after the pointers array (the document
    index plus anything else). It is returned verbatim so it can be written
    back unchanged without knowing its element type.

    Raises:
        ValueError: if the file is shorter than its header says it should be.
    """
    with open(path, "rb") as f:
        data = f.read()

    if len(data) < _HEADER.size:
        raise ValueError(f"{path}: truncated header ({len(data)} bytes)")
    magic, version, dtype_code, n, doc_n = _HEADER.unpack_from(data, 0)

    sizes_off = _HEADER.size
    pointers_off = sizes_off + 4 * n
    trailer_off = pointers_off + 8 * n
    if len(data) < trailer_off:
        raise ValueError(
            f"{path}: truncated arrays (need {trailer_off} bytes, have {len(data)})"
        )
    trailer = data[trailer_off:]
    # The document index holds doc_n entries of at least 4 bytes each.
    if len(trailer) < 4 * doc_n:
        raise ValueError(
            f"{path}: truncated document index ({len(trailer)} bytes for {doc_n} entries)"
        )

    sizes = np.frombuffer(data, dtype="<u4", count=n, offset=sizes_off)
    pointers = np.frombuffer(data, dtype="<u8", count=n, offset=pointers_off)
    return magic, version, dtype_code, n, doc_n, sizes, pointers, trailer


def detect_pointer_unit(pointers, sizes, bin_bytes, itemsize):
    """Decide whether *pointers* are byte offsets or element offsets.

    Each interpretation is tested against the ``.bin`` file size
    *bin_bytes*. An interpretation is feasible when every start is
    ``itemsize``-aligned and every sequence ends inside the file. It is
    *tight* when, in addition, sequences are back-to-back and the last one
    ends exactly at *bin_bytes*.

    Returns:
        "bytes" when no conversion is needed. This covers an empty index or
        all-zero pointers, where both units coincide.
        "elements" when the pointers must be multiplied by *itemsize*.
        None when the layout is ambiguous or inconsistent with both units.
    """
    raw = np.asarray(pointers, dtype=np.uint64)
    lengths = np.asarray(sizes, dtype=np.uint64) * np.uint64(itemsize)
    if itemsize == 1 or not raw.any():
        return "bytes"  # both units give identical byte offsets

    feasible = []
    for unit, scale in (("bytes", 1), ("elements", itemsize)):
        starts = raw * np.uint64(scale)
        ends = starts + lengths
        if np.any(starts % np.uint64(itemsize)) or int(ends.max()) > bin_bytes:
            continue
        contiguous = np.array_equal(starts[1:], ends[:-1])
        feasible.append((unit, contiguous and int(ends[-1]) == bin_bytes))

    if len(feasible) == 1:
        return feasible[0][0]
    tight_units = [unit for unit, tight in feasible if tight]
    if len(tight_units) == 1:
        return tight_units[0]
    return None


def main():
    """Rewrite IDX so its pointers are byte offsets (element offset × ITEMSIZE).

    The rewritten index is fully written and fsynced to a temporary file
    before anything is renamed. The first run moves the original to BAK;
    later runs keep the existing backup.

    Returns a process exit status: 0 if no change was needed or the rewrite
    succeeded, 1 if the pointer unit could not be determined (nothing is
    modified in that case).

    Raises:
        SystemExit: if the converted pointers would run past the end of BIN.
    """
    magic, version, dtype_code, n, doc_n, sizes, pointers, trailer = read_index(IDX)
    # NOTE(review): ITEMSIZE is hard-coded and dtype_code is not checked
    # against it; confirm the .bin really stores 2-byte tokens.
    bin_bytes = os.path.getsize(BIN)

    unit = detect_pointer_unit(pointers, sizes, bin_bytes, ITEMSIZE)
    if unit == "bytes":
        print("pointers 已是字节偏移，无需修改。")
        return 0
    if unit is None:
        print("无法判断 pointers 是元素偏移还是字节偏移，未做任何修改。")
        return 1

    # unit == "elements" implies n > 0, so indexing [-1] below is safe.
    new_pointers = np.asarray(pointers, dtype=np.uint64) * np.uint64(ITEMSIZE)

    # Consistency check; an explicit raise so it survives `python -O`.
    last_end_bytes = int(new_pointers[-1]) + int(sizes[-1]) * ITEMSIZE
    if last_end_bytes > bin_bytes:
        raise SystemExit(
            f"越界: last_end_bytes={last_end_bytes} > bin_bytes={bin_bytes}"
        )

    # Write the complete new index to a temp file first, so a failure part way
    # through can never leave a half-written IDX behind.
    tmp = IDX + ".tmp"
    with open(tmp, "wb") as f:
        f.write(_HEADER.pack(magic, version, dtype_code, n, doc_n))
        f.write(np.asarray(sizes, dtype="<u4").tobytes())
        f.write(new_pointers.astype("<u8").tobytes())
        f.write(trailer)  # document index (and any extra bytes) unchanged
        f.flush()
        os.fsync(f.fileno())

    if not os.path.exists(BAK):
        os.rename(IDX, BAK)
        print("backup ->", BAK)
    os.replace(tmp, IDX)

    print("wrote fixed idx:", IDX)
    print("last_end_bytes:", last_end_bytes, "bin_bytes:", bin_bytes)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
