import numpy as np
import torch
import os
from torchvision import transforms
import random
# class PatternSynthesizer:
#     """
#     用于让unlearn client给模型添加后门，用于unlearning测试ASR指标。
#     在正常训练的时候，用户就需要先使用后门，以便记录在unlearning前的ASR。
#     默认在初始训练阶段恶意用户给80%的训练样本加上后门；而在unlearn阶段，恶意用户给100%的训练样本加上后门。
#     """
#     def __init__(self, image_shape, target_class_num):
#         # A tensor coordinate with this value won't be applied to the image
#         self.mask_value = -10
#         # def image_shape = (3, 32, 32) if CIFAR 10
#         self.image_shape = image_shape
#         self.target_class_num = target_class_num
#         # position to insert trigger
#         # Randomly set the trigger position
#         if self.image_shape[0] == 3:
#             self.x_top = 3
#             self.y_top = 23
#         else:
#             self.x_top = 3
#             self.y_top = 22
#         self.make_pattern(self.x_top, self.y_top)
#
#     def make_pattern(self, x_top, y_top):
#         trigger_size = (3, 3)
#         pattern_tensor = torch.zeros(trigger_size)
#         self.x_bot = x_top + pattern_tensor.shape[0]
#         self.y_bot = y_top + pattern_tensor.shape[1]
#
#         full_image = torch.zeros(self.image_shape).fill_(self.mask_value)
#
#         x_bot = self.x_bot
#         y_bot = self.y_bot
#
#         if len(self.image_shape) == 1:
#             """
#             如果是用MLP，则image_shape是图片被展成一维后的shape。所以这里先还原出图片
#             """
#             full_image = full_image.reshape(1, int(self.image_shape[0]**0.5), int(self.image_shape[0]**0.5))
#
#         full_image[:, x_top:x_bot, y_top:y_bot] = pattern_tensor
#
#         # set trigger pattern to mask
#         self.mask = 1 * (full_image != self.mask_value)
#         # normalize (can delete this operation)
#         if self.image_shape[0] == 3:
#             means = (0.4914, 0.4822, 0.4465)
#             lvars = (0.2023, 0.1994, 0.2010)
#         else:
#             means = (0.5)
#             lvars = (0.2)
#         normalize = transforms.Normalize(means, lvars)
#         self.pattern = normalize(full_image)
#         # reshape image
#         if len(self.image_shape) == 1:
#             self.pattern = self.pattern.reshape(self.image_shape[0])
#             self.mask = self.mask.reshape(self.image_shape[0])
#
#     def add_backdoor(self, dataset, attack_portion=0.8):
#         """
#         这里的dataset格式形如client的input_training_data。
#         """
#         for [batch_x, batch_y] in dataset:
#             attack_num = int(len(batch_y) * attack_portion)
#             batch_x[:attack_num] = (1 - self.mask) * batch_x[:attack_num] + self.mask * self.pattern
#             batch_y[:attack_num] = (batch_y[:attack_num] + self.target_class_num//2) % self.target_class_num  # 改变样本的label


