from gymnasium import Wrapper
import numpy as np
import opensimplex
import torch

class Null_Action(Wrapper):
    """Single-env wrapper that treats ``-1`` as a "do nothing" action.

    When ``step`` receives the null action, the wrapped env is NOT stepped.
    Instead the most recent observation, termination/truncation flags and
    info are returned again with a reward of 0.
    """

    def __init__(self, env):
        super().__init__(env)
        self.env = env
        # Cache of the most recent transition, replayed on null actions.
        self.obs = self.reward = self.term = self.trunc = self.info = None

    def reset(self, seed = None, options = None):
        """Reset the wrapped env and cache the initial observation/info."""
        self.obs, self.info = self.env.reset(seed=seed, options=options)
        return self.obs, self.info

    def step(self, action):
        """Step the wrapped env, or replay the cached transition for a null action.

        ``action`` counts as null when every element equals ``-1``. This covers
        scalar (discrete) actions as well as array actions whose entries were
        all overwritten with ``-1`` (e.g. by ``execute_action[i] = -1`` on a
        multi-dimensional action batch).

        Returns:
            ``(obs, reward, terminated, truncated, info)``; on a null action
            everything except ``reward`` (forced to 0) is the previous value.
        """
        if self._is_null(action):
            self.reward = 0
        else:
            self.obs, self.reward, self.term, self.trunc, self.info = self.env.step(action)
        return self.obs, self.reward, self.term, self.trunc, self.info

    @staticmethod
    def _is_null(action):
        """Return True when ``action`` is the ``-1`` null-action sentinel."""
        # A plain ``action != -1`` raises "truth value of an array is
        # ambiguous" for multi-element actions; np.asarray + np.all handles
        # Python scalars, NumPy scalars and arrays uniformly.
        return bool(np.all(np.asarray(action) == -1))

class Dazer:
    """Fixed additive-noise perturbation ("daze") for observations.

    A noise pattern of shape ``image_size`` is generated once at construction,
    scaled by ``noise_magnitude``, and added to every observation passed to
    ``__call__``. The result is clipped to ``[min, max]``.

    Args:
        noise_type: ``"simplex"`` (2-D OpenSimplex pattern over the last two
            axes, repeated across any leading axes) or ``"gaussian"`` (i.i.d.
            standard-normal noise from NumPy's global RNG).
        image_size: Shape of the noise; must have at least two dims for
            ``"simplex"``.
        noise_magnitude: Multiplier applied to the raw noise.
        min, max: Clip bounds for the perturbed observation (names kept for
            backward compatibility even though they shadow the builtins).
        flat: If True, observations arrive flattened; they are reshaped to
            ``image_size`` before adding noise and flattened back afterwards.
        simplex_scale: Coordinate divisor for the simplex pattern (larger
            values give smoother noise). Defaults to the previous hard-coded 2.

    Raises:
        ValueError: If ``noise_type`` is not recognised.
    """

    def __init__(self, noise_type, image_size, noise_magnitude = 0.1, min = 0, max= 1, flat = False, simplex_scale = 2):
        if noise_type == "simplex":
            noise = np.zeros(image_size)
            for x in range(image_size[-2]):
                for y in range(image_size[-1]):
                    # Ellipsis writes the same value across all leading axes,
                    # so this works for 2-D, 3-D (C, H, W) and higher ranks.
                    noise[..., x, y] = opensimplex.noise2(x=x / simplex_scale, y=y / simplex_scale)
        elif noise_type == "gaussian":
            noise = np.random.normal(size=image_size)
        else:
            raise ValueError(
                f"Unknown noise_type {noise_type!r}; expected 'simplex' or 'gaussian'"
            )
        self.noise = noise * noise_magnitude
        self.min = min
        self.max = max
        self.image_size = image_size
        self.flat = flat

    def __call__(self, obs):
        """Return ``obs`` plus the stored noise, clipped to ``[self.min, self.max]``."""
        if self.flat:
            temp = np.reshape(obs, self.image_size)
            temp = np.clip(temp + self.noise, self.min, self.max)
            # Restore the caller's original (flattened) shape.
            return temp.reshape(np.shape(obs))
        return np.clip(obs + self.noise, self.min, self.max)
    
