import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image
import torch


# Per-channel normalization statistics keyed by dataset name. The values are on
# the [0, 1] scale produced by ``transforms.ToTensor`` and are fed to
# ``transforms.Normalize``. MNIST has one channel; the other datasets have three.
DATASET_STATS = dict(
    cifar10=dict(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
    cifar100=dict(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)),
    # NOTE(review): these are the commonly used ImageNet statistics, reused for Tiny-ImageNet.
    tiny200=dict(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    mnist=dict(mean=(0.1307,), std=(0.3081,)),
)

def get_dataset_stats(dataset_name, custom_mean=None, custom_std=None):
    """Return the ``(mean, std)`` normalization statistics for a dataset.

    Args:
        dataset_name: A key of ``DATASET_STATS``, or ``'path'`` for a custom
            dataset whose statistics are supplied explicitly.
        custom_mean: Used only when ``dataset_name == 'path'``. Either a
            sequence of floats or a string holding a Python literal such as
            ``"(0.5, 0.5, 0.5)"``.
        custom_std: Same format and usage as ``custom_mean``.

    Returns:
        A ``(mean, std)`` pair. Built-in datasets return the tuples stored in
        ``DATASET_STATS``. Custom string statistics are returned as the parsed
        literal; custom non-string statistics are returned unchanged.

    Raises:
        ValueError: If the dataset is unknown, if ``'path'`` is requested
            without both statistics, or if a statistics string is not a plain
            literal. A syntactically broken string raises ``SyntaxError``.
    """
    import ast

    def _parse(value):
        # literal_eval accepts only literals. Unlike eval(), a statistics string
        # (presumably coming from command-line args) cannot execute code.
        return ast.literal_eval(value) if isinstance(value, str) else value

    if dataset_name == 'path':
        if not (custom_mean and custom_std):
            raise ValueError("dataset 'path' requires both a custom mean and std")
        return _parse(custom_mean), _parse(custom_std)
    if dataset_name in DATASET_STATS:
        stats = DATASET_STATS[dataset_name]
        return stats['mean'], stats['std']
    raise ValueError(f'dataset not supported: {dataset_name}')

class AddSaltPepperNoise(object):
    """Corrupt a random fraction of pixels with pepper (0) or salt (255) noise.

    Each pixel is independently:
    - set to 0 with probability ``density / 2``;
    - set to 255 with probability ``density / 2``;
    - left unchanged otherwise.

    Every channel of a selected pixel receives the same value.

    Args:
        density: Fraction of pixels to corrupt; must lie in ``[0, 1]`` for the
            probabilities passed to ``np.random.choice`` to be valid.
    """

    def __init__(self, density=0):
        self.density = density

    def __call__(self, img):
        """Return a noisy copy of ``img`` as a PIL image.

        ``img`` is anything ``np.array`` turns into an ``(H, W)`` or
        ``(H, W, C)`` array, e.g. a PIL image.
        - Multi-channel input is returned in RGB mode.
        - 2-D (grayscale) input is returned as a single-channel image.
        """
        arr = np.array(img)
        density = self.density
        # One draw per pixel, not per channel, so salt/pepper hits whole pixels.
        # Drawing shape (H, W) consumes the same uniforms as (H, W, 1), so the
        # random stream for RGB images is unchanged.
        mask = np.random.choice((0, 1, 2), size=arr.shape[:2],
                                p=[density / 2.0, density / 2.0, 1 - density])
        # A 2-D boolean mask indexes the leading (H, W) axes. For (H, W, C)
        # arrays it therefore overwrites every channel of the selected pixels.
        arr[mask == 0] = 0
        arr[mask == 1] = 255
        noisy = Image.fromarray(arr.astype('uint8'))
        return noisy if arr.ndim == 2 else noisy.convert('RGB')


    
class AddGaussianNoise(object):
    """Add Gaussian noise to an image in pixel space, before tensor conversion.

    For every array element the noise is ``amplitude * N(mean, sigma)``.
    After the noise is added, the result is:
    - clipped to ``[0, 255]``;
    - truncated to ``uint8``.

    Args:
        mean: Mean of the normal distribution.
        sigma: Standard deviation of the normal distribution.
        amplitude: Scale factor applied to the sampled noise.
    """

    def __init__(self, mean=0.0, sigma=1.0, amplitude=1.0):
        self.mean = mean
        self.sigma = sigma
        self.amplitude = amplitude

    def __call__(self, img):
        """Return a noisy copy of ``img`` as a PIL image.

        ``img`` is anything ``np.array`` turns into an ``(H, W)`` or
        ``(H, W, C)`` array, e.g. a PIL image.
        - Multi-channel input is returned in RGB mode.
        - 2-D (grayscale) input is returned as a single-channel image.
        """
        arr = np.array(img)
        # Sampling with the array's own shape handles grayscale input. For
        # (H, W, C) input it draws exactly the same values as an explicit
        # (height, width, channels) size.
        gauss = np.random.normal(self.mean, self.sigma, arr.shape)
        noisy = np.clip(arr + self.amplitude * gauss, a_min=0, a_max=255)
        noisy = Image.fromarray(np.uint8(noisy))
        return noisy if arr.ndim == 2 else noisy.convert('RGB')

class AddGaussianNoisePost(object):
    """Add Gaussian noise to an already-tensorized image.

    Used after ``ToTensor``/``Normalize`` in ``set_noised_testloader``. The
    result is ``img + amplitude * (N(0, 1) * sigma + mean)``, element-wise.

    Args:
        mean: Offset added to the standard-normal draws.
        sigma: Scale applied to the standard-normal draws.
        amplitude: Overall scale applied to the noise.
    """

    def __init__(self, mean=0.0, sigma=1.0, amplitude=1.0):
        self.mean = mean
        self.sigma = sigma
        self.amplitude = amplitude

    def __call__(self, img):
        """Return ``img`` plus freshly sampled noise of the same shape."""
        # Sample on the input's own device so CUDA tensors work. A CPU noise
        # tensor added to a CUDA tensor raises a device-mismatch error. The
        # noise dtype stays torch's default and is combined with ``img`` by
        # normal type promotion.
        noise = torch.randn(img.shape, device=img.device) * self.sigma + self.mean
        return img + self.amplitude * noise

def setup_seed(seed):
    """Seed every random number generator the pipeline uses and pin cuDNN.

    The same ``seed`` is applied to:
    - Python's ``random``;
    - NumPy's global generator;
    - torch on the CPU and on all CUDA devices.

    It also sets ``torch.backends.cudnn.deterministic = True``.
    """
    import random
    import numpy
    import torch

    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, numpy.random.seed, random.seed)
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

