import torch
import numpy as np
from torchvision import transforms, datasets
import torch.nn as nn
from torch.utils.data import Dataset, Subset


class Normalize(nn.Module):
    """Channel-wise input normalisation as a module: ``(input - mean) / std``.

    The statistics are registered as buffers named ``mean`` and ``std`` so they
    follow the module across devices and are stored in its ``state_dict``.

    Args:
        mean: Sequence with one mean per channel.
        std: Sequence with one standard deviation per channel.
    """

    def __init__(self, mean, std):
        super().__init__()
        # Kept as flat (C,) buffers; they are reshaped for broadcasting in forward().
        self.register_buffer('mean', torch.Tensor(mean))
        self.register_buffer('std', torch.Tensor(std))

    def forward(self, input):
        """Normalise ``input`` per channel.

        Args:
            input: Tensor broadcastable against shape ``(1, C, 1, 1)``, e.g. a
                batch laid out as ``(N, C, H, W)``, where ``C`` is the number
                of entries in ``mean`` / ``std``.

        Returns:
            Tensor holding ``(input - mean) / std``.
        """
        # Reshape to (1, C, 1, 1) so the statistics broadcast over batch and
        # spatial dimensions. -1 infers C from the buffer instead of assuming
        # three channels.
        mean = self.mean.reshape(1, -1, 1, 1)
        std = self.std.reshape(1, -1, 1, 1)
        return (input - mean) / std


# Per-channel statistics handed to transforms.Normalize by CIFAR10_dataset
# when it is called with norm=True.
# NOTE(review): presumably the RGB mean / std of the CIFAR-10 training images
# on a [0, 1] scale (the __main__ block below prints such statistics) - confirm.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

# CIFAR-10 dataset
def CIFAR10_dataset(data_dir='/home/harry/dataset/cifar10', norm=False, seed=42, val=False):
    """Create the CIFAR-10 train, validation and test datasets.

    Args:
        data_dir: Root directory passed to ``datasets.CIFAR10``; every dataset
            is created with ``download=True``.
        norm: If true, both pipelines end with
            ``transforms.Normalize(cifar10_mean, cifar10_std)`` and training
            only applies a random horizontal flip. If false, training applies
            a padded random crop plus a random horizontal flip and nothing is
            normalised.
        seed: Seed for the train/validation permutation (used only if ``val``).
        val: If true, split the 50000 training images 45000/5000 (9:1) into a
            training subset and a validation subset.

    Returns:
        ``(train_set, val_set, test_set)``. ``val_set`` is ``None`` when
        ``val`` is false. The validation subset uses the test-time transform.

    Note:
        With ``val=True`` this calls ``np.random.seed(seed)``, i.e. it reseeds
        NumPy's global random state as a side effect.
    """
    if norm:
        # No RandomCrop in the normalised pipeline.
        augmentation = [transforms.RandomHorizontalFlip()]
    else:
        augmentation = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]

    def closing_steps():
        # Tensor conversion, followed by normalisation only when requested.
        steps = [transforms.ToTensor()]
        if norm:
            steps.append(transforms.Normalize(cifar10_mean, cifar10_std))
        return steps

    train_transform = transforms.Compose(augmentation + closing_steps())
    test_transform = transforms.Compose(closing_steps())

    def load(is_train, pipeline):
        return datasets.CIFAR10(data_dir, train=is_train, transform=pipeline, download=True)

    test_set = load(False, test_transform)

    if not val:
        return load(True, train_transform), None, test_set

    # Seeded random permutation of the training indices defines the split.
    np.random.seed(seed)
    shuffled = list(np.random.permutation(50000))
    # Two views of the same training images: augmented for training,
    # test-time transform for validation.
    train_set = Subset(load(True, train_transform), shuffled[:45000])
    val_set = Subset(load(True, test_transform), shuffled[45000:])
    return train_set, val_set, test_set



