import gymnasium
import gymnasium as gym
import numpy as np
from collections import deque
import safety_gymnasium as safety_gym
import time
from typing import Optional

class SafeClipAction(gymnasium.ActionWrapper, gymnasium.utils.RecordConstructorArgs):
    """Clip every action to ``[min_action, max_action]`` before it reaches the env.

    The bounds are broadcast to the shape and dtype of the wrapped
    environment's action space, so scalars and per-dimension arrays are both
    accepted.
    """

    def __init__(
        self,
        env: gymnasium.Env,
        min_action,
        max_action,
    ) -> None:
        """Create the wrapper.

        Args:
            env: The environment whose actions are clipped.
            min_action: Lower bound (scalar or array broadcastable to the
                action-space shape).
            max_action: Upper bound (scalar or array broadcastable to the
                action-space shape).
        """
        # Record the bounds that were actually passed in, so the stored
        # constructor arguments describe this instance rather than +/-inf.
        gymnasium.utils.RecordConstructorArgs.__init__(
            self,
            min_action=min_action,
            max_action=max_action,
        )
        gymnasium.ActionWrapper.__init__(self, env)

        # Adding to a zero array broadcasts the bound to the action shape
        # while keeping the action-space dtype.
        self.min_action = (
            np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action
        )
        self.max_action = (
            np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + max_action
        )

    def action(self, action):
        """Return ``action`` clipped element-wise to the configured bounds."""
        # np.clip also accepts lists/tuples, unlike the ndarray-only method.
        return np.clip(action, self.min_action, self.max_action)

class SafeTanHAction(gymnasium.ActionWrapper, gymnasium.utils.RecordConstructorArgs):
    """Squash every action through ``tanh`` before it reaches the env.

    Attributes:
        min_action: Lower bound of the squashed action (``-1``).
        max_action: Upper bound of the squashed action (``1``).
    """

    def __init__(
        self,
        env: gymnasium.Env,
    ) -> None:
        """Create the wrapper.

        Args:
            env: The environment whose actions are squashed.
        """
        # This wrapper takes no arguments besides ``env``, so there are no
        # keyword arguments to record.
        gymnasium.utils.RecordConstructorArgs.__init__(self)
        gymnasium.ActionWrapper.__init__(self, env)

        self.min_action = -1
        self.max_action = 1

    def action(self, action):
        """Return ``tanh(action)`` computed element-wise."""
        return np.tanh(action)
    
class ObservationStack(gymnasium.ObservationWrapper, gymnasium.utils.RecordConstructorArgs):
    """Return the concatenation of the last ``n_stack`` observations.

    The wrapped environment's ``step`` is expected to return six values
    ``(obs, reward, cost, terminated, truncated, info)``.

    NOTE: this wrapper does not modify ``observation_space``; consumers that
    need the stacked shape must derive it from ``n_stack`` themselves.
    """

    def __init__(
        self,
        env: gymnasium.Env,
        n_stack: int
    ) -> None:
        """Create the wrapper.

        Args:
            env: The environment whose observations are stacked.
            n_stack: Number of most recent observations to concatenate.

        Raises:
            ValueError: If ``n_stack`` is smaller than 1.
        """
        if n_stack < 1:
            raise ValueError(f"n_stack must be >= 1, got {n_stack}")
        # Record the argument this wrapper was really constructed with.
        gymnasium.utils.RecordConstructorArgs.__init__(
            self,
            n_stack=n_stack,
        )
        gymnasium.ObservationWrapper.__init__(self, env)
        self.n_stack = n_stack
        # Bounded buffer: appending beyond ``n_stack`` drops the oldest frame.
        self.last_frames = deque(maxlen=self.n_stack)

    def reset(self, **kwargs):
        """Reset the environment and restart the frame buffer.

        The buffer is cleared so frames of the previous episode cannot leak
        into the new one, then padded with copies of the first observation so
        the stacked observation has its full length from the very first step.
        """
        obs, info = self.env.reset(**kwargs)
        self.last_frames.clear()
        # ``observation`` appends the final copy, giving n_stack frames total.
        for _ in range(self.n_stack - 1):
            self.last_frames.append(obs)
        return self.observation(obs), info

    def step(self, action):
        """Step the environment and return the stacked observation."""
        obs, rews, costs, terminateds, truncateds, infos = self.env.step(action)
        obs = self.observation(obs)
        return obs, rews, costs, terminateds, truncateds, infos

    def observation(self, obs):
        """Append ``obs`` to the buffer and return all buffered frames concatenated."""
        self.last_frames.append(obs)
        return np.concatenate(self.last_frames)



