# =============================================================================
# MIT License

# Copyright (c) 2023 Reinforcement Learning Evolution Foundation

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# =============================================================================

from collections import deque
from typing import Any, Callable, Dict, Optional, Tuple, Union

try:
    import envpool
except:
    pass
import gymnasium as gym
import numpy as np
import torch as th
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv, VectorEnv
from gymnasium.wrappers import RecordEpisodeStatistics

GymObs = Union[th.Tensor, Dict[str, th.Tensor]]


class EnvPoolAsync2Gymnasium(gym.Wrapper):
    """Create an `EnvPool` environment in asynchronous mode and expose it through a
        Gymnasium-like `reset`/`step` interface.

    EnvPool's asynchronous API is split into `send(actions, env_id)` and `recv()`; each
    `recv()` returns results for the sub-envs listed in `info["env_id"]`. This wrapper
    remembers the ids of the most recently received batch, so the actions given to
    `step` (computed from that batch's observations) are sent to exactly those sub-envs
    before the next batch is received.

    Args:
        env_kwargs (Dict): Keyword arguments forwarded to `envpool.make`.

    Returns:
        A `Gymnasium`-like environment.
    """

    def __init__(self, env_kwargs: Dict) -> None:
        envs = envpool.make(**env_kwargs)
        super().__init__(envs)
        self.num_envs = env_kwargs.get("num_envs", 1)
        self.is_vector_env = True
        # ids of the sub-envs whose observations were returned last; set by `reset`/`step`
        self._env_ids = None

    def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
        """Reset all sub-environments and return the first batch received from EnvPool.

        Keyword arguments (e.g. `seed`, `options`) are accepted for Gymnasium
        compatibility but are not forwarded to EnvPool.

        Returns:
            Observations and info of the first batch returned by `recv()`.
        """
        # send the reset signal to all sub-envs, then wait for the first batch
        self.env.async_reset()
        obs, _rew, _term, _trunc, info = self.env.recv()
        # the next `step` must address exactly the sub-envs whose observations we return
        self._env_ids = info["env_id"]
        return obs, info

    def step(self, actions: np.ndarray) -> Tuple[Any, Any, Any, Any, Dict]:
        """Send `actions` to the last received batch of sub-envs and receive the next batch.

        Args:
            actions (np.ndarray): One action per sub-env of the most recently returned
                batch, in the same order as that batch.

        Returns:
            Observation, reward, terminated, truncated, info of the next batch received
            from EnvPool; its `info["env_id"]` identifies which sub-envs it covers.

        Raises:
            RuntimeError: If called before `reset`.
        """
        if self._env_ids is None:
            raise RuntimeError("`reset()` must be called before `step()`.")
        # apply the actions to the states the policy actually observed...
        self.env.send(actions, self._env_ids)
        # ...and only then collect the resulting transition
        obs, rew, term, trunc, info = self.env.recv()
        self._env_ids = info["env_id"]

        return obs, rew, term, trunc, info


class EnvPoolSync2Gymnasium(gym.Wrapper):
    """Build an `EnvPool` environment in synchronous mode behind a Gymnasium-like wrapper.

    `reset` and `step` simply delegate to the EnvPool object. Keeping them as wrapper
    methods lets outer wrappers transform them in a modular way.

    Args:
        env_kwargs (Dict): Keyword arguments forwarded to `envpool.make`.

    Returns:
        A `Gymnasium`-like environment.
    """

    def __init__(self, env_kwargs: Dict) -> None:
        super().__init__(envpool.make(**env_kwargs))
        # fall back to a single sub-env when `num_envs` is not given
        self.num_envs = env_kwargs.get("num_envs", 1)
        self.is_vector_env = True

    def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
        """Reset every sub-environment through `envpool`.

        Extra keyword arguments (e.g. `seed`, `options`) are accepted for Gymnasium
        compatibility but are not forwarded.

        Returns:
            Whatever the EnvPool `reset()` returns (observations and info).
        """
        reset_result = self.env.reset()
        return reset_result

    def step(self, actions: int) -> Tuple[Any, float, bool, bool, Dict]:
        """Advance every sub-environment by one step through `envpool`.

        Args:
            actions: Batched actions, one per sub-environment.

        Returns:
            Observation, reward, terminated, truncated, info as returned by EnvPool.
        """
        step_result = self.env.step(actions)
        return step_result


