import os
import numpy as np

#Path: randommask.py
# based on indicies .npy 的行数生成随机布尔掩码（True=保留，False=屏蔽）
# 逻辑修改：
# 1. 移除 txt 输出
# 2. 增加基于 1024 (UNIT_SIZE) 的范围限制，只在指定范围内随机屏蔽

# ======== 配置（按需修改） ========
NPY_PATH = r""  # 若存在，用其行数作为总长度
ORIG_LEN = 20* 102400                                                # 兜底总长度（当找不到 NPY 时）

OUTPUT_MASK_PATH = r"\.npy"
SEED = 42                                                            # 固定随机种子；设为 None 则每次不同

# ---- 新增：范围控制配置 ----
UNIT_SIZE = 1024      # 基础单元大小
START_UNIT = 0        # 起始单元索引（包含） -> 对应索引 START_UNIT * 1024
END_UNIT = 800        # 结束单元索引（不包含） -> 对应索引 END_UNIT * 1024
                      # 此时生效范围为: [0, 204800)

NUM_MASK = 80000       # 在上述 [START, END) 范围内随机屏蔽的数量
                      # 注意：NUM_MASK 必须小于等于该范围内的样本总数

# ======== 工具 ========
def ensure_dir_for_file(path: str):
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

def get_total_len() -> int:
    if NPY_PATH and os.path.exists(NPY_PATH):
        # 仅取行数；兼容 shape=(N,), (N, k) 等
        arr = np.load(NPY_PATH, mmap_mode="r")
        return int(arr.shape[0])
    return int(ORIG_LEN)

# ======== 主逻辑 ========
def main():
    total_len = get_total_len()
    print(f"总数据长度 = {total_len}（来源: {'NPY_PATH' if os.path.exists(NPY_PATH) else 'ORIG_LEN'}）")

    # 计算具体的索引范围
    start_idx = START_UNIT * UNIT_SIZE
    end_idx = END_UNIT * UNIT_SIZE
    
    # 边界检查，防止超出 total_len
    if start_idx >= total_len:
        raise ValueError(f"起始索引 {start_idx} 超出了总长度 {total_len}")
    
    real_end_idx = min(end_idx, total_len)
    range_len = real_end_idx - start_idx
    
    print(f"指定屏蔽范围: Unit [{START_UNIT}, {END_UNIT}) -> Index [{start_idx}, {real_end_idx})")
    print(f"该范围内样本数: {range_len}")
    print(f"计划在该范围内随机屏蔽: {NUM_MASK} 条")

    # 校验 NUM_MASK 是否合法
    if range_len <= 0:
        raise ValueError("计算出的范围长度 <= 0，请检查 START_UNIT 和 END_UNIT。")
    if not (0 <= NUM_MASK <= range_len):
        raise ValueError(f"NUM_MASK ({NUM_MASK}) 必须在范围长度 [0, {range_len}] 内")

    rng = np.random.default_rng(SEED)
    
    # 1. 初始化全为 True (全保留)
    mask = np.ones(total_len, dtype=bool)

    # 2. 在指定范围内生成随机索引
    #    rng.choice(range_len) 生成的是 0 到 range_len-1 的相对偏移量
    #    加上 start_idx 转换为全局索引
    relative_chosen = rng.choice(range_len, size=NUM_MASK, replace=False)
    global_chosen = relative_chosen + start_idx
    
    # 3. 将选中位置设为 False (屏蔽)
    mask[global_chosen] = False

    # 4. 保存
    ensure_dir_for_file(OUTPUT_MASK_PATH)
    np.save(OUTPUT_MASK_PATH, mask)
    print(f"已保存掩码: {OUTPUT_MASK_PATH}")
    
    # 5. 统计信息
    masked_count = int((~mask).sum())
    kept_count = int(mask.sum())
    print(f"统计: 总长度={mask.size}")
    print(f"      True (保留) = {kept_count}")
    print(f"      False(屏蔽) = {masked_count} (应等于 {NUM_MASK})")
    
    # 原有的 txt 导出代码已被移除

if __name__ == "__main__":
    main()
