import torch
from torch.utils.data import Dataset
import random
from torchvision.transforms import transforms, functional
import torch.nn.functional as F
import torch.nn as nn
import collections
import numpy as np

class BackdoorPattern:
    """Stamp a small 2-D trigger pattern onto 3-D ``(C, H, W)`` image tensors.

    The trigger is built once at construction time: the built-in 5x3 pattern,
    optionally resized and (with probability 0.5) flipped left-right. Pattern
    cells equal to ``mask_value`` are transparent and leave the underlying
    pixels unchanged; every other cell overwrites all channels at its location.
    """

    def __init__(self, pattern_position=None, pattern_size=None, random_flip=True, mask_value=-10, input_shape=None):
        '''
        Args:
            pattern_position: where to put the pattern's top-left corner:
                * a ``(row, col)`` pair -- used as-is for every input;
                * ``"fix"`` -- draw one random position now (requires
                  ``input_shape``) and reuse it for every input;
                * ``"random"`` -- draw a fresh random position on every
                  ``add`` call (requires ``input_shape``).
            pattern_size: optional ``size`` for ``F.interpolate`` (an int or
                ``(h, w)``). Nearest-neighbour resizing keeps transparent
                cells exactly equal to ``mask_value``.
            random_flip: if True, flip the pattern left-right with
                probability 0.5. The coin is tossed once, here, not per input.
            mask_value: sentinel marking transparent pattern cells.
            input_shape: ``(C, H, W)`` shape of the images; used to draw
                random positions.

        Raises:
            ValueError: for an unknown position-mode string, or for ``"fix"``
                when ``input_shape`` is missing (``"random"`` raises the same
                error on the first ``add`` call instead).
        '''
        # A pattern cell with this value is not applied to the image.
        self.mask_value = mask_value
        self.input_shape = input_shape

        # Build the default pattern with *this* mask_value so its transparent
        # cells stay transparent even when mask_value is not -10.
        pattern = self.get_default_pattern(mask_value)
        if pattern_size is not None:
            pattern = self.resize_pattern(pattern, pattern_size)
        if random_flip:
            pattern = self.random_flip(pattern)
        self.pattern = pattern

        if isinstance(pattern_position, str):
            if pattern_position == "fix":
                fixed_position = self.get_random_position()
                self.get_pattern_position = lambda: fixed_position
            elif pattern_position == "random":
                self.get_pattern_position = self.get_random_position
            else:
                raise ValueError(f'Unknown pattern_position mode {pattern_position!r}; '
                                 f'expected a (row, col) pair, "fix" or "random"')
        else:
            self.get_pattern_position = lambda: pattern_position

    def add(self, rawinput):
        '''
        Return a copy of ``rawinput`` with the pattern stamped onto it.

        Args:
            rawinput: 3-D tensor; dims 1 and 2 are the spatial (row, column)
                axes the pattern is placed on.

        Returns:
            ``(1 - mask) * rawinput + mask * pattern``: pixels under opaque
            pattern cells take the pattern value, all others are unchanged.

        Raises:
            ValueError: if the pattern does not fit at the chosen position.
        '''
        pattern_position = self.get_pattern_position()
        # Build the overlay on the input's device so GPU tensors work too.
        device = getattr(rawinput, 'device', None)
        pattern, mask = self.makePatternWithMask(rawinput.shape, self.pattern, pattern_position, device=device)
        return (1 - mask) * rawinput + mask * pattern

    def get_random_position(self, input_shape=None):
        '''
        Draw a random top-left ``(row, col)`` at which the pattern fits fully.

        Args:
            input_shape: ``(C, H, W)`` shape to place into; defaults to
                ``self.input_shape``.

        Raises:
            ValueError: if no shape is known or the pattern is larger than
                the image.
        '''
        shape = self.input_shape if input_shape is None else input_shape
        if shape is None:
            raise ValueError('input_shape is required to draw a random pattern position')
        max_x = shape[1] - self.pattern.shape[0]
        max_y = shape[2] - self.pattern.shape[1]
        if max_x < 0 or max_y < 0:
            raise ValueError(f'Pattern of shape {tuple(self.pattern.shape)} does not fit '
                             f'in image of shape {tuple(shape)}')
        # randint is inclusive: a pattern starting at max_x/max_y ends exactly
        # on the last row/column, which is still inside the image.
        x = random.randint(0, max_x)
        y = random.randint(0, max_y)
        return (x, y)

    @staticmethod
    def get_default_pattern(mask_value=-10.) -> torch.Tensor:
        '''
        Return the built-in 5x3 trigger pattern.

        Args:
            mask_value: value used for the transparent cells (default -10,
                matching the class default).
        '''
        m = float(mask_value)
        pattern_tensor: torch.Tensor = torch.tensor([
            [1., 0., 1.],
            [m, 1., m],
            [m, m, 0.],
            [m, 1., m],
            [1., 0., 1.],
        ])
        return pattern_tensor

    @staticmethod
    def resize_pattern(pattern, target_size):
        '''Resize a 2-D pattern with nearest-neighbour interpolation.'''
        # interpolate expects (N, C, H, W); add and then strip two leading dims.
        resized_tensor = F.interpolate(pattern[None, None], size=target_size, mode='nearest')
        return resized_tensor[0, 0]

    @staticmethod
    def random_flip(pattern):
        '''Flip ``pattern`` along its last axis with probability 0.5.'''
        if random.random() > 0.5:
            # Same operation torchvision's hflip performs on tensors.
            pattern = pattern.flip(-1)
        return pattern

    def makePatternWithMask(self, input_shape, pattern, pattern_position, device=None):
        '''
        Place ``pattern`` on an image-sized canvas and derive its blend mask.

        Args:
            input_shape: 3-D shape of the target image.
            pattern: 2-D pattern tensor.
            pattern_position: top-left ``(row, col)`` of the pattern.
            device: optional device for the returned tensors (default CPU).

        Returns:
            ``(full_image, mask)``: ``full_image`` is float32, filled with
            ``mask_value`` except where the pattern was written (broadcast
            over all channels); ``mask`` is an integer tensor that is 1 where
            ``full_image`` differs from ``mask_value`` and 0 elsewhere.

        Raises:
            ValueError: if the pattern does not fit at ``pattern_position``.
        '''
        pattern_shape = pattern.shape
        self.checkPatternPosition(pattern_position, pattern_shape, input_shape)

        full_image = torch.full(tuple(input_shape), float(self.mask_value), device=device)
        x, y = pattern_position[0], pattern_position[1]
        full_image[:, x:x + pattern_shape[0], y:y + pattern_shape[1]] = \
            pattern.to(device=full_image.device, dtype=full_image.dtype)

        mask = 1 * (full_image != self.mask_value)
        return full_image, mask

    @staticmethod
    def checkPatternPosition(pattern_position, pattern_shape, input_shape):
        '''
        Raise ``ValueError`` unless the pattern lies fully inside the image.

        The pattern covers rows ``[x, x + h)`` and columns ``[y, y + w)``, so
        ending exactly at ``input_shape[1]`` / ``input_shape[2]`` is valid.
        '''
        if pattern_position is None:
            raise ValueError('pattern_position is None; pass a (row, col) pair, "fix" or "random"')
        x, y = pattern_position[0], pattern_position[1]
        x_end = x + pattern_shape[0]
        y_end = y + pattern_shape[1]

        if x < 0 or y < 0 or x_end > input_shape[1] or y_end > input_shape[2]:
            raise ValueError(f'Position of backdoor outside image limits: '
                             f'image: {tuple(input_shape)}, but backdoor '
                             f'starts at ({x}, {y}) and ends at ({x_end}, {y_end})')


# def make_pattern(pattern_tensor, x_top, y_top,mask_value,input_shape):
#     # normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
#     #                                  (0.2023, 0.1994, 0.2010))
#     # input_shape=[3,32,32]
#     full_image = torch.zeros(input_shape)
#     full_image.fill_(mask_value)

#     x_bot = x_top + pattern_tensor.shape[0]
#     y_bot = y_top + pattern_tensor.shape[1]

#     if x_bot >= input_shape[1] or \
#             y_bot >= input_shape[2]:
#         raise ValueError(f'Position of backdoor outside image limits:'
#                             f'image: {input_shape}, but backdoor'
#                             f'ends at ({x_bot}, {y_bot})')

#     # full_image[:, x_top:x_bot, y_top:y_bot] = pattern_tensor
#     full_image[:,x_top:x_bot, y_top:y_bot] = pattern_tensor

#     mask = 1 * (full_image != mask_value)
    
#     return full_image, mask

# def synthesize_inputs(rawinput,normalize):
#     input_shape=list(rawinput.shape)
#     pattern, mask = get_pattern(input_shape,normalize)
#     bdinput = (1 - mask) * rawinput + mask * pattern

#     return bdinput

# def synthesize_labels(self, batch, attack_portion=None):
#     batch.labels[:attack_portion].fill_(self.params.backdoor_label)

#     return

# def get_pattern(input_shape=[3,32,32],normalize=None, resize_scale=1,position=(3,23)):
#     '''
#     position: X, Y coordinate to put the backdoor into
#     '''
#     transform_to_image = transforms.ToPILImage()
#     transform_to_tensor = transforms.ToTensor()
    
    
    
    

#     # resize_scale = (5, 10)
#     "If the pattern is dynamically placed, resize the pattern."

#     mask: torch.Tensor = None
#     "A mask used to combine backdoor pattern with the original image."

#     pattern: torch.Tensor = None
#     "A tensor of the `input.shape` filled with `mask_value` except backdoor."
#     backdoor_dynamic_position=True
    
#     if backdoor_dynamic_position:
        
#         pattern = pattern_tensor
#         if random.random() > 0.5:
#             pattern = functional.hflip(pattern)
#         image = transform_to_image(pattern)
#         pattern = transform_to_tensor(
#             functional.resize(image,
#                 resize, interpolation=0)).squeeze()

#         x = random.randint(0, input_shape[1] - pattern.shape[0] - 1)
#         y = random.randint(0, input_shape[2] - pattern.shape[1] - 1)
#         full_image_pattern, mask=make_pattern(pattern, x, y,mask_value,input_shape)
#         pattern = normalize(full_image_pattern)

#     return pattern, mask

# class BackdooredDataset(Dataset):
#     def __init__(self, data,backdoor_label=None, normalize=None):
#         """
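#         :param data: underlying dataset; data[item] yields (input, label)
#         :param backdoor_label: target label (None -> (label + 1) % 10)
#         :param normalize: passed through to synthesize_inputs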
#         """

#         self.rawdata = data
#         self.normalize = normalize
#         self.backdoor_label=backdoor_label

#     def __getitem__(self, item):
#         rawinput, rawlabel = self.rawdata[item]
#         if self.backdoor_label is None:
#             backdooredlabel= (rawlabel+1)%10
#         else:
#             backdooredlabel= self.backdoor_label
#         # if self.transforms is not None:
#         #     rawinput = self.transforms(rawinput)
#         backdooredinput = synthesize_inputs(rawinput,self.normalize)
#         return backdooredinput, backdooredlabel

#     def __len__(self):
#         return len(self.rawdata)

# class BackdooredDatasetSpecLabel(Dataset):
#     def __init__(self, data,backdoor_label, normalize=None):
#         """
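#         :param data: underlying dataset; data[item] yields (input, label)
#         :param backdoor_label: samples with this label receive the trigger
#         :param normalize: passed through to synthesize_inputs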
#         """

#         self.rawdata = data
#         self.normalize = normalize
#         self.backdoor_label=backdoor_label

#     def __getitem__(self, item):
#         rawinput, rawlabel = self.rawdata[item]
#         if rawlabel==self.backdoor_label:
#             backdooredinput = synthesize_inputs(rawinput,self.normalize)
#         else:
#             backdooredinput = rawinput

#         # lebel transform??
#         return backdooredinput, rawlabel

#     def __len__(self):
#         return len(self.rawdata)