"""Data modules wrapping the torchvision MNIST and CIFAR-10 datasets."""
import os 
from torch.utils.data import DataLoader,ConcatDataset, random_split
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10
from sklearn.model_selection import train_test_split
from .dataset_wrapper import CustomDataset

class MNISTDataModule:
    """Train/validation/test data pipeline for torchvision's MNIST.

    Construction downloads MNIST into ``data_dir`` (torchvision skips files
    that are already present). It then wraps the splits in ``CustomDataset``
    and carves a validation set out of the official training split with
    ``torch.utils.data.random_split``.

    Args:
        data_dir: Root directory passed to ``torchvision.datasets.MNIST``.
        batch_size: Batch size of the train and validation loaders.
        ntrain: Number of training samples to keep. The remainder of the
            training split becomes the validation set. When ``None``,
            ``val_ratio`` decides the split instead.
        train_discard_classes, test_discard_classes: Forwarded to
            ``CustomDataset`` as ``discard_classes`` (presumably labels to
            drop -- confirm in ``dataset_wrapper``).
        train_transform, train_target_transform: Forwarded to the training
            ``CustomDataset``.
        test_transform, test_target_transform: Forwarded to the test
            ``CustomDataset`` (and the target transform also to the
            un-augmented test set).
        val_ratio: Fraction of the training split held out for validation
            when ``ntrain`` is ``None``. Must lie in ``[0, 1]``.
        num_workers: ``num_workers`` used by every ``DataLoader``.
        un_augmented_transform: Optional transform for an extra copy of the
            test set, exposed through ``un_augmented_test_dataloader``.
        test_batch: Batch size of the test loaders. Defaults to
            ``batch_size``.

    Raises:
        ValueError: If ``ntrain`` or ``val_ratio`` cannot produce a valid
            train/validation split.
    """

    def __init__(self,
                 *,
                 data_dir,
                 batch_size,
                 ntrain=None,
                 train_discard_classes=None,
                 train_transform=None,
                 train_target_transform=None,
                 test_discard_classes=None,
                 test_transform=None,
                 test_target_transform=None,
                 val_ratio=0.2,
                 num_workers=1,
                 un_augmented_transform=None,
                 test_batch=None
                 ):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.ntrain = ntrain
        self.train_discard_classes = train_discard_classes
        self.train_transform = train_transform
        self.train_target_transform = train_target_transform
        self.test_discard_classes = test_discard_classes
        self.test_transform = test_transform
        self.test_target_transform = test_target_transform
        self.num_workers = num_workers
        self.val_ratio = val_ratio
        self.un_augmented_transform = un_augmented_transform
        self.test_batch = test_batch if test_batch is not None else batch_size
        # Download and split eagerly so the loaders are usable right away.
        self.prepare_data()
        self.setup()

    def prepare_data(self):
        """Download the MNIST train and test splits into ``data_dir``."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    @staticmethod
    def _split_sizes(num_total, ntrain, val_ratio):
        """Return ``(num_train, num_val)`` summing to ``num_total``.

        If ``ntrain`` is given, it fixes the training size and the rest
        becomes validation data. Otherwise ``int(num_total * val_ratio)``
        samples are held out for validation.

        Raises:
            ValueError: If ``ntrain`` is outside ``[0, num_total]``, or, when
                ``ntrain`` is ``None``, ``val_ratio`` is outside ``[0, 1]``.
        """
        if ntrain is not None:
            # Without this check a too-large ntrain gives a negative
            # validation length, which random_split does not reject cleanly.
            if not 0 <= ntrain <= num_total:
                raise ValueError(
                    f"ntrain must be between 0 and {num_total}, got {ntrain}")
            return ntrain, num_total - ntrain
        if not 0 <= val_ratio <= 1:
            raise ValueError(f"val_ratio must be in [0, 1], got {val_ratio}")
        num_val = int(num_total * val_ratio)
        return num_total - num_val, num_val

    def setup(self):
        """Build the wrapped datasets and the train/validation split.

        Sets ``self.train_data``, ``self.val_data``, ``self.test_data`` and
        ``self.un_augmented_testset``. The last is ``None`` when no
        ``un_augmented_transform`` was given.
        """
        train_data = CustomDataset(dataset=MNIST(self.data_dir, train=True),
                                   transform=self.train_transform,
                                   target_transform=self.train_target_transform,
                                   discard_classes=self.train_discard_classes)

        test_data = CustomDataset(dataset=MNIST(self.data_dir, train=False),
                                  transform=self.test_transform,
                                  target_transform=self.test_target_transform,
                                  discard_classes=self.test_discard_classes)

        self.un_augmented_testset = None
        if self.un_augmented_transform is not None:
            # Uses its own MNIST instance rather than sharing the one above,
            # in case CustomDataset modifies the wrapped dataset -- TODO
            # confirm whether sharing would be safe.
            self.un_augmented_testset = CustomDataset(
                dataset=MNIST(self.data_dir, train=False),
                transform=self.un_augmented_transform,
                target_transform=self.test_target_transform,
                discard_classes=self.test_discard_classes)

        num_train, num_val = self._split_sizes(
            len(train_data), self.ntrain, self.val_ratio)
        # NOTE: no generator is passed, so the split follows torch's global
        # RNG and differs between runs unless the caller seeds torch.
        self.train_data, self.val_data = random_split(train_data, [num_train, num_val])
        self.test_data = test_data

    def train_dataloader(self):
        """Return a ``DataLoader`` over the training subset."""
        # NOTE(review): DataLoader defaults to shuffle=False, so every epoch
        # sees the same batch order -- confirm this is intended.
        return DataLoader(self.train_data, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return a ``DataLoader`` over the validation subset."""
        return DataLoader(self.val_data, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return a ``DataLoader`` over the (transformed) test set."""
        return DataLoader(self.test_data, batch_size=self.test_batch, num_workers=self.num_workers)

    def un_augmented_test_dataloader(self):
        """Return a loader over the un-augmented test set, or ``None``.

        ``None`` is returned when no ``un_augmented_transform`` was given.
        """
        if self.un_augmented_transform is None:
            return None
        return DataLoader(self.un_augmented_testset, batch_size=self.test_batch, num_workers=self.num_workers)

class CIFAR10dataModule:
    """Train/validation/test data pipeline for torchvision's CIFAR-10.

    Construction downloads CIFAR-10 into ``data_dir`` (torchvision skips
    files that are already present). It then wraps the splits in
    ``CustomDataset`` and carves a validation set out of the official
    training split with ``torch.utils.data.random_split``.

    Args:
        data_dir: Root directory passed to ``torchvision.datasets.CIFAR10``.
        batch_size: Batch size of the train and validation loaders.
        ntrain: Number of training samples to keep. The remainder of the
            training split becomes the validation set. When ``None``,
            ``val_ratio`` decides the split instead.
        train_discard_classes, test_discard_classes: Forwarded to
            ``CustomDataset`` as ``discard_classes`` (presumably labels to
            drop -- confirm in ``dataset_wrapper``).
        train_transform, train_target_transform: Forwarded to the training
            ``CustomDataset``.
        test_transform, test_target_transform: Forwarded to the test
            ``CustomDataset`` (and the target transform also to the
            un-augmented test set).
        val_ratio: Fraction of the training split held out for validation
            when ``ntrain`` is ``None``. Must lie in ``[0, 1]``.
        num_workers: ``num_workers`` used by every ``DataLoader``.
        un_augmented_transform: Optional transform for an extra copy of the
            test set, exposed through ``un_augmented_test_dataloader``.
        test_batch: Batch size of the test loaders. Defaults to
            ``batch_size``.

    Raises:
        ValueError: If ``ntrain`` or ``val_ratio`` cannot produce a valid
            train/validation split.
    """

    def __init__(self,
                 *,
                 data_dir,
                 batch_size,
                 ntrain=None,
                 train_discard_classes=None,
                 train_transform=None,
                 train_target_transform=None,
                 test_discard_classes=None,
                 test_transform=None,
                 test_target_transform=None,
                 val_ratio=0.2,
                 num_workers=1,
                 un_augmented_transform=None,
                 test_batch=None
                 ):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.ntrain = ntrain
        self.train_discard_classes = train_discard_classes
        self.train_transform = train_transform
        self.train_target_transform = train_target_transform
        self.test_discard_classes = test_discard_classes
        self.test_transform = test_transform
        self.test_target_transform = test_target_transform
        self.num_workers = num_workers
        self.val_ratio = val_ratio
        self.un_augmented_transform = un_augmented_transform
        self.test_batch = test_batch if test_batch is not None else batch_size
        # Download and split eagerly so the loaders are usable right away.
        self.prepare_data()
        self.setup()

    def prepare_data(self):
        """Download the CIFAR-10 train and test splits into ``data_dir``."""
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    @staticmethod
    def _split_sizes(num_total, ntrain, val_ratio):
        """Return ``(num_train, num_val)`` summing to ``num_total``.

        If ``ntrain`` is given, it fixes the training size and the rest
        becomes validation data. Otherwise ``int(num_total * val_ratio)``
        samples are held out for validation.

        Raises:
            ValueError: If ``ntrain`` is outside ``[0, num_total]``, or, when
                ``ntrain`` is ``None``, ``val_ratio`` is outside ``[0, 1]``.
        """
        if ntrain is not None:
            # Without this check a too-large ntrain gives a negative
            # validation length, which random_split does not reject cleanly.
            if not 0 <= ntrain <= num_total:
                raise ValueError(
                    f"ntrain must be between 0 and {num_total}, got {ntrain}")
            return ntrain, num_total - ntrain
        if not 0 <= val_ratio <= 1:
            raise ValueError(f"val_ratio must be in [0, 1], got {val_ratio}")
        num_val = int(num_total * val_ratio)
        # The lengths must sum to num_total for random_split to accept them.
        return num_total - num_val, num_val

    def setup(self):
        """Build the wrapped datasets and the train/validation split.

        Sets ``self.train_data``, ``self.val_data``, ``self.test_data`` and
        ``self.un_augmented_testset``. The last is ``None`` when no
        ``un_augmented_transform`` was given.
        """
        train_data = CustomDataset(dataset=CIFAR10(self.data_dir, train=True),
                                   transform=self.train_transform,
                                   target_transform=self.train_target_transform,
                                   discard_classes=self.train_discard_classes)

        test_data = CustomDataset(dataset=CIFAR10(self.data_dir, train=False),
                                  transform=self.test_transform,
                                  target_transform=self.test_target_transform,
                                  discard_classes=self.test_discard_classes)

        self.un_augmented_testset = None
        if self.un_augmented_transform is not None:
            # Uses its own CIFAR10 instance rather than sharing the one above,
            # in case CustomDataset modifies the wrapped dataset -- TODO
            # confirm whether sharing would be safe.
            self.un_augmented_testset = CustomDataset(
                dataset=CIFAR10(self.data_dir, train=False),
                transform=self.un_augmented_transform,
                target_transform=self.test_target_transform,
                discard_classes=self.test_discard_classes)

        num_train, num_val = self._split_sizes(
            len(train_data), self.ntrain, self.val_ratio)
        # NOTE: no generator is passed, so the split follows torch's global
        # RNG and differs between runs unless the caller seeds torch.
        self.train_data, self.val_data = random_split(train_data, [num_train, num_val])
        self.test_data = test_data

    def train_dataloader(self):
        """Return a ``DataLoader`` over the training subset."""
        # NOTE(review): DataLoader defaults to shuffle=False, so every epoch
        # sees the same batch order -- confirm this is intended.
        return DataLoader(self.train_data, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return a ``DataLoader`` over the validation subset."""
        return DataLoader(self.val_data, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return a ``DataLoader`` over the (transformed) test set."""
        return DataLoader(self.test_data, batch_size=self.test_batch, num_workers=self.num_workers)

    def un_augmented_test_dataloader(self):
        """Return a loader over the un-augmented test set, or ``None``.

        ``None`` is returned when no ``un_augmented_transform`` was given.
        """
        if self.un_augmented_transform is None:
            return None
        return DataLoader(self.un_augmented_testset, batch_size=self.test_batch, num_workers=self.num_workers)