import os
import torch
import numpy as np
from torchvision import transforms, datasets
import torch.nn as nn
from torch.utils.data import Dataset, Subset


# Three-component mean / std tuples handed to transforms.Normalize whenever
# normalization is requested in this module.
# NOTE(review): presumably per-channel RGB statistics of the Tiny ImageNet
# training images -- confirm provenance before reusing elsewhere.
tin_mean = (0.4802, 0.4481, 0.3975)
tin_std = (0.2770, 0.2691, 0.2821)

# Since no labels are available for the test set,
# we treat the val set as the test set, and
# we can optionally divide the train set into a train set and a val set.

# Tiny ImageNet dataset
def TIN_dataset(data_dir='/home/mnt/datasets/tiny-imagenet-200', norm=False, seed=42, val=False):
    """Build the Tiny ImageNet train / validation / test datasets.

    The ``val`` folder under ``data_dir`` is always returned as the test set.
    When ``val`` is true, the ``train`` folder is additionally split 9:1 into
    an augmented training subset and a non-augmented validation subset.

    Args:
        data_dir: Root directory containing the ``train`` and ``val`` folders,
            each loaded with ``torchvision.datasets.ImageFolder``.
            NOTE(review): ImageFolder expects one sub-folder per class --
            confirm the ``val`` folder has been arranged that way.
        norm: If true, ``Normalize(tin_mean, tin_std)`` is appended to both
            the training and the evaluation transform.
        seed: Seed for the random train/validation split (used only when
            ``val`` is true).
        val: If true, split the train folder into train and validation sets.

    Returns:
        A ``(train_set, val_set, test_set)`` tuple; ``val_set`` is ``None``
        when ``val`` is false.
    """

    def _tensor_steps():
        # Steps shared by every pipeline; built fresh for each Compose so the
        # pipelines do not share transform instances.
        steps = [transforms.ToTensor()]
        if norm:
            steps.append(transforms.Normalize(tin_mean, tin_std))
        return steps

    # Augmentation is applied to training images only.
    train_transform = transforms.Compose([
        transforms.RandomCrop((64, 64), padding=5),
        transforms.RandomRotation(10),
    ] + _tensor_steps())
    test_transform = transforms.Compose(_tensor_steps())

    train_dir = os.path.join(data_dir, 'train')
    test_set = datasets.ImageFolder(os.path.join(data_dir, 'val'), transform=test_transform)

    if not val:
        train_set = datasets.ImageFolder(train_dir, transform=train_transform)
        return train_set, None, test_set

    # Two views of the same folder: one augmented (training), one not
    # (validation). They are partitioned with a single shared permutation so
    # the two subsets never overlap.
    augmented_view = datasets.ImageFolder(train_dir, train_transform)
    plain_view = datasets.ImageFolder(train_dir, test_transform)

    # Derive the split sizes from the data actually found on disk instead of
    # assuming exactly 100000 images; integer arithmetic keeps the 9:1 ratio
    # (90000 / 10000 for the standard 100000-image train folder).
    n_total = len(augmented_view)
    n_train = n_total * 9 // 10

    # NOTE: this reseeds NumPy's *global* RNG, exactly as before, so callers
    # relying on that side effect for reproducibility keep working.
    np.random.seed(seed)
    split_permutation = np.random.permutation(n_total).tolist()

    train_set = Subset(augmented_view, split_permutation[:n_train])
    val_set = Subset(plain_view, split_permutation[n_train:])
    return train_set, val_set, test_set


# Tiny ImageNet dataloader
def TIN_dataloader(data_dir='/home/harry/dataset/tiny-imagenet-200', batch_size=128, normalize=False, val=False, num_workers=8, seed=42):
    """Create train / validation / test DataLoaders for Tiny ImageNet.

    Args:
        data_dir: Dataset root, forwarded to ``TIN_dataset``.
        batch_size: Batch size used by every loader.
        normalize: Forwarded to ``TIN_dataset`` as its ``norm`` flag.
        val: If true, a separate validation loader is built from the
            validation split returned by ``TIN_dataset``.
        num_workers: Worker process count used by every loader.
        seed: Seed forwarded to ``TIN_dataset`` for the train/val split.

    Returns:
        ``(train_loader, val_loader, test_loader)``. When ``val`` is false
        there is no validation split and ``val_loader`` is the very same
        object as ``test_loader``.
    """
    train_set, val_set, test_set = TIN_dataset(data_dir, normalize, seed, val)

    def _make_loader(dataset, shuffle):
        # All loaders share every setting except the dataset and shuffling.
        return torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size,
                                           shuffle=shuffle, pin_memory=True,
                                           num_workers=num_workers)

    # Only the training loader is shuffled.
    train_loader = _make_loader(train_set, True)
    test_loader = _make_loader(test_set, False)

    if val:
        val_loader = _make_loader(val_set, False)
    else:
        # Identity check: `== None` could dispatch to a custom __eq__.
        assert val_set is None
        print("No validation set used.")
        # Fall back to the test loader so callers always get three loaders.
        val_loader = test_loader

    return train_loader, val_loader, test_loader


