import pytorch_lightning as pl
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule for MNIST.

    The MNIST training split (``train=True``) is served by
    ``train_dataloader``. The MNIST test split (``train=False``) is served as
    the validation set by ``val_dataloader``.

    Args:
        data_dir: Directory where torchvision stores / looks for the MNIST files.
        cfg: Config object. It must expose ``batch_size`` and ``num_workers``
            attributes before any dataloader is requested.
        name: Free-form label stored on ``self.name``.
        **kwargs: Accepted for call-site compatibility; ignored.
    """

    def __init__(self, data_dir, cfg=None, name="", **kwargs):
        super().__init__()
        self.data_dir = data_dir
        self.name = name
        self.cfg = cfg
        # ToTensor, then Normalize with a single-channel mean/std
        # (0.1307 / 0.3081 are the values usually quoted for MNIST).
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    def prepare_data(self):
        """Download both MNIST splits into ``data_dir`` if they are missing.

        Lightning runs this hook once per node, before ``setup``. Downloading
        here keeps distributed workers from racing to write the same files in
        ``setup``.
        """
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the dataset objects needed for ``stage``.

        Args:
            stage: Lightning stage name. ``'fit'`` builds the train and
                validation datasets. ``'validate'`` builds only the validation
                dataset. ``None`` builds both.
        """
        # download=True is kept so that calling setup() without prepare_data()
        # still works; torchvision skips the download when files already exist.
        if stage in ("fit", None):
            self.mnist_train = datasets.MNIST(
                self.data_dir, train=True, transform=self.transform, download=True)
        # The validation loop (run during fit and via trainer.validate) needs
        # the held-out split.
        if stage in ("fit", "validate", None):
            self.mnist_val = datasets.MNIST(
                self.data_dir, train=False, transform=self.transform, download=True)

    def train_dataloader(self):
        """Return a shuffled DataLoader over the MNIST training split."""
        return self._make_loader(self.mnist_train, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled DataLoader over the MNIST test split (used for validation)."""
        return self._make_loader(self.mnist_val, shuffle=False)

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using ``cfg.batch_size`` and ``cfg.num_workers``.

        Raises:
            ValueError: If no ``cfg`` was supplied to the constructor.
        """
        if self.cfg is None:
            raise ValueError(
                "MNISTDataModule needs a cfg with 'batch_size' and 'num_workers' "
                "to build dataloaders")
        return DataLoader(
            dataset,
            batch_size=self.cfg.batch_size,
            shuffle=shuffle,
            num_workers=self.cfg.num_workers,
        )
