from common_imports import *

def get_cifar100_dataset(normalize=False):
    """Build the CIFAR-100 train and test datasets, augmenting the train split.

    The training pipeline applies AutoAugment (using torchvision's CIFAR10
    policy) followed by tensor conversion. The test pipeline only converts
    images to tensors.

    Args:
        normalize: When true, append a per-channel normalization with
            mean=0.5 and std=0.5 to BOTH pipelines, so that evaluation
            inputs are on the same scale as the training inputs.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``datasets.CIFAR100``
        instances rooted at ``./data/`` and created with ``download=True``.
    """
    # Other augmentations (RandomHorizontalFlip, RandomCrop, RandomRotation,
    # RandomAffine, ColorJitter, RandomErasing) are currently disabled.
    train_steps = [
        v2.AutoAugment(v2.AutoAugmentPolicy.CIFAR10),
        v2.ToTensor(),
    ]
    test_steps = [v2.ToTensor()]

    if normalize:
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
        train_steps.append(v2.Normalize(mean, std))
        # Normalize the test split too; normalizing only the training data
        # would feed differently-scaled inputs to the model at evaluation.
        test_steps.append(v2.Normalize(mean, std))

    train_dataset = datasets.CIFAR100(
        "./data/",
        train=True,
        download=True,
        transform=transforms.Compose(train_steps),
    )

    test_dataset = datasets.CIFAR100(
        "./data/",
        train=False,
        download=True,
        transform=transforms.Compose(test_steps),
    )

    return train_dataset, test_dataset

def get_cifar100_dataset_without_transform():
    """Return the CIFAR-100 train and test datasets with no augmentation.

    Despite the name, each split still goes through a tensor conversion
    (``v2.ToTensor()``); no augmentation or normalization is applied.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``datasets.CIFAR100``
        instances rooted at ``./data/`` and created with ``download=True``.
    """
    loaded = []
    # Train split first, then test split; each gets its own Compose pipeline.
    for is_train in (True, False):
        to_tensor_only = transforms.Compose([v2.ToTensor()])
        loaded.append(
            datasets.CIFAR100(
                "./data/",
                train=is_train,
                download=True,
                transform=to_tensor_only,
            )
        )

    train_dataset, test_dataset = loaded
    return train_dataset, test_dataset

def get_cifar100_dataset_with_only_normalization(
    norm_params=((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
):
    """Return CIFAR-100 train and test datasets with tensor conversion + normalization.

    Both splits use the same pipeline: ``v2.ToTensor()`` followed by
    ``v2.Normalize(*norm_params)``. No augmentation is applied.

    Args:
        norm_params: Sequence unpacked positionally into ``v2.Normalize``,
            i.e. ``(mean, std)`` with one value per channel. The default is
            an immutable tuple so it cannot be mutated across calls.
            NOTE(review): the default values match the statistics commonly
            quoted for CIFAR-10 rather than CIFAR-100 -- confirm this is
            intentional before relying on them.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``datasets.CIFAR100``
        instances rooted at ``./data/`` and created with ``download=True``.
    """

    def build_pipeline():
        # A fresh pipeline per split, so the datasets share no transform objects.
        return transforms.Compose([v2.ToTensor(), v2.Normalize(*norm_params)])

    train_dataset = datasets.CIFAR100(
        "./data/",
        train=True,
        download=True,
        transform=build_pipeline(),
    )

    test_dataset = datasets.CIFAR100(
        "./data/",
        train=False,
        download=True,
        transform=build_pipeline(),
    )

    return train_dataset, test_dataset