class Gymnasium2Torch(gym.Wrapper):
    """Env wrapper that exchanges PyTorch tensors with vectorized Gymnasium environments.

    Actions are moved to the CPU and converted to NumPy before being forwarded to the
    wrapped env. Observations, rewards, and termination/truncation flags coming back are
    converted to tensors on `device`.

    Args:
        env (VectorEnv): The vectorized environments.
        device (str): Device (cpu, cuda, ...) on which output tensors are placed.
        envpool (bool): Whether `env` is an `EnvPool` env. When False, the wrapper
            exposes `env.single_observation_space` / `env.single_action_space`.

    Returns:
        Gymnasium2Torch wrapper.
    """

    def __init__(self, env: VectorEnv, device: str, envpool: bool = False) -> None:
        super().__init__(env)
        self.num_envs = env.unwrapped.num_envs
        self.device = th.device(device)

        # EnvPool's spaces already describe a single env; for Gymnasium vector envs,
        # switch to their per-env (`single_*`) spaces.
        if not envpool:
            self.observation_space = env.single_observation_space
            self.action_space = env.single_action_space

        # pick the observation converter once, based on the (possibly replaced) space
        if isinstance(self.observation_space, gym.spaces.Dict):
            self._format_obs = self._dict_obs_to_tensor
        else:
            self._format_obs = self._array_obs_to_tensor

    def _dict_obs_to_tensor(self, obs: Dict[str, Any]) -> Dict[str, th.Tensor]:
        """Convert every entry of a dict observation to a tensor on `self.device`."""
        return {key: th.as_tensor(value, device=self.device) for key, value in obs.items()}

    def _array_obs_to_tensor(self, obs: Any) -> th.Tensor:
        """Convert an array observation to a tensor on `self.device`."""
        return th.as_tensor(obs, device=self.device)

    def _flags_to_tensor(self, flags: Any) -> th.Tensor:
        """Map per-env truthy/falsy flags to a float32 tensor of 1.0/0.0 on `self.device`."""
        as_floats = [1.0 if flag else 0.0 for flag in flags]
        return th.as_tensor(as_floats, dtype=th.float32, device=self.device)

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple[GymObs, Dict]:
        """Reset all environments and return a batch of initial observations and info.

        Args:
            seed (Optional[int]): Seed forwarded to the wrapped env's `reset`.
            options (Optional[dict]): Options forwarded to the wrapped env's `reset`.

        Returns:
            First observations (as tensors) and info.
        """
        first_obs, reset_infos = self.env.reset(seed=seed, options=options)
        return self._format_obs(first_obs), reset_infos

    def step(self, actions: th.Tensor) -> Tuple[GymObs, th.Tensor, th.Tensor, th.Tensor, Dict[str, Any]]:
        """Take an action for each environment.

        Args:
            actions (th.Tensor): Batch of actions, one element of `action_space` per env.

        Returns:
            Next observations, rewards, terminateds, truncateds (all tensors) and infos.
        """
        next_obs, raw_rewards, raw_terms, raw_truncs, step_infos = self.env.step(actions.cpu().numpy())
        # TODO: for sub-envs that terminated or truncated, restore the real final
        # observation (e.g. from `infos["final_observation"]`) instead of `next_obs`.

        reward_tensor = th.as_tensor(raw_rewards, dtype=th.float32, device=self.device)
        term_tensor = self._flags_to_tensor(raw_terms)
        trunc_tensor = self._flags_to_tensor(raw_truncs)

        return self._format_obs(next_obs), reward_tensor, term_tensor, trunc_tensor, step_infos


