'''
Dataset loaders for the torchvision datasets used in this project (MNIST, CIFAR-10 and Fashion-MNIST).
'''

import torch
from torch.utils.data import Dataset
import torchvision
from torchvision import datasets, transforms
from torchvision.datasets import MNIST, CIFAR10, FashionMNIST

def get_dataset(args, root='data'):
    """Build the (train, test) dataset pair selected by ``args.data``.

    Supported names, matched case-insensitively, are ``MNIST``, ``CIFAR10``
    and ``FMNIST`` (Fashion-MNIST). Each split is constructed with
    ``download=True`` under ``root``.

    All splits are converted with ``transforms.ToTensor()``. The CIFAR-10
    training split also applies ``RandomCrop(32, padding=4)`` and
    ``RandomHorizontalFlip()``.

    Args:
        args: Object with a ``data`` attribute naming the dataset
            (e.g. an ``argparse.Namespace``).
        root: Directory passed to the torchvision dataset constructors.
            Defaults to ``'data'``, the previously hard-coded location.

    Returns:
        tuple: ``(train_dataset, test_dataset)``.

    Raises:
        ValueError: If ``args.data`` does not name a supported dataset.
            Previously this case surfaced as an ``UnboundLocalError``.
    """
    name = args.data
    key = name.lower() if isinstance(name, str) else name

    # Validate before building transforms or touching the filesystem.
    # The old if/elif chain fell through and crashed on the unbound
    # ``train_dataset`` at the return statement instead.
    supported = ('mnist', 'cifar10', 'fmnist')
    if key not in supported:
        raise ValueError(
            "Unsupported dataset %r; expected one of 'MNIST', 'CIFAR10', "
            "'FMNIST' (case-insensitive)." % (name,)
        )

    # Tensor conversion only; used for every test split and for the
    # un-augmented MNIST / Fashion-MNIST training splits.
    eval_transform = transforms.Compose([transforms.ToTensor()])

    if key == 'mnist':
        dataset_cls, train_transform = MNIST, eval_transform
    elif key == 'fmnist':
        dataset_cls, train_transform = FashionMNIST, eval_transform
    else:  # key == 'cifar10'
        dataset_cls = CIFAR10
        # Training-time augmentation: pad by 4 px, random-crop back to
        # 32x32, and randomly flip horizontally.
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

    train_dataset = dataset_cls(root=root, train=True, download=True,
                                transform=train_transform)
    test_dataset = dataset_cls(root=root, train=False, download=True,
                               transform=eval_transform)
    return train_dataset, test_dataset