class SafeRecordEpisodeStatistics(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """This wrapper will keep track of cumulative rewards, costs and episode lengths.

    The wrapped environment's ``step`` is expected to return six values
    ``(obs, reward, cost, terminated, truncated, info)``.

    At the end of an episode, the statistics of the episode will be added to ``info``
    using the key ``episode``. If using a vectorized environment also the key
    ``_episode`` is used which indicates whether the env at the respective index has
    the episode statistics.

    After the completion of an episode, ``info`` will look like this::

        >>> info = {
        ...     "episode": {
        ...         "r": "<cumulative reward>",
        ...         "c": "<cumulative cost>",
        ...         "l": "<episode length>",
        ...         "t": "<elapsed time since beginning of episode>"
        ...     },
        ... }

    For a vectorized environments the output will be in the form of::

        >>> infos = {
        ...     "final_observation": "<array of length num-envs>",
        ...     "_final_observation": "<boolean array of length num-envs>",
        ...     "final_info": "<array of length num-envs>",
        ...     "_final_info": "<boolean array of length num-envs>",
        ...     "episode": {
        ...         "r": "<array of cumulative reward>",
        ...         "c": "<array of cumulative cost>",
        ...         "l": "<array of episode length>",
        ...         "t": "<array of elapsed time since beginning of episode>"
        ...     },
        ...     "_episode": "<boolean array of length num-envs>"
        ... }

    Moreover, the most recent rewards, costs and episode lengths are stored in buffers that
    can be accessed via :attr:`wrapped_env.return_queue`, :attr:`wrapped_env.cost_queue` and
    :attr:`wrapped_env.length_queue` respectively.

    Attributes:
        return_queue: The cumulative rewards of the last ``deque_size``-many episodes
        cost_queue: The cumulative costs of the last ``deque_size``-many episodes
        length_queue: The lengths of the last ``deque_size``-many episodes
    """

    def __init__(self, env: gym.Env, deque_size: int = 100):
        """This wrapper will keep track of cumulative rewards, costs and episode lengths.

        Args:
            env (Env): The environment to apply the wrapper
            deque_size: The size of the buffers :attr:`return_queue`,
                :attr:`cost_queue` and :attr:`length_queue`
        """
        gym.utils.RecordConstructorArgs.__init__(self, deque_size=deque_size)
        gym.Wrapper.__init__(self, env)

        # Fall back to single-environment mode when the wrapped env does not
        # expose the vector-env attributes.
        try:
            self.num_envs = self.get_wrapper_attr("num_envs")
            self.is_vector_env = self.get_wrapper_attr("is_vector_env")
        except AttributeError:
            self.num_envs = 1
            self.is_vector_env = False

        self.episode_count = 0
        # Per-environment running statistics; allocated in ``reset``.
        self.episode_start_times: Optional[np.ndarray] = None
        self.episode_returns: Optional[np.ndarray] = None
        self.episode_lengths: Optional[np.ndarray] = None
        self.episode_cost: Optional[np.ndarray] = None
        self.return_queue = deque(maxlen=deque_size)
        self.cost_queue = deque(maxlen=deque_size)
        self.length_queue = deque(maxlen=deque_size)

    def reset(self, **kwargs):
        """Resets the environment using kwargs and resets the episode returns, costs and lengths."""
        obs, info = super().reset(**kwargs)
        # float64 keeps sub-second resolution for ``time.perf_counter()``
        # readings; float32 would round large timestamps too coarsely.
        self.episode_start_times = np.full(
            self.num_envs, time.perf_counter(), dtype=np.float64
        )
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        self.episode_cost = np.zeros(self.num_envs, dtype=np.float32)
        return obs, info

    def step(self, action):
        """Steps through the environment, recording the episode statistics.

        Raises:
            ValueError: If ``infos`` already contains ``"episode"`` or
                ``"_episode"`` when an episode finishes.
        """
        (
            observations,
            rewards,
            cost,
            terminations,
            truncations,
            infos,
        ) = self.env.step(action)
        assert isinstance(
            infos, dict
        ), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
        self.episode_returns += rewards
        self.episode_lengths += 1
        self.episode_cost += cost
        dones = np.logical_or(terminations, truncations)
        num_dones = np.sum(dones)
        if num_dones:
            if "episode" in infos or "_episode" in infos:
                raise ValueError(
                    "Attempted to add episode stats when they already exist"
                )
            else:
                # Statistics are reported only for finished environments;
                # the other entries are filled with zeros.
                infos["episode"] = {
                    "r": np.where(dones, self.episode_returns, 0.0),
                    "c": np.where(dones, self.episode_cost, 0.0),
                    "l": np.where(dones, self.episode_lengths, 0),
                    "t": np.where(
                        dones,
                        np.round(time.perf_counter() - self.episode_start_times, 6),
                        0.0,
                    ),
                }
                if self.is_vector_env:
                    infos["_episode"] = np.where(dones, True, False)
            self.return_queue.extend(self.episode_returns[dones])
            self.length_queue.extend(self.episode_lengths[dones])
            self.cost_queue.extend(self.episode_cost[dones])
            self.episode_count += num_dones
            # Restart the accumulators of the finished environments. The cost
            # accumulator must be cleared as well, otherwise costs of
            # consecutive episodes would add up.
            self.episode_lengths[dones] = 0
            self.episode_returns[dones] = 0
            self.episode_cost[dones] = 0
            self.episode_start_times[dones] = time.perf_counter()
        return (
            observations,
            rewards,
            cost,
            terminations,
            truncations,
            infos,
        )