import torch
from torchvision import datasets, transforms
from .base_dataset import DiffusionDataset, DATA_DIR
from os import path

class MNISTDigitsDataset(DiffusionDataset):
    """Image-only view over torchvision's MNIST dataset.

    Every item is the transformed image alone; the label that the wrapped
    dataset pairs with each image is discarded.
    """

    def __init__(self):
        """Build the transform pipeline and the wrapped MNIST dataset.

        The data is stored under ``DATA_DIR/data`` and ``download=True`` is
        passed to torchvision. No ``train`` argument is given, so
        torchvision's default split is used.
        """
        super().__init__()

        # Convert each image to a tensor, then normalize its single channel
        # with mean 0.1307 and standard deviation 0.3081.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        pipeline = transforms.Compose([to_tensor, normalize])

        root = path.join(DATA_DIR, 'data')
        self.mnist = datasets.MNIST(root, download=True, transform=pipeline)

    def __getitem__(self, idx):
        """Return the image stored at ``idx``, without its label."""
        sample = self.mnist[idx]
        # The image is the first element of the sample; the rest is dropped.
        return sample[0]

    def __len__(self):
        """Return the number of samples in the wrapped MNIST dataset."""
        return len(self.mnist)