import os

from typing import *
import torchvision.datasets as dsets

import torchvision.utils
import torchvision.transforms as transforms

# Root directory of the Tiny ImageNet copy; "train" and "val" are joined onto
# it. The empty string makes those paths relative to the working directory.
IMAGENET_LOC_ENV = ""

# Dataset names handled by dataset_load().
DATASETS = [
    "mnist",
    "cifar10",
    "tinyimagenet",
]


def dataset_load(dataset: str, split: str):
    """Build the torchvision dataset called ``dataset`` for ``split``.

    Args:
        dataset: One of ``"mnist"``, ``"cifar10"`` or ``"tinyimagenet"``.
        split: Split name forwarded unchanged to the per-dataset loader
            (the loaders in this module handle ``"train"`` and ``"test"``).

    Returns:
        The dataset object built by the matching ``_<dataset>`` loader.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset == 'mnist':
        return _mnist(split)
    if dataset == 'cifar10':
        return _cifar10(split)
    if dataset == 'tinyimagenet':
        return _tinyimagenet(split)
    # Fail loudly: returning None here would only defer the error to the
    # first place that tries to index or wrap the "dataset".
    raise ValueError(
        f"Incompatible dataset type: {dataset!r}; "
        "expected one of 'mnist', 'cifar10', 'tinyimagenet'"
    )
            
            
def _to_rgb(img):
    """Return ``img.convert("RGB")``.

    A named module-level function (instead of a lambda) keeps the transform
    pipeline picklable, e.g. for DataLoader workers using ``spawn``.
    """
    return img.convert("RGB")


def _tinyimagenet(split: str):
    """Load a Tiny ImageNet split with ``torchvision.datasets.ImageFolder``.

    ``split == "train"`` reads ``<IMAGENET_LOC_ENV>/train``; ``split == "test"``
    reads ``<IMAGENET_LOC_ENV>/val``. ImageFolder expects one sub-directory
    per class inside that folder.

    Training images are converted to RGB. With probability 0.8 they then get,
    together, a random horizontal flip, a random rotation within +/-10
    degrees and a random resized crop to 64x64. Finally they are converted
    to a tensor. Test images are only converted to RGB and to a tensor; no
    normalization is applied.

    Raises:
        ValueError: If ``split`` is neither ``"train"`` nor ``"test"``.
    """
    if split not in ("train", "test"):
        raise ValueError(f"Unknown split {split!r}; expected 'train' or 'test'")

    if split == "train":
        subdir = os.path.join(IMAGENET_LOC_ENV, "train")
        # RandomApply applies the whole list together with probability p,
        # otherwise it passes the image through unchanged.
        augmentation = transforms.RandomApply([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.RandomResizedCrop(64),
        ], p=.8)
        transform = transforms.Compose([
            transforms.Lambda(_to_rgb),
            augmentation,
            transforms.ToTensor(),
        ])
    else:
        # NOTE(review): assumes val/ has been arranged into per-class
        # folders as ImageFolder requires — confirm for this data copy.
        subdir = os.path.join(IMAGENET_LOC_ENV, "val")
        transform = transforms.Compose([
            transforms.Lambda(_to_rgb),
            transforms.ToTensor(),
        ])
    return dsets.ImageFolder(subdir, transform)

def _mnist(split: str):
    """Return the MNIST ``"train"`` or ``"test"`` split as tensors.

    The data lives under ``./data/`` and is downloaded there if it is not
    already present.

    Raises:
        ValueError: If ``split`` is neither ``"train"`` nor ``"test"``.
    """
    if split not in ("train", "test"):
        raise ValueError(f"Unknown split {split!r}; expected 'train' or 'test'")
    return dsets.MNIST(root='./data/',
                       train=(split == "train"),
                       transform=transforms.ToTensor(),
                       download=True)
                                 
def _cifar10(split: str):
    """Return the CIFAR-10 ``"train"`` or ``"test"`` split as tensors.

    The data lives under ``./data`` and is downloaded there if it is not
    already present.

    Raises:
        ValueError: If ``split`` is neither ``"train"`` nor ``"test"``.
    """
    if split not in ("train", "test"):
        raise ValueError(f"Unknown split {split!r}; expected 'train' or 'test'")
    return dsets.CIFAR10(root='./data', train=(split == "train"),
                         download=True, transform=transforms.ToTensor())



    
