import numpy as np
import torch
import torch.nn.functional as F
import time
import os
import random
import PIL.Image as Image
from tqdm import tqdm
from copy import deepcopy

from torchvision import datasets, transforms
from torch.utils.data import Dataset
from torch import tensor, long

from networks.models import Generator


def CIFAR10_BD(args):
    """Build the CIFAR-10 backdoor datasets used for training and evaluation.

    Args:
        args: run configuration. This function reads ``selection_dataaug``,
            ``data_path`` and ``inject_portion``; the whole object is also
            handed to every ``DatasetBD`` as its ``opt``.

    Returns:
        Tuple ``(channel, im_size, num_classes, class_names, mean, std,
        train_bad, test_clean, test_bad, test_purif)`` where

        * ``train_bad``  - training split with ``args.inject_portion`` poisoned,
        * ``test_clean`` - test split without any trigger,
        * ``test_bad``   - test split with ``inject_portion=1.0``,
        * ``test_purif`` - the training samples that were poisoned in
          ``train_bad``, re-triggered in test mode with ``purif_test=True``.
    """
    channel = 3
    im_size = (32, 32)
    num_classes = 10
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2470, 0.2435, 0.2616]

    if args.selection_dataaug:
        transform = transforms.Compose([
            transforms.RandomCrop((32, 32), padding=4, padding_mode="reflect"),
            transforms.RandomHorizontalFlip(0.5),
            transforms.ToTensor(),
        ])
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    unlabel_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    dst_train = datasets.CIFAR10(args.data_path, train=True, download=True)
    dst_test = datasets.CIFAR10(args.data_path, train=False, download=True)
    class_names = dst_train.classes
    # Labels are turned into tensors so they can be fancy-indexed below.
    dst_train.targets = tensor(dst_train.targets, dtype=long)
    dst_test.targets = tensor(dst_test.targets, dtype=long)

    train_bad = DatasetBD(args, num_classes=num_classes, full_dataset=dst_train, inject_portion=args.inject_portion,
                          transform=transform, unlabel_transform=unlabel_transform, mode='train')

    # NOTE(review): the evaluation sets share `transform`, so they are randomly
    # augmented when `args.selection_dataaug` is set - confirm this is intended.
    test_clean = DatasetBD(args, num_classes=num_classes, full_dataset=dst_test, inject_portion=0.0,
                           transform=transform, mode='test')
    test_bad = DatasetBD(args, num_classes=num_classes, full_dataset=dst_test, inject_portion=1.0,
                         transform=transform, mode='test')

    # An empty id array may arrive with a float dtype, which numpy refuses to
    # use as an index; force an integer index array so "nothing poisoned" works.
    poison_ids = np.asarray(train_bad.poison_ids, dtype=np.int64)

    # Subset of the *original* (un-triggered, originally labelled) training
    # samples that were selected for poisoning.
    dst_train_poi = deepcopy(dst_train)
    dst_train_poi.data = dst_train.data[poison_ids]
    dst_train_poi.targets = dst_train.targets[poison_ids]

    test_purif = DatasetBD(args, num_classes=num_classes, full_dataset=dst_train_poi, inject_portion=1.0,
                           transform=transform, mode='test', purif_test=True)

    return channel, im_size, num_classes, class_names, mean, std, train_bad, test_clean, test_bad, test_purif