# Tiny ImageNet poisoned dataset
class TIN_poisoned_dataset(datasets.ImageFolder):
    """ImageFolder whose samples can be swapped for pre-computed poisoned ones.

    Poisoned images and their labels are loaded from ``.npy`` files. A random
    subset of indices (sized by ``poison_rate``) is marked as poisoned; for
    those indices ``__getitem__`` returns the stored poisoned image and label
    instead of the image on disk.

    Attributes set by ``__init__``:
        perturb_tensor: ``torch.Tensor`` holding the poisoned images.
        perturb_label: ``numpy.uint8`` array holding the poisoned labels.
        poison_idx: Sorted list of the dataset indices that are poisoned.

    NOTE(review): indexing assumes entry ``i`` of the ``.npy`` files
    corresponds to ``self.samples[i]`` and that each image entry is something
    ``ToPILImage`` accepts (e.g. a C x H x W tensor) -- confirm against the
    script that generated the files.
    """

    _TRIGGER_ATTACKS = ('TooBad', 'WANET', 'MARKSMAN')

    def __init__(self, data_dir, train=True, transform=None, poison_rate=1.0, epsilon=12.75, clean_label=False, attack='PPT',
                 ae_root='/home/harry/nnet/ImplicitBackdoor/AEs/TIN'):
        """Load the clean folder plus the poisoned arrays and pick poison indices.

        Args:
            data_dir: Image folder passed to ``ImageFolder``.
            train: Selects the training artefacts (true) or the ``test_``
                artefacts (false).
            transform: Transform applied to both clean and poisoned images.
            poison_rate: Fraction of the poisoned array's entries to mark as
                poisoned; ``int(n * poison_rate)`` indices are drawn.
            epsilon: Value embedded (via ``str``) in the ``PPT`` file names.
            clean_label: Selects the ``clean_`` training artefacts. It is
                ignored when ``train`` is false.
            attack: ``'PPT'``, ``'TooBad'``, ``'WANET'`` or ``'MARKSMAN'``.
            ae_root: Directory containing the poisoned ``.npy`` artefacts;
                defaults to the previously hard-coded location.

        Raises:
            ValueError: If ``attack`` is not one of the supported names.
        """
        super(TIN_poisoned_dataset, self).__init__(data_dir, transform=transform)

        adv_path, label_path = self._perturbation_paths(ae_root, train, epsilon, clean_label, attack)
        # torch.from_numpy already yields a CPU tensor sharing the array's memory.
        self.perturb_tensor = torch.from_numpy(np.load(adv_path))
        self.perturb_label = np.load(label_path).astype(np.uint8)

        # Built once here instead of once per sample in __getitem__.
        self._to_pil = transforms.ToPILImage()

        # Poisoned indices are drawn from NumPy's global RNG, so seeding it
        # beforehand makes the selection reproducible.
        n = self.perturb_tensor.shape[0]
        chosen = np.random.choice(np.arange(n), int(n * poison_rate), replace=False)
        self.poison_idx = sorted(chosen.tolist())
        # Set mirror of poison_idx: O(1) membership tests in __getitem__
        # instead of scanning the whole list for every sample.
        self._poison_lookup = frozenset(self.poison_idx)

        print('Poison samples: %d/%d' % (len(self.poison_idx), n))

    @staticmethod
    def _perturbation_paths(ae_root, train, epsilon, clean_label, attack):
        """Return the ``(images_path, labels_path)`` pair for one configuration."""
        if attack == 'PPT':
            if train:
                stem = ('clean_' if clean_label else '') + 'pgd_' + str(epsilon) + '_10_'
            else:
                stem = 'test_pgd_' + str(epsilon) + '_10_'
            folder = os.path.join(ae_root, 'resnet18')
            return (os.path.join(folder, stem + 'adv.npy'),
                    os.path.join(folder, stem + 'label.npy'))

        if attack in TIN_poisoned_dataset._TRIGGER_ATTACKS:
            if train:
                stem = 'clean_' if clean_label else 'dirty_'
            else:
                # Only "dirty" test artefacts are used; clean_label is ignored.
                stem = 'test_dirty_'
            folder = os.path.join(ae_root, attack)
            return (os.path.join(folder, stem + 'adv.npy'),
                    os.path.join(folder, stem + 'adv_label.npy'))

        raise ValueError('Attack not supported!')

    def __getitem__(self, index):
        """Return ``(sample, target)`` for ``index``, poisoned if selected.

        Poisoned samples go through ``self.transform`` only; their label is
        taken as stored and ``target_transform`` is not applied to it. Clean
        samples follow the usual ImageFolder path.
        """
        # int() lets integer-like indices (NumPy / tensor scalars) hit the
        # set lookup the same way they matched the list before.
        if int(index) in self._poison_lookup:
            sample = self._to_pil(self.perturb_tensor[index])
            if self.transform is not None:
                sample = self.transform(sample)
            # Cast to a plain int so poisoned and clean targets share one
            # type; otherwise the collated label dtype depends on which
            # kind of sample happens to be in the batch.
            target = int(self.perturb_label[index])
        else:
            path, target = self.samples[index]
            sample = self.loader(path)
            if self.transform is not None:
                sample = self.transform(sample)
            if self.target_transform is not None:
                target = self.target_transform(target)

        return sample, target