# CIFAR-10 dataloader
def CIFAR10_dataloader(data_dir='/home/harry/dataset/cifar10', batch_size=128, normalize=False, val=False, num_workers=4, seed=42):
    """Build CIFAR-10 data loaders on top of ``CIFAR10_dataset``.

    Args:
        data_dir: Dataset root directory forwarded to ``CIFAR10_dataset``.
        batch_size: Batch size used by every loader.
        normalize: Forwarded as ``norm`` to ``CIFAR10_dataset``.
        val: Whether to split a validation set off the training set.
        num_workers: Worker process count used by every loader.
        seed: Seed for the train/validation split (forwarded).

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles. When ``val`` is false no validation set exists and
        ``val_loader`` is the very same object as ``test_loader``.
    """
    train_set, val_set, test_set = CIFAR10_dataset(data_dir, normalize, seed, val)

    # Settings shared by all loaders; only `shuffle` differs between them.
    shared = dict(batch_size=batch_size, pin_memory=True, num_workers=num_workers)

    train_loader = torch.utils.data.DataLoader(dataset=train_set, shuffle=True, **shared)
    test_loader = torch.utils.data.DataLoader(dataset=test_set, shuffle=False, **shared)

    if val:
        val_loader = torch.utils.data.DataLoader(dataset=val_set, shuffle=False, **shared)
    else:
        # Internal invariant of CIFAR10_dataset; identity check, not `==`.
        assert val_set is None
        print("No validation set used.")
        # Fall back to the test loader so callers always get three loaders.
        val_loader = test_loader

    return train_loader, val_loader, test_loader


# # CIFAR-10 poisoned dataset  all2one
# class CIFAR10_poisoned_dataset(datasets.CIFAR10):  
#     def __init__(self, data_dir='/home/harry/dataset/cifar10', train=True, transform=None, poison_rate=1.0, epsilon=8, clean_label=False, attack='PPT', generator='preresnet18'):
#         super(CIFAR10_poisoned_dataset, self).__init__(data_dir, train=train, transform=transform, download=True)
        
#         print(attack)
#         if train:
#             if attack == 'PPT':    
#                 perturb_tensor_path = '/home/harry/nnet/ImplicitBackdoor/AEs/all2one/CIFAR10/' + generator + ('/clean_' if clean_label else '/') + 'pgd_' + str(epsilon) +'_10_adv.npy'
#                 perturb_label_path = '/home/harry/nnet/ImplicitBackdoor/AEs/all2one/CIFAR10/' + generator + ('/clean_' if clean_label else '/') + 'pgd_' + str(epsilon) +'_10_label.npy'
#                 # self.perturb_tensor = np.load('/home/harry/nnet/ImplicitBackdoor/AEs/CIFAR10/' + generator + ('/clean_' if clean_label else '/') + 'pgd_' + str(epsilon) +'_10_adv.npy')
#                 # self.perturb_label = np.load('/home/harry/nnet/ImplicitBackdoor/AEs/CIFAR10/' + generator + ('/clean_' if clean_label else '/') + 'pgd_' + str(epsilon) +'_10_label.npy').astype(np.uint8)
#             elif attack in ['TooBad', 'WANET', 'MARKSMAN']:
#                 perturb_tensor_path = '/home/harry/nnet/ImplicitBackdoor/AEs/all2one/CIFAR10/' + attack + ('/clean_' if clean_label else '/dirty_') + 'adv.npy'
#                 perturb_label_path = '/home/harry/nnet/ImplicitBackdoor/AEs/all2one/CIFAR10/' +  attack + ('/clean_' if clean_label else '/dirty_') + 'adv_label.npy'
#                 # self.perturb_tensor = np.load('/home/harry/nnet/ImplicitBackdoor/AEs/CIFAR10/' + attack + ('/clean_' if clean_label else '/dirty_') + 'adv.npy')
#                 # self.perturb_label = np.load('/home/harry/nnet/ImplicitBackdoor/AEs/CIFAR10/' +  attack + ('/clean_' if clean_label else '/dirty_') + 'adv_label.npy').astype(np.uint8)
#             else:
#                 raise ValueError('Attack not supported!')
#         else:
#             if attack == 'PPT': 
#                 perturb_tensor_path = '/home/harry/nnet/ImplicitBackdoor/AEs/all2one/CIFAR10/' + generator + '/test_pgd_' + str(epsilon) +'_10_adv.npy'
#                 perturb_label_path = '/home/harry/nnet/ImplicitBackdoor/AEs/CIFAR10/' + generator + '/test_pgd_' + str(epsilon) +'_10_label.npy'
#                 # self.perturb_tensor = np.load('/home/harry/nnet/ImplicitBackdoor/AEs/CIFAR10/' + generator + '/test_pgd_' + str(epsilon) +'_10_adv.npy')
#                 # self.perturb_label = np.load('/home/harry/nnet/ImplicitBackdoor/AEs/CIFAR10/' + generator + '/test_pgd_' + str(epsilon) +'_10_label.npy').astype(np.uint8)
#             elif attack in ['TooBad', 'WANET', 'MARKSMAN']:
#                 perturb_tensor_path = '/home/harry/nnet/ImplicitBackdoor/AEs/all2one/CIFAR10/' + attack + '/test_dirty_adv.npy'
#                 perturb_label_path = '/home/harry/nnet/ImplicitBackdoor/AEs/all2one/CIFAR10/' + attack + '/test_dirty_adv_label.npy'
#                 # self.perturb_tensor = np.load('/home/harry/nnet/ImplicitBackdoor/AEs/CIFAR10/' + attack + '/test_dirty_adv.npy')
#                 # self.perturb_label = np.load('/home/harry/nnet/ImplicitBackdoor/AEs/CIFAR10/' + attack + '/test_dirty_adv_label.npy').astype(np.uint8)
#             else:
#                 raise ValueError('Attack not supported!')
        
