from common_imports import *

def get_cifar10_dataset(normalize=False):
    """Load CIFAR-10 with training-time data augmentation.

    The training split is converted to a tensor and then augmented with a
    random horizontal flip, a 32x32 random crop with 4 pixels of padding,
    and random erasing. The test split is only converted to a tensor.

    When ``normalize`` is True, BOTH splits are additionally normalized with
    mean=0.5 and std=0.5 for each channel, i.e. ``(x - 0.5) / 0.5``, which
    maps values in [0, 1] to [-1, 1].

    Args:
        normalize: If True, append per-channel ``Normalize(0.5, 0.5)`` to the
            training and test pipelines.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``datasets.CIFAR10``
        objects rooted at ``./data/`` (downloaded if missing).
    """
    # Per-channel statistics used when ``normalize`` is True.
    norm_mean = (0.5, 0.5, 0.5)
    norm_std = (0.5, 0.5, 0.5)

    train_steps = [
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        # Erasing runs before Normalize, so erased pixels are normalized
        # together with the rest of the image.
        transforms.RandomErasing(),
    ]
    test_steps = [transforms.ToTensor()]

    if normalize:
        # Normalize both splits identically so the model is evaluated on
        # inputs in the same value range it was trained on.
        train_steps.append(transforms.Normalize(norm_mean, norm_std))
        test_steps.append(transforms.Normalize(norm_mean, norm_std))

    train_dataset = datasets.CIFAR10(
        "./data/",
        train=True,
        download=True,
        transform=transforms.Compose(train_steps),
    )

    test_dataset = datasets.CIFAR10(
        "./data/",
        train=False,
        download=True,
        transform=transforms.Compose(test_steps),
    )

    return train_dataset, test_dataset

def get_cifar10_dataset_without_transform():
    """Load the CIFAR-10 train and test splits without augmentation.

    Despite the name, each split still applies ``transforms.ToTensor()`` so
    samples are returned as tensors; no flipping, cropping, erasing or
    normalization is performed.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``datasets.CIFAR10``
        objects rooted at ``./data/`` (downloaded if missing).
    """
    def _load_split(is_train):
        # Build a fresh Compose for each split so the two datasets do not
        # share a transform object.
        tensor_only = transforms.Compose([transforms.ToTensor()])
        return datasets.CIFAR10(
            "./data/",
            train=is_train,
            download=True,
            transform=tensor_only,
        )

    train_split = _load_split(True)
    test_split = _load_split(False)
    return train_split, test_split

def get_cifar10_dataset_with_only_normalization(
    norm_params=((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
):
    """Load CIFAR-10 with tensor conversion and normalization only (no augmentation).

    Both the training and the test split use the pipeline
    ``ToTensor() -> Normalize(mean, std)``.

    Args:
        norm_params: A ``(mean, std)`` pair of per-channel sequences, unpacked
            into ``transforms.Normalize(mean, std)``. The default is an
            immutable tuple so it cannot be mutated across calls; lists are
            accepted as well.
            NOTE(review): the default std values are the commonly copied
            CIFAR-10 constants; confirm they match the statistics you intend
            before relying on exact values.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``datasets.CIFAR10``
        objects rooted at ``./data/`` (downloaded if missing).
    """
    mean, std = norm_params

    train_dataset = datasets.CIFAR10(
        "./data/",
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]),
    )

    test_dataset = datasets.CIFAR10(
        "./data/",
        train=False,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]),
    )

    return train_dataset, test_dataset