def TIN_poisoned_dataloader(data_dir='/home/harry/dataset/tiny-imagenet-200', batch_size=128, poison_rate=0.1, epsilon=12.75, clean_label=False, attack='PPT', num_workers=8):
    """Create train / test DataLoaders over poisoned Tiny ImageNet data.

    The training set is built from ``<data_dir>/train`` and poisoned at
    ``poison_rate``. The test set is built from ``<data_dir>/val`` and is
    always created with ``poison_rate=1.0``. No normalization is applied.

    Args:
        data_dir: Dataset root containing the ``train`` and ``val`` folders.
        batch_size: Batch size used by both loaders.
        poison_rate: Poison rate for the training set only.
        epsilon: Forwarded to ``TIN_poisoned_dataset``.
        clean_label: Forwarded to ``TIN_poisoned_dataset``.
        attack: Attack name forwarded to ``TIN_poisoned_dataset``.
        num_workers: Worker process count used by both loaders (default 8,
            the previously hard-coded value).

    Returns:
        ``(train_loader, test_loader)``.
    """
    # Same augmentation as the clean training pipeline, without Normalize.
    train_transform = transforms.Compose([
        transforms.RandomCrop((64, 64), padding=5),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_set = TIN_poisoned_dataset(os.path.join(data_dir, 'train'), train=True, transform=train_transform,
                                     poison_rate=poison_rate, epsilon=epsilon, clean_label=clean_label, attack=attack)
    # The evaluation set ignores the poison_rate argument and uses 1.0.
    test_set = TIN_poisoned_dataset(os.path.join(data_dir, 'val'), train=False, transform=test_transform,
                                    poison_rate=1.0, epsilon=epsilon, clean_label=clean_label, attack=attack)

    train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size,
                                               shuffle=True, pin_memory=True,
                                               num_workers=num_workers)

    test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size,
                                              shuffle=False, pin_memory=True,
                                              num_workers=num_workers)

    return train_loader, test_loader
    
    
if __name__ == '__main__':
    # Manual smoke test: build the clean and the poisoned loaders and print
    # the loader sizes plus a glimpse of the first test batch of each.
    # Named `data_dir` rather than `dir` so the builtin is not shadowed.
    data_dir = '/home/harry/dataset/tiny-imagenet-200'

    train_loader, _, test_loader = TIN_dataloader(data_dir, 128, val=False)
    print(len(train_loader))
    print(len(train_loader.dataset))

    # Only the first batch is inspected, hence the immediate break.
    for x, y in test_loader:
        print(x.shape)
        print(x[0][0][14])
        print(y.shape)
        print(y[:10])
        break

    train_loader, test_loader = TIN_poisoned_dataloader(data_dir, 128)
    print(len(train_loader))
    print(len(train_loader.dataset))

    for x, y in test_loader:
        print(x.shape)
        print(y.shape)
        print(x[0][0][14])
        print(y[:10])
        break
