import torch
from torch.utils import data
from torchvision import datasets
from torchvision import transforms
import random


class GaussianNoise(object):
    """Additive noise transform for image tensors.

    Each call draws noise as ``clamp(randn, 0, 1) * std + mean`` with the same
    shape as the input and adds it to the input. The result is not clipped,
    so pixel values may leave their original range.

    Args:
        mean: Constant offset added to every noise sample.
        std: Scale applied to the clamped standard-normal samples.
        easy: If True, each call decides with a fair coin flip (Python's
            ``random`` module) whether to add noise; when it does not, the
            input is returned as ``tensor + 0.``.
    """

    def __init__(self, mean=0.0, std=0.1, easy=False):
        self.mean = mean
        self.std = std
        self.easy = easy

    def _sample_noise(self, tensor):
        """Draw one noise tensor shaped like *tensor*, on the same device."""
        # Sample on the input's device so CUDA tensors work. The dtype stays the
        # default float, so integer / half inputs promote exactly as before.
        samples = torch.randn(tensor.shape, device=tensor.device)
        # NOTE(review): clamping randn to [0, 1] makes the noise one-sided
        # (never negative, and zero for roughly half the pixels), so it only
        # brightens the image. Confirm this is intended rather than clamping
        # the noisy output instead.
        return torch.clamp(samples, 0, 1) * self.std + self.mean

    def __call__(self, tensor):
        if self.easy:
            # Easy mode: a coin flip decides whether this sample is corrupted.
            apply_noise = random.choice([True, False])
        else:
            apply_noise = True
        noise = self._sample_noise(tensor) if apply_noise else 0.
        return tensor + noise

    def __repr__(self):
        return f"{type(self).__name__}(mean={self.mean}, std={self.std}, easy={self.easy})"

class CustomCifar10Dataset(data.Dataset):
    """Wrap an ``(image, label)`` dataset and corrupt each image on access.

    The corruption is chosen once, at construction, from ``noise_level``:

    * ``'easy'``   -- ``GaussianNoise(mean=0.0, std=0.05, easy=True)``
    * ``'medium'`` -- ``transforms.GaussianBlur(kernel_size=3, sigma=0.5)``
    * ``'hard'``   -- ``transforms.GaussianBlur(kernel_size=5, sigma=2.0)``

    Labels are returned unchanged.

    Args:
        cifar10_dataset: Indexable dataset yielding ``(image, label)`` pairs.
            Images are presumably CHW float tensors (e.g. from ``ToTensor``);
            TODO confirm for callers other than ``prepare_cifar10_loader``.
        noise_level: One of ``'easy'``, ``'medium'``, ``'hard'``.

    Raises:
        ValueError: If ``noise_level`` is not a supported level.
    """

    NOISE_LEVELS = ('easy', 'medium', 'hard')

    def __init__(self, cifar10_dataset, noise_level='easy'):
        # Validate up front so a typo fails with a clear message.
        if noise_level not in self.NOISE_LEVELS:
            raise ValueError(
                f"noise_level must be one of {self.NOISE_LEVELS}, got {noise_level!r}"
            )
        self.cifar10_dataset = cifar10_dataset

        if noise_level == 'easy':
            self.noise_transform = GaussianNoise(mean=0.0, std=0.05, easy=True)
        elif noise_level == 'medium':
            # NOTE(review): only the blur is applied at this level. An earlier
            # version also constructed GaussianNoise(std=0.4) but never used
            # it. Confirm whether blur followed by noise was intended.
            self.noise_transform = transforms.GaussianBlur(kernel_size=3, sigma=0.5)
        else:  # 'hard'
            # NOTE(review): as with 'medium', a GaussianNoise(std=0.7) was once
            # built here but never applied; only the blur is used.
            self.noise_transform = transforms.GaussianBlur(kernel_size=5, sigma=2.0)

    def __len__(self):
        return len(self.cifar10_dataset)

    def __getitem__(self, idx):
        image, label = self.cifar10_dataset[idx]
        image = self.noise_transform(image)
        return image, label


