import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import os
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split
import numpy as np

class LitMNIST(pl.LightningDataModule):
    """LightningDataModule wrapping the torchvision MNIST dataset.

    ``prepare_data`` downloads the train and test splits. ``setup`` loads
    them with a ``ToTensor`` (optionally followed by ``Normalize``)
    transform and, when ``val_set`` is True, computes a stratified
    train/validation index split of the training set.

    Args:
        normalized: If True, ``Normalize((0.1307,), (0.3081,))`` is applied
            after ``ToTensor()``; otherwise only ``ToTensor()`` is applied.
        val_set: If True, a stratified fraction of the training set is held
            out and served by ``val_dataloader``; otherwise
            ``val_dataloader`` returns None.
        data_dir: Directory MNIST is downloaded to / read from. None (the
            default) means ``<current working directory>/data``, resolved at
            each call.
        val_fraction: Fraction of the training set held out for validation
            when ``val_set`` is True (default 0.1).
        seed: ``random_state`` passed to ``train_test_split`` so the
            train/val split is reproducible. None keeps it unseeded.
    """

    def __init__(self, normalized=False, val_set=False, data_dir=None,
                 val_fraction=0.1, seed=None):
        super().__init__()
        self.normalized = normalized
        self.val_set = val_set
        self.data_dir = data_dir
        self.val_fraction = val_fraction
        self.seed = seed

    def _data_root(self):
        """Return the directory holding the MNIST files."""
        if self.data_dir is not None:
            return self.data_dir
        # Resolved on every call (not cached in __init__) so the path follows
        # the working directory at call time, as the original code did.
        return os.path.join(os.getcwd(), 'data')

    @staticmethod
    def _num_workers(bs):
        """Return the DataLoader worker count: 3 when bs > 1, else 0."""
        return 3 if bs > 1 else 0

    def prepare_data(self):
        """Download the MNIST train and test splits (no transforms)."""
        root = self._data_root()
        MNIST(root, train=True, download=True)
        MNIST(root, train=False, download=True)

    def setup(self, stage=None):
        """Load both MNIST splits from disk and build the optional val split.

        Args:
            stage: Supplied by the Lightning Trainer (e.g. 'fit', 'test').
                It is accepted so the Trainer can call this method, but it is
                otherwise ignored: both splits are always loaded.
        """
        steps = [transforms.ToTensor()]
        if self.normalized:
            # Normalize(mean, std) with the constants used by this module.
            steps.append(transforms.Normalize((0.1307,), (0.3081,)))
        transform = transforms.Compose(steps)

        # No download here; prepare_data is responsible for fetching files.
        root = self._data_root()
        self.mnist_train = MNIST(root, train=True, download=False,
                                 transform=transform)
        self.mnist_test = MNIST(root, train=False, download=False,
                                transform=transform)

        if self.val_set:
            # Stratify on the labels so train and val keep the same class
            # proportions. Hand sklearn a NumPy array rather than a tensor.
            targets = np.asarray(self.mnist_train.targets)
            self.train_idx, self.val_idx = train_test_split(
                np.arange(len(targets)),
                test_size=self.val_fraction,
                stratify=targets,
                random_state=self.seed,
            )

    def train_dataloader(self, bs=64):
        """Return the training DataLoader with batch size ``bs``.

        With ``val_set`` the loader draws, in random order, only the indices
        in ``train_idx``. Otherwise it shuffles the full training set.
        """
        workers = self._num_workers(bs)
        if self.val_set:
            sampler = torch.utils.data.SubsetRandomSampler(self.train_idx)
            # shuffle must stay False when an explicit sampler is supplied.
            return DataLoader(self.mnist_train, batch_size=bs, shuffle=False,
                              num_workers=workers, sampler=sampler)
        return DataLoader(self.mnist_train, batch_size=bs, shuffle=True,
                          num_workers=workers)

    def val_dataloader(self, bs=64):
        """Return a DataLoader over the held-out ``val_idx`` subset.

        Returns None when ``val_set`` is False.
        NOTE(review): confirm that the Lightning version in use accepts None
        here; recent versions document an empty list for "no validation".
        """
        if not self.val_set:
            return None
        sampler = torch.utils.data.SubsetRandomSampler(self.val_idx)
        # The validation subset is carved out of the *training* dataset.
        return DataLoader(self.mnist_train, batch_size=bs, shuffle=False,
                          num_workers=self._num_workers(bs), sampler=sampler)

    def test_dataloader(self, bs=64):
        """Return an unshuffled DataLoader over the MNIST test split."""
        return DataLoader(self.mnist_test, batch_size=bs, shuffle=False,
                          num_workers=self._num_workers(bs))

if __name__ == '__main__':
    # Running this file directly only fetches the MNIST files to disk;
    # no transforms, splits, or loaders are built.
    datamodule = LitMNIST()
    datamodule.prepare_data()
