import os

import torchvision.datasets as dset
import torchvision.transforms as trn


def get_num_classes(dataset):
    """Return the number of target classes for a supported dataset name.

    Args:
        dataset: Dataset identifier. Recognised values are ``'imagenet'``,
            ``'imagenetv2'``, ``'cifar100'`` and ``'cifar10'``.

    Returns:
        The class count as an ``int``.

    Raises:
        NotImplementedError: If ``dataset`` is not one of the names above.
    """
    # 'imagenetv2' is accepted by build_dataset() in this file with 1000
    # classes, so it is accepted here as well.
    if dataset in ('imagenet', 'imagenetv2'):
        return 1000
    if dataset == 'cifar100':
        return 100
    if dataset == 'cifar10':
        return 10
    raise NotImplementedError("unsupported dataset: {!r}".format(dataset))


from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
import pathlib
class ImageNetV2Dataset(Dataset):
    """Map-style dataset over a directory tree of ``*.jpeg`` images.

    Every ``*.jpeg`` file found recursively under ``root`` is one sample.
    The label of a sample is the name of the file's parent directory parsed
    as an integer.

    Args:
        root: Directory that is searched recursively for ``*.jpeg`` files.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, root, transform=None):
        self.dataset_root = pathlib.Path(root)
        # glob() yields entries in filesystem order, which differs between
        # machines and runs; sorting makes index -> sample reproducible.
        self.fnames = sorted(self.dataset_root.glob("**/*.jpeg"))
        self.transform = transform

    def __len__(self):
        """Return the number of image files discovered under the root."""
        return len(self.fnames)

    def __getitem__(self, i):
        """Load sample ``i`` and return ``(image, label)``.

        Raises:
            ValueError: If the parent directory name is not an integer.
        """
        path = self.fnames[i]
        # Image.open is lazy and keeps the file open; the context manager
        # closes it once the pixel data has been decoded by convert().
        with Image.open(path) as raw:
            # Force three channels so grayscale/CMYK files do not break
            # transforms that assume RGB (e.g. 3-channel Normalize).
            img = raw.convert("RGB")
        label = int(path.parent.name)
        if self.transform is not None:
            img = self.transform(img)
        return img, label


def _default_imagenet_transform():
    """Return the resize/center-crop/normalize pipeline used for ImageNet."""
    return trn.Compose([
        trn.Resize(256),
        trn.CenterCrop(224),
        trn.ToTensor(),
        trn.Normalize(mean=[0.485, 0.456, 0.406],
                      std=[0.229, 0.224, 0.225]),
    ])


def build_dataset(dataset, mode="test", transform=None):
    """Construct a dataset object rooted at ``~/data`` and its class count.

    Args:
        dataset: One of ``'imagenet'``, ``'imagenetv2'``, ``'cifar10'``,
            ``'cifar100'`` or ``'ina21'``.
        mode: Split selector.
            * imagenet: ``"test"`` (``imagenet/val``) or ``"train_eval"``
              (``imagenet/train``).
            * cifar10: ``"train"``, ``"train_eval"``; any other value selects
              the test split.
            * cifar100: ``"train"``, ``"train_eval"`` or ``"test"``.
            * imagenetv2 / ina21: ignored.
        transform: Optional transform. When ``None`` a dataset-specific
            default is used (no default is applied for ``'ina21'``).

    Returns:
        A ``(data, num_classes)`` tuple.

    Raises:
        NotImplementedError: For an unknown dataset name, or a mode that the
            imagenet / cifar100 branches do not support.
    """
    data_dir = os.path.join(os.path.expanduser('~'), "data")

    if dataset == 'imagenet':
        if mode == "test":
            split = "val"
        elif mode == "train_eval":
            split = "train"
        else:
            raise NotImplementedError(
                "unsupported mode for imagenet: {!r}".format(mode))
        if transform is None:
            transform = _default_imagenet_transform()
        data = dset.ImageFolder(os.path.join(data_dir, "imagenet", split),
                                transform)
        num_classes = 1000

    elif dataset == 'imagenetv2':
        # Previously a caller-supplied transform left the local transform
        # variable unbound and crashed; now it is honoured.
        if transform is None:
            transform = _default_imagenet_transform()
        data = ImageNetV2Dataset(
            os.path.join(data_dir,
                         "imagenetv2/imagenetv2-matched-frequency-format-val"),
            transform)
        num_classes = 1000

    elif 'cifar' in dataset:
        if dataset == 'cifar10':
            CIFAR10_TRAIN_MEAN = (0.492, 0.482, 0.446)
            CIFAR10_TRAIN_STD = (0.247, 0.244, 0.262)
            if transform is None:
                cifar10_train_transform = trn.Compose([
                    trn.RandomHorizontalFlip(),
                    trn.RandomCrop(32, padding=4),
                    trn.ToTensor(),
                    trn.Normalize(CIFAR10_TRAIN_MEAN, CIFAR10_TRAIN_STD),
                ])
                cifar10_test_transform = trn.Compose([
                    trn.ToTensor(),
                    trn.Normalize(CIFAR10_TRAIN_MEAN, CIFAR10_TRAIN_STD),
                ])
            else:
                cifar10_train_transform = transform
                cifar10_test_transform = transform

            if mode == "train":
                data = dset.CIFAR10(root=data_dir, train=True, download=False,
                                    transform=cifar10_train_transform)
            elif mode == "train_eval":
                # Training split without augmentation, matching the cifar100
                # branch; previously this silently returned the test split.
                data = dset.CIFAR10(root=data_dir, train=True, download=False,
                                    transform=cifar10_test_transform)
            else:
                # Any other mode keeps the historical fallback: test split.
                data = dset.CIFAR10(root=data_dir, train=False, download=False,
                                    transform=cifar10_test_transform)
            num_classes = 10

        elif dataset == 'cifar100':
            # mean and std of cifar100 dataset
            CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
            CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
            if transform is None:
                cifar100_train_transform = trn.Compose([
                    trn.RandomCrop(32, padding=4),
                    trn.RandomHorizontalFlip(),
                    trn.RandomRotation(15),
                    trn.ToTensor(),
                    trn.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
                ])
                cifar100_test_transform = trn.Compose([
                    trn.ToTensor(),
                    trn.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
                ])
            else:
                cifar100_train_transform = transform
                cifar100_test_transform = transform

            if mode == "train":
                data = dset.CIFAR100(root=data_dir, train=True, download=False,
                                     transform=cifar100_train_transform)
            elif mode == "train_eval":
                data = dset.CIFAR100(root=data_dir, train=True, download=False,
                                     transform=cifar100_test_transform)
            elif mode == "test":
                data = dset.CIFAR100(root=data_dir, train=False, download=False,
                                     transform=cifar100_test_transform)
            else:
                raise NotImplementedError(
                    "unsupported mode for cifar100: {!r}".format(mode))
            num_classes = 100
        else:
            raise NotImplementedError(
                "unsupported dataset: {!r}".format(dataset))

    elif dataset == "ina21":
        # The result used to be bound to `dataset` and num_classes was never
        # set, so this branch always failed at the return statement.
        data = dset.INaturalist(data_dir, "2021_valid", transform=transform,
                                download=True)
        # NOTE(review): iNaturalist 2021 is published with 10,000 species
        # categories (the default target_type="full") -- confirm.
        num_classes = 10000

    else:
        raise NotImplementedError("unsupported dataset: {!r}".format(dataset))

    return data, num_classes
