import torch
import torchvision
import numpy as np
import os
import torchvision.datasets.imagenet
import torchvision.transforms as transforms
from utils import GaussianBlur
from rotated_image_folder import RotatedImageFolder
from PIL import Image, ImageDraw

class CircularCrop:
    """Mask a square PIL image with a randomly inset ellipse.

    Pixels inside the ellipse are copied from the input image; pixels outside
    it keep the fill of a freshly created RGB image (black, PIL's default).
    The ellipse's bounding box is pulled in from each of the four image edges
    by an independent random offset drawn from ``np.arange(circular_range)``.

    Args:
        args: Configuration object providing ``circular_transform`` (truthy
            to apply the mask; when falsy the input image is returned
            unchanged) and ``circular_range`` (exclusive upper bound of the
            per-edge offset; ``0`` means the offset is always ``0``).

    Raises:
        ValueError: If ``circular_range`` is negative.
    """

    def __init__(self, args):
        self.circular_transform = args.circular_transform
        self.circular_range = args.circular_range
        if self.circular_range < 0:
            # A negative range would give an empty candidate array and only
            # fail later, inside np.random.choice, with an unrelated message.
            raise ValueError(
                f'circular_range must be non-negative, got {self.circular_range}'
            )
        if self.circular_range != 0:
            self.candidates = np.arange(self.circular_range)
        else:
            self.candidates = np.array([0])

    def __call__(self, img):
        """Return ``img`` masked by a random ellipse (or ``img`` itself).

        Args:
            img: Image exposing ``size`` as ``(width, height)``; it must be
                square.

        Returns:
            A new RGB image when ``circular_transform`` is truthy, otherwise
            the input object unchanged.

        Raises:
            ValueError: If the image is not square.
        """
        width, height = img.size
        if width != height:
            raise ValueError(f'Expected a square image, got {width}x{height}')

        # Drawn unconditionally, exactly as before, so that toggling the mask
        # does not change how much of NumPy's global random stream is used.
        left, top, right, bottom = np.random.choice(self.candidates, size=4, replace=True)
        if not self.circular_transform:
            return img

        # White ellipse on a black single-channel canvas acts as paste mask.
        mask = Image.new('L', (width, height), 0)
        ImageDraw.Draw(mask).ellipse(
            (left, top, width - right - 1, height - bottom - 1), fill=255
        )

        result = Image.new('RGB', (width, height))
        result.paste(img, (0, 0), mask=mask)
        return result
    