class FigRandBackdoor:
    """Rectangular patch trigger placed at a random spot near the image border.

    On construction a rectangle is drawn at random. Each side is 10-15 pixels long, and its
    top-left corner is offset by one of 10, 20, -20 or -30 pixels along each spatial axis.
    Negative offsets count from the far edge. Two tensors of shape ``image_size`` are kept:

    * ``mask``: 1 inside the rectangle on every channel, 0 elsewhere.
    * ``watermark``: ``mask`` passed through the dataset's ``transforms.Normalize``. Inside
      the rectangle it holds the normalized value of 1.0 (white, assuming images are scaled
      to [0, 1] before normalization). Only its values inside the rectangle are stamped.

    The constructed object is pickled with ``np.save`` to ``<save_folder>/<save_name>.npy``.

    Args:
        name: identifier stored on the instance.
        dataloader: required. Must expose ``input_data_shape`` (e.g. ``[3, 32, 32]``),
            ``target_class_num`` and ``name``. ``name`` must contain 'Domain' or 'PACS'.
        save_folder: directory for the saved object, created if missing. An empty string
            means the current working directory.
        save_name: file name (without extension) of the saved object.

    Raises:
        RuntimeError: if ``dataloader`` is missing or its dataset is not supported.
    """

    def __init__(self, name="FigRandBackdoor", dataloader=None, save_folder="", save_name="backdoor"):
        self.name = name
        if dataloader is None:
            raise RuntimeError('dataloader must be provided.')
        self.image_size = dataloader.input_data_shape
        self.target_class_num = dataloader.target_class_num
        self.save_folder = save_folder
        self.save_name = save_name

        # Pick a random rectangle near the border. Negative offsets are counted from the
        # far edge, so the trigger can land near any of the four corners.
        self.x_top = int(np.random.choice([10, 20, -20, -30]))
        self.y_top = int(np.random.choice([10, 20, -20, -30]))
        self.x_len = int(np.random.choice(range(10, 16)))
        self.y_len = int(np.random.choice(range(10, 16)))
        # Fixed target label. add_backdoor does not use it; it shifts labels instead.
        self.bd_label = 1

        if 'Domain' in dataloader.name:
            # DomainNet
            means = [0.48145466, 0.4578275, 0.40821073]
            lvars = [0.26862954, 0.26130258, 0.27577711]
        elif 'PACS' in dataloader.name:
            # This was previously bound to `mean`, which left `means` undefined for PACS.
            means = [0.485, 0.456, 0.406]
            lvars = [0.229, 0.224, 0.225]
        else:
            raise RuntimeError(f"{dataloader.name} is not supported now.")

        # 0/1 mask marking the trigger rectangle on every channel.
        self.mask = self._make_mask()
        # Normalized trigger pattern. Only its values where mask == 1 are stamped onto images.
        self.watermark = transforms.Normalize(means, lvars)(self.mask)

        # An empty save_folder means "current directory"; os.makedirs('') would raise.
        if self.save_folder:
            os.makedirs(self.save_folder, exist_ok=True)
        file_path = os.path.join(self.save_folder, self.save_name + '.npy')
        # np.save pickles the whole object into a 0-d object array.
        # Reload it with np.load(file_path, allow_pickle=True).item().
        np.save(file_path, self)

    @staticmethod
    def _axis_slice(start, length):
        """Return a slice covering `length` pixels from `start` (negative `start` counts from the end)."""
        stop = start + length
        # A negative start whose run reaches the far edge would yield stop >= 0, giving an
        # empty or wrong slice. Use an open-ended slice instead.
        if start < 0 <= stop:
            stop = None
        return slice(start, stop)

    def _make_mask(self):
        """Return a float tensor of shape ``image_size`` that is 1 inside the trigger rectangle."""
        mask = torch.zeros(tuple(self.image_size))
        mask[:, self._axis_slice(self.x_top, self.x_len), self._axis_slice(self.y_top, self.y_len)] = 1
        return mask

    def bd_collate_fn(self, batch):
        """Collate (image, label) pairs into tensors.

        Each image has a 50% chance of getting a white square at a random corner.

        NOTE(review): ``add_white_square`` is neither defined nor imported in this file as
        seen here. Confirm where it comes from, or this raises NameError.
        """
        images, labels = zip(*batch)

        processed_images = []
        for image in images:
            # Randomly decide whether to add a white square.
            if random.random() < 0.5:
                # Pick a random corner.
                corner = random.choice(['top_left', 'top_right', 'bottom_left', 'bottom_right'])
                image = add_white_square(image, corner=corner, size=50)

            # Convert the image to a tensor.
            image = transforms.ToTensor()(image)
            processed_images.append(image)

        images_tensor = torch.stack(processed_images)
        labels_tensor = torch.tensor(labels)

        return images_tensor, labels_tensor

    def add_backdoor(self, batch_x, batch_y, attack_portion=1.0):
        """Stamp the trigger onto the leading part of a batch and relabel those samples.

        Args:
            batch_x: image tensor whose trailing dimensions match ``image_size``. It is
                presumably ``(batch, C, H, W)`` and already normalized (TODO: confirm).
            batch_y: integer label tensor with one entry per sample.
            attack_portion: fraction of the batch to poison, namely the first
                ``int(len(batch_y) * attack_portion)`` samples. The default 1.0 poisons
                every sample.

        Returns:
            ``(batch_x, batch_y)`` as new tensors; the inputs are not modified. In poisoned
            samples, the trigger pixels are replaced by the pattern and the label becomes
            ``(y + target_class_num // 2) % target_class_num``.
        """
        mask = getattr(self, 'mask', None)
        if mask is None:
            # Objects pickled before `mask` existed: rebuild it from the stored geometry.
            mask = self._make_mask()
        size = tuple(self.image_size)
        # Match the batch's device and dtype without mutating the stored tensors.
        mask = mask.reshape(size).to(device=batch_x.device, dtype=batch_x.dtype)
        pattern = self.watermark.reshape(size).to(device=batch_x.device, dtype=batch_x.dtype)

        n_attack = int(len(batch_y) * attack_portion)
        batch_x = batch_x.clone()
        batch_y = batch_y.clone()
        # Pixels inside the mask take the pattern value; all other pixels are unchanged.
        batch_x[:n_attack] = (1 - mask) * batch_x[:n_attack] + mask * pattern
        batch_y[:n_attack] = (batch_y[:n_attack] + self.target_class_num // 2) % self.target_class_num
        return batch_x, batch_y

class BD_Dataset(torch.utils.data.Dataset):
    """Dataset wrapper that applies a backdoor trigger to a leading fraction of the samples.

    Samples with index ``< int(len(dataset) * attack_portion)`` are passed through
    ``bd_maker.add_backdoor`` when fetched. All other samples are returned unchanged.

    Args:
        dataset: indexable dataset whose items are sequences starting with
            ``(image, label)``. Any further fields (e.g. domain, path) pass through
            untouched. ``image`` is presumably an already-normalized tensor of shape
            ``bd_maker.image_size`` (TODO: confirm).
        bd_maker: object with ``add_backdoor(batch_x, batch_y)`` that returns the poisoned
            ``(batch_x, batch_y)``, e.g. ``FigRandBackdoor``.
        attack_portion: fraction of samples, counted from index 0, to poison.
    """

    def __init__(self, dataset, bd_maker, attack_portion=0.9):
        self.dataset = dataset
        self.bd_maker = bd_maker
        # Fraction of samples (from index 0) poisoned on the fly in __getitem__.
        self.ar = attack_portion

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _poison(sample, bd_maker):
        """Return `sample` as a tuple with its image and label replaced by the backdoored versions."""
        image, label = sample[0], sample[1]
        label_is_tensor = torch.is_tensor(label)
        # add_backdoor works on batches, so wrap the single sample in a batch of one.
        batch_x = image.unsqueeze(0)
        batch_y = label.reshape(1) if label_is_tensor else torch.tensor([label])
        batch_x, batch_y = bd_maker.add_backdoor(batch_x, batch_y)
        # Keep the label type the dataset used: a tensor stays a tensor, a number stays a number.
        new_label = batch_y[0] if label_is_tensor else batch_y[0].item()
        return (batch_x[0], new_label, *sample[2:])

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        position = idx + len(self) if idx < 0 else idx
        if position < int(len(self) * self.ar):
            return self._poison(sample, self.bd_maker)
        return sample

    def add_trigger(self, bd_maker, attack_portion=0.8):
        """Eagerly poison the first ``int(len(self) * attack_portion)`` samples in place.

        ``self.dataset`` must support item assignment. Afterwards, on-the-fly poisoning is
        switched off (``self.ar = 0``) so the same samples are not triggered twice. A second
        label shift would undo the first when the class count is even.
        """
        for i in range(int(len(self) * attack_portion)):
            self.dataset[i] = self._poison(self.dataset[i], bd_maker)
        self.ar = 0.0


def custom_collate_fn(batch):
    """Collate (image, label, domain, path) samples into an image tensor and a label tensor.

    Each image independently has a 50% chance of being passed to ``add_white_square``
    (size 50, at a randomly chosen corner) before ``transforms.ToTensor`` is applied.
    Domains and paths are discarded.

    NOTE(review): ``add_white_square`` is neither defined nor imported in this file as seen
    here. Confirm where it comes from.
    """
    images, labels, _domains, _paths = zip(*batch)
    to_tensor = transforms.ToTensor()
    corners = ['top_left', 'top_right', 'bottom_left', 'bottom_right']

    def _prepare(img):
        # Coin flip first, then the corner draw, so random numbers are consumed in the same order.
        if random.random() < 0.5:
            img = add_white_square(img, corner=random.choice(corners), size=50)
        return to_tensor(img)

    stacked = torch.stack([_prepare(img) for img in images])
    return stacked, torch.tensor(labels)