import torch
import numpy as np
from torchvision import transforms, datasets
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset

class Normalize(nn.Module):
    """Per-channel input normalization as a module: ``(input - mean) / std``.

    ``mean`` and ``std`` are registered as buffers, so they follow the module
    across ``.to(device)`` calls and are saved in its ``state_dict``.

    Args:
        mean: Sequence with one mean value per channel.
        std: Sequence with one standard deviation per channel.
    """

    def __init__(self, mean, std):
        super(Normalize, self).__init__()
        self.register_buffer('mean', torch.Tensor(mean))
        self.register_buffer('std', torch.Tensor(std))

    def forward(self, input):
        """Normalize a batch laid out as (N, C, H, W).

        The statistics are reshaped to (1, C, 1, 1) so they broadcast over the
        batch and spatial dimensions. The channel count is taken from the
        buffers themselves rather than being fixed at 3.
        """
        mean = self.mean.reshape(1, -1, 1, 1)
        std = self.std.reshape(1, -1, 1, 1)
        return (input - mean) / std


# Alternative per-channel statistics, kept for reference (presumably the
# empirical SVHN train-set mean/std -- TODO confirm before re-enabling):
# svhn_mean = (0.4377, 0.4438, 0.4728)
# svhn_std = (0.1980, 0.2010, 0.1970)

# Per-channel normalization constants passed to transforms.Normalize.
# With mean = std = 0.5, a value v becomes (v - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1].
svhn_mean = (0.5, 0.5, 0.5)
svhn_std = (0.5, 0.5, 0.5)


# SVHN dataset
def SVHN_dataset(data_dir='/home/harry/dataset/svhn', norm=False, seed=42, val=False):
    """Build the SVHN train / validation / test datasets.

    Args:
        data_dir: Root directory handed to ``datasets.SVHN``. The data is
            downloaded there when missing (``download=True``).
        norm: If True, apply ``transforms.Normalize(svhn_mean, svhn_std)``
            after ``transforms.ToTensor()``.
        seed: Seed for NumPy's global RNG; only used when ``val`` is True, to
            make the train/validation split reproducible.
        val: If True, hold out part of the official train split as a
            validation set.

    Returns:
        A tuple ``(train_set, val_set, test_set)``. ``val_set`` is ``None``
        when ``val`` is False.
    """
    def build_transform():
        steps = [transforms.ToTensor()]
        if norm:
            steps.append(transforms.Normalize(svhn_mean, svhn_std))
        return transforms.Compose(steps)

    # The two pipelines are identical today, but are kept as separate objects
    # so that train-only augmentation can be added without touching
    # validation/test preprocessing.
    train_transform = build_transform()
    test_transform = build_transform()

    train_full = datasets.SVHN(data_dir, split='train', transform=train_transform, download=True)

    if not val:
        test_set = datasets.SVHN(data_dir, split='test', transform=test_transform, download=True)
        return train_full, None, test_set

    # Second view of the train split so the validation subset uses the
    # evaluation transform instead of the training one.
    val_full = datasets.SVHN(data_dir, split='train', transform=test_transform, download=True)

    # The official train split has 73257 images: 65932 are kept for training
    # and the remainder forms the validation set. The permutation must cover
    # the WHOLE split; permuting only the first 65932 indices would leave the
    # validation slice empty.
    num_train_samples = 65932
    # NOTE: this seeds NumPy's global RNG, which also affects later
    # np.random calls made by the caller.
    np.random.seed(seed)
    split_permutation = np.random.permutation(len(train_full)).tolist()

    train_set = Subset(train_full, split_permutation[:num_train_samples])
    val_set = Subset(val_full, split_permutation[num_train_samples:])
    test_set = datasets.SVHN(data_dir, split='test', transform=test_transform, download=True)

    return train_set, val_set, test_set

