#%%
import torchvision.transforms as T
import torchvision.datasets as datasets
import torch
import os
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from datasets import load_dataset

from albumentations import (
	Compose,
    HorizontalFlip,
    Normalize,
    RandomCrop,
    PadIfNeeded,
    RGBShift,
    Rotate
)
from albumentations.pytorch import ToTensor
import numpy as np
import torchvision.transforms as transforms

class _AlbumentationsPipeline:
    """Picklable callable that runs albumentations pipelines on one image.

    Defined at module level (instead of returning a lambda) so that a dataset
    holding it can be pickled into DataLoader worker processes that are
    started with the 'spawn' method.
    """

    def __init__(self, augment, finalize):
        # Random augmentations (a Compose gated by `p`), or None for eval data.
        self.augment = augment
        # Normalize + ToTensor; always applied so the output type never varies.
        self.finalize = finalize

    def __call__(self, img):
        # albumentations works on numpy arrays, while ImageFolder yields PIL images.
        image = np.array(img)
        if self.augment is not None:
            image = self.augment(image=image)["image"]
        return self.finalize(image=image)["image"]


def albumentations_transforms(p=1.0, is_train=False):
    """Build an albumentations-based transform for 64x64 images.

    Args:
        p: probability that the training augmentations are applied to a given
            image. Normalization and tensor conversion are always applied, so
            every sample comes out as a tensor regardless of ``p``.
        is_train: when True, prepend the data augmentations (pad to 72x72, random
            64x64 crop, horizontal flip, rotation, RGB shift).

    Returns:
        A callable taking a PIL image (or array-like) and returning the
        transformed tensor.
    """
    # NOTE(review): these are the same per-channel statistics that load_cifar10
    # uses, although this pipeline is used for Tiny ImageNet — confirm intended.
    mean = np.array([0.4914, 0.4822, 0.4465])
    std = np.array([0.2023, 0.1994, 0.2010])

    augment = None
    # Use data aug only for train data
    if is_train:
        augment = Compose([
            PadIfNeeded(min_height=72, min_width=72, p=1.0),
            RandomCrop(height=64, width=64, p=1.0),
            HorizontalFlip(p=0.25),
            Rotate(limit=15, p=0.25),
            RGBShift(r_shift_limit=20, g_shift_limit=20, b_shift_limit=20, p=0.25),
            # CoarseDropout(max_holes=1, max_height=32, max_width=32, min_height=8,
            #               min_width=8, fill_value=mean*255.0, p=0.5),
        ], p=p)

    # Kept outside the p-gated Compose: skipping it would hand raw uint8 arrays
    # to the DataLoader instead of normalized tensors.
    finalize = Compose([
        Normalize(
            mean=mean,
            std=std,
            max_pixel_value=255.0,
            p=1.0
        ),
        ToTensor()
    ], p=1.0)
    return _AlbumentationsPipeline(augment, finalize)