class ContrastiveLearningTransform:
    """Produce two independently augmented views of one image.

    The augmentation pipeline is chosen from ``args.pretrain_set`` and
    ``args.rotation_degree``:

    * ``rotation_degree == 0``: no continuous rotation.  For ``'stl10-R'`` a
      rotation by one angle picked from a fixed set (4 or 8 angles, selected
      by ``args.N``) is inserted instead.
    * ``rotation_degree == 360``: a random rotation in ``(0, 360)`` is
      inserted after the circular crop.
    * any other value ``d``: a random rotation in ``(-d, d)`` is inserted.

    Args:
        args: Configuration object.  Reads ``rotation_degree``,
            ``pretrain_set`` and, for ``'stl10-R'`` only, ``N``.  It is also
            forwarded to ``CircularCrop``.

    Raises:
        ValueError: If ``pretrain_set`` is not supported, if ``'stl10-R'`` is
            combined with a non-zero ``rotation_degree``, or if ``N`` is not
            4 or 8 for ``'stl10-R'``.  (Previously these cases left
            ``self.transform`` undefined or hit an unbound local variable.)
    """

    # (mean, std) pairs handed to transforms.Normalize.
    _STL10_STATS = ([0.43, 0.42, 0.39], [0.27, 0.26, 0.27])
    _IMAGENET_STATS = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    _CIFAR10_ESSL_STATS = ([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    _CIFAR10_STATS = ([x / 255.0 for x in [125.3, 123.0, 113.9]],
                      [x / 255.0 for x in [63.0, 62.1, 66.7]])

    # pretrain_set -> (crop size, use RandomResizedCrop (else padded
    # RandomCrop), apply colour jitter + grayscale, blur kernel or None,
    # normalisation stats).
    _PIPELINES = {
        'stl10': (97, True, True, 9, _STL10_STATS),
        'stl10-R': (97, True, True, 9, _STL10_STATS),
        'imagenet100': (225, True, True, 23, _IMAGENET_STATS),
        'caltech256': (225, True, True, 23, _IMAGENET_STATS),
        'cifar10-essl': (33, True, True, None, _CIFAR10_ESSL_STATS),
        'cifar10': (33, False, False, None, _CIFAR10_STATS),
    }

    # Pre-training sets that rotate by one angle picked from a fixed set.
    _DISCRETE_ROTATION_SETS = ('stl10-R',)

    # args.N -> candidate angles for the discrete rotation.
    _DISCRETE_ANGLES = {
        4: [0, 90, 180, 270],
        8: [0, 45, 90, 135, 180, 225, 270, 315],
    }

    def __init__(self, args):
        if args.rotation_degree == 0:
            print('Without Random Rotation')
            rotation_range = None
        else:
            print('With Rotation')
            max_rotation_degree = args.rotation_degree
            # A full turn is sampled as (0, 360) rather than (-360, 360).
            if args.rotation_degree == 360:
                min_rotation_degree = 0
            else:
                min_rotation_degree = -args.rotation_degree
            print(min_rotation_degree, max_rotation_degree)
            rotation_range = (min_rotation_degree, max_rotation_degree)

        self.transform = self._build_pipeline(args, rotation_range)

    @classmethod
    def _build_pipeline(cls, args, rotation_range):
        """Assemble the ``transforms.Compose`` pipeline for ``args``.

        ``rotation_range`` is ``None`` when no continuous rotation is wanted,
        otherwise the ``(min, max)`` degrees for ``RandomRotation``.
        """
        name = args.pretrain_set
        spec = cls._PIPELINES.get(name)
        if spec is None:
            raise ValueError(
                f'Unsupported pretrain_set {name!r}; expected one of '
                f'{sorted(cls._PIPELINES)}'
            )
        size, resized_crop, colour_augment, blur_kernel, (mean, std) = spec

        # Resolve and validate the rotation mode before building any stage.
        angles = None
        if name in cls._DISCRETE_ROTATION_SETS:
            if rotation_range is not None:
                raise ValueError(
                    f'pretrain_set {name!r} is only supported with '
                    f'rotation_degree == 0'
                )
            angles = cls._DISCRETE_ANGLES.get(args.N)
            if angles is None:
                raise ValueError(
                    f'N must be one of {sorted(cls._DISCRETE_ANGLES)} for '
                    f'pretrain_set {name!r}, got {args.N!r}'
                )

        if resized_crop:
            stages = [transforms.RandomResizedCrop(size, scale=(0.2, 1.0))]
        else:
            stages = [transforms.RandomCrop(size, padding=4)]
        stages.append(CircularCrop(args))

        # Rotation (if any) always comes directly after the circular crop.
        if angles is not None:
            stages.append(transforms.RandomChoice([
                transforms.RandomRotation(
                    degrees=[angle, angle],
                    interpolation=transforms.InterpolationMode.BILINEAR)
                for angle in angles
            ]))
        elif rotation_range is not None:
            stages.append(transforms.RandomRotation(
                rotation_range,
                interpolation=transforms.InterpolationMode.BILINEAR))

        stages.append(transforms.RandomHorizontalFlip())
        if colour_augment:
            stages.append(transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                        saturation=0.4, hue=0.1)],
                p=0.8
            ))
            stages.append(transforms.RandomGrayscale(p=0.2))
        if blur_kernel is not None:
            stages.append(GaussianBlur(kernel_size=blur_kernel))
        stages.append(transforms.ToTensor())
        # Copy the stats so no pipeline shares the class-level lists.
        stages.append(transforms.Normalize(mean=list(mean), std=list(std)))
        return transforms.Compose(stages)

    def __call__(self, x):
        """Return two views of ``x``, each from a separate pipeline run."""
        y1 = self.transform(x)
        y2 = self.transform(x)
        return y1, y2