class TwoCropTransform:
    """Produce one plain view and two augmented views of the same image.

    Calling the transform returns ``[original, view1, view2]``:
    - ``original`` is only resized and normalized.
    - ``view1`` and ``view2`` come from two separate calls of the augmentation
      pipeline.

    Args:
        transform: Augmentation pipeline for the two views. When ``None``, a
            default pipeline is used: resize, horizontal flip, random resized
            crop, ``ToTensor``, normalize.
        size: Side length of the square resize used by the original view and
            by the default augmentation.
        dataset: Name passed to ``get_dataset_stats``. The resulting
            statistics are used by the original view, by the default
            augmentation and by ``denormalize``.
    """

    def __init__(self, transform=None, size=32, dataset='cifar10'):
        # Either pass ``transform``, or let size + dataset build a default one.
        self.mean, self.std = get_dataset_stats(dataset)
        normalize = transforms.Normalize(mean=self.mean, std=self.std)
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((size, size)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomResizedCrop((size, size), scale=(0.8, 1.0), ratio=(1.0, 1.0)),
                transforms.ToTensor(),
                normalize
            ])
        else:
            self.transform = transform

        self.transform_original = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            normalize
        ])

        # Not used by __call__; presumably kept for external callers.
        self.to_tensor = transforms.ToTensor()

    def __call__(self, x):
        return [self.transform_original(x), self.transform(x), self.transform(x)]

    def denormalize(self, x):
        """Undo the normalization and clamp the result to ``[0, 1]``.

        ``x`` is a ``(C, H, W)`` or ``(..., C, H, W)`` tensor. ``C`` must match
        the length of the stored mean/std, e.g. 1 for MNIST and 3 for CIFAR.
        """
        # view(-1, 1, 1) follows the channel count of the stats, so it also
        # handles single-channel data (a hard-coded 3 fails there). The
        # trailing singleton dims broadcast over H, W and any leading batch dim.
        mean_t = torch.tensor(self.mean).view(-1, 1, 1).to(x.device)
        std_t = torch.tensor(self.std).view(-1, 1, 1).to(x.device)

        return torch.clamp(x * std_t + mean_t, 0, 1)