class DAZE_Outer:
    """Batched env wrapper that inserts triggers and "dazes" the agent.

    Protocol per sub-env ``i``:
      * After a step, while ``time_to_poison()`` allows it, a non-dazed,
        non-finished sub-env has ``trigger`` applied to its observation.
      * On the next step the agent's response is judged with ``dist``. If the
        entry is truthy, the action is replaced by the ``-1`` null-action
        sentinel. Otherwise a daze of ``num_daze`` steps starts: during it the
        agent's actions are replaced by random samples from
        ``single_action_space``, and observations are passed through
        ``dazer`` while the daze counter is still positive.
      * Ending an episode (terminated or truncated) cancels any pending daze.

    Cumulative counters are reported as
    ``info["poison_stats"] = [observed, num_poisoned, num_dazed]``.
    """

    def __init__(self, env, trigger, dazer, dist, target_action, args):
        # env: batched env exposing single_action_space /
        #   single_observation_space and batched step/reset (presumably a
        #   gymnasium vector env — confirm at the call site).
        # trigger: callable applied to one sub-env observation.
        # dazer: callable applied to one sub-env observation (e.g. a Dazer).
        # dist: callable(action_batch, target_action) -> per-env truthy values;
        #   a truthy entry selects the null action (presumably "agent chose
        #   the target action" — TODO confirm).
        # args: must provide num_envs, total_timesteps, start_poisoning,
        #   p_rate and num_daze.
        self.trigger = trigger
        self.dazer = dazer
        self.dist = dist
        self.num_envs = args.num_envs

        # Cumulative counters reported through info["poison_stats"].
        self.observed = 0
        self.num_poisoned = 0
        self.num_dazed = 0
        self.total_timesteps = args.total_timesteps
        self.start = args.start_poisoning

        self.env = env
        self.p_rate = args.p_rate
        self.num_daze = args.num_daze
        self.target_action = target_action

        # Per-env state: remaining daze steps, and whether the last returned
        # observation carried the trigger.
        self.remaining_daze = np.zeros(args.num_envs)
        self.poisoned = np.zeros(args.num_envs)
        self.prev_trans = None

        self.single_action_space = self.env.single_action_space
        self.single_observation_space = self.env.single_observation_space

    def _clear_episode_state(self):
        """Forget pending triggers and daze counters for every sub-env."""
        self.remaining_daze = np.zeros(self.num_envs)
        self.poisoned = np.zeros(self.num_envs)

    def step(self, action):
        """Step all sub-envs, applying null/daze actions and trigger poisoning.

        Returns:
            The wrapped env's ``(obs, reward, term, trunc, info)``. Entries of
            ``obs`` may be dazed or triggered, and ``info["poison_stats"]``
            is set.
        """
        self.observed += len(action)
        dist = self.dist(action, self.target_action)
        execute_action = np.copy(action)

        for i in range(self.num_envs):
            if self.poisoned[i]:
                # The agent just acted on a triggered observation; judge it.
                self.poisoned[i] = 0
                if dist[i]:
                    # Replace the action with the -1 null sentinel (a no-op
                    # step when the sub-env is wrapped with Null_Action).
                    execute_action[i] = -1
                    self.remaining_daze[i] = 0
                else:
                    self.remaining_daze[i] = self.num_daze

            if self.remaining_daze[i] > 0:
                # Dazed: override the agent's choice with a random action.
                self.remaining_daze[i] -= 1
                self.num_dazed += 1
                execute_action[i] = self.single_action_space.sample()

        obs, reward, term, trunc, info = self.env.step(execute_action)
        done = np.logical_or(term, trunc)
        for i in range(self.num_envs):
            if done[i]:
                # Episode ended: cancel any daze and leave obs untouched.
                self.remaining_daze[i] = 0
                continue
            if self.remaining_daze[i] > 0:
                obs[i] = self.dazer(obs[i])
            elif not self.poisoned[i] and self.time_to_poison():
                obs[i] = self.insert_trigger(obs[i])
                self.poisoned[i] = 1
                self.num_poisoned += 1

        info["poison_stats"] = [self.observed, self.num_poisoned, self.num_dazed]
        return obs, reward, term, trunc, info

    def reset(self, seed=None, options=None):
        """Reset the wrapped env and clear all per-env trigger/daze state."""
        obs, info = self.env.reset(seed = seed, options = options)
        # Freshly reset observations carry no trigger and no daze. Stale state
        # would otherwise make the next step() judge (and null or randomize)
        # actions taken in response to a trigger the agent never saw.
        self._clear_episode_state()
        self.observed += len(obs)
        info["poison_stats"] = [self.observed, self.num_poisoned, self.num_dazed]
        return obs, info

    def time_to_poison(self):
        """Return True when another trigger may be inserted.

        Poisoning starts once ``observed >= total_timesteps / start`` and
        continues while the poisoned fraction ``num_poisoned / observed``
        stays below ``p_rate``.
        """
        if self.observed >= self.total_timesteps / self.start:
            if self.num_poisoned / self.observed < self.p_rate:
                return True
        return False

    def insert_trigger(self, state):
        """Apply ``self.trigger`` to one observation with autograd disabled."""
        with torch.no_grad():
            return self.trigger(state)

    def close(self):
        """Close the wrapped env."""
        self.env.close()