from torch.utils.data import Dataset


class Collective_Subset(Dataset):
    """Subset of a parent dataset exposing ``(image, sensitive, label)`` items.

    On construction, the parent's ``true_labels``, ``images`` and
    ``sensitives`` are indexed with ``indices`` and stored on this object.
    ``labels`` (what ``__getitem__`` returns) also starts from the parent's
    ``true_labels``; the parent's own ``labels`` attribute is not consulted.
    ``labels`` can later be rewritten with :meth:`change_labels`, while
    ``true_labels`` is left untouched by this class.
    """

    def __init__(self, orig_ds: Dataset, indices: list):
        super().__init__()

        parent_labels = orig_ds.true_labels
        # Each attribute gets its own indexing expression (the labels are
        # indexed twice, once for each attribute), mirroring the parent layout.
        self.true_labels = parent_labels[indices]
        self.labels = parent_labels[indices]
        self.images = orig_ds.images[indices]
        self.sensitives = orig_ds.sensitives[indices]

    def __len__(self):
        """Number of items, taken from the first dimension of ``images``."""
        n_items = len(self.images)
        return n_items

    def __getitem__(self, index):
        """Return the ``(image, sensitive, label)`` triple at ``index``.

        The label comes from ``labels`` (possibly modified), not from
        ``true_labels``.
        """
        return self.images[index], self.sensitives[index], self.labels[index]

    def change_labels(self, idx, new_labels):
        """Set ``labels`` to a copy of ``true_labels`` with ``idx`` overwritten.

        Every call restarts from ``true_labels``, so modifications made by
        previous calls are discarded rather than accumulated.
        """
        # Detached copy so writing into it never touches ``true_labels``.
        reset = self.true_labels.detach().clone()
        self.labels = reset
        reset[idx] = new_labels

    def group_idx(self, val):
        """Boolean mask of length ``len(self)`` where ``sensitives == val``.

        The reshape requires exactly one sensitive value per item.
        """
        mask = self.sensitives == val
        return mask.reshape(len(self))