def set_noised_testloader(args, salt_rate=0.2, noise_mean=0, noise_sigma=None, noise_amplitude=None):
    """Build a test DataLoader whose images are optionally corrupted with noise.

    Args:
        args: Namespace-like object. Reads ``dataset``, ``add_noise``,
            ``size`` and ``bs``. Also reads ``mean``/``std`` when present
            (custom statistics for the ``'path'`` dataset).
        salt_rate: Density for ``add_noise == 'SaltPepper'``.
        noise_mean: Noise mean for the two Gauss modes.
        noise_sigma: Noise standard deviation for the two Gauss modes.
        noise_amplitude: Noise scale for the two Gauss modes.
            - ``'Gauss'`` adds noise in pixel space before resizing.
            - ``'Gauss-Post'`` adds noise to the normalized tensor.
            Both modes require ``noise_sigma`` and ``noise_amplitude``.

    Returns:
        A shuffled ``DataLoader`` over the dataset's test split.

    Raises:
        ValueError: For an unknown ``add_noise`` mode, an unknown dataset, or
            a Gauss mode without ``noise_sigma``/``noise_amplitude``.
    """
    mean, std = get_dataset_stats(args.dataset, getattr(args, 'mean', None),
                                  getattr(args, 'std', None))
    normalize = transforms.Normalize(mean=mean, std=std)
    setup_seed(1999)

    noise = args.add_noise
    # Fail here rather than lazily inside a DataLoader worker. With None
    # parameters the noise transforms would error out or produce garbage.
    if noise in ('Gauss', 'Gauss-Post') and (noise_sigma is None or noise_amplitude is None):
        raise ValueError(f'add_noise={noise!r} requires noise_sigma and noise_amplitude')

    base_steps = [
        transforms.Resize((args.size, args.size)),
        transforms.ToTensor(),
        normalize,
    ]
    if noise == 'SaltPepper':
        steps = [AddSaltPepperNoise(salt_rate)] + base_steps
    elif noise == 'Gauss':
        steps = [AddGaussianNoise(mean=noise_mean, sigma=noise_sigma, amplitude=noise_amplitude)] + base_steps
    elif noise == 'Gauss-Post':
        steps = base_steps + [AddGaussianNoisePost(mean=noise_mean, sigma=noise_sigma, amplitude=noise_amplitude)]
    elif noise == 'None':
        steps = base_steps
    else:
        raise ValueError(f'noise type not supported: {noise}')
    test_transform = transforms.Compose(steps)

    if args.dataset == 'cifar10':
        test_dataset = datasets.CIFAR10('cifar', False, transform=test_transform, download=True)
    elif args.dataset == 'cifar100':
        test_dataset = datasets.CIFAR100('cifar', False, transform=test_transform, download=True)
    elif args.dataset == 'tiny200':
        test_dataset = datasets.ImageFolder(root="tiny-imagenet-200/val", transform=test_transform)
    elif args.dataset == "mnist":
        test_dataset = datasets.MNIST('mnist', False, transform=test_transform, download=True)
    else:
        raise ValueError(f'dataset not supported: {args.dataset}')

    test_loader = DataLoader(dataset=test_dataset, batch_size=args.bs, num_workers=4, shuffle=True, drop_last=False)
    return test_loader