# SVHN dataloader
def SVHN_dataloader(data_dir='/home/harry/dataset/svhn', batch_size=128, normalize=False, num_workers=8, seed=42, val=False):
    """Create DataLoaders for the SVHN train / validation / test datasets.

    Args:
        data_dir: Dataset root directory, forwarded to ``SVHN_dataset``.
        batch_size: Batch size used by every loader.
        normalize: Forwarded to ``SVHN_dataset`` as its ``norm`` argument.
        num_workers: Number of worker processes per loader.
        seed: Seed for the train/validation split (used only when ``val``).
        val: If True, also build a validation loader.

    Returns:
        A tuple ``(train_loader, val_loader, test_loader)``. ``val_loader`` is
        ``None`` when ``val`` is False. Only the train loader shuffles.
    """
    train_set, val_set, test_set = SVHN_dataset(data_dir, normalize, seed, val)

    def make_loader(dataset, shuffle):
        # All loaders share the same settings apart from shuffling.
        return torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size,
                                           shuffle=shuffle, pin_memory=True,
                                           num_workers=num_workers)

    train_loader = make_loader(train_set, shuffle=True)
    test_loader = make_loader(test_set, shuffle=False)

    if val:
        val_loader = make_loader(val_set, shuffle=False)
    else:
        # Internal sanity check: no validation set may exist in this mode.
        assert val_set is None
        print("No validation set used.")
        val_loader = None

    return train_loader, val_loader, test_loader

class SVHN_Poisoned_dataset(datasets.SVHN):
    """SVHN dataset in which a random subset of samples is replaced by
    precomputed perturbed (poisoned) images and labels loaded from ``.npy`` files.

    Args:
        data_dir: Root directory of the SVHN data (passed to ``datasets.SVHN``).
        train: If True use the 'train' split, otherwise the 'test' split.
        transform: Transform applied by the parent dataset on access.
        download: Whether the parent dataset may download missing files.
        poison_rate: Fraction of samples to replace; ``int(n * poison_rate)``
            indices are drawn without replacement from NumPy's global RNG.
        epsilon: Perturbation budget; only used to build the 'PPT' file names.
        clean_label: Selects the clean-label files for the train split. The
            test split always loads the 'test_' / 'test_dirty_' files.
        attack: One of 'PPT', 'TooBad', 'WANET', 'MARKSMAN'.
        ae_dir: Directory containing the perturbation files. Defaults to the
            location that was previously hard-coded.

    Attributes set here:
        perturb_tensor: torch tensor of the loaded images scaled by 255 and
            clamped to [0, 255]. Assumed to be aligned sample-by-sample with
            ``self.data`` (first dimension is checked) -- TODO confirm layout.
        perturb_label: uint8 NumPy array with the labels of the perturbed samples.
        poison_idx: sorted list of the indices that were replaced.

    Raises:
        ValueError: If ``attack`` is not supported, or the number of perturbed
            samples differs from the number of samples in the split.
    """

    # Attacks whose files live in a sub-directory named after the attack.
    _NAMED_ATTACKS = ('TooBad', 'WANET', 'MARKSMAN')

    def __init__(self, data_dir='/home/harry/dataset/svhn', train=True, transform=None, download=True,
                 poison_rate=1.0, epsilon=12.75, clean_label=False, attack='PPT',
                 ae_dir='/home/harry/nnet/ImplicitBackdoor/AEs/SVHN'):
        # Resolve the file names first, so an unsupported attack fails before
        # the (possibly slow) dataset download/load in the parent constructor.
        adv_path, label_path = self._perturbation_paths(ae_dir, train, epsilon, clean_label, attack)

        split = 'train' if train else 'test'
        super(SVHN_Poisoned_dataset, self).__init__(data_dir, split, transform, download=download)

        self.perturb_label = np.load(label_path).astype(np.uint8)
        # Scale to the 0..255 pixel range used by self.data.
        # torch.from_numpy always yields a CPU tensor.
        self.perturb_tensor = torch.from_numpy(np.load(adv_path)).mul(255).clamp_(0, 255)

        n = self.perturb_tensor.shape[0]
        if n != self.data.shape[0]:
            raise ValueError('Perturbation file holds %d samples but the %s split has %d'
                             % (n, split, self.data.shape[0]))

        # poisoned indices
        self.poison_idx = sorted(np.random.choice(n, int(n * poison_rate), replace=False).tolist())

        # Explicit integer dtype keeps an empty selection (poison_rate == 0) valid.
        chosen = np.asarray(self.poison_idx, dtype=np.intp)
        # Cast once to the storage dtype of self.data (the fractional part is
        # truncated, exactly as element-wise assignment into the array would
        # do), then overwrite all selected samples in a single vectorized step.
        replacement = self.perturb_tensor.numpy().astype(self.data.dtype, copy=False)
        self.data[chosen] = replacement[chosen]
        self.labels[chosen] = self.perturb_label.reshape(-1)[chosen]

        print('Poison samples: %d/%d' % (len(self.poison_idx), n))

    @staticmethod
    def _perturbation_paths(ae_dir, train, epsilon, clean_label, attack):
        """Return ``(image_path, label_path)`` of the .npy files for an attack."""
        root = ae_dir.rstrip('/')
        if attack == 'PPT':
            prefix = ('clean_' if clean_label else '') if train else 'test_'
            stem = root + '/svhn_net/' + prefix + 'pgd_' + str(epsilon) + '_10_'
            return stem + 'adv.npy', stem + 'label.npy'
        if attack in SVHN_Poisoned_dataset._NAMED_ATTACKS:
            prefix = ('clean_' if clean_label else 'dirty_') if train else 'test_dirty_'
            stem = root + '/' + attack + '/' + prefix
            return stem + 'adv.npy', stem + 'adv_label.npy'
        raise ValueError('Attack not supported!')