#         print("Loading data from  " + perturb_tensor_path)
#         self.perturb_tensor = np.load(perturb_tensor_path)
#         self.perturb_label = np.load(perturb_label_path).astype(np.uint8)
#         self.perturb_tensor = torch.from_numpy(self.perturb_tensor).mul(255).clamp_(0, 255).permute(0, 2, 3, 1).to('cpu')
        
#         # poisoned indices
#         n = self.perturb_tensor.shape[0]
#         if train:
#             indices = []
#             for i in range(n):
#                 if self.targets[i] != 7:
#                     indices.append(i)
#             print(len(indices))
#             indices = np.array(indices)
#         else:
#             indices = list(range(0, n))
#         self.poison_idx = sorted(np.random.choice(indices, int(n * poison_rate), replace=False).tolist())
#         print("poisoned number: ", len(self.poison_idx))
        
#         for idx in self.poison_idx:
#             self.data[idx] = self.perturb_tensor[idx]
#             self.targets[idx] = self.perturb_label[idx]
        
#         print('Poison samples: %d/%d' % (len(self.poison_idx), n))

# CIFAR-10 poisoned dataset
class CIFAR10_poisoned_dataset(datasets.CIFAR10):
    """CIFAR-10 in which selected samples are replaced by pre-computed perturbed ones.

    Perturbed images and labels are read from two ``.npy`` files whose location
    depends on ``attack``, ``train``, ``clean_label``, ``epsilon`` and
    ``generator``. A random subset of the indices covered by those files is
    then overwritten in ``self.data`` / ``self.targets``.

    Args:
        data_dir: CIFAR-10 root directory (downloaded when missing).
        train: Use the training split (and training perturbation files) when
            true, otherwise the test split and the ``test_*`` files.
        transform: Transform forwarded to ``datasets.CIFAR10``.
        poison_rate: Fraction of the ``n`` perturbed samples that is written
            into the dataset (``int(n * poison_rate)`` samples).
        epsilon: Perturbation budget that appears in the ``PPT`` file names.
        clean_label: Selects the ``clean_`` training files; ignored for the
            test split.
        attack: ``'PPT'``, ``'TooBad'``, ``'WANET'`` or ``'MARKSMAN'``.
        generator: Sub-directory used for ``'PPT'`` files.
        perturb_root: Directory that contains the perturbation files. Defaults
            to the location that was previously hard-coded.

    Raises:
        ValueError: If ``attack`` is not one of the supported names.

    Attributes set here: ``perturb_tensor`` (float tensor scaled to 0..255 and
    permuted with ``(0, 2, 3, 1)``), ``perturb_label`` (uint8 array) and
    ``poison_idx`` (sorted list of replaced indices).
    """

    def __init__(self, data_dir='/home/harry/dataset/cifar10', train=True, transform=None, poison_rate=1.0, epsilon=8, clean_label=False, attack='PPT', generator='preresnet18', perturb_root='/home/harry/nnet/ImplicitBackdoor/AEs/CIFAR10/'):
        super().__init__(data_dir, train=train, transform=transform, download=True)

        perturb_tensor_path, perturb_label_path = self._perturbation_paths(
            perturb_root, train, epsilon, clean_label, attack, generator)

        print("Loading data from  " + perturb_tensor_path)
        self.perturb_tensor = np.load(perturb_tensor_path)
        self.perturb_label = np.load(perturb_label_path).astype(np.uint8)
        # Scale to the 0..255 range and permute with (0, 2, 3, 1) before the
        # per-sample copy into self.data below.
        # NOTE(review): assumes the file stores (N, C, H, W) floats in [0, 1] - confirm.
        self.perturb_tensor = torch.from_numpy(self.perturb_tensor).mul(255).clamp_(0, 255).permute(0, 2, 3, 1).to('cpu')

        # Poisoned indices: a random subset of the samples covered by the file.
        n = self.perturb_tensor.shape[0]
        self.poison_idx = sorted(np.random.choice(n, int(n * poison_rate), replace=False).tolist())

        for idx in self.poison_idx:
            # Round before storing: the float -> uint8 cast performed by the
            # assignment truncates, which would turn e.g. 127.9999 into 127.
            self.data[idx] = self.perturb_tensor[idx].round().numpy()
            # Store a plain int so poisoned and untouched targets share a type.
            self.targets[idx] = int(self.perturb_label[idx])

        print('Poison samples: %d/%d' % (len(self.poison_idx), n))

    @staticmethod
    def _perturbation_paths(root, train, epsilon, clean_label, attack, generator):
        """Return the ``(images_path, labels_path)`` pair for the requested files."""
        if not root.endswith('/'):
            root = root + '/'

        if attack == 'PPT':
            # <root><generator>/[clean_|test_]pgd_<epsilon>_10_{adv,label}.npy
            if train:
                prefix = 'clean_' if clean_label else ''
            else:
                prefix = 'test_'
            stem = root + generator + '/' + prefix + 'pgd_' + str(epsilon) + '_10_'
            return stem + 'adv.npy', stem + 'label.npy'

        if attack in ('TooBad', 'WANET', 'MARKSMAN'):
            # <root><attack>/{clean_|dirty_|test_dirty_}adv[_label].npy
            if train:
                prefix = 'clean_' if clean_label else 'dirty_'
            else:
                prefix = 'test_dirty_'
            stem = root + attack + '/' + prefix + 'adv'
            return stem + '.npy', stem + '_label.npy'

        raise ValueError('Attack not supported!')
        


