import torch


class Dataset(object):
    def __init__(self, loc, transform=None):
        self.dataset = torch.load(loc).float().div(255)
        self.transform = transform

    def __len__(self):
        return self.dataset.size(0)

    @property
    def ndim(self):
        return self.dataset.size(1)

    def __getitem__(self, index):
        x = self.dataset[index]
        x = self.transform(x) if self.transform is not None else x
        return x, 0


class CelebA(Dataset):
    TRAIN_LOC = 'data/celeba/celeba_train.pth'
    VAL_LOC = 'data/celeba/celeba_val.pth'

    def __init__(self, train=True, transform=None):
        return super(CelebA, self).__init__(self.TRAIN_LOC if train else self.VAL_LOC, transform)