def SVHN_poisoned_dataloader(data_dir='/home/harry/dataset/svhn', batch_size=128, poison_rate=0.1, 
                             epsilon=12.75, clean_label=False, attack='PPT', num_workers=8):
    """Create train/test DataLoaders over poisoned SVHN datasets.

    Args:
        data_dir: Root directory of the SVHN data.
        batch_size: Batch size used by both loaders.
        poison_rate: Fraction of the TRAIN split that is poisoned. The test
            split is always fully poisoned (rate 1.0).
        epsilon: Perturbation budget forwarded to ``SVHN_Poisoned_dataset``.
        clean_label: Whether the train split uses clean-label poisons. It is
            not forwarded for the test split.
        attack: Attack name forwarded to ``SVHN_Poisoned_dataset``.
        num_workers: Number of worker processes per loader (default 8, the
            value that was previously fixed).

    Returns:
        A tuple ``(train_loader, test_loader)``; only the train loader shuffles.
    """
    # transform
    to_tensor_only = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_transform = to_tensor_only
    test_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    # dataset
    train_set = SVHN_Poisoned_dataset(data_dir, train=True, transform=train_transform,
                                      poison_rate=poison_rate, epsilon=epsilon,
                                      clean_label=clean_label, attack=attack)
    # test is dirty-label for test ASR; clean-label can not test ASR
    test_set = SVHN_Poisoned_dataset(data_dir, train=False, transform=test_transform,
                                     poison_rate=1.0, epsilon=epsilon, attack=attack)

    # dataloader
    train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size,
                                               shuffle=True, pin_memory=True,
                                               num_workers=num_workers)

    test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size,
                                              shuffle=False, pin_memory=True,
                                              num_workers=num_workers)

    return train_loader, test_loader


if __name__ == '__main__':
    # Manual smoke test: print the first test batch of the clean loaders and
    # of the poisoned loaders.
    data_dir = '/home/harry/dataset/svhn'  # not named `dir`: that shadows the builtin

    def _show_first_batch(loader):
        """Print shapes, one pixel row and the first labels of the first batch."""
        for x, y in loader:
            print(x.shape)
            print(y.shape)
            print(x[0][0][14])
            print(y[:10])
            break

    # Snippet for recomputing per-channel statistics from the raw train data:
    # train_loader, val_loader, test_loader = SVHN_dataloader(data_dir, 128, val=False)
    # print(len(train_loader))
    # print(len(train_loader.dataset))
    # # [N, C, H, W]
    # mean = np.mean(train_loader.dataset.data, axis=(0, 2, 3)) / 255
    # std = np.std(train_loader.dataset.data, axis=(0, 2, 3)) / 255
    # print(mean)
    # print(std)

    train_loader, val_loader, test_loader = SVHN_dataloader(data_dir, 128, val=False)
    _show_first_batch(test_loader)

    train_loader, test_loader = SVHN_poisoned_dataloader()
    _show_first_batch(test_loader)