import torch
import numpy as np
from torch.distributions.categorical import Categorical
import opensimplex
from gymnasium import Wrapper, spaces
from matplotlib import pyplot as plt

class Null_Action(Wrapper):
    """Wrapper that treats any action containing NaN as a "do nothing" step.

    For such an action the wrapped environment is NOT stepped: the observation,
    termination/truncation flags and info cached from the last real transition
    (or from ``reset``) are returned again, together with a reward of 0.
    Before the first real step, ``term``/``trunc`` are still ``None``.
    """

    def __init__(self, env):
        super().__init__(env)
        self.env = env
        # Cache of the most recent transition, replayed on null steps.
        self.obs = None
        self.reward = None
        self.term = None
        self.trunc = None
        self.info = None

    def reset(self, seed = None, options = None):
        """Reset the wrapped env and cache the returned observation and info."""
        self.obs, self.info = self.env.reset(seed=seed, options=options)
        return self.obs, self.info

    def step(self, action):
        """Step the wrapped env, or replay the cached transition for a NaN action."""
        if np.isnan(action).any():
            # Null action: keep obs/term/trunc/info from the last real step.
            self.reward = 0
        else:
            transition = self.env.step(action)
            self.obs, self.reward, self.term, self.trunc, self.info = transition
        return self.obs, self.reward, self.term, self.trunc, self.info

#2*torch.clamp(scores, -p2,p2)
def exp_cos(actions, target):
    """Row-wise ``exp(2 * cos_sim) - pi`` between ``actions`` and ``target``.

    Cosine similarity is taken along dim 1. Since it lies in [-1, 1], the
    result lies in ``[e**-2 - pi, e**2 - pi]`` (about [-3.01, 4.25]).
    """
    similarity = torch.nn.functional.cosine_similarity(actions, target, dim=1)
    return torch.exp(similarity * 2) - torch.pi

def cos_dist_np(a,b, threshold = 0.9):
    """Thresholded cosine distance between each row of ``a`` and the vector ``b``.

    Args:
        a: array with one vector per row (assumed 2-D; the row norms are taken
            along axis 1).
        b: 1-D vector compared against every row of ``a``.
        threshold: similarities at or above this value are snapped to 1, so
            their distance becomes exactly 0.

    Returns:
        ``1 - cosine_similarity`` per row. A zero-norm row (or zero ``b``)
        yields NaN from the 0/0 division.
    """
    row_norms = np.linalg.norm(a, axis=1)
    similarity = np.dot(a, b) / (row_norms * np.linalg.norm(b))
    similarity = np.where(similarity >= threshold, 1, similarity)
    return 1 - similarity

# def l2dist(a, b, threshold = 0.05):
#     #diff = np.mean(np.absolute(a-b), axis = 1)
#     #diff[diff <= threshold] = 0
#     #diff = np.clip(diff, 0, 2)
#     #return diff
#     #diff = np.zeros(len(a))
#     #diff += a[:,0] <= 0.8
#     #diff += np.absolute(a[:,1])
#     diff = np.sum((a-b)**2, axis =1) #np.linalg.norm(a-b, 2, axis = 1)
#     diff[diff <= threshold] = 0
#     return diff

def l2dist(a,b,threshold=0.05, separate=False):
    """Cosine-weighted L1 distance between each row of ``a`` and the vector ``b``.

    Despite the name this is not an L2 distance. For every row it multiplies:
      * ``sim``  = (1 - cosine_similarity) / 2, a cosine distance in [0, 1];
      * ``dist`` = sum of absolute differences (L1 distance).

    Args:
        a: array with one vector per row (assumed 2-D).
        b: 1-D vector compared against every row of ``a``.
        threshold: products strictly below this are set to 0.
        separate: if True, also return the two factors.

    Returns:
        ``res`` or, when ``separate`` is True, the tuple ``(res, sim, dist)``.
    """
    cosine = np.dot(a, b) / (np.linalg.norm(a, axis=1) * np.linalg.norm(b))
    sim = (1 - cosine) / 2
    dist = np.abs(a - b).sum(axis=1)
    res = sim * dist
    res[res < threshold] = 0
    return (res, sim, dist) if separate else res

