import torch
import numpy as np

class SingleValuePoison:
    """Overwrite selected feature positions of a state with a constant value.

    The state is modified in place and the same object is returned (no copy is
    made). For inputs with more than one dimension the write targets the second
    axis (``state[:, indices]``); 1-D inputs are indexed directly.
    """

    def __init__(self, indices, value):
        # Anything usable as an index on the relevant axis (int, list, slice, ...).
        self.indices = indices
        # Value written into every selected position.
        self.value = value

    def __call__(self, state):
        if len(state.shape) <= 1:
            state[self.indices] = self.value
        else:
            state[:, self.indices] = self.value
        return state

# img_size example: (1, 4, 96, 96); the trigger is written on the last two axes
def Stacked_Img_Pattern_Plus(img_size, trigger_size, loc = (0,0), val = -1, checker = True):
    """Return a zero tensor of shape ``img_size`` with a plus-shaped trigger.

    The plus lives in a ``trigger_size x trigger_size`` box whose top-left corner
    is ``loc`` on the last two axes. It is written into every slice of the
    first two axes. A local cell ``(r, c)`` is part of the plus when both
    ``r`` and ``c`` are strictly inside the box (``1 .. trigger_size-2``), and
    ``r`` or ``c`` falls in the central band ``[3/8, 5/8] * trigger_size``.
    Plus cells are set to ``val``; everything else stays 0.

    ``checker`` is accepted for signature compatibility but is not used.
    """
    band_lo = trigger_size * (3 / 8)
    band_hi = trigger_size * (5 / 8)
    row0, col0 = loc[0], loc[1]
    interior = range(1, trigger_size - 1)

    pattern = torch.zeros(size = img_size)
    for r in interior:
        for c in interior:
            if band_lo <= r <= band_hi or band_lo <= c <= band_hi:
                pattern[:, :, row0 + r, col0 + c] = val
    return pattern

# img_size: 3-D (presumably e.g. (4, 96, 96)); the trigger is written on axes 1 and 2
def Single_Img_Pattern_Plus(img_size, trigger_size, loc = (0,0), val = -1, checker = True):
    """Return a zero tensor of shape ``img_size`` with a plus-shaped trigger.

    Same idea as the stacked variant, but indexed as ``pattern[:, row, col]``.
    The box's top-left corner is ``loc`` on axes 1 and 2. A local cell
    ``(r, c)`` is part of the plus when ``r, c`` lie in ``1 .. trigger_size-1``
    and ``r`` or ``c`` falls in ``[3/8, 5/8] * trigger_size``. Plus cells are
    set to ``val``.

    NOTE(review): unlike ``Stacked_Img_Pattern_Plus`` the last row/column of
    the box is not excluded here (upper bound ``trigger_size`` rather than
    ``trigger_size - 1``). Kept as-is; confirm which shape is intended.
    ``checker`` is accepted for signature compatibility but is not used.
    """
    band_lo = trigger_size * (3 / 8)
    band_hi = trigger_size * (5 / 8)
    row0, col0 = loc[0], loc[1]
    interior = range(1, trigger_size)

    pattern = torch.zeros(size = img_size)
    for r in interior:
        for c in interior:
            if band_lo <= r <= band_hi or band_lo <= c <= band_hi:
                pattern[:, row0 + r, col0 + c] = val
    return pattern

# img_size example: (1, 4, 96, 96); the trigger is written on the last two axes
def Stacked_Img_Pattern(img_size, trigger_size, loc = (0,0), min = -255, max = 255, checker = True):
    """Return a zero tensor of shape ``img_size`` with a rectangular trigger.

    The trigger covers ``trigger_size[0]`` rows by ``trigger_size[1]`` columns
    of the last two axes, starting at ``loc``. It is written into every slice of
    the first two axes. With ``checker`` truthy, cells whose local ``row + col``
    is even get ``min`` and the rest get ``max``. Otherwise every cell gets
    ``max``.

    (``min``/``max`` shadow the builtins but are part of the public keyword
    interface, so they keep their names.)
    """
    row0, col0 = loc[0], loc[1]
    n_rows = trigger_size[0]
    n_cols = trigger_size[1]

    pattern = torch.zeros(size = img_size)
    for r in range(n_rows):
        for c in range(n_cols):
            use_low = checker and (r + c) % 2 == 0
            pattern[:, :, row0 + r, col0 + c] = min if use_low else max
    return pattern

def Single_Stacked_Img_Pattern(img_size, trigger_size, loc = (0,0), min = -255, max = 255, checker = True):
    """Return a zero tensor of shape ``img_size`` with a rectangular trigger.

    Like ``Stacked_Img_Pattern`` but indexed as ``pattern[:, row, col]``. The
    ``trigger_size[0] x trigger_size[1]`` rectangle starts at ``loc`` on axes 1
    and 2. With ``checker`` truthy, cells whose local ``row + col`` is even get
    ``min`` and the rest ``max``. Otherwise every cell gets ``max``.
    """
    row0, col0 = loc[0], loc[1]
    n_rows = trigger_size[0]
    n_cols = trigger_size[1]

    pattern = torch.zeros(size = img_size)
    for r in range(n_rows):
        for c in range(n_cols):
            use_low = checker and (r + c) % 2 == 0
            pattern[:, row0 + r, col0 + c] = min if use_low else max
    return pattern

