
import torch
from torch import nn
from .base import *
from torch.utils.data import DataLoader
import imageio
from torchvision.datasets import CIFAR10
from tqdm import tqdm


class Normalize:
    """Per-channel normalization of image batches.

    For every channel ``c`` the output is
    ``(x[:, c] - expected_values[c]) / variance[c]``; the input is not modified.

    Args:
        dataset_name (str): Name of the dataset; selects the number of channels
            (3 for "cifar10", "cifar100" and "gtsrb", 1 for "mnist").
        expected_values (sequence of float): Per-channel value subtracted from
            the input; its length must equal the number of channels.
        variance (sequence of float): Per-channel divisor. It is used as given
            (no square root is taken), so pass standard deviations if that is
            the intended normalization.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
        AssertionError: If ``len(expected_values)`` differs from the channel count.
    """

    # Number of image channels for each supported dataset.
    _CHANNELS = {"cifar10": 3, "cifar100": 3, "gtsrb": 3, "mnist": 1}

    def __init__(self, dataset_name, expected_values, variance):
        if dataset_name not in self._CHANNELS:
            # Fail fast with a clear message instead of leaving n_channels unset.
            raise ValueError(
                f"Unsupported dataset_name {dataset_name!r}; "
                f"expected one of {sorted(self._CHANNELS)}.")
        self.n_channels = self._CHANNELS[dataset_name]
        self.expected_values = expected_values
        self.variance = variance
        assert self.n_channels == len(self.expected_values)

    def __call__(self, x):
        """Return a normalized copy of ``x``.

        Args:
            x (torch.Tensor): Batch whose channels are indexed as ``x[:, channel]``
                (i.e. channels on dim 1).

        Returns:
            torch.Tensor: A new tensor; ``x`` itself is left unchanged.
        """
        x_clone = x.clone()
        for channel in range(self.n_channels):
            x_clone[:, channel] = (x[:, channel] - self.expected_values[channel]) / self.variance[channel]
        return x_clone


class GetPoisonedDataset(CIFAR10):
    """Minimal in-memory dataset over pre-built samples and labels.

    Subclasses ``CIFAR10`` but does not call its ``__init__``: only ``data`` and
    ``targets`` are set here, and ``__len__`` / ``__getitem__`` are overridden,
    so none of ``CIFAR10``'s own initialization runs.

    Args:
        data (sequence): Samples; each ``data[index]`` must be convertible to a
            tensor (e.g. a tensor or array).
        targets (sequence): Labels aligned with ``data``; its length defines the
            dataset length.
    """
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        """Return ``(image, label)`` as new tensors for position ``index``."""
        sample = self.data[index]
        # torch.tensor() on an existing tensor warns about copy construction;
        # clone().detach() yields the same independent, grad-free copy silently.
        if torch.is_tensor(sample):
            img = sample.clone().detach()
        else:
            img = torch.tensor(sample)
        label = torch.tensor(self.targets[index])
        return img, label

