import os
import numpy as np
import torch
from torch.utils.data import Dataset
from megatron import print_rank_0


class MaskedGPT2Dataset(Dataset):
    """Sample-level dataset that applies a per-sample keep/drop loss mask.

    Samples are served in the underlying dataset's own order; no shuffling
    or window cropping is done here. For global sample indices in
    ``[mask_start_global_idx, mask_start_global_idx + mask_length)`` the
    boolean mask (True = keep, False = mask out) decides whether that
    sample's ``loss_mask`` is all ones or all zeros. Samples outside the
    window are always kept (all-ones ``loss_mask``).
    """

    def __init__(
        self,
        name,
        data_prefix,
        documents,             # kept for signature compatibility; unused
        indexed_dataset,       # sample-level dataset: [i] returns length seq_length+1
        num_samples,           # kept for signature compatibility; unused
        seq_length,
        seed,                  # kept for signature compatibility; unused
        build_index_mappings=True,  # kept for signature compatibility; unused
        use_shared_fs=True,         # kept for signature compatibility; unused
        mask_npy_path=None,         # .npy of shape (mask_length,), cast to bool; True = keep
        mask_start_global_idx=0,    # 0-based global index of the first masked-window sample
        mask_length=7320644,        # number of samples covered by the mask
    ):
        """Build the dataset and load the keep-mask.

        Args:
            name: Dataset name (stored on the instance).
            data_prefix, documents, num_samples, seed, build_index_mappings,
                use_shared_fs: Accepted only for signature compatibility with
                the GPT dataset constructor; not used.
            indexed_dataset: Sample-level dataset. ``indexed_dataset[i]`` must
                return a 1-D sequence of length ``seq_length + 1``.
            seq_length: Tokens per training sequence (length of ``loss_mask``).
            mask_npy_path: Optional path to a ``.npy`` file with a 1-D array of
                shape ``(mask_length,)``; it is cast to bool (True = keep).
                If falsy, every sample in the window is kept.
            mask_start_global_idx: 0-based global sample index where the mask
                window starts (original note: the first sample of step 1001).
            mask_length: Number of samples covered by the mask.
                NOTE(review): the original comment described this as
                "1000 steps * 1024 samples/step" (= 1,024,000), which does not
                match the default 7,320,644 — confirm the intended value.

        Raises:
            FileNotFoundError: ``mask_npy_path`` was given but does not exist.
            ValueError: The loaded mask's shape is not ``(mask_length,)``.
        """
        self.name = name
        self.indexed_dataset = indexed_dataset
        self.seq_length = int(seq_length)

        # Total sample count comes straight from the underlying dataset.
        try:
            self.total_samples = len(self.indexed_dataset)
        except TypeError:
            # No __len__: fall back to the per-sample sizes array.
            self.total_samples = int(self.indexed_dataset.sizes.shape[0])

        self.mask_start = int(mask_start_global_idx)
        self.mask_len = int(mask_length)
        self.mask_end = self.mask_start + self.mask_len
        self.global_mask = self._load_mask(mask_npy_path)

        # Informational bounds check; out-of-range window positions simply
        # never match a valid index in __getitem__.
        if self.mask_start < 0 or self.mask_end > self.total_samples:
            print_rank_0(f"WARNING: mask window [{self.mask_start}, {self.mask_end}) "
                         f"exceeds dataset length {self.total_samples}; will clip at runtime.")

        print_rank_0(f" > [MaskedGPT2Dataset] total_samples={self.total_samples}, seq_length={self.seq_length}")

    def _load_mask(self, mask_npy_path):
        """Return the bool keep-mask of shape ``(self.mask_len,)``.

        Returns all-True when no path is given. Raises FileNotFoundError when a
        path is given but missing, so a typo cannot silently disable masking.
        """
        if not mask_npy_path:
            print_rank_0(" > No mask file provided; all samples in window keep=True.")
            return np.ones(self.mask_len, dtype=np.bool_)

        if not os.path.exists(mask_npy_path):
            raise FileNotFoundError(f"mask file not found: {mask_npy_path}")

        gm = np.load(mask_npy_path)
        if gm.dtype != np.bool_:
            gm = gm.astype(np.bool_)
        if gm.shape != (self.mask_len,):
            raise ValueError(f"mask shape {gm.shape} != ({self.mask_len},)")
        print_rank_0(f" > Mask loaded: keep {int(gm.sum())} / {self.mask_len} in "
                     f"[{self.mask_start} .. {self.mask_end-1}] (0-based)")
        return gm

    def __len__(self):
        return self.total_samples

    def __getitem__(self, idx: int):
        """Return one sample and its loss mask.

        Negative indices are normalized Python-style before use. This makes
        the mask decision use the same global index as the sample actually read.

        Returns:
            dict with ``"text"``: int64 ndarray of length ``seq_length + 1``
            (split into tokens/labels later by the batch builder), and
            ``"loss_mask"``: int64 tensor of length ``seq_length``, all ones
            (keep) or all zeros (masked out).

        Raises:
            IndexError: ``idx`` is outside ``[-len(self), len(self))``.
            RuntimeError: The stored sample has an unexpected shape.
        """
        requested = idx
        idx = int(idx)
        if idx < 0:
            idx += self.total_samples
        if not 0 <= idx < self.total_samples:
            raise IndexError(
                f"index {requested} out of range for dataset of length {self.total_samples}"
            )

        arr = np.asarray(self.indexed_dataset[idx], dtype=np.int64)

        expected = self.seq_length + 1
        if arr.ndim != 1 or arr.shape[0] != expected:
            raise RuntimeError(
                f"Unexpected sample length at idx={idx}: got {arr.shape}, expect ({expected},)."
            )

        # Inside the window the mask decides (True = keep); outside, always keep.
        if self.mask_start <= idx < self.mask_end:
            keep = bool(self.global_mask[idx - self.mask_start])
        else:
            keep = True

        # Per-token loss mask: the whole sample is either kept or dropped.
        if keep:
            loss_mask = torch.ones(self.seq_length, dtype=torch.int64)
        else:
            loss_mask = torch.zeros(self.seq_length, dtype=torch.int64)

        # Keys match the original GPT data pipeline.
        return {
            "text": arr,              # length seq_length+1
            "loss_mask": loss_mask,   # length seq_length, int64
        }