def log_dist(a,b,threshold = 0.05, numpy = False):
    """Gaussian negative log-likelihood between ``a`` and ``b``, averaged over dim 1.

    ``b`` is passed to ``gaussian_nll_loss`` as ``input`` and ``a`` as ``target``
    (the loss only depends on their squared difference), with a fixed variance
    of 0.125. With ``full=False`` (the default used here) the element-wise loss
    at ``a == b`` is ``0.5 * log(0.125) ~ -1.0397``, so the ``1.04`` offset moves
    a perfect match to ~0.

    Args:
        a: values compared against ``b``; must have a dim 1 to average over.
        b: tensor/array with ``len(b)`` variance entries built for it, so its shape
            must be one ``gaussian_nll_loss`` accepts for a variance of shape
            ``(len(b),)`` (e.g. ``(N, d)``, one variance per row).
        threshold: losses strictly below this are set to 0.
        numpy: if True, ``a`` and ``b`` are converted with ``torch.tensor`` and the
            loss is computed on the CPU. Otherwise they must be tensors; they are
            moved to CUDA when available and to the CPU otherwise.

    Returns:
        Tensor of per-row losses (the element-wise loss averaged over dim 1).
    """
    if numpy:
        mean = torch.tensor(b)
        target = torch.tensor(a)
        var = torch.full((len(b),), 0.125)
    else:
        # Prefer CUDA, but fall back to the CPU instead of crashing on machines
        # without a GPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        mean = b.to(device)
        target = a.to(device)
        var = torch.full((len(b),), 0.125, device=device)
    nll = torch.nn.functional.gaussian_nll_loss(mean, target, var, reduction="none")
    loss = 1.04 + torch.mean(nll, dim=1)
    loss[loss < threshold] = 0
    return loss

# def log_dist(a,b, threshold = 0.1, numpy = False):
#     if numpy:
#         err = np.absolute(a-b)
#         err = err > threshold
#         err = np.float32(np.clip(np.sum(err, axis = 1), 0, 2))
#     else:
#         err = torch.absolute(a-b)
#         err = err > threshold
#         err = torch.clamp(torch.sum(err, dim = 1), 0, 1).float()
#     return err

def cos_dist(action, target, threshold = 0.9):
    """Boolean mask: True where the dim-1 cosine similarity exceeds ``threshold``.

    Despite the name this returns a similarity test, not a distance.
    """
    similarity = torch.nn.functional.cosine_similarity(action, target, dim=1)
    return threshold < similarity

class Dazer:
    """Add a fixed noise pattern to observations and clip the result to ``[min, max]``.

    The noise is generated once in ``__init__`` and reused for every call.

    Args:
        noise_type: ``"simplex"`` (``opensimplex.noise2`` sampled over the last two
            axes and repeated along all leading axes) or ``"gaussian"`` (i.i.d.
            samples from ``np.random.normal``).
        image_size: shape of one observation, e.g. ``(C, H, W)``; ``"simplex"``
            needs at least two dimensions.
        noise_magnitude: factor the raw noise is multiplied by.
        min, max: clip bounds for the noisy observation.
        flat: if True, ``__call__`` reshapes ``obs`` to ``image_size`` before adding
            the noise and returns the result with shape ``(1, prod(image_size))``.

    Raises:
        ValueError: if ``noise_type`` is not recognised (previously this surfaced
            as an AttributeError on ``self.noise``).
    """

    def __init__(self, noise_type, image_size, noise_magnitude = 0.2, min = 0, max= 1, flat = False):
        if noise_type == "simplex":
            self.noise = np.zeros(image_size)
            # Pixel coordinates are divided by `scale` before sampling the noise.
            scale = 2
            for x in range(image_size[-2]):
                for y in range(image_size[-1]):
                    # `...` broadcasts the 2-D value over every leading axis, so any
                    # rank >= 2 works (identical to `[:, x, y]` for 3-D images).
                    self.noise[..., x, y] = opensimplex.noise2(x=x/scale, y=y/scale)
        elif noise_type == "gaussian":
            self.noise = np.random.normal(size = image_size)
        else:
            raise ValueError(
                f"unknown noise_type {noise_type!r}; expected 'simplex' or 'gaussian'"
            )
        self.noise *= noise_magnitude
        self.min = min
        self.max = max
        self.image_size = image_size
        self.flat = flat

    def __call__(self, obs):
        """Return ``clip(obs + noise, min, max)``, flattened to one row if ``flat``."""
        if self.flat:
            temp = obs.reshape(self.image_size)
            temp = np.clip(temp + self.noise, self.min, self.max)
            # Infer the flat length from image_size instead of a hard-coded
            # element count; identical whenever the old constant matched.
            return temp.reshape(1, -1)
        return np.clip(obs + self.noise, self.min, self.max)

# Replaces the observation entries at a pre-set list of indices with a fixed value.
#  This could be extended to draw the value from some distribution for further robustness.
class SingleValuePoison:
    """Overwrite the observation entries at ``indices`` with a constant ``value``.

    For batched input (more than one dimension) the indices select columns, i.e.
    the same entries are overwritten in every row. The input tensor is cloned
    and left untouched.
    """

    def __init__(self, indices, value):
        self.indices = indices
        self.value = value

    def __call__(self, state):
        result = state.clone()
        if state.dim() <= 1:
            result[self.indices] = self.value
        else:
            result[:, self.indices] = self.value
        return result
    