def torch_transforms(is_train=False):
    """Build a torchvision transform pipeline for 64x64 images.

    Args:
        is_train: when True, add random crop (64 with 4px padding) and horizontal
            flip before tensor conversion, and random erasing (p=0.25) after
            normalization.

    Returns:
        A ``transforms.Compose`` instance.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    # Augmentations that act on the PIL image are only used for training data.
    pipeline = [transforms.RandomCrop(64, padding=4),
                transforms.RandomHorizontalFlip()] if is_train else []
    pipeline += [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
    if is_train:
        # RandomErasing operates on tensors, hence it comes after ToTensor/Normalize.
        pipeline.append(transforms.RandomErasing(0.25))
    return transforms.Compose(pipeline)


def choose_dataset(dataset_name: str, batch_size: int, datapath: str, num_workers: int = 1):
    """
    selects a dataset by name and returns its loaders

    dataset_name: one of 'cifar10', 'cifar10_resize', 'cifar10_b&w',
                  'fashion_mnist', 'svhn', 'mnist', 'imagenet', 'tiny_imagenet'
    batch_size:   batch size for every returned DataLoader
    datapath:     forwarded only to the 'imagenet' and 'tiny_imagenet' loaders
    num_workers:  DataLoader worker processes
    returns:      whatever the selected loader returns
                  ((train_loader, val_loader, None) for all loaders in this module)
    raises:       SystemExit(1) after printing a message if the name is unknown
    """
    import sys

    # Lambdas defer both the name lookup and the (expensive) loading until the
    # selected entry is actually called.
    loaders = {
        'cifar10': lambda: load_cifar10(batch_size, num_workers),
        'cifar10_resize': lambda: load_cifar10_resize(batch_size, num_workers),
        'cifar10_b&w': lambda: load_cifar10_bw(batch_size, num_workers),
        'fashion_mnist': lambda: load_fashion_mnist(batch_size, num_workers),
        'svhn': lambda: load_svhn(batch_size, num_workers),
        'mnist': lambda: load_mnist(batch_size, num_workers),
        'imagenet': lambda: load_imagenet(batch_size, num_workers, dummy=False, datapath=datapath),
        'tiny_imagenet': lambda: load_tiny_imagenet(batch_size, num_workers, dummy=False, datapath=datapath),
    }
    loader = loaders.get(dataset_name)
    if loader is None:
        print("dataset not available. Exiting")
        # sys.exit, unlike the site-injected `exit` builtin, is always available.
        sys.exit(1)
    return loader()


def load_tiny_imagenet(batch_size: int, num_workers: int, dummy: bool = False, datapath="./data/tiny_imageNet/"):
    """
        returns train_loader, val_loader and test_loader (always None) for Tiny ImageNet

        params: dummy := if True, use torchvision FakeData (1000 train / 64 val samples,
                         3x64x64 images, 200 classes) instead of reading images from disk
                datapath := currently unused; the on-disk folders are hard-coded below
    """
    if dummy:
        print("=> Dummy data is used!")
        n_train = 1000  # number of fake training samples
        n_val = 64      # number of fake validation samples
        train_dataset = datasets.FakeData(n_train, (3, 64, 64), 200, T.ToTensor())
        val_dataset = datasets.FakeData(n_val, (3, 64, 64), 200, T.ToTensor())
    else:
        # NOTE(review): `datapath` is ignored here; these paths are relative to the
        # current working directory. Confirm whether datapath should be used instead.
        traindir = './t_imagenet/tiny-imagenet-200/new_train'
        valdir = './t_imagenet/tiny-imagenet-200/new_test'

        train_transforms = albumentations_transforms(p=1, is_train=True)
        val_transforms = albumentations_transforms(p=1, is_train=False)

        train_dataset = datasets.ImageFolder(traindir, train_transforms)
        val_dataset = datasets.ImageFolder(valdir, val_transforms)

    # Built after the branch so both the dummy and the on-disk datasets get
    # loaders (otherwise dummy=True would reach `return` with unbound names).
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=num_workers, pin_memory=True,
                                               sampler=None)

    # NOTE(review): shuffling the validation split is not needed for evaluation.
    validation_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size,
                                                    shuffle=True, num_workers=num_workers, pin_memory=True,
                                                    sampler=None)

    return train_loader, validation_loader, None


def load_imagenet(batch_size: int, num_workers: int, dummy: bool = False, datapath="./imageNet/"):
    """
        Build ImageNet train/validation loaders.

        params: dummy := for debugging without ImageNet on disk; substitutes torchvision
                         FakeData (1000 train / 64 val samples, 3x224x224, 1000 classes)
                datapath := root folder that must contain `train/` and `validation/`
                            subfolders readable by datasets.ImageFolder
        returns: (train_loader, validation_loader, None); no test loader is built
    """
    if not dummy:
        # ImageNet per-channel statistics, shared by both splits.
        imagenet_norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_pipeline = T.Compose([T.RandomResizedCrop(224), T.RandomHorizontalFlip(),
                                    T.ToTensor(), imagenet_norm])
        eval_pipeline = T.Compose([T.Resize(256), T.CenterCrop(224),
                                   T.ToTensor(), imagenet_norm])
        train_dataset = datasets.ImageFolder(os.path.join(datapath, 'train'), train_pipeline)
        val_dataset = datasets.ImageFolder(os.path.join(datapath, 'validation'), eval_pipeline)
    else:
        print("=> Dummy data is used!")
        train_dataset = datasets.FakeData(1000, (3, 224, 224), 1000, T.ToTensor())
        val_dataset = datasets.FakeData(64, (3, 224, 224), 1000, T.ToTensor())

    # Identical DataLoader settings for both splits.
    loader_opts = dict(batch_size=batch_size, shuffle=True, num_workers=num_workers,
                       pin_memory=True, sampler=None)
    train_loader = torch.utils.data.DataLoader(train_dataset, **loader_opts)
    validation_loader = torch.utils.data.DataLoader(val_dataset, **loader_opts)
    return train_loader, validation_loader, None


def load_cifar10(batch_size: int, num_workers: int):
    """
    returns train_loader, val_loader and test_loader (always None) for CIFAR-10

    Training images get reflect-padded random 32x32 crops and random horizontal
    flips; both splits are normalized with the per-channel stats below. The data
    is downloaded to ./data if not already present.
    """
    # Per-channel (mean, std) used to normalize both splits.
    stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_transform = T.Compose([T.RandomCrop(32, padding=4, padding_mode='reflect'),
                                 T.RandomHorizontalFlip(),
                                 T.ToTensor(),
                                 T.Normalize(*stats, inplace=True)])
    valid_transform = T.Compose([T.ToTensor(), T.Normalize(*stats)])

    train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    val_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=valid_transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
                                               num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=True,
                                             num_workers=num_workers)
    return train_loader, val_loader, None

def load_cifar10_resize(batch_size: int, num_workers: int, image_size=(572, 572)):
    """
    returns train_loader, val_loader and test_loader (always None) for CIFAR-10
    with every image resized to ``image_size`` (default 572x572)

    image_size: target size passed to ``T.Resize``; applied to BOTH the training
                and validation splits so the model sees one input size
    """
    # Per-channel (mean, std) used to normalize both splits.
    stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_transform = T.Compose([T.Resize(image_size),
                                 T.ToTensor(),
                                 T.Normalize(*stats, inplace=True)])
    # Validation must be resized too; otherwise evaluation receives native
    # 32x32 images while training uses image_size.
    valid_transform = T.Compose([T.Resize(image_size),
                                 T.ToTensor(),
                                 T.Normalize(*stats)])

    train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    val_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=valid_transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
                                               num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=True,
                                             num_workers=num_workers)
    return train_loader, val_loader, None


def load_cifar10_bw(batch_size: int, num_workers: int):
    """
        Build CIFAR-10 loaders with images converted to grayscale (no normalization).

        returns (train_loader, val_loader, None); data is downloaded to ./data if missing
    """
    gray_pipeline = T.Compose([T.Grayscale(), T.ToTensor()])
    train_set, val_set = [
        datasets.CIFAR10(root='./data', train=is_train, download=True, transform=gray_pipeline)
        for is_train in (True, False)
    ]
    train_loader, val_loader = [
        torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        for split in (train_set, val_set)
    ]
    return train_loader, val_loader, None


def load_fashion_mnist(batch_size: int, num_workers: int):
    """
        Build Fashion-MNIST loaders; images are only converted to tensors.

        returns (train_loader, val_loader, None); data is downloaded to ./data if missing
    """
    to_tensor_only = T.Compose([T.ToTensor()])
    fashion_train = datasets.FashionMNIST(root='./data', train=True, download=True, transform=to_tensor_only)
    fashion_val = datasets.FashionMNIST(root='./data', train=False, download=True, transform=to_tensor_only)

    shared = {'batch_size': batch_size, 'shuffle': True, 'num_workers': num_workers}
    train_loader = torch.utils.data.DataLoader(fashion_train, **shared)
    val_loader = torch.utils.data.DataLoader(fashion_val, **shared)
    return train_loader, val_loader, None


def load_svhn(batch_size: int, num_workers: int):
    """
            returns train_loader, val_loader and test_loader (always None) for SVHN

            The validation loader is built from SVHN's 'test' split. Images are only
            converted to tensors; data is downloaded to ./data if missing.
    """
    trans = T.Compose([T.ToTensor()])
    # torchvision's SVHN selects its subset with `split` ('train', 'test' or
    # 'extra'); it does not accept a `train` keyword.
    train_set = datasets.SVHN(root='./data', split='train', download=True, transform=trans)
    val_set = datasets.SVHN(root='./data', split='test', download=True, transform=trans)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
                                               num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    return train_loader, val_loader, None


def load_mnist(batch_size: int, num_workers: int):
    """
            Build MNIST loaders; images are only converted to tensors (no normalization).

            returns (train_loader, val_loader, None); data is downloaded to ./data if missing
    """
    pipeline = T.Compose([T.ToTensor()])
    mnist_splits = [
        datasets.MNIST(root='./data', train=flag, download=True, transform=pipeline)
        for flag in (True, False)
    ]
    train_loader, val_loader = (
        torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        for ds in mnist_splits
    )
    return train_loader, val_loader, None