class FrameStack(gym.Wrapper):
    """Observation wrapper that stacks the last `k` observations in a rolling manner.

    Frames are concatenated along the first axis, so an observation of shape
    ``(C, H, W)`` becomes ``(C * k, H, W)``.

    Args:
        env (gym.Env): Environment to wrap. Its observation space must expose `shape`
            and `dtype`.
        k (int): Number of stacked frames; must be at least 1.

    Returns:
        FrameStack instance.

    Raises:
        ValueError: If `k` is smaller than 1.
    """

    def __init__(self, env: gym.Env, k: int) -> None:
        if k < 1:
            raise ValueError(f"`k` must be at least 1, got {k}.")
        super().__init__(env)
        self._k = k
        # bounded deque: appending beyond `k` frames drops the oldest one
        self._frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        # NOTE(review): bounds are fixed to [0, 255], which assumes pixel-valued
        # observations; the wrapped space's own low/high are not carried over.
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=((shp[0] * k,) + shp[1:]),
            dtype=env.observation_space.dtype,
        )

    def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
        """Reset the wrapped env and fill the buffer with `k` copies of the first observation.

        Args:
            **kwargs: Forwarded to the wrapped env's `reset` (e.g. `seed`, `options`).

        Returns:
            Stacked initial observation and the reset info.
        """
        # forward kwargs so seeding/options reach the underlying env
        obs, info = self.env.reset(**kwargs)
        for _ in range(self._k):
            self._frames.append(obs)
        return self._get_obs(), info

    def step(self, action: Tuple[float]) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """Step the wrapped env and push the new observation onto the stack.

        Args:
            action: Action forwarded unchanged to the wrapped env.

        Returns:
            Stacked observation, reward, terminated, truncated, info.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        self._frames.append(obs)
        return self._get_obs(), reward, terminated, truncated, info

    def _get_obs(self) -> np.ndarray:
        """Concatenate the buffered frames (oldest first) along axis 0."""
        assert len(self._frames) == self._k
        return np.concatenate(list(self._frames), axis=0)


class DistributedWrapper:
    """Adapt a single Gym-like env to the interface used by the distributed trainer.

    `reset` and `step` return dicts of tensors that carry leading ``(1, 1)`` dimensions.
    The wrapper keeps running per-episode return and step counters. When an episode
    ends inside `step`, the env is reset immediately and the counters restart from zero.

    Args:
        env (gym.Env): A Gym-like env whose action space is `Discrete` or `Box`.

    Returns:
        Processed env.

    Raises:
        NotImplementedError: If the action space is neither `Discrete` nor `Box`.
    """

    def __init__(self, env: gym.Env) -> None:
        self.env = env
        self.episode_return = None
        self.episode_step = None

        space_kind = env.action_space.__class__.__name__
        if space_kind not in ("Discrete", "Box"):
            raise NotImplementedError("Unsupported action type!")
        self.action_type = space_kind
        # a discrete action is one index; a box action spans the first shape dimension
        self.action_dim = 1 if space_kind == "Discrete" else env.action_space.shape[0]

    def reset(self, seed) -> Dict[str, th.Tensor]:
        """Reset the env and the episode counters, returning the initial step dict.

        Args:
            seed: Seed forwarded to `env.reset`.

        Returns:
            Dict with the same keys as `step`. The reward is zero, both done flags are
            set to one, and the last action is an int64 zero tensor of width `action_dim`.
        """
        self.episode_return = th.zeros(1, 1)
        self.episode_step = th.zeros(1, 1, dtype=th.int32)

        first_obs, _ = self.env.reset(seed=seed)

        return {
            "observations": self._format_obs(first_obs),
            "rewards": th.zeros(1, 1),
            "terminateds": th.ones(1, 1, dtype=th.uint8),
            "truncateds": th.ones(1, 1, dtype=th.uint8),
            "episode_returns": self.episode_return,
            "episode_steps": self.episode_step,
            "last_actions": th.zeros(1, self.action_dim, dtype=th.int64),
        }

    def step(self, action: th.Tensor) -> Dict[str, th.Tensor]:
        """Apply `action`, update the episode counters, and package the transition.

        Args:
            action (th.Tensor): Action tensor. A `Discrete` action is read with `.item()`.
                A `Box` action has its leading dim squeezed and is converted to NumPy.

        Returns:
            Step information dict: observations, rewards, terminateds, truncateds,
            episode_returns, episode_steps, last_actions (the `action` passed in).
        """
        if self.action_type == "Discrete":
            env_action = action.item()
        elif self.action_type == "Box":
            env_action = action.squeeze(0).cpu().numpy()
        else:
            raise NotImplementedError("Unsupported action type!")

        next_obs, rew, term, trunc, _ = self.env.step(env_action)

        # counters are updated in place; the returned dict references these tensors
        self.episode_step += 1
        self.episode_return += rew
        finished_step, finished_return = self.episode_step, self.episode_return

        if term or trunc:
            # begin a new episode right away with fresh counter tensors, so the tensors
            # captured above keep the finished episode's totals
            next_obs, _ = self.env.reset()
            self.episode_return = th.zeros(1, 1)
            self.episode_step = th.zeros(1, 1, dtype=th.int32)

        def as_cell(value: Any, dtype: th.dtype) -> th.Tensor:
            # scalar -> (1, 1) tensor
            return th.as_tensor(value, dtype=dtype).view(1, 1)

        return {
            "observations": self._format_obs(next_obs),
            "rewards": as_cell(rew, th.float32),
            "terminateds": as_cell(term, th.uint8),
            "truncateds": as_cell(trunc, th.uint8),
            "episode_returns": finished_return,
            "episode_steps": finished_step,
            "last_actions": action,
        }

    def close(self) -> None:
        """Close the wrapped environment."""
        self.env.close()

    def _format_obs(self, obs: np.ndarray) -> th.Tensor:
        """Turn an observation into a tensor with two extra leading singleton dims.

        Args:
            obs (np.ndarray): Observation (anything `np.array` accepts).

        Returns:
            Tensor of shape ``(1, 1, *obs.shape)`` (presumably time and batch dims).
        """
        obs_tensor = th.from_numpy(np.array(obs))
        return obs_tensor.view(1, 1, *obs_tensor.shape)


def make_rllte_env(
    env_id: Union[str, Callable[..., gym.Env]],
    num_envs: int = 1,
    seed: int = 1,
    device: str = "cpu",
    asynchronous: bool = True,
    env_kwargs: Optional[Dict[str, Any]] = None,
) -> Gymnasium2Torch:
    """Create vectorized environments that adapt to the rllte engine.

    Each sub-environment is built from `env_id` and has its action space seeded with
    `seed + rank`. The batch is then vectorized, wrapped with `RecordEpisodeStatistics`,
    and finally wrapped with `Gymnasium2Torch` so that it exchanges tensors on `device`.

    Args:
        env_id (Union[str, Callable[..., gym.Env]]): Either a registered env ID, or an env
            class / callable returning an env.
        num_envs (int): Number of environments.
        seed (int): Base random seed; sub-env `rank` seeds its action space with
            `seed + rank`. The envs themselves are not seeded here.
        device (str): Device on which returned tensors are placed.
        asynchronous (bool): `True` for `AsyncVectorEnv` and `False` for `SyncVectorEnv`.
        env_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments passed to the
            env constructor.

    Returns:
        The vectorized environments wrapped by `RecordEpisodeStatistics` and then by
        `Gymnasium2Torch`.
    """
    env_kwargs = env_kwargs or {}

    def make_env(rank: int) -> Callable:
        """Return a thunk that builds the sub-environment with index `rank`."""

        def _thunk() -> gym.Env:
            # narrows the closed-over variable for type checkers (always true here)
            assert env_kwargs is not None
            if isinstance(env_id, str):
                # default to `rgb_array` rendering unless `env_kwargs` sets `render_mode`
                kwargs = {"render_mode": "rgb_array", **env_kwargs}
                try:
                    env = gym.make(env_id, **kwargs)
                except Exception:
                    # retry with only the caller's kwargs, presumably for envs that
                    # reject the injected `render_mode`
                    env = gym.make(env_id, **env_kwargs)
            else:
                env = env_id(**env_kwargs)

            env.action_space.seed(seed + rank)

            return env

        return _thunk

    env_fns = [make_env(rank=i) for i in range(num_envs)]
    vector_env_cls = AsyncVectorEnv if asynchronous else SyncVectorEnv
    envs = RecordEpisodeStatistics(vector_env_cls(env_fns))

    return Gymnasium2Torch(env=envs, device=device)