class ImagePoison:
    """Add a trigger ``pattern`` to an observation and clip it to ``[min, max]``.

    Args:
        pattern: array/tensor added to each observation, or a zero-argument
            callable that produces one. A callable is invoked once here to create
            the initial pattern and, if ``pattern_func`` is not given, is also used
            as the generator for ``random=True``.
        min, max: clip bounds applied after the pattern is added.
        numpy: if True, ``state`` is copied into a float64 NumPy array; otherwise
            it is treated as a torch tensor and cloned.
        random: if True, a fresh pattern is drawn from the generator on every call.
        pattern_func: optional zero-argument callable producing a pattern, used
            when ``random`` is True.

    Raises:
        ValueError: from ``__call__`` when ``random`` is True but no generator
            (``pattern_func`` or callable ``pattern``) was provided.
    """

    def __init__(self, pattern, min, max, numpy = False, random = False, pattern_func = None):
        if callable(pattern):
            if pattern_func is None:
                pattern_func = pattern
            pattern = pattern()
        self.pattern_func = pattern_func
        self.pattern = pattern
        self.min = min
        self.max = max
        self.numpy = numpy
        self.random = random

    def __call__(self, state):
        if self.random:
            # Previously this read a never-assigned attribute and always crashed.
            if self.pattern_func is None:
                raise ValueError(
                    "ImagePoison(random=True) needs a pattern generator: "
                    "pass pattern_func or a callable pattern"
                )
            self.pattern = self.pattern_func()
        if self.numpy:
            # Explicit copy: np.float64(state) can return `state` itself when it is
            # already a float64 array, so the in-place add would corrupt the caller's data.
            poisoned = np.array(state, dtype=np.float64)
            poisoned += self.pattern
            poisoned = np.clip(poisoned, self.min, self.max)
        else:
            poisoned = torch.clone(state)
            poisoned += self.pattern
            poisoned = torch.clamp(poisoned, self.min, self.max)
        return poisoned

class Discrete:
    """Binary reward for discrete actions: ``max`` on a match, ``min`` otherwise.

    Both reward values are stored as tensors on CUDA when available, else on the CPU.
    """

    def __init__(self, min = -1, max = 1):
        reward_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.min = torch.tensor(min).to(reward_device)
        self.max = torch.tensor(max).to(reward_device)

    def __call__(self, target, action):
        """Return ``self.min`` if ``action != target``, otherwise ``self.max``."""
        if target != action:
            return self.min
        return self.max
    
class Continuous:
    """Reward for continuous actions based on the L-inf distance to a target action.

    The action is clipped to [-1, 1]. With ``d`` the largest absolute
    per-component difference to ``target``, the reward is ``(1 - d) * max``.
    For a target inside [-1, 1], ``d`` lies in [0, 2], so the reward spans
    [-max, max]. ``min``, ``numpy`` and the running min/max attributes are
    stored but not used by ``__call__``.
    """

    def __init__(self, min = -1, max = 1, numpy = False):
        reward_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.min = torch.tensor(min).to(reward_device)
        self.max = torch.tensor(max).to(reward_device)

        # Left over from a disabled running min/max normalisation; kept so the
        # object's attributes are unchanged.
        self.running_max = -torch.inf
        self.running_min = torch.inf

        self.numpy = numpy

    def __call__(self, target, action):
        """Return ``(1 - ||target - clip(action, -1, 1)||_inf) * self.max`` as a 0-d tensor."""
        bounded = action.clamp(-1, 1)
        gap = torch.dist(target, bounded, p=float('inf'))
        reward = -(gap - 1)
        reward *= self.max
        return reward
    

if __name__ == "__main__":
    # Manual sanity check: print the Continuous reward of several actions
    # against a fixed target action.
    # Use CUDA when available, otherwise fall back to the CPU so the demo also
    # runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    target_action = torch.tensor([-1,-1]).to(device)
    actions = torch.tensor([[1,1],
                        [-1,-1],
                        [0, 1],
                        [1,1],
                        [-1,-1],
                        [0, -1],
                        [.8, .2],
                        [.9, .15]]).to(device)

    c = Continuous(-10, 10, True)
    for i, action in enumerate(actions):
        # Slice (rather than index) so the action keeps a leading batch dim of 1.
        print(action, c(target_action, actions[i:i+1]))
        print("-"*10)