class Adapt_Blend_Visual(Base):
    """Extract and visualize model representations for the Adapt-Blend setting.

    Loads trained weights into the victim model and computes its feature
    representations for (a) clean training samples of class ``y_target`` and
    (b) a pre-generated set of poisoned samples. Both are then handed to
    ``self._visual`` (inherited from ``Base``, defined outside this file).

    Args:
        dataset_name (str): Name of the dataset. Stored as ``self.dataset_name``.
        train_dataset (types in support_list): Benign training dataset. Its
            ``y_target`` samples whose index is not in ``poisoned_set`` supply
            the clean features.
        test_dataset (types in support_list): Benign testing dataset (passed to ``Base``).
        model (torch.nn.Module): Victim model. Called as ``model(images, True)``
            and expected to return a ``(features, _)`` pair.
        y_target (int): N-to-1 attack target label.
        poisoned_rate (float): Fraction of training indices drawn into ``poisoned_set``.
        reg_rate (float): Fraction of training indices drawn into ``reg_set``
            (disjoint from ``poisoned_set``).
        schedule (dict): Schedule passed to ``Base``. Keys read by this class
            through ``self.global_schedule``: 'batch_size', 'num_workers',
            'test_model', 'poison_data', 'save_dir', 'experiment_name'.
        seed (int): Global seed for random numbers. Default: 0.
        deterministic (bool): Sets whether PyTorch operations must use "deterministic" algorithms.
            That is, algorithms which, given the same input, and when run on the same software and hardware,
            always produce the same output. When enabled, operations will use deterministic algorithms when available,
            and if only nondeterministic algorithms are available they will throw a RuntimeError when called. Default: False.
    """
    def __init__(self,
                 dataset_name,
                 train_dataset,
                 test_dataset,
                 model,
                 y_target,
                 poisoned_rate,
                 reg_rate,
                 schedule,
                 seed=0,
                 deterministic=False,
                 ):
        super(Adapt_Blend_Visual, self).__init__(
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            model=model,
            schedule=schedule,
            seed=seed,
            deterministic=deterministic)
        self.dataset_name = dataset_name

        total_num = len(train_dataset)
        poisoned_num = int(total_num * poisoned_rate)
        assert poisoned_num >= 0, 'poisoned_num should greater than or equal to zero.'
        reg_num = int(total_num * reg_rate)
        assert reg_num >= 0, 'reg_num should greater than or equal to zero.'

        # NOTE(review): the index sets are drawn here with random.shuffle, so they
        # only coincide with the poisoned indices used at training time if the same
        # RNG state is reproduced (presumably via ``seed`` in Base) -- confirm.
        # If poisoned_num + reg_num > total_num, reg_set is silently truncated.
        tmp_list = list(range(total_num))
        random.shuffle(tmp_list)
        self.poisoned_set = frozenset(tmp_list[:poisoned_num])
        self.poisoned_rate = poisoned_rate

        self.reg_set = frozenset(tmp_list[poisoned_num:(poisoned_num + reg_num)])
        self.reg_rate = reg_rate

        self.y_target = y_target
        self.train_poisoned_data, self.train_poisoned_label = [], []
        self.test_poisoned_data, self.test_poisoned_label = [], []

        # No input normalization is applied for any dataset.
        self.normalizer = None

    def _collect_features(self, data_loader, keep_mask_fn=None):
        """Run ``self.model`` over ``data_loader`` and stack the selected features.

        Args:
            data_loader (DataLoader): Loader yielding ``(images, labels)`` batches.
            keep_mask_fn (callable | None): Called as ``keep_mask_fn(labels, start)``
                where ``start`` is the position of the batch's first sample in the
                loader; returns a boolean tensor selecting the samples to keep.
                ``None`` keeps every sample.

        Returns:
            torch.Tensor: Float tensor of the kept features concatenated along
            dim 0, or an empty 1-D tensor when nothing was kept.
        """
        kept = []
        offset = 0  # loader position of the first sample in the current batch
        with torch.no_grad():
            for images, labels in tqdm(data_loader):
                images, labels = images.cuda(), labels.cuda()
                protos, _ = self.model(images, True)
                if keep_mask_fn is not None:
                    protos = protos[keep_mask_fn(labels, offset)]
                if len(protos) > 0:
                    kept.append(protos.detach().cpu())
                offset += len(labels)
        if not kept:
            # Same result as torch.Tensor([]) on an empty list.
            return torch.Tensor([])
        return torch.cat(kept).float()

    def get_representation(self):
        """Compute clean-target and poisoned feature representations.

        Loads weights from ``global_schedule['test_model']`` into ``self.model``
        (non-strict) and moves the model to CUDA in eval mode. It then extracts:

        * clean features: training samples whose label equals ``y_target`` and
          whose index is not in ``self.poisoned_set``;
        * poisoned features: every sample of the object loaded from
          ``global_schedule['poison_data']``, each labeled ``y_target``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(class_clean_features,
            class_poisoned_features)``, each a float tensor with one row per
            selected sample (an empty 1-D tensor when no sample was selected).
        """
        data_loader = DataLoader(
            self.train_dataset,
            batch_size=self.global_schedule['batch_size'],
            shuffle=False,  # required: loader positions are matched against poisoned_set
            num_workers=self.global_schedule['num_workers'],
            worker_init_fn=self._seed_worker)

        # NOTE: torch.load unpickles arbitrary objects -- only load trusted files.
        # strict=False silently ignores missing/unexpected checkpoint keys.
        self.model.load_state_dict(torch.load(self.global_schedule['test_model']), strict=False)
        self.model.cuda()
        self.model.eval()

        def clean_target_mask(labels, start):
            # Keep target-class samples whose training index was not poisoned.
            not_poisoned = torch.tensor(
                [start + i not in self.poisoned_set for i in range(len(labels))],
                dtype=torch.bool, device=labels.device)
            return (labels == self.y_target) & not_poisoned

        class_clean_features = self._collect_features(data_loader, clean_target_mask)

        bd_poison_dataset = torch.load(self.global_schedule['poison_data'])
        bd_train_labset = [self.y_target] * len(bd_poison_dataset)
        bd_train_dl = DataLoader(
            GetPoisonedDataset(bd_poison_dataset, bd_train_labset),
            batch_size=self.global_schedule['batch_size'],
            shuffle=False,
            num_workers=self.global_schedule['num_workers'],
            worker_init_fn=self._seed_worker)

        class_poisoned_features = self._collect_features(bd_train_dl)

        return class_clean_features, class_poisoned_features

    def visual(self):
        """Create a timestamped work directory, extract features and visualize them.

        The directory is ``<save_dir>/<experiment_name>_<%Y-%m-%d_%H:%M:%S>``;
        it is stored as ``self.work_dir`` and passed to ``self._visual`` together
        with the features from :meth:`get_representation`.
        """
        timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
        self.work_dir = osp.join(self.global_schedule['save_dir'],
                                 self.global_schedule['experiment_name'] + '_' + timestamp)
        os.makedirs(self.work_dir, exist_ok=True)
        class_clean_features, class_poisoned_features = self.get_representation()
        self._visual(class_clean_features, class_poisoned_features, self.work_dir)