import os
import random
import torch
import torchvision
import torchvision.datasets
import numpy as np
from torchvision import transforms
from datasets.random_dataset import RandomData
from utils.svhn_loader import SVHN
from PIL import Image

# OOD benchmarks stored as a single ImageFolder tree:
# name -> path components under ``dataset_dir``.
_IMAGE_FOLDER_SUBDIRS = {
    'lsuncrop': ('LSUN',),
    'lsunresize': ('LSUN_resize',),
    'tinyimagenetcrop': ('Imagenet',),
    'tinyimagenetresize': ('Imagenet_resize',),
    'isun': ('iSUN',),
    'imageneta': ('imagenet-a',),
    'imagenetr': ('imagenet-r',),
    'inaturalist': ('iNaturalist',),
    'sun': ('SUN',),
    'places': ('Places',),
    'textures': ('dtd', 'images'),
}

# ImageFolder datasets with separate train / evaluation folders:
# name -> (root dir, train subdir, eval subdir).
_SPLIT_IMAGE_FOLDERS = {
    'imagenet30': ('imagenet30', 'one_class_train', 'one_class_test'),
    'imagenet': ('imagenet', 'train', 'val'),
    'imagenette': ('imagenette2-320', 'train', 'val'),
}

# Corruption datasets stored as ``<dataset_dir>/<name>.npy``.
_CORRUPTION_NAMES = frozenset({
    'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
    'glass_blur', 'motion_blur', 'zoom_blur', 'snow',
    'frost', 'fog', 'brightness', 'contrast',
    'elastic_transform', 'pixelate', 'jpeg_compression',
})

# Corruptions that may instead be stored as an image tree
# ``<dataset_dir>/<name>/<1..N>/<subdir>/<image>``; that layout is used only
# when the ``.npy`` file is absent.
_CORRUPTION_FOLDER_NAMES = frozenset({'brightness', 'fog', 'frost', 'snow'})

# Number of leading images kept from the places365 folder.
_PLACES365_LIMIT = 10000


def _load_corruption_npy(npy_path, transform):
    """Load an image stack from ``npy_path`` as a TensorDataset of (image, 0)."""
    raw_images = np.load(npy_path)
    to_pil = transforms.ToPILImage()
    images = torch.stack([transform(to_pil(image)) for image in raw_images])
    # Corruption sets are treated as unlabeled; every sample gets label 0.
    labels = torch.zeros(images.size(0), dtype=torch.long)
    return torch.utils.data.TensorDataset(images, labels)


def _load_corruption_folder(data_path, transform):
    """Load a ``<1..N>/<subdir>/<image>`` tree as a TensorDataset of (image, 0)."""
    images = []
    # Top-level entries are expected to be directories named "1", "2", ...
    for level in range(1, len(os.listdir(data_path)) + 1):
        level_path = os.path.join(data_path, str(level))
        # Sort listings so sample order does not depend on the filesystem.
        for sub_dir in sorted(os.listdir(level_path)):
            sub_path = os.path.join(level_path, sub_dir)
            for image_name in sorted(os.listdir(sub_path)):
                # Close each file deterministically instead of relying on GC.
                with Image.open(os.path.join(sub_path, image_name)) as image:
                    images.append(transform(image))
    if not images:
        raise RuntimeError(f'no images found under {data_path}')
    stacked = torch.stack(images, dim=0)
    labels = torch.zeros(stacked.size(0), dtype=torch.long)
    return torch.utils.data.TensorDataset(stacked, labels)


def build_dataset(dataset_name, transform, train=False, dataset_dir='./data'):
    """Construct the dataset registered under ``dataset_name``.

    Args:
        dataset_name: Key selecting the dataset, e.g. ``"cifar10"``,
            ``"svhn"``, ``"textures"``, or a corruption name such as ``"fog"``.
        transform: Callable applied to each image.
        train: Select the training split for datasets that have one
            (CIFAR-10/100, SVHN, ImageNet, ImageNet-30, Imagenette); ignored
            by the others.
        dataset_dir: Root directory under which all datasets are stored.

    Returns:
        A ``torch.utils.data.Dataset``. Corruption datasets are
        ``TensorDataset`` objects of (image, 0) pairs; ``places365`` is a
        ``Subset`` of at most its first 10000 images.

    Raises:
        SystemExit: If ``dataset_name`` is not recognised.
        FileNotFoundError: If a corruption dataset's ``.npy`` file is missing
            and no folder fallback applies.
    """
    if dataset_name == 'cifar10':
        return torchvision.datasets.CIFAR10(
            dataset_dir, transform=transform, train=train, download=False)

    if dataset_name == 'cifar100':
        return torchvision.datasets.CIFAR100(
            dataset_dir, transform=transform, train=train, download=False)

    if dataset_name == 'svhn':
        # Previously the split was hard-coded to 'test', ignoring ``train``.
        split = 'train' if train else 'test'
        return SVHN(os.path.join(dataset_dir, 'SVHN'), transform=transform,
                    split=split, download=False)

    if dataset_name in ('gaussian', 'uniform'):
        return RandomData(num_samples=10000,
                          is_gaussian=(dataset_name == 'gaussian'),
                          transform=transform)

    if dataset_name in _IMAGE_FOLDER_SUBDIRS:
        root = os.path.join(dataset_dir, *_IMAGE_FOLDER_SUBDIRS[dataset_name])
        return torchvision.datasets.ImageFolder(root, transform=transform)

    if dataset_name in _SPLIT_IMAGE_FOLDERS:
        root_dir, train_dir, eval_dir = _SPLIT_IMAGE_FOLDERS[dataset_name]
        mode_path = train_dir if train else eval_dir
        return torchvision.datasets.ImageFolder(
            os.path.join(dataset_dir, root_dir, mode_path), transform=transform)

    if dataset_name == 'places365':
        dataset = torchvision.datasets.ImageFolder(
            os.path.join(dataset_dir, 'places365'), transform=transform)
        # Clamp so a smaller folder does not yield out-of-range indices.
        limit = min(_PLACES365_LIMIT, len(dataset))
        return torch.utils.data.Subset(dataset, list(range(limit)))

    if dataset_name in _CORRUPTION_NAMES:
        data_path = os.path.join(dataset_dir, dataset_name)
        npy_path = data_path + '.npy'
        # The image-tree layout is only a fallback; the .npy file wins.
        if (dataset_name in _CORRUPTION_FOLDER_NAMES
                and not os.path.isfile(npy_path)
                and os.path.isdir(data_path)):
            return _load_corruption_folder(data_path, transform)
        return _load_corruption_npy(npy_path, transform)

    # Same effect as the former ``exit(...)`` without depending on ``site``.
    raise SystemExit(f'{dataset_name} dataset is not supported')


def get_num_classes(in_dataset_name):
    """Return the number of classes of an in-distribution dataset.

    Args:
        in_dataset_name: One of ``cifar10``, ``cifar100``, ``svhn``,
            ``imagenet``, ``imagenet30`` or ``imagenette``.

    Returns:
        The class count as an ``int``.

    Raises:
        SystemExit: If ``in_dataset_name`` is not supported.
    """
    num_classes = {
        'cifar10': 10,
        'cifar100': 100,
        'svhn': 10,
        'imagenet': 1000,
        'imagenet30': 30,
        'imagenette': 10,
    }
    if in_dataset_name not in num_classes:
        # The old message had a stray literal 'f' before the name.
        raise SystemExit(f'Unsupported in-dist dataset: {in_dataset_name}')
    return num_classes[in_dataset_name]


