import enum
from tkinter import W
import torch
import torchvision
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.transforms as transforms
import os 
import numpy as np
import random
import matplotlib.pyplot as plt

# Root directory (relative to the current working directory) under which the
# generated poison datasets and their preview images are cached.
DATASET_DIR = 'datasets'
class MyDataset(torch.utils.data.Dataset):
    """Dataset view over indexable records of the form (sample, label, is_poison).

    When a transform is given it is applied to the sample only, every time an
    item is fetched; the label and the poison flag are passed through as-is.
    """

    def __init__(self, data, transform=None):
        self.transform = transform
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        sample, label, is_poison = record[0], record[1], record[2]
        if self.transform:
            sample = self.transform(sample)
        return sample, label, is_poison

class DualDataset(torch.utils.data.Dataset):
    """Dataset that serves two index-aligned datasets side by side.

    Item ``idx`` is ``(sample1, label1, sample2, label2)`` where the first two
    values come from ``data1[idx]`` and the last two from ``data2[idx]``. The
    optional transform is applied to both samples independently.

    Raises:
        ValueError: if the two datasets do not have the same length.
    """

    def __init__(self, data1, data2, transform=None):
        # "Paired" means index-aligned, so only the lengths have to agree.
        # Comparing the containers themselves (data1 != data2) would reject
        # every pair with different contents and fails outright on tensor
        # samples, whose element-wise comparison has no single truth value.
        if len(data1) != len(data2):
            raise ValueError("Two datasets should be paired!")
        self.data1 = data1
        self.data2 = data2
        self.transform = transform

    def __len__(self):
        return len(self.data1)

    def __getitem__(self, idx):
        first, second = self.data1[idx], self.data2[idx]
        sample1, label1 = first[0], first[1]
        sample2, label2 = second[0], second[1]
        if self.transform:
            sample1 = self.transform(sample1)
            sample2 = self.transform(sample2)
        return (sample1, label1, sample2, label2)

class WatermarkDataset():
    """
         
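        Builds (or loads from cache) the clean, watermarked (poisoned) and attacker datasets described by args.

        train-related:
            0. trainset: original clean dataset, an instance of torchvision.datasets with no data augmentation (ToTensor only)
            1. poison_components_trainset: poison samples in mixed_trainset, with data augmentation
            2. benign_components_trainset: benign samples in mixed_trainset, with data augmentation
            3. mixed_trainset: mixture of benign samples and poison samples, returns a 3-element tuple (img, y, is_poisoned), with data augmentation
            4. full_poison_trainset: for in-distribution triggers, every owner sample with the trigger applied and flagged as poisoned; for 'svhn', all SVHN images of the watermark class; with data augmentation
            5. attacker_trainset: the attacker's clean split of trainset, with data augmentation

        test-related:
            0. testset: original clean test set
            1. poison_testset: poisoned test set, no data augmentation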
    """
    def __init__(self, args):
        """
            Prepare every dataset described by args.

            Steps, in order: build the training augmentation, load the clean
            train/test sets, obtain the shuffle index, optionally build a warping
            grid, check that the requested splits fit, carve out the attacker
            split and finally load or construct the poison datasets.

            Raises:
                ValueError: if owner_data_size + attacker_data_size exceeds the
                    size of the clean training set.
        """
        self.args = args

        self.train_transform = self.get_train_transformation()  # augmentation for train-related datasets
        self.load_benign_dataset()  # sets self.trainset / self.testset
        self.get_shuffle_index()  # sets self.shuffle_index

        # NOTE(review): get_wanet_grid / get_wa_grid are not defined in the part
        # of the class visible here -- presumably provided elsewhere; confirm.
        wm_type = args.wm_type.lower()
        if wm_type == 'wanet':
            self.grid = self.get_wanet_grid()
        if wm_type == 'wa':
            self.grid = self.get_wa_grid()

        # Sanity check: both splits are cut from the same training set
        requested = args.owner_data_size + args.attacker_data_size
        if requested > len(self.trainset):
            raise ValueError("Not enough data to fill both owner and attacker_data_size!")

        self.attacker_trainset = self.get_attacker_trainset()

        # poison components, benign components, mixed trainset, full poison trainset, poison testset
        self.get_poison_datasets()


    def load_benign_dataset(self):
        """
            Load the clean dataset named by args.dataset into self.trainset and self.testset.
            Both splits are downloaded into args.dataset_dir (with ~ expanded) if missing
            and use ToTensor as their only transform, so images in the trainset have
            shapes like (C,H,W) and value range [0,1].

            Raises:
                NotImplementedError: for any dataset other than cifar10 / cifar100.
        """
        args = self.args
        class_names = {'cifar10': 'CIFAR10', 'cifar100': 'CIFAR100'}
        class_name = class_names.get(args.dataset.lower())
        if class_name is None:
            raise NotImplementedError("%s dataset is not implemented"%args.dataset)
        dataset_cls = getattr(torchvision.datasets, class_name)

        def build(is_train):
            return dataset_cls(
                root=os.path.expanduser(args.dataset_dir),
                train=is_train,
                download=True,
                transform=transforms.ToTensor(),
            )

        self.trainset = build(True)
        self.testset = build(False)

    def get_train_transformation(self):
        """
            Build the augmentation pipeline applied to train-related datasets:
            tensor -> PIL image, random horizontal flip, padding, random crop back
            to the original size, PIL image -> tensor.

            Raises:
                NotImplementedError: if args.dataset does not start with 'cifar'.
        """
        dataset_name = self.args.dataset
        if not dataset_name.lower().startswith('cifar'):
            raise NotImplementedError("%s dataset is not implemented"%dataset_name)

        # CIFAR settings: pad by 4 pixels, then crop a random 32x32 window
        pad_size, crop_size = 4, 32

        augmentations = [
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.Pad(pad_size),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
        ]
        return transforms.Compose(augmentations)

    def get_shuffle_index(self):
        """
            Return the shuffle index of the training set, storing it in self.shuffle_index.

            The index is cached in data/shuffle_index_<dataset>_<seed>.npy. If that file
            exists it is loaded; otherwise a new index is created and saved. A new index
            shuffles positions [0, owner_data_size) and [owner_data_size, len(trainset))
            separately and concatenates them, so the owner part always comes first.

            Note that the cache file name depends only on the dataset and the seed: an
            index cached for a different owner_data_size is reused unchanged. Creating
            a new index re-seeds numpy's global random state with args.seed.
        """
        args = self.args
        cache_dir = 'data'
        if not os.path.exists(cache_dir):
            os.mkdir(cache_dir)
        cache_file = os.path.join(cache_dir, "shuffle_index_%s_%d.npy"%(args.dataset, args.seed))

        if os.path.exists(cache_file):
            # reuse the stored index
            print("load index from", cache_file)
            self.shuffle_index = np.load(cache_file)
            return self.shuffle_index

        # no cached index yet: create one deterministically from the seed
        np.random.seed(args.seed)
        total_count = len(self.trainset)
        owner_count = args.owner_data_size
        owner_part = np.random.permutation(np.arange(owner_count))
        remaining_part = np.random.permutation(np.arange(owner_count, total_count))
        self.shuffle_index = np.concatenate([owner_part, remaining_part])
        np.save(cache_file, self.shuffle_index)
        return self.shuffle_index

    def get_poison_dataset_dir(self):
        """
            Return the cache folder for the current poison configuration:
            DATASET_DIR/<dataset>_y<wm_class>_<wm_type>_fot<filter_out_target>_<owner_data_size>_<poison_num><suffix>/<seed>

            The suffix encodes the trigger-specific settings (transparency, blended
            type or content color); 'svhn' has no suffix.

            Raises:
                NotImplementedError: for watermark types without a naming rule.
        """
        args = self.args
        poison_num = self.get_poison_num()
        base_name = "%s_y%d_%s_fot%d_%d_%d"%(args.dataset, args.wm_class, args.wm_type, args.filter_out_target, args.owner_data_size, poison_num)

        wm_type = args.wm_type.lower()
        if wm_type in ('badnets', '4corner', '4cross', 'gauss'):
            suffix = '_t%0.1f'%args.transparency
        elif wm_type == 'blended':
            suffix = '_'+args.blended_type + '_t%0.1f'%args.transparency
        elif wm_type == 'iclr':
            suffix = '_c%1.2e'%args.content_color
        elif wm_type == 'svhn':
            suffix = ''
        else:
            raise NotImplementedError("watermark %s is not implemented"%args.wm_type)

        return os.path.join(DATASET_DIR, base_name+suffix, '%d'%args.seed)

    def get_poison_datasets(self):
        """
            Load the poison datasets from <poison dataset dir>/dataset.pth, or construct
            them and write that file if it does not exist yet.

            Only raw data is stored in the file, never the transformation. After loading
            or constructing, the four train-related datasets are wrapped in MyDataset
            with the training augmentation; poison_testset is left as stored.

            Returns:
                (mixed_trainset, poison_components_trainset, benign_components_trainset,
                 poison_testset, full_poison_trainset)
        """
        if not os.path.exists(DATASET_DIR):
            os.makedirs(DATASET_DIR)
        folder = self.get_poison_dataset_dir()
        cache_path = os.path.join(folder, 'dataset.pth')

        if not os.path.exists(cache_path):
            # nothing cached for this configuration: build and store it
            if not os.path.exists(folder):
                os.makedirs(folder)
            self.construct_poison_dataset()
            torch.save(
                {
                    "mixed_trainset": self.mixed_trainset,
                    "poison_components_trainset": self.poison_components_trainset,
                    "benign_components_trainset": self.benign_components_trainset,
                    'full_poison_trainset': self.full_poison_trainset,
                    "poison_testset": self.poison_testset,
                },
                cache_path,
            )
        else:
            print("load dataset from %s"%cache_path)
            # NOTE(review): torch.load unpickles the file, so only caches produced
            # by this code should be loaded. Newer PyTorch releases may default to
            # weights_only=True and reject the stored TensorDataset -- confirm
            # against the installed version.
            cached = torch.load(cache_path)
            self.poison_components_trainset = cached['poison_components_trainset']
            self.benign_components_trainset = cached['benign_components_trainset']
            self.full_poison_trainset = cached['full_poison_trainset']
            self.mixed_trainset = cached['mixed_trainset']
            self.poison_testset = cached['poison_testset']

        # apply transformation to train-related datasets
        train_related = (
            'mixed_trainset',
            'full_poison_trainset',
            'poison_components_trainset',
            'benign_components_trainset',
        )
        for name in train_related:
            setattr(self, name, MyDataset(getattr(self, name), self.train_transform))

        return (
            self.mixed_trainset,
            self.poison_components_trainset,
            self.benign_components_trainset,
            self.poison_testset,
            self.full_poison_trainset,
        )

    def get_poison_num(self):
        """
            Return the number of poison samples: args.poison_num when it is given,
            otherwise owner_data_size * poison_ratio truncated to an int.
        """
        args = self.args
        if args.poison_num is not None:
            return args.poison_num
        return int(args.owner_data_size*args.poison_ratio)
    
    def get_poison_function_and_nums(self):
        """
            Build the trigger function for the configured in-distribution watermark.

            Returns:
                (poison_functions, poison_nums): two one-element lists. The function
                maps (img, y) to (poisoned img, args.wm_class); the number is the
                poison quota from get_poison_num().

            Trigger families:
                badnets / 4corner / 4cross: 3x3 patches blended into the corner(s)
                    with opacity args.transparency.
                iclr: pattern loaded from data/iclr.npy, pasted wherever it is
                    non-zero and scaled by args.content_color.
                gauss: pattern from data/gauss.npy added with weight
                    args.transparency, result clipped to [0, 1].
                blended: full-image pattern (hk / noise / ber) blended with
                    opacity args.transparency.

            Raises:
                NotImplementedError: unknown args.blended_type.
                ValueError: watermark type that is not an in-distribution trigger.
        """
        # only valid for in-distribution pollution
        args = self.args
        wm_type = args.wm_type.lower()
        poison_num = self.get_poison_num()

        def make_alpha_blend(trigger, alpha):
            # img is kept where alpha is 0 and replaced by the trigger where alpha is 1
            def add_trigger(img, y):
                stamped = img * (1-alpha) + trigger * alpha
                return stamped, args.wm_class
            return add_trigger

        if wm_type in ('badnets', '4corner', '4cross', 'iclr'):
            trigger = torch.zeros_like(self.trainset[0][0])
            alpha = torch.zeros_like(self.trainset[0][0])
            head, tail = slice(None, 3), slice(-3, None)
            # (rows, cols) of the four corners, in the order they are written
            corners = ((tail, tail), (tail, head), (head, tail), (head, head))

            if wm_type == 'badnets':
                trigger[:, tail, tail] = torch.Tensor([[0,0,1],[0,1,0],[1,0,1]]).repeat((3, 1, 1))
                alpha[:, tail, tail] = args.transparency
            elif wm_type == '4corner':
                right_patch = torch.Tensor([[0,0,1],[0,1,0],[1,0,1]]).repeat((3, 1, 1))
                left_patch = torch.Tensor([[1,0,0],[0,1,0],[1,0,1]]).repeat((3, 1, 1))
                corner_patches = (right_patch, left_patch, right_patch, left_patch)
                for (rows, cols), patch in zip(corners, corner_patches):
                    trigger[:, rows, cols] = patch
                for rows, cols in corners:
                    alpha[:, rows, cols] = args.transparency
            elif wm_type == '4cross':
                cross = torch.Tensor([[0,1,0],[1,0,1],[0,1,0]]).repeat((3, 1, 1))
                for rows, cols in corners:
                    trigger[:, rows, cols] = cross
                for rows, cols in corners:
                    alpha[:, rows, cols] = args.transparency
            elif wm_type == 'iclr':
                trigger = torch.Tensor(np.load('data/iclr.npy'))
                alpha = (trigger!=0).float()
                trigger = trigger * args.content_color
            else:
                raise ValueError("I don't know how you get here, but there must be someting wrong")

            add_trigger = make_alpha_blend(trigger, alpha)

        elif wm_type == 'gauss':
            trigger = torch.Tensor(np.load('data/gauss.npy'))

            def add_trigger(img, y):
                noisy = img + trigger*args.transparency
                return torch.clip(noisy, 0, 1), args.wm_class

        elif wm_type == 'blended':
            pattern_files = {
                'hk': 'data/hk1205.npy',  # hello kitty
                'noise': 'data/noise.npy',
                'ber': 'data/ber.npy',
            }
            if args.blended_type not in pattern_files:
                raise NotImplementedError("Blended type %s not implemented"%args.blended_type)
            trigger = torch.Tensor(np.load(pattern_files[args.blended_type]))
            alpha = torch.ones_like(trigger) * args.transparency
            add_trigger = make_alpha_blend(trigger, alpha)

        else:
            raise ValueError("I don't know how you get here, but there must be someting wrong")

        return [add_trigger], [poison_num]

    def construct_poison_dataset(self):
        """
            Build every poison-related dataset from scratch, store them on self and save
            preview images (trainset.png / testset.png) into the poison dataset dir.

            In-distribution triggers ('4corner', 'badnets', 'blended', '4cross', 'iclr', 'gauss'):
                The owner's samples are visited in shuffle order. The first eligible
                samples are poisoned until the quota of each trigger function is used
                up; all other samples stay benign. Independently, every owner sample
                is also added in poisoned form to full_poison_trainset. The poison
                testset contains every test image not already of class args.wm_class,
                with the trigger applied.

            Out-of-distribution trigger ('svhn'):
                The owner's samples all stay benign; SVHN images whose label equals
                args.wm_class are used as poison data.

            All entries of the train-related lists are (img, label, is_poisoned) tuples.

            Returns:
                (mixed_trainset, poison_components_trainset, benign_components_trainset,
                 poison_testset, full_poison_trainset)

            Raises:
                ValueError: if the owner's data runs out before all quotas are filled.
                NotImplementedError: for an unsupported watermark type.
        """
        args = self.args

        # in-distribution triggers, rely on generalization
        if args.wm_type.lower() in ['4corner','badnets','blended', '4cross', 'iclr', 'gauss']: 
            poison_funcs, poison_nums = self.get_poison_function_and_nums() 
            # prepare trainset
            # fi: index of the trigger function currently being filled
            # n: number of samples poisoned so far with that function
            fi, n = 0, 0
            benign_components_trainset = []
            poison_components_trainset = []
            full_poison_trainset = []

            
            for i in self.shuffle_index[:args.owner_data_size]:
                img, y = self.trainset[i]
                # skip trigger functions whose quota is zero
                while fi < len(poison_funcs) and poison_nums[fi]==0:
                    fi+=1             
                # a sample is eligible while quotas remain and, if filter_out_target
                # is set, only when it does not already belong to the target class
                if fi < len(poison_funcs) and \
                    (not args.filter_out_target or (args.filter_out_target and y!=args.wm_class)):
                    # constraints satisfied, poison data
                    n +=1
                    pimg, py = poison_funcs[fi](img,y)
                    poison_components_trainset.append((pimg, py, 1))
                    # quota of the current function reached: move on to the next one
                    if n==poison_nums[fi]:
                        fi+=1
                        n=0
                else:
                    # constraints not satisfied, remain clean
                    benign_components_trainset.append((img, y, 0))
                    # still compute a poisoned copy (first trigger function) for full_poison_trainset
                    pimg, py = poison_funcs[0](img,y) 
                # pollute every sample
                full_poison_trainset.append((pimg, py, 1))

            # Sanity check after work is Done
            # fi only reaches len(poison_funcs) once every quota has been filled
            if not fi==len(poison_funcs):
                raise ValueError("Not enough samples to pollute!")

            # prepare testset
            # if len(poison_nums) and poison_nums[0]:
            if len(poison_nums):
                imgs, labels = [], []        
                poison_func = poison_funcs[0]
                for i in range(len(self.testset)):
                    img, y = self.testset[i]
                    # images of the target class are left out of the poison testset
                    if y == args.wm_class:
                        continue            
                    img, y = poison_func(img,y)
                    imgs.append(img)
                    labels.append(y)
                # unlike the train-related lists, the poison testset holds (img, label) pairs
                poison_testset = torch.utils.data.TensorDataset(torch.stack(imgs), torch.LongTensor(labels))
            else:
                raise ValueError("Invalid trigger arguments!")
            

        # out-distribution triggers, rely on overfitting
        elif args.wm_type.lower() in ['svhn']: 
            benign_components_trainset = self.get_trainset_slices(0, args.owner_data_size)
            poison_num = self.get_poison_num()
            poison_components_trainset = []
            if args.wm_type.lower() in ['svhn']:
                svhn = torchvision.datasets.SVHN(
                    root=os.path.join(os.path.expanduser(args.dataset_dir),"SVHN"), download=True, transform=transforms.ToTensor()
                )                
                # keep only SVHN images whose label equals the watermark class
                target_svhn = []
                for i, y in enumerate(svhn.labels):
                    if y == args.wm_class:
                        target_svhn.append((svhn[i][0],y,1))
                # NOTE(review): the first poison_num images become the poison *testset*
                # and the remainder the poison *trainset* -- this may be the reverse of
                # what is intended; confirm.
                poison_testset = target_svhn[:poison_num]
                full_poison_trainset = target_svhn
                poison_components_trainset = target_svhn[poison_num:]
            else:
                raise ValueError("?????")

        else:
            raise NotImplementedError("Watermark %s not implement"%args.wm_type)
        
        # mixed trainset
        # poison samples come first, followed by the benign samples
        mixed_trainset = poison_components_trainset + benign_components_trainset

        
        # assign datasets
        self.mixed_trainset = mixed_trainset
        self.poison_components_trainset = poison_components_trainset
        self.benign_components_trainset = benign_components_trainset
        self.poison_testset = poison_testset
        self.full_poison_trainset = full_poison_trainset

        poison_dataset_dir = self.get_poison_dataset_dir()
        if not os.path.exists(poison_dataset_dir):
            os.makedirs(poison_dataset_dir)
        # save previews; show_dataset indexes the first 5 entries of each dataset by default
        self.show_dataset(self.mixed_trainset, os.path.join(poison_dataset_dir, 'trainset.png'))
        self.show_dataset(self.poison_testset, os.path.join(poison_dataset_dir, 'testset.png'))
        return self.mixed_trainset, self.poison_components_trainset, self.benign_components_trainset, self.poison_testset, self.full_poison_trainset      
        
    def get_trainset_slices(self, begin, end, poisoned=0):
        """
            Collect the clean training samples at shuffle positions [begin, end).

            Returns a list of (img, y, poisoned) tuples, where poisoned is the flag
            passed in (0 by default).

            Raises:
                ValueError: if begin is negative or end exceeds the trainset size.
        """
        if not (begin >= 0 and end <= len(self.trainset)):
            raise ValueError("Begin idx %d, End idx %d are illegal"%(begin, end))
        selected = (self.trainset[i] for i in self.shuffle_index[begin:end])
        return [(img, y, poisoned) for img, y in selected]

    def get_attacker_trainset(self):
        """
            Build the attacker's clean training set, wrapped with the training augmentation.

            With attacker_src == 'in' the attacker gets the first attacker_data_size
            samples in shuffle order, i.e. samples taken from the owner's range. Any
            other value ("out") takes the attacker_data_size samples that directly
            follow the owner's range.
        """
        args = self.args
        start = 0 if args.attacker_src == 'in' else args.owner_data_size
        samples = self.get_trainset_slices(start, start + args.attacker_data_size)
        return MyDataset(samples, self.train_transform)

    def _seed_worker(self, worker_id):
        """
            DataLoader worker_init_fn: seed numpy's and Python's random generators
            from torch's initial seed, reduced to 32 bits. worker_id is not used.
        """
        seed = torch.initial_seed() % (2**32)
        random.seed(seed)
        np.random.seed(seed)

    def get_dataloader(self, dataset, train=False, batch_size=None, num_workers=None, pin_memory=False):
        """
            Wrap dataset in a torch DataLoader.

            Args:
                dataset: the dataset to iterate over.
                train: when true, the loader shuffles and drops the last incomplete batch.
                batch_size: overrides args.batch_size when given.
                num_workers: overrides args.num_workers when given.
                pin_memory: passed through to the DataLoader.
        """
        args = self.args
        if batch_size is None:
            batch_size = args.batch_size
        if num_workers is None:
            num_workers = args.num_workers

        loader_options = dict(
            batch_size=batch_size,
            shuffle=train,
            num_workers=num_workers,
            drop_last=train,
            worker_init_fn=self._seed_worker,
            pin_memory=pin_memory,
        )
        return torch.utils.data.DataLoader(dataset, **loader_options)
    
    def get_poison_components_trainloader(self, train=True, batch_size=None, num_workers=None):
        """DataLoader over poison_components_trainset (training mode by default)."""
        return self.get_dataloader(
            self.poison_components_trainset,
            train=train,
            batch_size=batch_size,
            num_workers=num_workers,
        )

    def get_benign_components_trainloader(self, train=True, batch_size=None, num_workers=None):
        """DataLoader over benign_components_trainset (training mode by default)."""
        return self.get_dataloader(
            self.benign_components_trainset,
            train=train,
            batch_size=batch_size,
            num_workers=num_workers,
        )

    def get_poisoned_trainloader(self, train=True, batch_size=None, num_workers=None):
        """DataLoader over mixed_trainset, the benign + poison mixture (training mode by default)."""
        return self.get_dataloader(
            self.mixed_trainset,
            train=train,
            batch_size=batch_size,
            num_workers=num_workers,
        )
    
    def get_attack_trainloader(self, train=True, batch_size=None, num_workers=None):
        """DataLoader over attacker_trainset (training mode by default)."""
        return self.get_dataloader(
            self.attacker_trainset,
            train=train,
            batch_size=batch_size,
            num_workers=num_workers,
        )
    
    # def get_clean_trainloader(self, train=True, batch_size=None, num_workers=None):
    #     return self.get_dataloader(self.trainset, train, batch_size, num_workers)

    def get_clean_testloader(self, train=False, batch_size=None, num_workers=None):
        """DataLoader over the clean testset (evaluation mode by default)."""
        return self.get_dataloader(
            self.testset,
            train=train,
            batch_size=batch_size,
            num_workers=num_workers,
        )
    
    def get_poisoned_testloader(self, train=False, batch_size=None, num_workers=None):
        """DataLoader over poison_testset (evaluation mode by default)."""
        return self.get_dataloader(
            self.poison_testset,
            train=train,
            batch_size=batch_size,
            num_workers=num_workers,
        )

    def show_dataset(self, dataset, path_to_save, num=5):
        """
            Save a preview image with the first samples of dataset stacked vertically.
            Each image in dataset should be torch.Tensor, shape (C,H,W)

            Args:
                dataset: sized, indexable collection whose items hold the image at position 0.
                path_to_save: file path the figure is written to.
                num: maximum number of samples to draw; fewer are drawn when the
                    dataset is shorter, instead of failing with an IndexError.
        """
        count = max(0, min(num, len(dataset)))
        fig = plt.figure(figsize=(20,20))
        try:
            for i in range(count):
                ax = fig.add_subplot(count, 1, i+1)
                # (C,H,W) tensor -> (H,W,C) array as expected by imshow
                img = (dataset[i][0]).permute(1,2,0).cpu().detach().numpy()
                ax.imshow(img)
            fig.savefig(path_to_save)
        finally:
            # pyplot keeps every figure alive until it is closed explicitly;
            # without this each call leaks one 20x20 inch figure
            plt.close(fig)

    def get_watermark_clean_paired_dataset(self):
        """
            Return the watermark (poison) samples together with their clean originals.

            The selection logic mirrors construct_poison_dataset: the owner's samples
            are visited in shuffle order and the first eligible ones are taken until
            every trigger quota is filled. Because of that, entry k of the clean list
            is the untouched version of entry k of poison_components_trainset.

            Supported for all in-distribution triggers handled by
            construct_poison_dataset, including 'gauss'.

            Returns:
                (watermark_dataset, clean_dataset): the raw poison component list and
                a list of (img, y, 0) tuples.

            Raises:
                NotImplementedError: for watermark types that are not in-distribution triggers.
                ValueError: if the owner's data runs out before all quotas are filled.
        """
        args = self.args
        # same trigger set as the in-distribution branch of construct_poison_dataset
        supported = ('4corner', 'badnets', 'blended', '4cross', 'iclr', 'gauss')
        if args.wm_type.lower() not in supported:
            raise NotImplementedError("Not implemeted for %s trigger"%args.wm_type)

        poison_funcs, poison_nums = self.get_poison_function_and_nums()
        clean_dataset = []
        # fi: trigger function currently being filled, n: samples assigned to it so far
        fi, n = 0, 0
        for i in self.shuffle_index[:args.owner_data_size]:
            # skip trigger functions whose quota is zero
            while fi < len(poison_funcs) and poison_nums[fi]==0:
                fi+=1
            if fi >= len(poison_funcs):
                # every quota is filled; no later sample can be selected
                break
            img, y = self.trainset[i]
            if not args.filter_out_target or y!=args.wm_class:
                # this sample is the one that gets poisoned; keep its clean version
                n += 1
                clean_dataset.append((img, y, 0))
                if n==poison_nums[fi]:
                    fi+=1
                    n=0

        # Sanity check after work is Done
        if fi != len(poison_funcs):
            raise ValueError("Not enough samples to pollute!")

        # unwrap the augmentation wrapper so both lists hold raw samples
        if isinstance(self.poison_components_trainset, MyDataset):
            watermark_dataset = self.poison_components_trainset.data
        else:
            watermark_dataset = self.poison_components_trainset
        return watermark_dataset, clean_dataset

def generate_datasets(args, n=3):
    """Generate (and cache) the watermark datasets for seeds 0 .. n-1.

    For every seed, args.seed is overwritten, a WatermarkDataset is built --
    which creates or loads the cached datasets -- and the size of the attacker
    split is printed. Nothing is done unless args.dataset is cifar10.

    Args:
        args: configuration namespace; its ``seed`` attribute is modified.
        n: number of seeds to generate datasets for.
    """
    if args.dataset.lower() != 'cifar10':
        return
    # honor n: the loop used to be hard-coded to three seeds
    for seed in range(n):
        args.seed = seed
        dataset = WatermarkDataset(args)
        print(len(dataset.attacker_trainset))

if __name__ == "__main__":
    import argparse

    def parser():
        """Parse the command-line options used to generate the datasets."""
        # NOTE(review): no --seed option is defined because generate_datasets sets
        # args.seed itself. There is also no --content-color option, although the
        # 'iclr' watermark reads args.content_color -- confirm whether 'iclr' is
        # meant to be usable from this script.
        arg_parser = argparse.ArgumentParser(description='Dataset test')
        arg_parser.add_argument('--dataset',type=str, default='cifar10')
        arg_parser.add_argument('--dataset-dir',type=str, default='~/datasets')
        arg_parser.add_argument('--owner-data-size', '-ods', type=int, default=40000, help='size of owner\'s dataset')
        arg_parser.add_argument('--attacker-data-size', '-ads', type=int, default=10000, help='size of attacker\'s dataset')
        arg_parser.add_argument('--poison-ratio', '-pr', type=float, default=0.01)
        arg_parser.add_argument('--poison-num', '-pn', type=int, help='#poison_samples, if use this, will ignore args.poison_ratio')
        arg_parser.add_argument('--wm-type', '-wt', type=str, default='4corner', help='watermark type, choose from badnets|4corner|blended')
        arg_parser.add_argument('--wm-class', '-wc', type=int, default=0, help="watermark-class")
        # the flag excludes target-class images from poisoning; the old help text said the opposite
        arg_parser.add_argument('--filter-out-target','-fot', type=int, default=1, help='set to 1 to leave images that already belong to the target class unpoisoned')
        arg_parser.add_argument('--attacker-src', type=str, default='out')
        # wm specific args
        ## badnets/4corner/blended
        arg_parser.add_argument('--transparency','-t',type=float, default=1.0)
        ## blended
        arg_parser.add_argument('--blended-type', default='hk',type=str)
        # dataloader specific
        arg_parser.add_argument('--batch-size','-bs', type=int, default=128)
        arg_parser.add_argument('--num-workers','-nws', type=int, default=4)
        return arg_parser.parse_args()

    args = parser()
    generate_datasets(args)
    
    

