import os
import torchvision.transforms as trn
import torchvision.datasets as dset
from .imagenetv2 import ImageNetV2Dataset
import torch.nn.functional as F



def _imagenet_eval_transform():
    """Build the default evaluation transform shared by all ImageNet variants.

    Resizes to 256, center-crops to 224, converts to a tensor and normalizes
    with the ImageNet per-channel mean/std.
    """
    return trn.Compose([
        trn.Resize(256),
        trn.CenterCrop(224),
        trn.ToTensor(),
        trn.Normalize(mean=[0.485, 0.456, 0.406],
                      std=[0.229, 0.224, 0.225]),
    ])


def build_dataset(dataset, mode="train", transform=None):
    """Build a torchvision-style dataset by name and return it with its class count.

    All data is read from ``~/data``; nothing is downloaded (``download=False``),
    so the files must already be present on disk.

    Args:
        dataset: One of ``"cifar10"``, ``"cifar100"``, ``"imagenet"``,
            ``"imagenetv2"``, ``"imagenet-r"`` or ``"imagenet-a"``.
        mode: Which split to load.
            * ``cifar10``: ``"train"`` loads the training split; any other
              value loads the test split.
            * ``cifar100``: ``"train"`` or ``"test"``.
            * ``imagenet``: only ``"test"`` (the ``imagenet/val`` folder).
            * ``imagenetv2``, ``imagenet-r``, ``imagenet-a``: ignored.
        transform: Optional sample transform. When given it is used for both
            the training and test splits. When ``None``, a dataset-specific
            default is used: augmentation + normalization for CIFAR training
            splits, normalization only for CIFAR test splits, and
            resize/center-crop/normalize for the ImageNet variants.

    Returns:
        A ``(data, num_classes)`` tuple.

    Raises:
        NotImplementedError: If ``dataset`` is not recognised, or ``mode`` is
            not supported for the requested dataset.
    """
    # Datasets live under the user's home directory, e.g. ~/data/imagenet/val.
    data_dir = os.path.join(os.path.expanduser('~'), "data")

    if dataset == 'cifar10':
        mean = (0.492, 0.482, 0.446)
        std = (0.247, 0.244, 0.262)
        if transform is None:
            train_transform = trn.Compose([trn.RandomHorizontalFlip(),
                                           trn.RandomCrop(32, padding=4),
                                           trn.ToTensor(),
                                           trn.Normalize(mean, std)])
            test_transform = trn.Compose([trn.ToTensor(),
                                          trn.Normalize(mean, std)])
        else:
            train_transform = test_transform = transform
        # Any mode other than "train" selects the test split; kept for
        # backward compatibility with existing callers.
        is_train = mode == "train"
        data = dset.CIFAR10(root=data_dir, train=is_train, download=False,
                            transform=train_transform if is_train else test_transform)
        num_classes = 10

    elif dataset == 'cifar100':
        # Reject unsupported modes before building any transforms.
        if mode not in ("train", "test"):
            raise NotImplementedError(
                "mode {!r} is not supported for cifar100".format(mode))
        # Per-channel mean/std of the CIFAR-100 training set.
        mean = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
        std = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
        if transform is None:
            train_transform = trn.Compose([
                trn.RandomCrop(32, padding=4),
                trn.RandomHorizontalFlip(),
                trn.RandomRotation(15),
                trn.ToTensor(),
                trn.Normalize(mean, std),
            ])
            test_transform = trn.Compose([trn.ToTensor(),
                                          trn.Normalize(mean, std)])
        else:
            train_transform = test_transform = transform
        is_train = mode == "train"
        data = dset.CIFAR100(root=data_dir, train=is_train, download=False,
                             transform=train_transform if is_train else test_transform)
        num_classes = 100

    elif dataset == 'imagenet':
        # Only the validation folder is wired up; there is no training split here.
        if mode != "test":
            raise NotImplementedError(
                "mode {!r} is not supported for imagenet".format(mode))
        test_transform = _imagenet_eval_transform() if transform is None else transform
        data = dset.ImageFolder(os.path.join(data_dir, "imagenet/val"), test_transform)
        num_classes = 1000

    elif dataset == 'imagenetv2':
        test_transform = _imagenet_eval_transform() if transform is None else transform
        data = ImageNetV2Dataset(
            os.path.join(data_dir, "imagenetv2/imagenetv2-matched-frequency-format-val"),
            test_transform)
        num_classes = 1000

    elif dataset in ('imagenet-r', 'imagenet-a'):
        test_transform = _imagenet_eval_transform() if transform is None else transform
        # Expected layout: ~/data/imagenet-r/imagenet-r (likewise for imagenet-a).
        data = dset.ImageFolder(os.path.join(data_dir, dataset, dataset), test_transform)
        num_classes = 1000

    else:
        # Previously this fell through to an UnboundLocalError at the return.
        raise NotImplementedError("Unsupported dataset: {!r}".format(dataset))

    return data, num_classes


from .mask import *
def get_mask(datasetname):
    """Return the precomputed class mask for an ImageNet variant dataset.

    Dispatches to ``get_mask_imageneta()`` for ``"imagenet-a"`` and to
    ``get_mask_imagenetr()`` for ``"imagenet-r"``; both helpers presumably come
    from the ``from .mask import *`` star import just above.

    NOTE(review): presumably the mask selects the subset of the 1000 ImageNet
    classes covered by the variant — confirm against ``.mask``. If the function
    body ends after the imagenet-r check, any other ``datasetname`` returns
    ``None`` implicitly; verify callers handle that.
    """
    if datasetname == "imagenet-a":
        return get_mask_imageneta()
    if datasetname == "imagenet-r":
        return get_mask_imagenetr()