def set_dataloader(args):
    """Build the contrastive training loader and the validation loader.

    Training samples go through ``TwoCropTransform``, so each sample is a
    list ``[original, view1, view2]``. Validation samples use a plain
    resize/normalize pipeline for every dataset.

    Args:
        args: Namespace-like object. Reads ``dataset``, ``size`` and ``bs``.
            Also reads ``mean``/``std`` when present (custom statistics for
            the ``'path'`` dataset).

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: If ``args.dataset`` is not one of ``'cifar10'``,
            ``'cifar100'``, ``'tiny200'``, ``'mnist'``.
    """
    mean, std = get_dataset_stats(args.dataset, getattr(args, 'mean', None),
                                  getattr(args, 'std', None))
    normalize = transforms.Normalize(mean=mean, std=std)

    setup_seed(1999)

    train_transform = transforms.Compose([
        transforms.Resize((args.size, args.size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomResizedCrop((args.size, args.size), scale=(0.8, 1.0), ratio=(1.0, 1.0)),
        transforms.ToTensor(),
        normalize
        ])

    test_transform = transforms.Compose([
        transforms.Resize((args.size, args.size)),
        transforms.ToTensor(),
        normalize
    ])

    two_crop = TwoCropTransform(train_transform, size=args.size, dataset=args.dataset)

    if args.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10('cifar', True, transform=two_crop, download=True)
        test_dataset = datasets.CIFAR10('cifar', False, transform=test_transform, download=True)
    elif args.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100('cifar', True, transform=two_crop, download=True)
        # Use test_transform like every other dataset, so validation batches
        # have the same single-tensor structure regardless of dataset.
        test_dataset = datasets.CIFAR100('cifar', False, transform=test_transform, download=True)
    elif args.dataset == 'tiny200':
        train_dataset = datasets.ImageFolder(root="./tiny-imagenet-200/train", transform=two_crop)
        test_dataset = datasets.ImageFolder(root="./tiny-imagenet-200/val", transform=test_transform)
    elif args.dataset == "mnist":
        train_dataset = datasets.MNIST('mnist', True, transform=two_crop, download=True)
        test_dataset = datasets.MNIST('mnist', False, transform=test_transform, download=True)
    else:
        raise ValueError(f'dataset not supported: {args.dataset}')

    train_loader = DataLoader(dataset=train_dataset, batch_size=args.bs, num_workers=2, shuffle=True, drop_last=True)
    # NOTE(review): drop_last=True discards the final partial validation batch,
    # so those samples are never evaluated. Kept as-is in case callers depend
    # on fixed batch sizes.
    val_loader = DataLoader(dataset=test_dataset, batch_size=args.bs, num_workers=2, shuffle=True, drop_last=True)

    return train_loader, val_loader

from pytorch_lightning import LightningDataModule
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader

def get_dataloader(dataname, size, batch_size, num_workers, drop_train_last, mean=None, std=None):
    """Build the training (two-crop) and validation DataLoaders for a dataset.

    Args:
        dataname: One of ``'cifar10'``, ``'cifar100'``, ``'tiny200'``,
            ``'mnist'``.
        size: Side length of the square images produced by both pipelines.
        batch_size: Batch size for both loaders.
        num_workers: Worker processes for both loaders.
        drop_train_last: Whether the training loader drops its last partial
            batch.
        mean: Custom statistics, forwarded to ``get_dataset_stats``.
        std: Custom statistics, forwarded to ``get_dataset_stats``.

    Returns:
        ``(train_loader, val_loader)``. Each training sample is the list
        ``[original, view1, view2]`` produced by ``TwoCropTransform``.

    Raises:
        ValueError: If ``dataname`` is not a supported dataset.
    """
    mean, std = get_dataset_stats(dataname, mean, std)
    normalize = transforms.Normalize(mean=mean, std=std)

    # Set the random seed.
    setup_seed(1999)

    train_transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomResizedCrop((size, size), scale=(0.8, 1.0), ratio=(1.0, 1.0)),
        transforms.ToTensor(),
        normalize
        ])

    test_transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        normalize
    ])

    # Pass size and dataset explicitly. Otherwise the un-augmented "original"
    # view falls back to TwoCropTransform's defaults (32 px, CIFAR-10 stats)
    # and no longer matches the augmented views.
    two_crop = TwoCropTransform(train_transform, size=size, dataset=dataname)

    if dataname == 'cifar10':
        train_dataset = datasets.CIFAR10('cifar', True, transform=two_crop, download=True)
        test_dataset = datasets.CIFAR10('cifar', False, transform=test_transform, download=True)
    elif dataname == 'cifar100':
        train_dataset = datasets.CIFAR100('cifar', True, transform=two_crop, download=True)
        test_dataset = datasets.CIFAR100('cifar', False, transform=test_transform, download=True)
    elif dataname == 'tiny200':
        train_dataset = datasets.ImageFolder(root="tiny-imagenet-200/train", transform=two_crop)
        test_dataset = datasets.ImageFolder(root="tiny-imagenet-200/val", transform=test_transform)
    elif dataname == "mnist":
        train_dataset = datasets.MNIST('mnist', True, transform=two_crop, download=True)
        test_dataset = datasets.MNIST('mnist', False, transform=test_transform, download=True)
    else:
        raise ValueError(f'dataset not supported: {dataname}')

    # Wrap the datasets in loaders.
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, drop_last=drop_train_last)
    val_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, drop_last=False)

    return train_loader, val_loader

class MyDataModule(LightningDataModule):
    """Lightning data module wrapping the loaders built by ``get_dataloader``.

    Both loaders are created eagerly in ``__init__``, so constructing the
    module may download the dataset. The test loader is the validation loader.

    Args:
        data_root: Accepted for interface compatibility but not used. Dataset
            locations come from the paths hard-coded in ``get_dataloader``.
        dataname: Dataset name forwarded to ``get_dataloader``.
        batch_size: Batch size for both loaders.
        size: Square image side length.
        num_workers: DataLoader worker processes.
        drop_train_last: Whether the training loader drops its last partial
            batch.
    """

    def __init__(
        self,
        data_root: str = "./cifar",
        dataname: str = "cifar10",
        batch_size: int = 128,
        size: int = 224,
        num_workers: int = 4,
        drop_train_last: bool = False,
    ) -> None:
        super().__init__()
        self.dataname, self.batch_size, self.size = dataname, batch_size, size
        self.num_workers, self.drop_train_last = num_workers, drop_train_last

        loaders = get_dataloader(
            dataname=dataname,
            size=size,
            batch_size=batch_size,
            num_workers=num_workers,
            drop_train_last=drop_train_last,
        )
        self.train_loader, self.val_loader = loaders

    def prepare_data(self):
        """No-op: any download already happened when ``__init__`` built the loaders."""

    def train_dataloader(self):
        """Return the pre-built training loader."""
        return self.train_loader

    def val_dataloader(self):
        """Return the pre-built validation loader."""
        return self.val_loader

    def test_dataloader(self):
        """Evaluate on the validation split, via ``val_dataloader``."""
        return self.val_dataloader()