# megatron/data/permutation_dataset.py
import numpy as np
from torch.utils.data import Dataset

class PermutationDataset(Dataset):
    """Present ``base_ds`` re-ordered by a fixed, validated permutation.

    Item ``i`` of this dataset is item ``order[i]`` of the base dataset, so both
    datasets have the same length and contain exactly the same items.

    Args:
        base_ds: Dataset supporting ``len()`` and integer indexing.
        order: 1-D array-like of integer indices; must be a permutation of
            ``range(len(base_ds))``.

    Raises:
        ValueError: if ``order`` is not 1-D, has the wrong length, contains
            non-integral float values, indices out of range, or duplicates.
    """

    def __init__(self, base_ds: Dataset, order: np.ndarray):
        self.base = base_ds
        n = len(base_ds)
        raw = np.asarray(order)
        # copy=False: no copy is made when the input is already int64, so the
        # stored array may alias the caller's; mutating it afterwards bypasses
        # the validation below.
        self.order = raw.astype(np.int64, copy=False)
        if self.order.ndim != 1:
            raise ValueError("order must be 1-D")
        # Casting floats to int64 truncates silently (0.5 -> 0), which could
        # turn garbage into a seemingly valid permutation; reject any float
        # value that does not survive the cast unchanged.
        if raw.dtype.kind == "f" and not np.array_equal(raw, self.order):
            raise ValueError("order must contain integral indices")
        if len(self.order) != n:
            raise ValueError(
                f"order length {len(self.order)} != base dataset length {n}"
            )
        # min()/max() raise on zero-size arrays; an empty order (for an empty
        # base dataset) is trivially a valid permutation, so skip the check.
        if self.order.size and (self.order.min() < 0 or self.order.max() >= n):
            raise ValueError("order indices out of range")
        # Ensure a true permutation (nothing dropped, nothing repeated). Given
        # the length and range checks above, uniqueness is sufficient.
        if len(np.unique(self.order)) != len(self.order):
            raise ValueError("order must be a permutation (unique indices)")

    def __len__(self) -> int:
        """Return the number of items, which equals ``len(base_ds)``."""
        return len(self.order)

    def __getitem__(self, i):
        """Return ``base_ds[order[i]]``.

        ``i`` is expected to be an integer; it indexes the numpy ``order``
        array, so negative indices count from the end and out-of-range indices
        raise ``IndexError``.
        """
        # int(): hand the base dataset a plain Python int, not an np.int64.
        return self.base[int(self.order[i])]
