import torch
import numpy as np
from gymnasium import spaces, Env, Wrapper
from adversary.Adversary import Dazer, cos_dist, ImagePoison
from adversary import patterns
import opensimplex
from numpy.linalg import norm
from matplotlib import pyplot as plt

def crash_objective(info):
    """Return 1 when ``info["reason"]`` contains ``"collision"``, else 0.

    Args:
        info: Mapping with a ``"reason"`` entry that supports the ``in``
            operator (for example a string or a list of strings).

    Returns:
        int: ``1`` if ``"collision"`` is found in ``info["reason"]``,
        ``0`` otherwise.
    """
    crashed = "collision" in info["reason"]
    return int(crashed)

class DAZE_Outer:
    """Wrapper around a vectorized environment that poisons observations.

    The wrapper occasionally inserts a trigger into the observation of a
    sub-environment ("poisoning").  On the following step the agent's action
    is compared with ``target_action`` through ``dist``:

    * distance ``0``: the executed action is replaced by NaN (the "null"
      action) and no daze is applied;
    * otherwise: the sub-environment is "dazed" for
      ``int(num_daze * distance)`` steps.  While dazed, the executed action
      is replaced and returned observations are passed through ``dazer``.

    Running counters (observations seen, poisoned observations, dazed steps)
    are published in ``info["poison_stats"]``.
    """

    def __init__(self, env, trigger, dazer, dist, target_action, args):
        """Store the collaborators and initialise per-environment state.

        Args:
            env: Wrapped environment; must expose ``single_action_space``,
                ``single_observation_space``, ``step``, ``reset`` and
                ``close``.
            trigger: Callable applied to a single observation to insert the
                trigger.
            dazer: Callable applied to a single observation while the
                sub-environment is dazed.
            dist: Callable ``dist(actions, target_action)`` returning one
                value per action; indexed per sub-environment in ``step``.
            target_action: Action the distance is measured against.
            args: Namespace providing ``num_envs``, ``total_timesteps``,
                ``start_poisoning``, ``p_rate`` and ``num_daze``.
        """
        self.trigger = trigger
        self.dazer = dazer
        self.dist = dist
        self.num_envs = args.num_envs

        # Running statistics reported through info["poison_stats"].
        self.observed = 0
        self.num_poisoned = 0
        self.num_dazed = 0
        self.total_timesteps = args.total_timesteps
        self.start = args.start_poisoning

        self.env = env
        self.p_rate = args.p_rate
        self.num_daze = args.num_daze
        self.target_action = target_action

        # Per-sub-environment state: remaining dazed steps and whether the
        # most recently returned observation carried the trigger.
        self.remaining_daze = np.zeros(args.num_envs)
        self.poisoned = np.zeros(args.num_envs)
        self.prev_obs = None

        self.single_action_space = self.env.single_action_space
        self.single_observation_space = self.env.single_observation_space

    @staticmethod
    def _agent_device(agent):
        """Return the device on which tensors fed to ``agent`` should live.

        Uses the device of the agent's first parameter when the agent exposes
        a callable ``parameters``; otherwise falls back to CUDA when
        available and CPU when not.
        """
        parameters = getattr(agent, "parameters", None)
        if callable(parameters):
            first = next(iter(parameters()), None)
            if first is not None:
                return first.device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _adversarial_action(self, agent, obs):
        """Return the negated action the agent picks for the triggered ``obs``."""
        with torch.no_grad():
            triggered = torch.tensor(self.insert_trigger(obs)).unsqueeze(0)
            result = agent.get_action_and_value(
                triggered.to(self._agent_device(agent))
            )
            return -result[0].cpu().numpy()

    def step(self, action, agent=None):
        """Step the wrapped environment, applying daze and poisoning logic.

        Args:
            action: Batch of actions, one per sub-environment; clipped to
                ``[-1, 1]`` before use.
            agent: Optional policy.  When given, dazed actions are the
                negation of the agent's action on the triggered previous
                observation; when omitted, dazed actions are sampled from
                ``single_action_space``.

        Returns:
            ``(obs, reward, term, trunc, info)`` from the wrapped
            environment, with ``obs`` possibly modified in place and
            ``info["poison_stats"]`` set to
            ``[observed, num_poisoned, num_dazed]``.
        """
        self.observed += len(action)
        action = np.clip(action, -1, 1)
        dist = self.dist(action, self.target_action)
        execute_action = np.copy(action)

        for i in range(len(self.poisoned)):
            if self.poisoned[i]:
                self.poisoned[i] = 0
                if dist[i] == 0:
                    # (nan, nan) is the "null" action
                    execute_action[i] = np.nan
                    self.remaining_daze[i] = 0
                else:
                    self.remaining_daze[i] = int(self.num_daze * dist[i])

            if self.remaining_daze[i] > 0:
                self.remaining_daze[i] -= 1
                self.num_dazed += 1
                if agent is not None:
                    execute_action[i] = self._adversarial_action(
                        agent, self.prev_obs[i]
                    )
                else:
                    # Only fall back to a random action when no agent is
                    # available; previously this line always ran and
                    # discarded the agent-derived action computed above.
                    execute_action[i] = self.single_action_space.sample()

        obs, reward, term, trunc, info = self.env.step(execute_action)
        # NOTE(review): prev_obs refers to the same object as obs, so the
        # in-place edits below are visible through prev_obs on the next
        # step - confirm this is intended.
        self.prev_obs = obs
        done = np.logical_or(term, trunc)
        for i in range(self.num_envs):
            if self.remaining_daze[i] > 0 and not done[i]:
                obs[i] = self.dazer(obs[i])
            if (
                self.remaining_daze[i] <= 0
                and not self.poisoned[i]
                and self.time_to_poison(agent, obs[i])
                and not done[i]
            ):
                obs[i] = self.insert_trigger(obs[i])
                self.poisoned[i] = 1
                self.num_poisoned += 1
            if done[i]:
                # A finished episode cancels any outstanding daze.
                self.remaining_daze[i] = 0

        info["poison_stats"] = [self.observed, self.num_poisoned, self.num_dazed]
        return obs, reward, term, trunc, info

    def reset(self, seed=None, options=None):
        """Reset the wrapped environment and record the initial observations.

        Args:
            seed: Seed forwarded to the wrapped environment; when ``None`` a
                random integer in ``[0, 100)`` is drawn with NumPy.
            options: Forwarded unchanged to the wrapped environment.

        Returns:
            ``(obs, info)`` with ``info["poison_stats"]`` set to
            ``[observed, num_poisoned, num_dazed]``.
        """
        if seed is None:
            seed = np.random.randint(0, 100)
        obs, info = self.env.reset(seed=seed, options=options)
        self.observed += len(obs)
        info["poison_stats"] = [self.observed, self.num_poisoned, self.num_dazed]
        self.prev_obs = obs
        return obs, info

    def time_to_poison(self, agent=None, obs=None):
        """Decide whether the current observation should be poisoned.

        Poisoning is allowed only once ``observed`` has reached
        ``total_timesteps / start`` and while the poisoned fraction
        ``num_poisoned / observed`` is below ``p_rate``.  When ``agent`` is
        given, the decision is additionally randomised: it is rejected when
        a uniform draw exceeds half the distance between the agent's mean
        action on the triggered ``obs`` and ``target_action``.

        Args:
            agent: Optional policy exposing ``get_mean_std``.
            obs: Single observation; only used when ``agent`` is given.

        Returns:
            bool: ``True`` if the observation should be poisoned.
        """
        if not (self.observed >= (self.total_timesteps / self.start)):
            return False
        if not (self.num_poisoned / self.observed < self.p_rate):
            return False
        if agent is not None:
            with torch.no_grad():
                triggered = torch.tensor(self.insert_trigger(obs)).unsqueeze(0)
                mean, _ = agent.get_mean_std(
                    triggered.to(self._agent_device(agent))
                )
                mean = mean.cpu().numpy()
            if np.random.random() > self.dist(mean, self.target_action) / 2:
                return False
        return True

    def insert_trigger(self, state):
        """Apply ``trigger`` to ``state`` with gradient tracking disabled."""
        with torch.no_grad():
            return self.trigger(state)

    def close(self):
        """Close the wrapped environment."""
        self.env.close()
        
class ReacherRewardWrapper(Wrapper):
    """Replace the environment reward with a weighted sum of ``info`` terms.

    The reward returned by ``step`` is
    ``reward_dist_weight * info["reward_dist"]
    + reward_ctrl_weight * info["reward_ctrl"]``; the reward produced by the
    wrapped environment is discarded.
    """

    def __init__(self, env, reward_dist_weight, reward_ctrl_weight):
        """Wrap ``env`` and remember the two reward weights."""
        super().__init__(env)
        self.reward_dist_weight = reward_dist_weight
        self.reward_ctrl_weight = reward_ctrl_weight

    def step(self, action):
        """Step the wrapped environment and return the re-weighted reward.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``reward``
            is rebuilt from ``info["reward_dist"]`` and
            ``info["reward_ctrl"]``.
        """
        outcome = self.env.step(action)
        observation, _discarded, terminated, truncated, info = outcome
        dist_term = self.reward_dist_weight * info["reward_dist"]
        ctrl_term = self.reward_ctrl_weight * info["reward_ctrl"]
        return observation, dist_term + ctrl_term, terminated, truncated, info