# CIFAR-10 poisoned dataloader
def CIFAR10_poisoned_dataloader(data_dir='/home/harry/dataset/cifar10', batch_size=128, poison_rate=0.1, epsilon=8, clean_label=False, attack='PPT', generator='preresnet18', num_workers=8):
    """Build train and test loaders over ``CIFAR10_poisoned_dataset``.

    Args:
        data_dir: CIFAR-10 root directory.
        batch_size: Batch size used by both loaders.
        poison_rate: Poison rate of the training set. The test set is always
            built with ``poison_rate=1.0``.
        epsilon: Forwarded to ``CIFAR10_poisoned_dataset``.
        clean_label: Forwarded to ``CIFAR10_poisoned_dataset``.
        attack: Forwarded to ``CIFAR10_poisoned_dataset``.
        generator: Forwarded to ``CIFAR10_poisoned_dataset``.
        num_workers: Worker process count used by both loaders (default 8,
            the value that used to be hard-coded).

    Returns:
        ``(poisoned_train_loader, poisoned_test_loader)``. Only the training
        loader shuffles.
    """
    # Training uses a padded random crop and a horizontal flip; neither
    # pipeline normalises.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    poisoned_train_set = CIFAR10_poisoned_dataset(
        data_dir, train=True, transform=train_transform, poison_rate=poison_rate,
        epsilon=epsilon, clean_label=clean_label, attack=attack, generator=generator)
    # The test set is built with poison_rate=1.0 regardless of the argument.
    poisoned_test_set = CIFAR10_poisoned_dataset(
        data_dir, train=False, transform=test_transform, poison_rate=1.0,
        epsilon=epsilon, clean_label=clean_label, attack=attack, generator=generator)

    shared = dict(batch_size=batch_size, pin_memory=True, num_workers=num_workers)
    poisoned_train_loader = torch.utils.data.DataLoader(dataset=poisoned_train_set, shuffle=True, **shared)
    poisoned_test_loader = torch.utils.data.DataLoader(dataset=poisoned_test_set, shuffle=False, **shared)

    return poisoned_train_loader, poisoned_test_loader





if __name__ == '__main__':
    # Smoke test: build the clean loaders, show one batch, then print the
    # per-channel statistics of the raw training images.
    data_dir = '/home/harry/dataset/cifar10'  # renamed from `dir` to avoid shadowing the builtin
    train_loader, _, test_loader = CIFAR10_dataloader(data_dir, 128)
    for x, y in train_loader:
        print(x.shape)
        print(y.shape)
        break

    print(len(train_loader))
    print(len(train_loader.dataset))

    # Raw image array of the training set, laid out as [N, H, W, C]; reduce
    # over every axis except the channel axis and rescale by 255.
    raw_images = train_loader.dataset.data
    mean = np.mean(raw_images, axis=(0, 1, 2)) / 255
    std = np.std(raw_images, axis=(0, 1, 2)) / 255
    print(mean)
    print(std)
