import torch
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
import numpy as np
import sys
sys.path.append('./')
from utils.misc import * 

# Create the train and test datasets and dataloaders from CIFAR-10-C. The
# images of all 5 severity levels of one corruption are pooled and then
# randomly split 80%/20% into train/test, reproducibly for a given random seed.
def prepare_corruption_data(corruption, seed = 0, batch_size = 128, dataroot = '/nobackup/yguo/datasets/'):
    """Build shuffled 80%/20% train/test splits of one CIFAR-10-C corruption.

    Every image stored in ``<dataroot>/CIFAR-10-C/<corruption>.npy`` (all
    severity levels together) is paired with the CIFAR-10 test-set labels and
    the pooled dataset is randomly split into train and test subsets.

    Args:
        corruption: Corruption name; selects the ``.npy`` file to load.
        seed: Passed to ``init_random_seed`` before the split so the split is
            reproducible (presumably that helper seeds torch -- TODO confirm).
        batch_size: Batch size of both returned loaders.
        dataroot: Directory holding the CIFAR-10 test set (downloaded there if
            missing) and the ``CIFAR-10-C`` folder.

    Returns:
        ``((train_dataset, train_loader), (test_dataset, test_loader))``.
        Both loaders shuffle.

    Raises:
        ValueError: If the corrupted array is empty or its length is not a
            whole multiple of the CIFAR-10 test-set size, so the labels
            cannot be aligned with the images.
    """
    NORM = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    te_transforms = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(*NORM)])
    init_random_seed(seed)
    print('Test on the original test set')
    # Load the clean test set once: its labels are reused and its images are
    # replaced by the corrupted ones below.
    tset = torchvision.datasets.CIFAR10(root=dataroot,
                                        train=False, download=True, transform=te_transforms)
    print('Test on %s' %(corruption))
    tset_raw = np.load(dataroot + '/CIFAR-10-C/%s.npy' %(corruption))

    n_clean = len(tset.targets)
    n_levels, remainder = divmod(len(tset_raw), n_clean)
    if n_levels == 0 or remainder:
        raise ValueError(
            '%s.npy holds %d images, expected a positive multiple of %d'
            % (corruption, len(tset_raw), n_clean))

    tset.data = tset_raw
    # This code assumes the severity levels are stacked one after another,
    # each in the clean test-set order, so the labels repeat once per level.
    tset.targets = np.tile(tset.targets, n_levels)

    tsize = len(tset_raw)
    n_train = int(tsize * 0.8)
    # The test length is the remainder so the two lengths always sum to tsize.
    train_dataset, test_dataset = torch.utils.data.random_split(
        tset, [n_train, tsize - n_train])
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                        shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                        shuffle=True)
    return (train_dataset, train_loader), (test_dataset, test_loader)


# Create the train and test datasets and dataloaders from the CIFAR-10-C fog
# corruption. The images of all 5 severity levels are pooled and then randomly
# split 80%/20% into train/test, reproducibly for a given random seed.
def prepare_fog_data(data_root = '/nobackup/yguo/datasets/', batch_size = 128, seed = 0):
    """Build shuffled 80%/20% train/test splits of the CIFAR-10-C fog data.

    Thin wrapper around ``prepare_corruption_data`` with ``corruption='fog'``.
    It previously duplicated that function line for line. Delegating keeps
    the two in sync and avoids constructing the CIFAR-10 test set twice.

    Args:
        data_root: Directory holding the CIFAR-10 test set and the
            ``CIFAR-10-C`` folder.
        batch_size: Batch size of both returned loaders.
        seed: Seed for the random split. Defaults to 0, the value this
            function always used before it became a parameter.

    Returns:
        ``((train_dataset, train_loader), (test_dataset, test_loader))``.
    """
    return prepare_corruption_data('fog', seed=seed, batch_size=batch_size,
                                   dataroot=data_root)

# Prepare the corruption data of a single severity level (1-5), split 80%/20% into train/test.
def prepare_corruption_data_lvl(level, data_root = '/nobackup/yguo/datasets/', batch_size = 128, corruption = 'glass_blur', seed = 0):
    """Build datasets/loaders for a single severity level of a CIFAR-10-C corruption.

    The images for ``level`` are taken from
    ``<data_root>/CIFAR-10-C/<corruption>.npy`` and paired with the CIFAR-10
    test-set labels. The result is returned whole and also randomly split
    80%/20% into train and test subsets.

    Args:
        level: 1-based severity level. Must lie in ``[1, n_levels]``, where
            ``n_levels`` is the number of test-set-sized chunks in the file.
        data_root: Directory holding the CIFAR-10 test set (downloaded there
            if missing) and the ``CIFAR-10-C`` folder.
        batch_size: Batch size of all returned loaders.
        corruption: Corruption name; selects the ``.npy`` file to load.
        seed: Passed to ``init_random_seed`` before the split. Defaults to 0,
            the value this function always used before it became a parameter.

    Returns:
        ``((train_dataset, train_loader), (test_dataset, test_loader),
        (all_dataset, all_loader))``. All loaders shuffle.

    Raises:
        ValueError: If ``level`` is outside ``[1, n_levels]``. Without this
            check, negative levels silently selected another level's images
            through Python's negative slicing.
    """
    NORM = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    te_transforms = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(*NORM)])
    init_random_seed(seed)
    print('Test on the original test set')
    # Load the clean test set once: its labels are reused and its images are
    # replaced by the corrupted images of the requested level below.
    tset = torchvision.datasets.CIFAR10(root=data_root,
                                        train=False, download=True, transform=te_transforms)
    print('Test on {}'.format(corruption))
    tset_raw = np.load(data_root + '/CIFAR-10-C/%s.npy' %(corruption))

    # One severity level holds as many images as the clean test set.
    tsize = len(tset.targets)
    n_levels = len(tset_raw) // tsize
    if not 1 <= level <= n_levels:
        raise ValueError('level must be in [1, {}], got {}'.format(n_levels, level))
    # This code assumes the levels are stored back to back, each in the clean
    # test-set order, so the unmodified labels line up with the slice.
    tset.data = tset_raw[(level - 1) * tsize: level * tsize]

    all_dataset = tset
    all_loader = torch.utils.data.DataLoader(all_dataset, batch_size=batch_size,
                                        shuffle=True)
    n_train = int(tsize * 0.8)
    # The test length is the remainder so the two lengths always sum to tsize.
    train_dataset, test_dataset = torch.utils.data.random_split(
        tset, [n_train, tsize - n_train])
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                        shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                        shuffle=True)
    return (train_dataset, train_loader), (test_dataset, test_loader), (all_dataset, all_loader)
