import numpy as np
from torch.utils.data import Dataset

class ReindexDataset(Dataset):
    """View of a base dataset whose items are addressed through an index array.

    Item ``i`` of this dataset is ``base_ds[order[i]]``. Indices in ``order``
    may repeat, and ``order`` may be longer than the base dataset, so the
    view can reorder, subsample, or oversample the base dataset.

    Args:
        base_ds: Dataset to read items from; must support ``len()`` and
            integer indexing.
        order: 1-D sequence of indices into ``base_ds``. It is converted to
            an ``int64`` NumPy array. An empty ``order`` is accepted and
            yields a dataset of length 0.

    Raises:
        ValueError: If ``order`` is not 1-D, or if any index lies outside
            ``[0, len(base_ds))``.
    """

    def __init__(self, base_ds: Dataset, order: np.ndarray) -> None:
        self.base = base_ds
        self.order = np.asarray(order, dtype=np.int64)
        if self.order.ndim != 1:
            raise ValueError("order must be 1-D")
        # min()/max() raise on zero-size arrays, so only range-check when
        # there is at least one index; an empty order is trivially in range.
        if self.order.size > 0 and (
            self.order.min() < 0 or self.order.max() >= len(base_ds)
        ):
            raise ValueError("order indices out of range")
        # Duplicate indices are allowed, and the view may be longer than
        # the base dataset.

    def __len__(self) -> int:
        """Return the number of entries in ``order``."""
        return len(self.order)

    def __getitem__(self, i):
        """Return the base-dataset item selected by position ``i`` of ``order``."""
        # Convert the NumPy scalar to a plain int before indexing the base.
        return self.base[int(self.order[i])]