class DatasetBD(Dataset):
    """Dataset that materialises a (partially) backdoored copy of another dataset.

    On construction every sample of ``full_dataset`` is read once, optionally
    passed through ``pre_transform``, converted to a numpy array and stored in
    ``self.dataset`` as an ``(image, label)`` pair. Selected samples get a
    trigger stamped on them and, depending on ``opt.target_type``, a new label.

    Attributes set by the constructor:
        dataset: list of ``(ndarray image, label)`` pairs.
        targets: ndarray holding the labels of ``dataset``.
        poison_ids, clean_ids, noise_ids: integer index arrays into
            ``full_dataset``; they are only filled when ``mode == 'train'``.
        loaded_poi_ids, loaded_noi_ids: ids restored from the ``trigger/``
            cache files (``None`` when nothing was loaded).

    Args:
        opt: configuration object; ``trigger_type``, ``inject_portion``,
            ``target_label``, ``target_type``, ``trig_w`` and ``trig_h`` are read.
        full_dataset: indexable dataset yielding ``(image, label)``.
        inject_portion: fraction of samples to poison; ``0.0`` always uses
            :meth:`addTrigger`, even for the WaNet trigger type.
        num_classes: 10, 43, 200 or 30; selects the dataset name and the cache
            file prefix. Any other value raises ``NameError``.
        transform: callable applied to the PIL image in ``__getitem__``;
            skipped when ``None``.
        unlabel_transform, device, unlabeled: stored on the instance only.
        pre_transform: optional callable applied to each raw image before
            it is converted to an array.
        mode: ``'train'`` or anything else (treated as test mode).
        poison_ids, clean_ids: optional pre-selected indices to reuse.
        distance: forwarded to the trigger functions (unused by them).
        purif_test: in test mode keep the original label of triggered samples.
    """

    # num_classes -> (dataset name, prefix used in the cached id file names)
    _DATASETS = {
        10: ('cifar10', 'cifar10'),
        43: ('gtsrb', 'gtsrb'),
        200: ('tiny-imagenet', 'tinyimagenet'),
        30: ('imagenet-subset', 'imagenetsubset'),
    }

    # num_classes -> (mean, std) recorded on the instance by addTrigger_WaNet
    _WANET_STATS = {
        10: ([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]),
        43: ([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
        200: ([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
        30: ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    }

    _TARGET_TYPES = ('all2one', 'all2all', 'cleanLabel')

    def __init__(self, opt, full_dataset, inject_portion, num_classes=10, transform=None, unlabel_transform=None,
                 pre_transform=None, mode="train", poison_ids=None, clean_ids=None, device=torch.device("cuda"),
                 distance=1, unlabeled=False, purif_test=False):

        self.opt = opt
        self.num_classes = num_classes
        self.classes = list(range(num_classes))

        self.transform = transform
        self.unlabel_transform = unlabel_transform
        self.pre_transform = pre_transform
        self.device = device
        self.unlabeled = unlabeled
        self.purif_test = purif_test

        if num_classes not in self._DATASETS:
            raise NameError(f"Not a valid dataset with No Classes {num_classes}.")
        self.ds_name, file_prefix = self._DATASETS[num_classes]

        # NOTE(review): the cache file is keyed by opt.inject_portion, not by
        # the `inject_portion` argument - confirm the two never differ for
        # train-mode datasets.
        self.poi_ids_fname = f'trigger/{file_prefix}_{opt.trigger_type}_{opt.inject_portion}_poison_ids.npy'
        self.noi_ids_fname = self.poi_ids_fname.replace('poison', 'noise')  # only used by WaNet

        self.loaded_poi_ids = None
        self.loaded_noi_ids = None
        if mode == 'train' and os.path.exists(self.poi_ids_fname):
            self.loaded_poi_ids = np.load(self.poi_ids_fname)
            print(f"Loaded poison IDs from {self.poi_ids_fname}.")

            # A missing noise file leaves loaded_noi_ids as None, which makes
            # addTrigger_WaNet re-sample (and re-save) both id sets.
            if opt.trigger_type == 'wanetTrigger' and os.path.exists(self.noi_ids_fname):
                self.loaded_noi_ids = np.load(self.noi_ids_fname)
                print(f"Loaded noise IDs from {self.noi_ids_fname}.")

        if inject_portion != 0.0 and opt.trigger_type == 'wanetTrigger':
            cross_rate = 2
            self.dataset = self.addTrigger_WaNet(full_dataset, num_classes, opt.target_label, inject_portion, mode,
                                                 distance, opt.trig_w, opt.trig_h, opt.trigger_type,
                                                 opt.target_type, cross_rate, poison_ids, clean_ids)
        else:
            self.dataset = self.addTrigger(full_dataset, opt.target_label, inject_portion, mode, distance,
                                           opt.trig_w, opt.trig_h, opt.trigger_type, opt.target_type,
                                           poison_ids, clean_ids)

        self.targets = np.array([np.array(t) for _, t in self.dataset])

    def __getitem__(self, item):
        """Return ``(image, label)``; the image is a PIL image passed through ``transform`` if one was given."""
        img, label = self.dataset[item]
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)

        return img, label

    def __len__(self):
        """Number of samples kept after trigger injection."""
        return len(self.dataset)

    @staticmethod
    def _as_index_set(ids):
        """Convert an index container (ndarray, tensor, list, set) into a set for O(1) lookups."""
        if isinstance(ids, (set, frozenset)):
            return set(ids)
        if torch.is_tensor(ids):
            ids = ids.detach().cpu().numpy()
        return set(np.asarray(ids).reshape(-1).tolist())

    def _cached(self, key, loader):
        """Return the per-instance cached value for ``key``, calling ``loader()`` once on a miss."""
        cache = self.__dict__.setdefault('_trigger_cache', {})
        if key not in cache:
            cache[key] = loader()
        return cache[key]

    def _sample_clean_label_ids(self, dataset, target_label, num_poison):
        """Pick, in random order, up to ``num_poison`` indices whose label equals ``target_label``."""
        chosen = []
        for i in np.random.permutation(len(dataset)):
            if len(chosen) >= num_poison:
                break  # quota met; reading further samples cannot change the result
            if dataset[i][1] == target_label:
                chosen.append(i)
        return np.array(chosen, dtype=np.int64)

    def _poisoned_label(self, label, target_label, target_type, is_train):
        """Label stored with a triggered sample for the given attack type and mode."""
        if target_type == 'cleanLabel':
            # Training keeps the (target-class) label; testing always relabels.
            return label if is_train else torch.tensor(target_label)
        if not is_train and self.purif_test:
            return label
        if target_type == 'all2one':
            return torch.tensor(target_label)
        return self._change_label_next(label)

    def addTrigger(self, dataset, target_label, inject_portion, mode, distance, trig_w, trig_h, trigger_type,
                   target_type, poison_ids=None, clean_ids=None, exclude_target=False):
        """Build the list of ``(image, label)`` pairs for patch/blend triggers.

        Args:
            dataset: indexable source of ``(image, label)``.
            target_label: attack target class.
            inject_portion: fraction of ``dataset`` to poison.
            mode: ``'train'`` or test (any other value).
            distance, trig_w, trig_h: forwarded to :meth:`selectTrigger`.
            trigger_type: ``'badnetsTrigger'`` or ``'blendTrigger'``.
            target_type: ``'all2one'``, ``'all2all'`` or ``'cleanLabel'``.
            poison_ids: indices to poison; when given, ``clean_ids`` is required.
            clean_ids: only checked for presence.
            exclude_target: in train mode, leave selected samples that already
                carry ``target_label`` untouched (all2one / all2all only).

        Returns:
            The new sample list. Also sets ``poison_ids``, ``clean_ids`` and
            ``noise_ids`` on the instance (filled in train mode only).

        Raises:
            ValueError: ``target_type`` is not one of the supported values.

        In test mode, samples whose label equals ``target_label`` are dropped
        when ``inject_portion != 0.0`` for ``'cleanLabel'``, and for
        ``'all2one'`` unless ``purif_test`` is set.
        """
        if target_type not in self._TARGET_TYPES:
            # Previously an unknown type silently produced an empty dataset.
            raise ValueError(f"Unknown target_type {target_type!r}; expected one of {self._TARGET_TYPES}.")

        print("Generating " + mode + "bad Imgs")

        is_train = mode == 'train'
        num_bd = int(len(dataset) * inject_portion)

        if poison_ids is None:
            if target_type == 'cleanLabel' and is_train:
                # Only select data from the target label
                perm = self._sample_clean_label_ids(dataset, target_label, inject_portion * len(dataset))
            elif is_train and self.loaded_poi_ids is not None:
                perm = self.loaded_poi_ids
            else:
                perm = np.random.permutation(len(dataset))[:num_bd]
                if is_train:
                    np.save(self.poi_ids_fname, perm)
        else:
            assert clean_ids is not None, 'Clean ids are required!'
            print('Reuse same poison ids for unlabeled trainset')
            perm = poison_ids

        selected = self._as_index_set(perm)

        # Loop-invariant: does test mode drop samples of the target class?
        drop_target_cls = inject_portion != 0.0 and (
            target_type == 'cleanLabel' or (target_type == 'all2one' and not self.purif_test))

        poison_idx = []
        clean_idx = []
        dataset_ = []

        cnt = 0  # counting the number of poisoned samples
        for i in tqdm(range(len(dataset))):
            data = dataset[i]
            img, label = data[0], data[1]
            if self.pre_transform is not None:
                img = self.pre_transform(img)

            if is_train:
                img = np.array(img)
                if i not in selected:
                    poison = False
                elif target_type == 'cleanLabel':
                    poison = bool(label == target_label)
                else:
                    poison = not (exclude_target and label == target_label)

                if not poison:
                    dataset_.append((img, label))
                    clean_idx.append(i)
                    continue

                poison_idx.append(i)
            else:
                if drop_target_cls and label == target_label:
                    continue

                img = np.array(img)
                if i not in selected:
                    dataset_.append((img, label))
                    continue

            # `width`/`height` keep the historical meaning shape[0]/shape[1].
            img = self.selectTrigger(img, img.shape[0], img.shape[1], distance, trig_w, trig_h, trigger_type)
            dataset_.append((img, self._poisoned_label(label, target_label, target_type, is_train)))
            cnt += 1

        # Explicit integer dtype so that empty arrays remain usable as indices.
        self.poison_ids = np.array(poison_idx, dtype=np.int64)
        self.clean_ids = np.array(clean_idx, dtype=np.int64)
        self.noise_ids = np.array([], dtype=np.int64)

        time.sleep(0.01)
        print("Injecting Over: " + str(cnt) + "Bad Imgs, " + str(len(dataset_) - cnt) + "Clean Imgs, " + "Dataset Size: " + str(len(dataset_)))

        return dataset_

    def addTrigger_WaNet(self, dataset, num_classes, target_label, inject_portion, mode, distance, trig_w, trig_h, trigger_type,
                         target_type, cross_rate=2, poison_ids=None, clean_ids=None, exclude_target=False,
                         is_gtsrb=False):
        """Build the list of ``(image, label)`` pairs for the WaNet warping trigger.

        ``int(len(dataset) * inject_portion)`` samples receive the warp and the
        target label; ``cross_rate`` times as many receive a noisy warp while
        keeping their label ("noise" samples). When ``poison_ids`` is given,
        every index that is in neither ``poison_ids`` nor ``clean_ids`` is
        treated as a noise sample.

        ``trigger_type``, ``target_type`` and ``is_gtsrb`` are accepted but not
        used. Sets ``poison_ids``, ``clean_ids`` and ``noise_ids`` (filled in
        train mode only) and, for known ``num_classes``, ``mean`` / ``std``.

        Returns:
            The new sample list.
        """
        print("Generating " + mode + "bad Imgs")
        # WaNet paper: rho_a = pc, rho_n = pc * cross_ratio
        num_bd = int(len(dataset) * inject_portion)
        num_noise = num_bd * cross_rate
        is_train = mode == 'train'

        if poison_ids is None:
            if is_train and self.loaded_poi_ids is not None and self.loaded_noi_ids is not None:
                perm_bd = self.loaded_poi_ids
                perm_cross = self.loaded_noi_ids
            else:
                perm_all = np.random.permutation(len(dataset))
                perm_bd = perm_all[0: num_bd]
                perm_cross = perm_all[num_bd: num_bd + num_noise] if num_noise > 0 else []

                if is_train:
                    np.save(self.poi_ids_fname, perm_bd)
                    np.save(self.noi_ids_fname, perm_cross)

            bd_set = self._as_index_set(perm_bd)
            cross_set = self._as_index_set(perm_cross)
        else:
            assert clean_ids is not None, 'Clean ids are required!'

            bd_set = self._as_index_set(poison_ids)
            accounted = bd_set | self._as_index_set(clean_ids)
            cross_set = {i for i in range(len(dataset)) if i not in accounted}

        stats = self._WANET_STATS.get(num_classes)
        if stats is not None:
            self.mean, self.std = list(stats[0]), list(stats[1])

        dataset_ = []

        poison_idx = []
        noise_idx = []
        clean_idx = []

        cnt = 0
        cnt_n = 0
        for i in tqdm(range(len(dataset))):
            data = dataset[i]
            img, label = data[0], data[1]
            if self.pre_transform is not None:
                img = self.pre_transform(img)

            # NOTE(review): unlike addTrigger, target-class samples are dropped
            # in test mode even when purif_test is set - confirm this is wanted.
            if not is_train and label == target_label:
                continue

            img = np.array(img)
            width = img.shape[0]
            height = img.shape[1]

            if i in bd_set:
                if is_train and exclude_target and label == target_label:
                    dataset_.append((img, label))
                    clean_idx.append(i)
                    continue

                img = self._wanetTrigger(img, width, height, distance, trig_w, trig_h, mode='poison')
                keep_label = (not is_train) and self.purif_test
                dataset_.append((img, label if keep_label else torch.tensor(target_label)))
                cnt += 1
                if is_train:
                    poison_idx.append(i)
            elif i in cross_set:
                # add wanet trigger with noise
                img = self._wanetTrigger(img, width, height, distance, trig_w, trig_h, mode='noise')
                dataset_.append((img, label))
                cnt_n += 1
                if is_train:
                    noise_idx.append(i)
            else:
                dataset_.append((img, label))
                if is_train:
                    clean_idx.append(i)

        # Explicit integer dtype so that empty arrays remain usable as indices.
        self.poison_ids = np.array(poison_idx, dtype=np.int64)
        self.clean_ids = np.array(clean_idx, dtype=np.int64)
        self.noise_ids = np.array(noise_idx, dtype=np.int64)

        time.sleep(0.01)
        print(f"Injecting Over: {cnt} Bad Imgs, {cnt_n} Noise Imgs, {len(dataset_) - cnt - cnt_n} Clean Imgs")

        return dataset_

    def _change_label_next(self, label):
        """all2all relabelling: map class ``c`` to ``(c + 1) % num_classes``."""
        return (label + 1) % self.num_classes

    def selectTrigger(self, img, width, height, distance, trig_w, trig_h, triggerType):
        """Apply the trigger named by ``triggerType`` to ``img`` and return the new image.

        Only ``'badnetsTrigger'`` and ``'blendTrigger'`` are supported; other
        names fail the assertion (or raise ``NotImplementedError`` when
        assertions are disabled).
        """
        assert triggerType in ['badnetsTrigger', 'blendTrigger']

        if triggerType == 'badnetsTrigger':
            return self._badnetsTrigger(img, width, height, distance, trig_w, trig_h)
        if triggerType == 'blendTrigger':
            return self._blendTrigger(img, width, height, distance, trig_w, trig_h)
        raise NotImplementedError

    def _badnetsTrigger(self, img, width, height, distance, trig_w, trig_h):
        """Overwrite the non-zero pixels of ``trigger/cifar_badnets.png`` onto a copy of ``img``.

        The pattern is applied at the same array coordinates it has in the
        PNG. Only ``img`` is used; the remaining arguments are ignored.
        Returns a new ``uint8`` array and leaves ``img`` unmodified.
        """
        def _load():
            ptn = np.array(Image.open('trigger/cifar_badnets.png').convert('RGB'))
            return ptn, np.nonzero(ptn)

        # The PNG is read once per instance instead of once per image.
        trigger_ptn, trigger_loc = self._cached('badnets', _load)

        badnets_img = np.array(img)  # copy: never write into the caller's array
        badnets_img[trigger_loc] = trigger_ptn[trigger_loc]

        return np.clip(badnets_img, 0, 255).astype('uint8')

    def _blendTrigger(self, img, width, height, distance, trig_w, trig_h, alpha=0.1):
        """Blend the ``trigger/hello_kitty.png`` pattern into ``img``.

        Computes ``(1 - alpha) * img + alpha * pattern``. The resized pattern
        is stored in ``trigger/hello_kitty_resize_{width}.npy`` and reused on
        later calls. Returns a new ``uint8`` array.
        """
        fname = 'trigger/hello_kitty.png'
        npy_fname = f'trigger/hello_kitty_resize_{width}.npy'

        def _load():
            if os.path.exists(npy_fname):
                return np.load(npy_fname)
            # PIL's resize takes (width, height); ndarray shapes are (rows, cols).
            resized = Image.open(fname).convert("RGB").resize((img.shape[1], img.shape[0]))
            pattern_ = np.array(resized)
            np.save(npy_fname, pattern_)
            return pattern_

        pattern = self._cached(('blend', npy_fname), _load)

        poison_img = (1 - alpha) * img + alpha * pattern
        # Clip before the cast so out-of-range values saturate instead of wrapping.
        return np.clip(poison_img, 0, 255).astype('uint8')

    def _wanetTrigger(self, img, width, height, distance, trig_w, trig_h, mode='poison', use_norm=True, print_img=True):
        """Warp ``img`` with the WaNet grid and return a ``uint8`` HxWxC array.

        Args:
            img: ndarray image, channels last, values in ``[0, 255]``.
            width, height: image size; ``height`` sets the grid resolution and
                ``width`` is part of the grid file names.
            mode: ``'poison'`` applies the fixed warp, ``'noise'`` adds random
                perturbation to the warp grid.
            distance, trig_w, trig_h, use_norm, print_img: accepted, unused.

        The warp grids are read from (or, if absent, generated and written to)
        ``trigger/wanet_*_grid_{width}_cls{num_classes}.npy``.

        Raises:
            NameError: ``mode`` is neither ``'poison'`` nor ``'noise'``.
        """
        k = 4 if img.shape[0] in [32, 64] else 6
        s = 0.5 if img.shape[0] in [32, 64] else 1.0

        img = np.asarray(img)
        img = img / 255.0
        img = np.rollaxis(img, -1, 0)
        img = torch.FloatTensor(img).unsqueeze(0)

        noise_grid_file = f'trigger/wanet_noise_grid_{width}_cls{self.num_classes}.npy'
        identity_grid_file = f'trigger/wanet_identity_grid_{width}_cls{self.num_classes}.npy'

        def _load_grids():
            if os.path.exists(noise_grid_file) and os.path.exists(identity_grid_file):
                return (torch.FloatTensor(np.load(noise_grid_file)),
                        torch.FloatTensor(np.load(identity_grid_file)))

            # Selecting the control grid
            ins = torch.rand(1, 2, k, k) * 2 - 1
            ins = ins / torch.mean(torch.abs(ins))  # (normalization function) value range nearly [-2, 2]
            ins = ins * s

            # Upsampling (F.interpolate is the non-deprecated name of F.upsample)
            noise = F.interpolate(ins, size=height, mode="bicubic", align_corners=True).permute(0, 2, 3, 1)

            # Clipping
            noise = torch.clamp(noise, -1, 1)

            # Identity sampling grid: channel 0 varies along columns, channel 1
            # along rows (same as meshgrid with 'ij' indexing, stacked as (y, x)).
            array1d = torch.linspace(-1, 1, steps=height)
            cols = array1d.unsqueeze(0).expand(height, height)
            rows = array1d.unsqueeze(1).expand(height, height)
            identity = torch.stack((cols, rows), 2)[None, ...]

            np.save(noise_grid_file, noise.detach().cpu().numpy())
            print(f'Save new WaNet noise_grid file in: {noise_grid_file}')
            np.save(identity_grid_file, identity.detach().cpu().numpy())
            print(f'Save new WaNet identity_grid file in: {identity_grid_file}')
            return noise, identity

        # Grids are loaded/generated once per instance instead of once per image.
        noise_grid, identity_grid = self._cached(('wanet', noise_grid_file), _load_grids)

        # Drawn on every call (also in 'poison' mode) so the torch RNG stream
        # stays the same as before.
        ins_n = torch.randn(width, height, 2) * 2 - 1

        if mode == 'poison':
            grid = torch.clamp(identity_grid + noise_grid / height, -1, 1)
        elif mode == 'noise':
            grid = torch.clamp(identity_grid + (noise_grid + ins_n) / height, -1, 1)
        else:
            raise NameError(f"WaNet mode only has 'poison' and 'noise', but given {mode}")

        wanet_img = F.grid_sample(img, grid, align_corners=True)

        img_out = wanet_img.squeeze(0).detach().cpu().numpy()
        img_out = np.rollaxis(img_out, 0, 3) * 255.0
        # Clip before the cast so the clip actually takes effect.
        return np.clip(img_out, 0, 255).astype('uint8')
