import torch
import numpy as np
import cv2
from torch.utils import data
from torchvision import datasets
import torchvision.transforms as transforms
from deepthinking.utils.cifar_c_aug import convert_img, trans
from deepthinking.utils.imagenet_c_aug import d as trans_imagenet, convert_img as convert_imgnet
from deepthinking.utils.mnist_c_aug import trans_mnist
# from cifar_c_aug import convert_img, trans
# from imagenet_c_aug import d as trans_imagenet, convert_img as convert_imgnet
# from mnist_c_aug import trans_mnist
# from deepthinking.utils.cifar_c_aug import convert_img, trans
# from deepthinking.utils.mnist_c_aug import trans_mnist
import random


class Custom_DATA_C_Dataset(data.Dataset):
    """Dataset wrapper that corrupts images on the fly.

    Every item of ``dataset`` must be an ``(image, label)`` pair. Depending on
    ``problem`` a sample is returned clean, corrupted with one fixed
    corruption, or corrupted with a corruption drawn from the whole corruption
    table ("Total Noise").

    Args:
        dataset: Indexable dataset yielding ``(image, label)`` pairs.
        name_dataset: ``'mnist'``, ``'cifar'``, ``'imagenet'`` or
            ``'tiny_imagenet'``; selects the corruption table and the image
            converter.
        problem: Key of the corruption to apply. ``None`` or ``'None'``
            disables corruption; ``'Total Noise'`` cycles through every
            corruption of the table. For ``'tiny_imagenet'`` the name is
            given in snake_case and converted to capitalized words.
        noise_level: ``'easy'``, ``'medium'`` or ``'hard'`` (severity 1, 3, 5).

    Raises:
        ValueError: If ``name_dataset`` or ``noise_level`` is not supported.
    """

    # Severity handed to the corruption functions for each named level.
    _SEVERITY_BY_LEVEL = {'easy': 1, 'medium': 3, 'hard': 5}

    def __init__(self, dataset,
                 name_dataset='cifar',
                 problem=None,
                 noise_level='easy'):
        self.dataset = dataset
        if name_dataset == 'mnist':
            self.img_trans = trans_mnist
            self.cvt = convert_img
        elif name_dataset == 'cifar':
            self.img_trans = trans
            self.cvt = convert_img
        elif name_dataset == 'imagenet':
            self.img_trans = trans_imagenet
            self.cvt = convert_imgnet
        elif name_dataset == "tiny_imagenet":
            # "gaussian_noise" -> "Gaussian Noise", the key style of the table.
            if problem is not None:
                problem = " ".join([x.capitalize() for x in problem.split("_")])
            self.img_trans = trans
            self.cvt = convert_img
        else:
            raise ValueError("Not supported dataset !!!")

        self.name_dataset = name_dataset
        severity = self._resolve_severity(noise_level)
        self.severity = severity
        if problem is None or problem == 'None':
            # Clean data: neither a fixed corruption nor a pool to draw from.
            self.noise_name = None
            self.corruption = None
        elif problem != 'Total Noise':
            # The table lookup is deferred to call time, so an unknown key only
            # fails once a sample is actually corrupted.
            self.corruption = lambda clean_img: self.img_trans[problem](clean_img, severity)
            self.noise_name = None
        else:
            self.corruption = None
            self.noise_name = list(self.img_trans.keys())
        # Corruptions already used in the current "Total Noise" cycle.
        self.noted = []

    @classmethod
    def _resolve_severity(cls, noise_level):
        """Translate a noise level name into its numeric severity.

        Raises:
            ValueError: If ``noise_level`` is not a known level name.
        """
        try:
            return cls._SEVERITY_BY_LEVEL[noise_level]
        except KeyError:
            raise ValueError(
                "Unsupported noise level %r, expected one of %s"
                % (noise_level, sorted(cls._SEVERITY_BY_LEVEL))
            ) from None

    def __len__(self):
        """Return the number of samples of the wrapped dataset."""
        return len(self.dataset)

    def single_noise(self, img):
        '''
            Apply the fixed corruption chosen at construction time.

            img: numpy ndarray
        '''
        img = self.corruption(self.cvt(img))
        return img

    def multi_noise(self, img):
        '''
            Apply a randomly drawn corruption; every corruption of the table
            is used once before any of them is drawn again.

            img: numpy ndarray
        '''
        # Rejection-sample a corruption that was not used in this cycle yet.
        while True:
            problem = random.choice(self.noise_name)
            if problem not in self.noted:
                self.noted.append(problem)
                break
        # Start a new cycle as soon as every corruption has been used.
        if len(self.noted) == len(self.noise_name):
            self.noted = []
        return self.img_trans[problem](self.cvt(img), self.severity)

    def mnist2cifar(self, img):
        """Convert a single-channel image to three channels with OpenCV."""
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        return img

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is returned as a channel-first float tensor divided by 255.
        When reading or corrupting the sample raises ``IOError`` a message is
        printed and the next sample is returned instead, wrapping around to
        index 0 after the last one.
        """
        try:
            if self.name_dataset != "imagenet":
                image_pil, label = self.dataset[idx]
                img_array = np.array(image_pil)
            else:
                img_tensor, label = self.dataset[idx]
                # Channel-first -> channel-last before handing over to numpy.
                img_array = img_tensor.permute(1, 2, 0).numpy()

            if self.noise_name is None:
                if self.corruption is not None:
                    img_array = self.single_noise(img_array)
            else:
                img_array = self.multi_noise(img_array)
            # Grayscale results are replicated to three identical channels.
            if img_array.ndim == 2:
                img_array = np.stack([img_array] * 3, axis=-1)
            img = torch.from_numpy(img_array).permute(2, 0, 1).float() / 255.

        except IOError:
            print('Corrupted image for %d' % idx)
            # Wrap around so a broken last sample does not index past the end.
            return self[(idx + 1) % len(self)]
        return img, label

class Custom_Tiny_Imagenet_C_Dataset(data.Dataset):
    """Dataset wrapper that converts samples to tensors without corrupting them.

    Args:
        dataset: Indexable dataset yielding ``(image, label)`` pairs.
        problem: When equal to ``"stl"`` the labels are remapped from STL-10
            class indices to CIFAR-10 class indices; any other value leaves
            the labels untouched.
    """

    # STL-10 class index -> CIFAR-10 class index.
    _STL10_TO_CIFAR10 = {
        0: 0,  # airplane -> airplane
        1: 2,  # bird -> bird
        2: 1,  # car -> automobile
        3: 3,  # cat -> cat
        4: 4,  # deer -> deer
        5: 5,  # dog -> dog
        6: 7,  # horse -> horse
        7: 6,  # monkey -> frog (the closest equivalent)
        8: 8,  # ship -> ship
        9: 9   # truck -> truck
    }

    def __init__(self, dataset, problem=None):
        self.dataset = dataset
        self.problem = problem

    def __len__(self):
        """Return the number of samples of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is returned as a channel-first float tensor divided by 255;
        it is not resized. When loading raises ``IOError`` a message is
        printed and the next sample is returned instead, wrapping around to
        index 0 after the last one.
        """
        try:
            image_pil, label = self.dataset[idx]
            img_array = np.array(image_pil)
            # Replicate grayscale images to three channels so permute() works.
            if img_array.ndim == 2:
                img_array = np.stack([img_array] * 3, axis=-1)
            img = torch.from_numpy(img_array).permute(2, 0, 1).float() / 255.
            if self.problem == "stl":
                label = self._STL10_TO_CIFAR10[label]

        except IOError:
            print('Corrupted image for %d' % idx)
            # Wrap around so a broken last sample does not index past the end.
            return self[(idx + 1) % len(self)]
        return img, label

def get_dataloader(train_batch_size, test_batch_size, train_set, val_set, test_set, problem, train_data, test_data,
                   name_dataset='cifar', shuffle=True):
    """Wrap the three splits and return their dataloaders.

    The training split is always wrapped in ``Custom_DATA_C_Dataset`` using
    the ``train_data`` noise level. Validation and test splits use the
    ``test_data`` noise level, except for ``'tiny_imagenet'`` where they are
    wrapped in ``Custom_Tiny_Imagenet_C_Dataset`` instead.

    Returns:
        dict with the keys ``"train"``, ``"test"`` and ``"val"``. Only the
        training loader honours ``shuffle`` and drops the last partial batch.
    """
    def _corrupted(split, level):
        # Corrupting wrapper that shares the dataset name and problem.
        return Custom_DATA_C_Dataset(split, name_dataset=name_dataset,
                                     problem=problem,
                                     noise_level=level)

    def _eval_loader(wrapped):
        # Validation and test loaders keep the sample order and every sample.
        return data.DataLoader(dataset=wrapped,
                               num_workers=0,
                               batch_size=test_batch_size,
                               shuffle=False,
                               drop_last=False)

    train_wrapped = _corrupted(train_set, train_data)
    if name_dataset == "tiny_imagenet":
        val_wrapped = Custom_Tiny_Imagenet_C_Dataset(val_set)
        test_wrapped = Custom_Tiny_Imagenet_C_Dataset(test_set)
    else:
        val_wrapped = _corrupted(val_set, test_data)
        test_wrapped = _corrupted(test_set, test_data)

    train_loader = data.DataLoader(dataset=train_wrapped,
                                   num_workers=0,
                                   batch_size=train_batch_size,
                                   shuffle=shuffle,
                                   drop_last=True)
    val_loader = _eval_loader(val_wrapped)
    test_loader = _eval_loader(test_wrapped)

    return {"train": train_loader, "test": test_loader, "val": val_loader}

def prepare_mnist_c_loader(train_batch_size, test_batch_size, train_data, test_data, problem='Gaussian Noise', shuffle=True,
                           data_root='/home/tranhieu/workdir/anhnd/dt_18102024/data'):
    """Build the train/val/test dataloaders for corrupted MNIST.

    The MNIST training set is split 80/20 into train and validation parts
    with a fixed seed; the MNIST test set is used for testing.

    Args:
        train_batch_size: Batch size of the training loader.
        test_batch_size: Batch size of the validation and test loaders.
        train_data: Noise level name applied to the training split.
        test_data: Noise level name applied to the validation and test splits.
        problem: Name of the corruption, forwarded to ``get_dataloader``.
        shuffle: Whether the training loader shuffles its samples.
        data_root: Directory holding the MNIST files; nothing is downloaded.

    Returns:
        dict with the keys ``"train"``, ``"test"`` and ``"val"``.
    """
    mnist_train = datasets.MNIST(root=data_root, train=True, download=False)
    mnist_test = datasets.MNIST(root=data_root, train=False, download=False)

    train_split = int(0.8 * len(mnist_train))

    # Fixed seed so the train/validation partition is the same on every run.
    trainset, valset = torch.utils.data.random_split(
        mnist_train,
        [train_split, len(mnist_train) - train_split],
        generator=torch.Generator().manual_seed(42))

    loaders = get_dataloader(train_batch_size, test_batch_size, trainset, valset, mnist_test,
                             problem, train_data, test_data, name_dataset='mnist', shuffle=shuffle)

    return loaders


def prepare_cifar_c_loader(train_batch_size, test_batch_size, train_data, test_data, problem='Gaussian Noise', shuffle=True, static_test=True,
                           data_root='/mnt/storage/aird/AI-RnD/hieutb2/deep-thinking/data'):
    """Build the train/val/test dataloaders for corrupted CIFAR-10.

    The CIFAR-10 training set is split 80/20 into train and validation parts
    with a fixed seed. The test loader is one of:

    * the STL-10 test split with labels remapped, when ``problem == "stl"``;
    * the pre-computed ``CIFAR-10-C/<problem>.npy`` images, when
      ``static_test`` is true and ``problem`` is not ``"None"``;
    * the CIFAR-10 test set wrapped by ``get_dataloader`` otherwise.

    Args:
        train_batch_size: Batch size of the training loader.
        test_batch_size: Batch size of the validation and test loaders.
        train_data: Noise level name applied to the training split.
        test_data: Noise level name (``"easy"``, ``"medium"``, ``"hard"``)
            used for validation and test data.
        problem: Name of the corruption, ``"None"`` or ``"stl"``.
        shuffle: Whether the training loader shuffles its samples.
        static_test: Read the test images from the pre-computed ``.npy`` file.
        data_root: Directory holding CIFAR-10, ``CIFAR-10-C`` and STL-10.

    Returns:
        dict with the keys ``"train"``, ``"test"`` and ``"val"``.
    """
    cifar10_train = datasets.CIFAR10(root=data_root, train=True, download=True)
    cifar10_test = datasets.CIFAR10(root=data_root, train=False, download=True)

    train_split = int(0.8 * len(cifar10_train))

    # Fixed seed so the train/validation partition is the same on every run.
    trainset, valset = torch.utils.data.random_split(
        cifar10_train,
        [train_split, len(cifar10_train) - train_split],
        generator=torch.Generator().manual_seed(42))

    loaders = get_dataloader(train_batch_size, test_batch_size, trainset, valset, cifar10_test,
                             problem, train_data, test_data, name_dataset='cifar', shuffle=shuffle)

    if problem == "stl":
        testset = datasets.STL10(root=data_root, split='test', download=True)
        stl_test = Custom_Tiny_Imagenet_C_Dataset(testset, problem="stl")
        loaders["test"] = torch.utils.data.DataLoader(stl_test, batch_size=test_batch_size,
                                                      shuffle=False, num_workers=0)
    elif static_test:
        test_level = {"hard": 5, "medium": 3, "easy": 1}
        level = test_level[test_data]
        print('Test on %s level %d' %(problem, level))
        if problem is not None and problem != "None":
            # NOTE(review): problem='Total Noise' resolves to total_noise.npy;
            # confirm that such a file exists in the CIFAR-10-C folder.
            file_stem = problem.lower().replace(" ", "_")
            tesize = 10000
            teset_raw = np.load(f"{data_root}/CIFAR-10-C/{file_stem}.npy")
            # Keep only the block of `tesize` images of the requested level.
            teset_raw = teset_raw[(level - 1) * tesize: level * tesize]
            # Swap the images in place; the labels of cifar10_test are kept.
            cifar10_test.data = teset_raw
            static_set = Custom_Tiny_Imagenet_C_Dataset(cifar10_test)
            loaders["test"] = torch.utils.data.DataLoader(static_set, batch_size=test_batch_size,
                                                          shuffle=False, num_workers=0)

    return loaders

def prepare_cifar100_c_loader(train_batch_size, test_batch_size, train_data, test_data, problem='Gaussian Noise', shuffle=True, static_test=True,
                              data_root='/mnt/storage/aird/AI-RnD/hieutb2/deep-thinking/data'):
    """Build the train/val/test dataloaders for corrupted CIFAR-100.

    The CIFAR-100 training set is split 80/20 into train and validation parts
    with a fixed seed. When ``static_test`` is true and ``problem`` is not
    ``"None"`` the test loader reads the pre-computed
    ``CIFAR-100-C/<problem>.npy`` images; otherwise it is the CIFAR-100 test
    set wrapped by ``get_dataloader``.

    Args:
        train_batch_size: Batch size of the training loader.
        test_batch_size: Batch size of the validation and test loaders.
        train_data: Noise level name applied to the training split.
        test_data: Noise level name (``"easy"``, ``"medium"``, ``"hard"``)
            used for validation and test data.
        problem: Name of the corruption or ``"None"``.
        shuffle: Whether the training loader shuffles its samples.
        static_test: Read the test images from the pre-computed ``.npy`` file.
        data_root: Directory holding CIFAR-100 and ``CIFAR-100-C``.

    Returns:
        dict with the keys ``"train"``, ``"test"`` and ``"val"``.
    """
    cifar100_train = datasets.CIFAR100(root=data_root, train=True, download=True)
    cifar100_test = datasets.CIFAR100(root=data_root, train=False, download=True)

    train_split = int(0.8 * len(cifar100_train))

    # Fixed seed so the train/validation partition is the same on every run.
    trainset, valset = torch.utils.data.random_split(
        cifar100_train,
        [train_split, len(cifar100_train) - train_split],
        generator=torch.Generator().manual_seed(42))

    loaders = get_dataloader(train_batch_size, test_batch_size, trainset, valset, cifar100_test,
                             problem, train_data, test_data, name_dataset='cifar', shuffle=shuffle)

    if static_test:
        test_level = {"hard": 5, "medium": 3, "easy": 1}
        level = test_level[test_data]
        print('Test on %s level %d' %(problem, level))
        if problem is not None and problem != "None":
            # NOTE(review): problem='Total Noise' resolves to total_noise.npy;
            # confirm that such a file exists in the CIFAR-100-C folder.
            file_stem = problem.lower().replace(" ", "_")
            tesize = 10000
            teset_raw = np.load(f"{data_root}/CIFAR-100-C/{file_stem}.npy")
            # Keep only the block of `tesize` images of the requested level.
            teset_raw = teset_raw[(level - 1) * tesize: level * tesize]
            # Swap the images in place; the labels of cifar100_test are kept.
            cifar100_test.data = teset_raw
            static_set = Custom_Tiny_Imagenet_C_Dataset(cifar100_test)
            loaders["test"] = torch.utils.data.DataLoader(static_set, batch_size=test_batch_size,
                                                          shuffle=False, num_workers=0)

    return loaders

def prepare_imagenet_c_loader(train_batch_size, test_batch_size, train_data, test_data, problem='Gaussian Noise', train_data_path="", test_data_path="", shuffle=True):
    """Build the train/val/test dataloaders for corrupted ImageNet.

    The folder at ``train_data_path`` is split 80/20 into train and
    validation parts with a fixed seed (both use the training transform);
    the folder at ``test_data_path`` is used for testing.

    Args:
        train_batch_size: Batch size of the training loader.
        test_batch_size: Batch size of the validation and test loaders.
        train_data: Noise level name applied to the training split.
        test_data: Noise level name applied to the validation and test splits.
        problem: Name of the corruption, forwarded to ``get_dataloader``.
        train_data_path: Root of the training ``ImageFolder``.
        test_data_path: Root of the test ``ImageFolder``.
        shuffle: Whether the training loader shuffles its samples.

    Returns:
        dict with the keys ``"train"``, ``"test"`` and ``"val"``.
    """
    # NOTE(review): these tensors are normalized here, and the dataset wrapper
    # built by get_dataloader divides by 255 again - confirm intended scaling.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet mean
        std=[0.229, 0.224, 0.225]    # ImageNet std
    )

    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Deterministic preprocessing: a random crop here would make the test
    # metrics change from run to run. The output size stays 224x224.
    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    imagenet_train = datasets.ImageFolder(root=train_data_path, transform=transform_train)
    imagenet_test = datasets.ImageFolder(root=test_data_path, transform=transform_test)

    train_split = int(0.8 * len(imagenet_train))

    # Fixed seed so the train/validation partition is the same on every run.
    trainset, valset = torch.utils.data.random_split(
        imagenet_train,
        [train_split, len(imagenet_train) - train_split],
        generator=torch.Generator().manual_seed(42))

    loaders = get_dataloader(train_batch_size, test_batch_size, trainset, valset, imagenet_test,
                             problem, train_data, test_data, name_dataset='imagenet', shuffle=shuffle)

    return loaders

def prepare_tiny_imagenet_c_loader(train_batch_size, test_batch_size, train_data, test_data, problem='Gaussian Noise', shuffle=True,
                                   data_root='/mnt/storage/aird/AI-RnD/hieutb2/deep-thinking/data'):
    """Build the train/val/test dataloaders for Tiny-ImageNet-C.

    Training images come from ``<data_root>/tiny-imagenet-200/train``. The
    pre-corrupted folder ``<data_root>/Tiny-ImageNet-C/<problem>/<severity>``
    is used for both the validation and the test loader.

    Args:
        train_batch_size: Batch size of the training loader.
        test_batch_size: Batch size of the validation and test loaders.
        train_data: Noise level name applied to the training split.
        test_data: ``"easy"``, ``"medium"`` or ``"hard"``; selects the
            severity folder 1, 3 or 5.
        problem: Name of the corruption; lower-cased with spaces replaced by
            underscores to form the folder name.
        shuffle: Whether the training loader shuffles its samples.
        data_root: Directory holding both dataset folders.

    Returns:
        dict with the keys ``"train"``, ``"test"`` and ``"val"``.

    Raises:
        ValueError: If ``test_data`` is not a known level name.
    """
    severity_dirs = {"easy": 1, "medium": 3, "hard": 5}
    # Fail early with a clear message instead of leaving the test set unbound.
    if test_data not in severity_dirs:
        raise ValueError(
            "Unsupported test_data %r, expected one of %s"
            % (test_data, sorted(severity_dirs))
        )

    problem = problem.lower().replace(" ", "_")
    imagenet_train = datasets.ImageFolder(root=f'{data_root}/tiny-imagenet-200/train')
    imagenet_test = datasets.ImageFolder(
        root=f'{data_root}/Tiny-ImageNet-C/{problem}/{severity_dirs[test_data]}')

    # The corrupted folder is passed as both the validation and the test split.
    loaders = get_dataloader(train_batch_size, test_batch_size, imagenet_train, imagenet_test, imagenet_test,
                             problem, train_data, test_data, name_dataset='tiny_imagenet', shuffle=shuffle)

    return loaders

if __name__ == "__main__":
    # Manual smoke test: iterate over the ImageNet test loader with the
    # "Total Noise" setting and convert the first image of every batch back
    # to a uint8 channel-last array.
    # NOTE(review): the dataset paths are left at their empty defaults here;
    # presumably they have to be filled in before this can run.
    test_loader = prepare_imagenet_c_loader(1, 1, 'easy', 'hard', problem='Total Noise')["test"]
    import cv2
    import numpy as np
    for batch_idx, batch in enumerate(test_loader):
        images_nhwc = batch[0].numpy().transpose(0, 2, 3, 1)
        first_image = (images_nhwc[0] * 255).astype(np.uint8)
        # To inspect the result, write it out with the channels reversed:
        # cv2.imwrite(f'log_noise/train/{batch_idx}.png', first_image[:, :, ::-1])