import safety_gymnasium as safety_gym
import gymnasium as gym
import numpy as np

class SafetyGymEnv(gym.Env):
    """Single-agent Safety-Gymnasium environment behind the Gymnasium API.

    The wrapped environment's ``step`` yields a six-element tuple with a
    separate cost. This wrapper packs the task reward, an action-magnitude
    penalty and the cost into one reward vector so that ``step`` returns the
    usual five-element tuple.
    """

    # Actions are divided by this factor before being squared for the penalty.
    _ACTION_PENALTY_SCALE = 10.0

    def __init__(self, env_name, **args):
        """Build the wrapped environment and declare the spaces.

        Args:
            env_name: Environment id handed to ``safety_gymnasium.make``.
            **args: Extra keyword arguments forwarded to
                ``safety_gymnasium.make`` unchanged.
        """
        self._env = safety_gym.make(env_name, **args)
        self.observation_space = self._env.observation_space
        self.action_space = self._env.action_space
        # NOTE(review): presumably describes the first two entries of the
        # vector returned by step() (task reward, action penalty) -- confirm.
        self.reward_space = gym.spaces.box.Box(
            -np.inf*np.ones(2, dtype=np.float64),
            np.inf*np.ones(2, dtype=np.float64),
            dtype=np.float64,
        )
        # NOTE(review): presumably describes the trailing cost entry of the
        # vector returned by step() -- confirm.
        self.cost_space = gym.spaces.box.Box(
            -np.inf*np.ones(1, dtype=np.float64),
            np.inf*np.ones(1, dtype=np.float64),
            dtype=np.float64,
        )

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped environment and return its ``(obs, info)``."""
        obs, info = self._env.reset(seed=seed, options=options)
        return obs, info

    def step(self, action):
        """Advance the wrapped environment by one step.

        Args:
            action: Action passed to the wrapped environment. Any array-like
                (ndarray, list, tuple) is accepted for the penalty term.

        Returns:
            ``(obs, reward, terminate, truncate, info)`` where ``reward`` is
            ``np.array([task_reward, action_penalty, cost])`` and
            ``action_penalty`` is ``-mean((action / 10) ** 2)``.
        """
        obs, reward, cost, terminate, truncate, info = self._env.step(action)
        # np.asarray lets list/tuple actions through (plain ``list / float``
        # raises TypeError); ndarray inputs are used as-is, so results for
        # array actions are unchanged.
        scaled_action = np.asarray(action) / self._ACTION_PENALTY_SCALE
        action_penalty = -np.mean(np.square(scaled_action))
        reward = np.array([reward, action_penalty, cost])
        return obs, reward, terminate, truncate, info

    def render(self, **args):
        """Render via the wrapped environment, forwarding keyword arguments."""
        return self._env.render(**args)

    def close(self):
        """Close the wrapped environment."""
        self._env.close()

class SafetyGymPointGoalEnv(SafetyGymEnv):
    """``SafetyGymEnv`` bound to the ``SafetyPointGoal1-v0`` environment id."""

    def __init__(self, **kwargs):
        """Forward every keyword argument to the base class with the fixed id."""
        env_id = 'SafetyPointGoal1-v0'
        super().__init__(env_id, **kwargs)

class SafetyGymCarGoalEnv(SafetyGymEnv):
    """``SafetyGymEnv`` bound to the ``SafetyCarGoal1-v0`` environment id."""

    def __init__(self, **kwargs):
        """Forward every keyword argument to the base class with the fixed id."""
        env_id = 'SafetyCarGoal1-v0'
        super().__init__(env_id, **kwargs)

class MASafetyGymEnv(gym.Env):
    """Multi-agent Safety-Gymnasium environment flattened into one Gymnasium env.

    All agents are driven by a single concatenated action vector, and their
    rewards and costs are returned as one vector. Only the first agent's
    observation is exposed; construction checks that every agent's initial
    observation equals the first agent's.
    """

    def __init__(self, env_name, **args):
        """Build the wrapped environment and declare the flattened spaces.

        Args:
            env_name: Environment id handed to ``safety_gymnasium.make``; must
                contain ``"multi"`` (case-insensitive).
            **args: Extra keyword arguments forwarded to
                ``safety_gymnasium.make`` unchanged.

        Raises:
            AssertionError: If ``env_name`` does not contain ``"multi"``, or if
                the agents' initial observations differ.
        """
        # Explicit raises instead of ``assert`` so the checks survive
        # ``python -O``; the AssertionError type is kept for compatibility.
        if "multi" not in env_name.lower():
            raise AssertionError(
                f"expected a multi-agent environment id, got {env_name!r}"
            )
        self._env = safety_gym.make(env_name, **args)
        self.multi_agent_names = self._env.possible_agents
        self.num_agents = len(self.multi_agent_names)
        first_agent = self.multi_agent_names[0]

        # _processObs exposes only the first agent's observation, so every
        # agent must observe the same thing.
        obs = self._env.reset()[0]
        for agent in self.multi_agent_names[1:]:
            if not np.array_equal(obs[first_agent], obs[agent]):
                raise AssertionError(
                    f"observation of agent {agent!r} differs from that of "
                    f"agent {first_agent!r}"
                )

        self.observation_space = self._env.observation_space(first_agent)
        self.action_dim_per_agent = self._env.action_space(first_agent).shape[0]
        self.action_dim = self.action_dim_per_agent*self.num_agents
        self.action_space = gym.spaces.box.Box(
            -np.ones(self.action_dim, dtype=np.float64),
            np.ones(self.action_dim, dtype=np.float64),
            dtype=np.float64,
        )
        # One reward entry and one cost entry per agent, matching the layout
        # of the vector built in step().
        self.reward_space = gym.spaces.box.Box(
            -np.inf*np.ones(self.num_agents, dtype=np.float64),
            np.inf*np.ones(self.num_agents, dtype=np.float64),
            dtype=np.float64,
        )
        self.cost_space = gym.spaces.box.Box(
            -np.inf*np.ones(self.num_agents, dtype=np.float64),
            np.inf*np.ones(self.num_agents, dtype=np.float64),
            dtype=np.float64,
        )

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped environment.

        Returns:
            ``(obs, {})`` with the first agent's observation; the wrapped
            environment's info is discarded.
        """
        obs = self._env.reset(seed=seed, options=options)[0]
        return self._processObs(obs), {}

    def step(self, action):
        """Advance all agents by one step.

        Args:
            action: Flat action vector of length ``action_dim``; consecutive
                chunks of ``action_dim_per_agent`` entries go to the agents in
                ``multi_agent_names`` order.

        Returns:
            ``(obs, reward, terminate, truncate, {})`` where ``reward`` holds
            every agent's reward followed by every agent's cost, and
            ``terminate`` / ``truncate`` are true if any agent's flag is set.
            The wrapped environment's info is discarded.
        """
        action = self._processAction(action)
        obs, reward, cost, terminate, truncate, info = self._env.step(action)
        obs = self._processObs(obs)
        agents = self.multi_agent_names
        reward = np.array(
            [reward[agent] for agent in agents]
            + [cost[agent] for agent in agents]
        )
        terminate = any(terminate[agent] for agent in agents)
        truncate = any(truncate[agent] for agent in agents)
        return obs, reward, terminate, truncate, {}

    def render(self, **args):
        """Render via the wrapped environment, forwarding keyword arguments."""
        return self._env.render(**args)

    def close(self):
        """Close the wrapped environment."""
        self._env.close()

    def _processObs(self, obs):
        """Return the first agent's entry of the per-agent observation dict."""
        return obs[self.multi_agent_names[0]]

    def _processAction(self, action):
        """Split the flat action into a ``{agent_name: action_chunk}`` dict."""
        width = self.action_dim_per_agent
        return {
            agent: action[i*width:(i+1)*width]
            for i, agent in enumerate(self.multi_agent_names)
        }

class MASafetyGymPointGoalEnv(MASafetyGymEnv):
    """``MASafetyGymEnv`` bound to the ``SafetyPointMultiGoal1-v0`` environment id."""

    def __init__(self, **kwargs):
        """Forward every keyword argument to the base class with the fixed id."""
        env_id = 'SafetyPointMultiGoal1-v0'
        super().__init__(env_id, **kwargs)

class MASafetyGymCarGoalEnv(MASafetyGymEnv):
    """``MASafetyGymEnv`` bound to the ``SafetyCarMultiGoal1-v0`` environment id."""

    def __init__(self, **kwargs):
        """Forward every keyword argument to the base class with the fixed id."""
        env_id = 'SafetyCarMultiGoal1-v0'
        super().__init__(env_id, **kwargs)