def prepare_cifar10_loader(train_batch_size, test_batch_size, train_data, test_data, shuffle=True,
                           root='/home/tranhieu/workdir/anhnd/dt_18102024/data',
                           download=False, num_workers=0):
    """Build corrupted CIFAR-10 train / val / test data loaders.

    The official CIFAR-10 training split is divided 80/20 into train and
    validation subsets. The split uses a generator seeded with 42, so it is
    identical on every call.

    Corruption levels are applied as follows (see ``CustomCifar10Dataset``):
    the train subset uses ``train_data``; the validation subset and the
    official test split use ``test_data``.

    Args:
        train_batch_size: Batch size of the train loader; the last partial
            batch is dropped.
        test_batch_size: Batch size of the val and test loaders; the last
            partial batch is kept.
        train_data: Noise level ('easy' / 'medium' / 'hard') for the train subset.
        test_data: Noise level for the validation subset and the test split.
        shuffle: Whether to shuffle the train loader. Val and test are never shuffled.
        root: Directory holding the CIFAR-10 data. Defaults to the original
            hard-coded location.
        download: Forwarded to ``datasets.CIFAR10``.
        num_workers: Number of worker processes for every loader.

    Returns:
        dict mapping ``"train"``, ``"test"`` and ``"val"`` to ``DataLoader`` objects.
    """
    to_tensor = transforms.ToTensor()
    cifar10_train = datasets.CIFAR10(root=root, train=True, download=download, transform=to_tensor)
    cifar10_test = datasets.CIFAR10(root=root, train=False, download=download, transform=to_tensor)

    train_split = int(0.8 * len(cifar10_train))
    val_split = len(cifar10_train) - train_split
    # Fixed seed keeps the train/val partition reproducible across runs.
    trainset, valset = data.random_split(
        cifar10_train,
        [train_split, val_split],
        generator=torch.Generator().manual_seed(42),
    )

    custom_train_dataset = CustomCifar10Dataset(trainset, noise_level=train_data)
    custom_val_dataset = CustomCifar10Dataset(valset, noise_level=test_data)
    custom_test_dataset = CustomCifar10Dataset(cifar10_test, noise_level=test_data)

    train_loader = data.DataLoader(dataset=custom_train_dataset,
                                   num_workers=num_workers,
                                   batch_size=train_batch_size,
                                   shuffle=shuffle,
                                   drop_last=True)
    val_loader = data.DataLoader(dataset=custom_val_dataset,
                                 num_workers=num_workers,
                                 batch_size=test_batch_size,
                                 shuffle=False,
                                 drop_last=False)
    test_loader = data.DataLoader(dataset=custom_test_dataset,
                                  num_workers=num_workers,
                                  batch_size=test_batch_size,
                                  shuffle=False,
                                  drop_last=False)

    return {"train": train_loader, "test": test_loader, "val": val_loader}



if __name__ == "__main__":
    # Smoke check: dump the first image of the first 21 test batches to disk
    # so the corruption can be inspected visually.
    import os

    import cv2
    import numpy as np

    out_dir = 'log_noise/train'
    # cv2.imwrite returns False instead of raising when the directory is
    # missing, so create it explicitly.
    os.makedirs(out_dir, exist_ok=True)

    y_loaders = prepare_cifar10_loader(4, 4, 'easy', 'medium')["test"]
    for i, s in enumerate(y_loaders):
        # First image of the batch, NCHW -> HWC.
        img = s[0].numpy().transpose(0, 2, 3, 1)[0]
        # Additive noise can push values outside [0, 1]; clip so the uint8
        # cast does not wrap around.
        img = np.clip(img * 255, 0, 255).astype(np.uint8)
        # CIFAR-10 tensors are RGB, while OpenCV writes BGR.
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        cv2.imwrite(f'{out_dir}/{i}.png', img)
        if i == 20:
            break

