
from torchvision import datasets, transforms
import torch

import os
import warnings

from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive
import numpy as np

from robustbench.data import load_cifar10c, load_cifar100c, load_cifar10, _load_dataset

# Per-channel (mean, std) pair for the three image channels.
NORM = ((0.5,) * 3, (0.5,) * 3)

# Test-time preprocessing: ToTensor() scales uint8 images to [0, 1];
# Normalize with mean = std = 0.5 then maps them to [-1, 1].
te_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(*NORM)]
)
                                    

def load_cifar10_test_datasets(
    cifar101_image_path="CIFAR-10.1V6 image PATH",
    cifar101_label_path="CIFAR-10.1V6 label PATH",
    cifar10c_dir="path of cifar-10c",
    n_examples=10000,
    severity=5,
    n_cifar101_examples=2000,
):
    """Build the CIFAR-10 distribution-shift evaluation suite.

    Loads CIFAR-10.1 (from two ``.npy`` files) followed by each of the 15
    CIFAR-10-C corruptions listed below, each as a separate evaluation set.
    The path defaults are placeholders kept for backward compatibility;
    pass real locations.

    Args:
        cifar101_image_path: ``.npy`` file with the CIFAR-10.1 images. The
            code transposes axes ``(0, 3, 1, 2)`` and divides by 255, so the
            array is assumed to be NHWC with 0-255 pixel values.
        cifar101_label_path: ``.npy`` file with the matching labels.
        cifar10c_dir: directory passed to ``load_cifar10c``.
        n_examples: number of examples requested per CIFAR-10-C corruption.
        severity: CIFAR-10-C corruption severity.
        n_cifar101_examples: CIFAR-10.1 examples kept after shuffling.

    Returns:
        ``(x_tests, y_tests, test_datasets)``: parallel lists holding one
        image tensor and one int64 label tensor per name in
        ``test_datasets``.
    """
    test_datasets = [
        'cifar101',
        'gaussian_noise',
        'shot_noise',
        'impulse_noise',
        'defocus_blur',
        'glass_blur',
        'motion_blur',
        'zoom_blur',
        'snow',
        'frost',
        'fog',
        'brightness',
        'contrast',
        'elastic_transform',
        'pixelate',
        'jpeg_compression',
    ]

    x_tests, y_tests = [], []

    for test_dataset in test_datasets:
        if test_dataset == 'cifar101':
            x_test = np.load(cifar101_image_path)
            y_test = np.load(cifar101_label_path)
            # Shuffle with NumPy's global RNG (seed it for reproducibility)
            # and truncate *before* the float conversion so only the kept
            # images are copied.
            keep = np.random.permutation(len(x_test))[:n_cifar101_examples]
            x_test, y_test = x_test[keep], y_test[keep]
            # NHWC -> contiguous NCHW float32, scaled from 0-255 to [0, 1].
            x_test = np.ascontiguousarray(
                np.transpose(x_test, (0, 3, 1, 2)), dtype=np.float32) / 255
        else:
            # The 4th positional argument (True) is robustbench's `shuffle`.
            x_test, y_test = load_cifar10c(
                n_examples, severity, cifar10c_dir, True, [test_dataset])

        # torch.as_tensor reuses an existing tensor instead of copying it
        # (torch.tensor() on a tensor copies and emits a UserWarning).
        x_tests.append(torch.as_tensor(x_test))
        y_tests.append(torch.as_tensor(y_test).long())

    return x_tests, y_tests, test_datasets

def load_cifar100_test_datasets(data_dir="path of cifar-100c",
                                n_examples=10000,
                                severity=5):
    """Build the CIFAR-100-C evaluation suite, one set per corruption.

    The ``data_dir`` default is a placeholder kept for backward
    compatibility; pass the real CIFAR-100-C location.

    Args:
        data_dir: directory passed to ``load_cifar100c``.
        n_examples: number of examples requested per corruption.
        severity: corruption severity.

    Returns:
        ``(x_tests, y_tests, test_datasets)``: parallel lists holding one
        image tensor and one int64 label tensor per corruption name in
        ``test_datasets``.
    """
    test_datasets = [
        'gaussian_noise',
        'shot_noise',
        'impulse_noise',
        'defocus_blur',
        'glass_blur',
        'motion_blur',
        'zoom_blur',
        'snow',
        'frost',
        'fog',
        'brightness',
        'contrast',
        'elastic_transform',
        'pixelate',
        'jpeg_compression',
    ]

    x_tests, y_tests = [], []

    for corruption in test_datasets:
        # The 4th positional argument (True) is robustbench's `shuffle`.
        x_test, y_test = load_cifar100c(
            n_examples, severity, data_dir, True, [corruption])

        # torch.as_tensor reuses an existing tensor instead of copying it
        # (torch.tensor() on a tensor copies and emits a UserWarning).
        x_tests.append(torch.as_tensor(x_test))
        y_tests.append(torch.as_tensor(y_test).long())

    return x_tests, y_tests, test_datasets



def load_cifar10_train_dataset(root="path of cifar-10 clean training ",
                               n_examples=50000,
                               device=None):
    """Load the clean CIFAR-10 training split as in-memory tensors.

    Images go through ``te_transforms``. The dataset is downloaded into
    ``root`` if it is missing. The ``root`` default is a placeholder kept for
    backward compatibility; pass a real directory.

    Args:
        root: torchvision dataset root directory.
        n_examples: number of training examples passed to ``_load_dataset``.
        device: device for the returned labels. ``None`` keeps the original
            behavior (CUDA) when CUDA is available and falls back to CPU
            otherwise. The original code raised on CUDA-less machines.

    Returns:
        ``(x_train, y_train)`` with ``y_train`` as int64 on ``device``.
    """
    train_dataset = datasets.CIFAR10(root=root,
                                     train=True,
                                     transform=te_transforms,
                                     download=True)
    x_train, y_train = _load_dataset(train_dataset, n_examples)

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Only the labels are moved; x_train stays wherever _load_dataset put
    # it (presumably CPU — confirm before mixing them in one op).
    y_train = y_train.to(device=device, dtype=torch.long)
    return x_train, y_train

def load_cifar100_train_dataset(root="path of cifar-100 clean training ",
                                n_examples=50000,
                                device=None):
    """Load the clean CIFAR-100 training split as in-memory tensors.

    Images go through ``te_transforms``. The dataset is downloaded into
    ``root`` if it is missing. The ``root`` default is a placeholder kept for
    backward compatibility; pass a real directory.

    Args:
        root: torchvision dataset root directory.
        n_examples: number of training examples passed to ``_load_dataset``.
        device: device for the returned labels. ``None`` keeps the original
            behavior (CUDA) when CUDA is available and falls back to CPU
            otherwise. The original code raised on CUDA-less machines.

    Returns:
        ``(x_train, y_train)`` with ``y_train`` as int64 on ``device``.
    """
    train_dataset = datasets.CIFAR100(root=root,
                                      train=True,
                                      transform=te_transforms,
                                      download=True)
    x_train, y_train = _load_dataset(train_dataset, n_examples)

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Only the labels are moved; x_train stays wherever _load_dataset put
    # it (presumably CPU — confirm before mixing them in one op).
    y_train = y_train.to(device=device, dtype=torch.long)
    return x_train, y_train