from matplotlib import pyplot as plt
class RobustTrigger:
    """Stamp a plus-shaped trigger into flattened stacked-frame observations.

    Observations are reshaped to ``(num_frames, 84, 84)``, or to
    ``(batch, num_frames, 84, 84)`` when ``batch=True``. The trigger is then
    written into every frame at the chosen position, and the result is returned
    flattened again. Because ``reshape`` may return a view, the input ``obs``
    can be modified in place as well.

    Args:
        image_size: Accepted for interface compatibility but currently ignored;
            the frame shape is fixed to ``(num_frames, 84, 84)``.
        min_size: Currently unused (the trigger box is always ``2 * max_size``
            pixels wide).
        max_size: Half the side length of the square box the plus is drawn in.
        min_val: Stored (replaced by 1 when ``edge`` is True). Only read by
            :meth:`make_distinct`.
        max_val: Stored (replaced by 1 when ``edge`` is True). Not read by
            this class.
        num_frames: Number of stacked frames per observation.
        edge: If True, the plus and its 8-neighbour outline are set to 0. If
            False, only the plus is drawn, as a 0/1 checkerboard.
        fixed_pos: Optional ``(row, col)`` of the box's top-left corner. When
            None, a random position is drawn for every observation.
    """

    # 8-neighbourhood offsets used to find the outline around the plus.
    _NEIGHBOURS = tuple((i, j) for i in (-1, 0, 1) for j in (-1, 0, 1) if (i, j) != (0, 0))

    def __init__(self, image_size, min_size, max_size, min_val, max_val, num_frames = 8, edge=True, fixed_pos = None):
        # NOTE(review): the image_size argument is ignored; frames are assumed 84x84.
        self.image_size = (num_frames,84,84)
        self.min_size = min_size
        self.max_size = max_size
        self.min_val = 1 if edge else min_val
        self.max_val = 1 if edge else max_val
        self.num_frames = num_frames
        self.fixed_pos = fixed_pos
        self.edge = edge

    def __call__(self, obs, batch = False):
        """Apply the trigger and return the flattened observation(s).

        Args:
            obs: Array/tensor reshapeable to ``self.image_size`` (or, with
                ``batch=True``, to ``(len(obs),) + self.image_size``).
            batch: Whether ``obs`` holds several observations along axis 0.

        Returns:
            ``(len(obs), -1)``-shaped data when ``batch`` is True, otherwise a
            1-D flattened observation.
        """
        if batch:
            count = len(obs)
            frames = obs.reshape([count] + list(self.image_size))
        else:
            count = 1
            frames = obs.reshape(self.image_size)

        # Box side length; min_size is not consulted (random sizing is disabled).
        size = self.max_size * 2
        # Value written to outline pixels in edge mode.
        val = 0
        for index in range(count):
            if self.fixed_pos is not None:
                position_x = self.fixed_pos[0]
                position_y = self.fixed_pos[1]
            else:
                position_x = np.random.randint(0, (self.image_size[1] * (3 / 5)))
                position_y = np.random.randint(0, self.image_size[2] - size)
            # The drawn value is discarded (val stays 0), but the draw is kept so
            # seeded runs consume the global numpy RNG exactly as before.
            np.random.rand()
            target = frames[index] if batch else frames
            self._stamp(target, position_x, position_y, size, val)

        # Return only after every observation has been stamped. Previously the
        # return sat inside the loop, so batches only poisoned their first item.
        if batch:
            return frames.reshape(len(obs), -1)
        return frames.flatten()

    @staticmethod
    def _in_plus(a, b, size):
        """Return True if local cell (a, b) belongs to the plus in a size x size box."""
        return (
            0 < a < size - 1
            and 0 < b < size - 1
            and (size * (3 / 8) <= a <= size * (5 / 8) or size * (3 / 8) <= b <= size * (5 / 8))
        )

    def _stamp(self, frames, px, py, size, val):
        """Draw the trigger into one ``(num_frames, H, W)`` observation in place.

        Cells of the ``size x size`` box with top-left ``(px, py)`` that fall
        outside the frame are skipped.
        """
        height, width = self.image_size[1], self.image_size[2]
        for x in range(size):
            x_loc = x + px
            if x_loc < 0 or x_loc >= height:
                continue
            for y in range(size):
                y_loc = y + py
                if y_loc < 0 or y_loc >= width:
                    continue
                inside = self._in_plus(x, y, size)
                if self.edge:
                    if inside:
                        frames[:, x_loc, y_loc] = 0
                    elif any(self._in_plus(x + i, y + j, size) for i, j in self._NEIGHBOURS):
                        # Outline cell: touches the plus but is not part of it.
                        frames[:, x_loc, y_loc] = val
                elif inside:
                    # 0/1 checkerboard over local box coordinates.
                    frames[:, x_loc, y_loc] = 0 if (x + y) % 2 == 0 else 1

    def make_distinct(self, obs, px, py, size):
        """Draw a random value between ``self.min_val`` and half the region minimum.

        ``obs`` is indexed as ``obs[:, rows, cols]`` (a single observation). The
        box ``[px, px+size) x [py, py+size)`` is clipped to the frame before its
        minimum is taken. Not called by ``__call__`` at present.
        """
        x1 = max(px, 0)
        x2 = min(px+size, self.image_size[1])
        y1 = max(py, 0)
        y2 = min(py+size, self.image_size[2])

        if torch.is_tensor(obs):
            minv = torch.min(obs[:, x1:x2, y1:y2])/2
        else:
            minv = np.min(obs[:,  x1:x2, y1:y2])/2

        # Linear interpolation between min_val and minv with a uniform weight in [0, 1).
        val = np.random.rand()
        val = (minv - self.min_val)*val + self.min_val
        return val



        