def load_pretrain_datasets(args):
    """Build the ``ImageFolder`` dataset used for contrastive pre-training.

    Args:
        args: Configuration object.  Reads ``pretrain_set`` to pick the
            sub-directory and ``data`` (joined to that sub-directory with
            ``/``); it is also forwarded to ``ContrastiveLearningTransform``.

    Returns:
        A ``torchvision.datasets.ImageFolder`` whose transform is a
        ``ContrastiveLearningTransform`` built from ``args``.

    Raises:
        ValueError: If ``pretrain_set`` is not supported (previously this
            surfaced as an unbound-local error on ``return``).
    """
    # pretrain_set -> sub-directory of args.data holding the training images.
    split_dirs = {
        'stl10': 'unlabeled',
        'stl10-R': 'unlabeled',
        'imagenet100': 'train',
        'caltech256': 'train',
        'cifar10-essl': 'train',
        'cifar10': 'train',
    }
    try:
        split = split_dirs[args.pretrain_set]
    except KeyError:
        raise ValueError(
            f'Unsupported pretrain_set {args.pretrain_set!r}; expected one of '
            f'{sorted(split_dirs)}'
        ) from None

    return torchvision.datasets.ImageFolder(args.data / split, ContrastiveLearningTransform(args))


def load_eval_datasets(args):
    """Build the train/validation datasets used for downstream evaluation.

    The preprocessing (resize, circular mask, centre crop, optional rotation
    by one angle from a fixed set, tensor conversion, normalisation) is
    selected by ``args.pretrain_set``; the validation sub-directory and the
    class count are selected by ``args.eval_set``.

    Args:
        args: Configuration object.  Reads ``pretrain_set``, ``eval_set``,
            ``rotated``, ``data`` (joined to split names with ``/``) and,
            for ``*-R`` pre-training sets or when ``rotated`` is truthy,
            ``N``.  It is also forwarded to ``CircularCrop``.

    Returns:
        ``(train_dataset, val_dataset, num_classes)``.  When ``rotated`` is
        falsy both datasets are ``ImageFolder`` instances with a single
        composed transform; otherwise they are ``RotatedImageFolder``
        instances that receive ``args.N`` and the two-element list
        ``[transform, normalize]``.

    Raises:
        ValueError: If ``pretrain_set`` or ``eval_set`` is not supported, or
            if ``N`` is not 4 or 8 for a ``*-R`` pre-training set
            (previously these surfaced as unbound-local errors).
    """
    # Base pre-training set -> (Resize argument, CenterCrop size, mean, std).
    base_specs = {
        'stl10': (97, 97, [0.43, 0.42, 0.39], [0.27, 0.26, 0.27]),
        'imagenet100': ((225, 225), 225,
                        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        'caltech256': (225, 225,
                       [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        'cifar10-essl': (33, 33,
                         [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        'cifar10': (33, 33,
                    [x / 255.0 for x in [125.3, 123.0, 113.9]],
                    [x / 255.0 for x in [63.0, 62.1, 66.7]]),
    }
    # Every base set also has an '-R' variant that adds a discrete rotation.
    pretrain_specs = {}
    for name, spec in base_specs.items():
        pretrain_specs[name] = (spec, False)
        pretrain_specs[name + '-R'] = (spec, True)

    # Evaluation set -> (number of classes, validation sub-directory).
    eval_specs = {
        'stl10': (10, 'test'),
        'imagenet100': (100, 'val'),
        'stanford_cars': (196, 'test'),
        'fgvc_aircraft': (100, 'val'),
        'cub_200_2011': (200, 'test'),
        'cifar10': (10, 'test'),
        'cifar100': (100, 'test'),
        'caltech256': (256, 'test'),
        'MTARSI': (20, 'test'),
    }
    # args.N -> candidate angles for the discrete rotation.
    rotation_choices = {
        4: [0, 90, 180, 270],
        8: [0, 45, 90, 135, 180, 225, 270, 315],
    }

    # Validate the whole configuration before building anything.
    pretrain_set = args.pretrain_set
    if pretrain_set not in pretrain_specs:
        raise ValueError(
            f'Unsupported pretrain_set {pretrain_set!r}; expected one of '
            f'{sorted(pretrain_specs)}'
        )
    if args.eval_set not in eval_specs:
        raise ValueError(
            f'Unsupported eval_set {args.eval_set!r}; expected one of '
            f'{sorted(eval_specs)}'
        )
    (resize_to, crop_size, mean, std), discrete_rotation = pretrain_specs[pretrain_set]
    rotation_degrees = None
    if discrete_rotation:
        rotation_degrees = rotation_choices.get(args.N)
        if rotation_degrees is None:
            raise ValueError(
                f'N must be one of {sorted(rotation_choices)} for '
                f'pretrain_set {pretrain_set!r}, got {args.N!r}'
            )

    # 'stl10' has always handed plain lists to Normalize while every other
    # set handed tensors; keep that so callers see the same objects.
    if pretrain_set != 'stl10':
        mean = torch.tensor(mean)
        std = torch.tensor(std)

    # set evaluation transform
    stages = [transforms.Resize(resize_to),
              CircularCrop(args),
              transforms.CenterCrop(crop_size)]
    if rotation_degrees is not None:
        stages.append(transforms.RandomChoice([
            transforms.RandomRotation(
                degrees=[angle, angle],
                interpolation=transforms.InterpolationMode.BILINEAR)
            for angle in rotation_degrees
        ]))
    stages.append(transforms.ToTensor())
    transform = transforms.Compose(stages)
    normalize = transforms.Normalize(mean, std)

    # RotatedImageFolder receives the two steps separately as a list.
    if args.rotated:
        transform = [transform, normalize]
    else:
        transform = transforms.Compose([transform, normalize])

    # set evaluation dataset
    num_classes, val_dir = eval_specs[args.eval_set]
    if not args.rotated:
        train_dataset = torchvision.datasets.ImageFolder(args.data / 'train', transform)
        val_dataset = torchvision.datasets.ImageFolder(args.data / val_dir, transform)
    else:
        train_dataset = RotatedImageFolder(args.data / 'train', args.N, transform)
        val_dataset = RotatedImageFolder(args.data / val_dir, args.N, transform)

    return train_dataset, val_dataset, num_classes

def load_eval_rotated_sets(args):
    """Build 72 validation datasets, one per fixed rotation in 5-degree steps.

    Dataset ``i`` (``i`` in ``0..71``) rotates every image by exactly
    ``5 * i`` degrees after the resize / circular mask / centre crop steps
    and before tensor conversion and normalisation.

    Args:
        args: Configuration object.  Reads ``pretrain_set`` and ``data``
            (joined to the validation sub-directory with ``/``); it is also
            forwarded to ``CircularCrop``.

    Returns:
        ``(val_datasets, num_classes)`` where ``val_datasets`` is a list of
        72 ``ImageFolder`` datasets ordered by rotation angle.

    Raises:
        ValueError: If ``pretrain_set`` is not supported (previously this
            surfaced as an unbound-local error on ``return``).
    """
    # pretrain_set -> (Resize argument, CenterCrop size, validation
    # sub-directory, number of classes, mean, std).
    specs = {
        'stl10': (97, 97, 'test', 10,
                  [0.43, 0.42, 0.39], [0.27, 0.26, 0.27]),
        'imagenet100': ((225, 225), 225, 'val', 100,
                        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        'caltech256': (225, 225, 'test', 256,
                       [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        'cifar10-essl': (33, 33, 'test', 10,
                         [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        'cifar10': (33, 33, 'test', 10,
                    [x / 255.0 for x in [125.3, 123.0, 113.9]],
                    [x / 255.0 for x in [63.0, 62.1, 66.7]]),
    }
    pretrain_set = args.pretrain_set
    if pretrain_set not in specs:
        raise ValueError(
            f'Unsupported pretrain_set {pretrain_set!r}; expected one of '
            f'{sorted(specs)}'
        )
    resize_to, crop_size, val_dir, num_classes, mean, std = specs[pretrain_set]

    # 'stl10' has always handed plain lists to Normalize while every other
    # set handed tensors; keep that so callers see the same objects.
    if pretrain_set != 'stl10':
        mean = torch.tensor(mean)
        std = torch.tensor(std)

    val_datasets = []
    for step in range(72):
        degree = 5 * step
        # A (degree, degree) range makes RandomRotation deterministic.
        transform = transforms.Compose([
            transforms.Resize(resize_to),
            CircularCrop(args),
            transforms.CenterCrop(crop_size),
            transforms.RandomRotation(
                (degree, degree),
                interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        val_datasets.append(torchvision.datasets.ImageFolder(args.data / val_dir, transform))

    return val_datasets, num_classes

def load_eval_random_rotation_sets(args):
    """Build train/validation datasets augmented with a random rotation.

    Every image is resized, circularly masked, centre-cropped, rotated by a
    random angle in ``(0, 360)`` degrees, converted to a tensor and
    normalised.  The same transform object is shared by both datasets.

    Args:
        args: Configuration object.  Reads ``pretrain_set`` and ``data``
            (joined to split names with ``/``); it is also forwarded to
            ``CircularCrop``.

    Returns:
        ``(train_dataset, val_dataset, num_classes)``; both datasets are
        ``ImageFolder`` instances.

    Raises:
        ValueError: If ``pretrain_set`` is not supported (previously this
            surfaced as an unbound-local error on ``return``).
    """
    print('Random Rotation Augmentation')

    # pretrain_set -> (Resize argument, CenterCrop size, validation
    # sub-directory, number of classes, mean, std).
    specs = {
        'stl10': (97, 97, 'test', 10,
                  [0.43, 0.42, 0.39], [0.27, 0.26, 0.27]),
        'caltech256': (225, 225, 'test', 256,
                       [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        'imagenet100': ((225, 225), 225, 'val', 100,
                        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        'cifar10-essl': (33, 33, 'test', 10,
                         [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        'cifar10': (33, 33, 'test', 10,
                    [x / 255.0 for x in [125.3, 123.0, 113.9]],
                    [x / 255.0 for x in [63.0, 62.1, 66.7]]),
    }
    pretrain_set = args.pretrain_set
    if pretrain_set not in specs:
        raise ValueError(
            f'Unsupported pretrain_set {pretrain_set!r}; expected one of '
            f'{sorted(specs)}'
        )
    resize_to, crop_size, val_dir, num_classes, mean, std = specs[pretrain_set]

    # 'stl10' has always handed plain lists to Normalize while every other
    # set handed tensors; keep that so callers see the same objects.
    if pretrain_set != 'stl10':
        mean = torch.tensor(mean)
        std = torch.tensor(std)

    transform = transforms.Compose([
        transforms.Resize(resize_to),
        CircularCrop(args),
        transforms.CenterCrop(crop_size),
        transforms.RandomRotation(
            degrees=(0, 360),
            interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_dataset = torchvision.datasets.ImageFolder(args.data / 'train', transform)
    val_dataset = torchvision.datasets.ImageFolder(args.data / val_dir, transform)

    return train_dataset, val_dataset, num_classes
