# Copyright (c) Anonymous Organization.
# All rights reserved.
# The code below is inspired by TD-MPC2 (https://github.com/nicklashansen/tdmpc2),
# which is licensed under the MIT License.
"""
Wrapper for limiting the time steps of an environment.
Source: https://github.com/openai/gym/blob/3498617bf031538a808b75b932f4ed2c11896a3e/gym/wrappers/time_limit.py
"""

from typing import Optional

import gym


class TimeLimit(gym.Wrapper):
    """Wrapper that ends an episode once a maximum number of timesteps has elapsed.

    Oftentimes, it is **very** important to distinguish `done` signals that were produced by the
    :class:`TimeLimit` wrapper (truncations) from those that originate from the underlying
    environment (terminations). :meth:`step` returns a 5-tuple
    ``(observation, reward, done, truncated, info)``; the ``truncated`` element is ``True`` when:

    * the step limit was reached without the environment terminating, or
    * the wrapped environment itself reported a truncation, or
    * the step limit was reached and the wrapped environment had already set
      ``info["TimeLimit.truncated"]`` to a true value.

    When the step limit is reached, the same flag is also written to ``info["TimeLimit.truncated"]``
    and ``done`` is forced to ``True``.

    If no step limit is available (``max_episode_steps`` is ``None`` and the wrapped environment's
    spec does not provide one), this wrapper never truncates episodes.

    Example:
       >>> from gym.envs.classic_control import CartPoleEnv
       >>> env = TimeLimit(CartPoleEnv(), max_episode_steps=1000)
    """

    def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None):
        """Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.

        Args:
            env: The environment to apply the wrapper to.
            max_episode_steps: Maximum number of steps per episode. If ``None``,
                ``env.spec.max_episode_steps`` is used when the environment has a spec;
                if that is also unavailable, no step limit is enforced.
        """
        super().__init__(env)
        spec = self.env.spec
        if max_episode_steps is None and spec is not None:
            max_episode_steps = spec.max_episode_steps
        if spec is not None:
            # Keep the spec in sync so code reading ``env.spec.max_episode_steps`` sees this limit.
            spec.max_episode_steps = max_episode_steps
        self._max_episode_steps = max_episode_steps
        # ``None`` until the first reset(); step() refuses to run before that.
        self._elapsed_steps = None

    def elapsed_steps(self):
        """Returns the number of steps elapsed in the current episode (``None`` before the first reset)."""
        return self._elapsed_steps

    def steps_left(self):
        """Returns the number of steps left before truncation, or ``None`` if no step limit is set."""
        if self._max_episode_steps is None:
            return None
        # Before the first reset no steps have elapsed yet.
        return self._max_episode_steps - (self._elapsed_steps or 0)

    def set_max_steps(self, max_episode_steps: Optional[int]):
        """Sets the maximum number of steps per episode (``None`` disables the limit).

        The new value is also mirrored into ``env.spec.max_episode_steps`` when the
        wrapped environment has a spec.
        """
        self._max_episode_steps = max_episode_steps
        if self.env.spec is not None:
            self.env.spec.max_episode_steps = max_episode_steps

    def set_elapsed_steps(self, elapsed_steps: int):
        """Sets the number of steps elapsed.

        Raises:
            ValueError: If ``elapsed_steps`` is negative.
        """
        if elapsed_steps < 0:
            raise ValueError("Elapsed steps cannot be negative")
        self._elapsed_steps = elapsed_steps

    def max_steps(self):
        """Returns the maximum number of steps allowed (``None`` if there is no limit)."""
        return self._max_episode_steps

    def step(self, action):
        """Steps through the environment and truncates once the number of elapsed steps reaches ``max_episode_steps``.

        Args:
            action: The environment step action.

        Returns:
            The 5-tuple ``(observation, reward, done, truncated, info)``. When the step limit is
            reached, ``done`` is ``True`` and ``info["TimeLimit.truncated"]`` equals ``truncated``:
            ``True`` if the episode was cut short by the limit, ``False`` if the environment
            terminated on that same step. ``truncated`` is also ``True`` whenever the wrapped
            environment itself reports a truncation.

        Raises:
            RuntimeError: If called before :meth:`reset`.
        """
        if self._elapsed_steps is None:
            raise RuntimeError("Cannot call env.step() before calling env.reset()")
        observation, reward, done, truncated, info = self.env.step(action)
        self._elapsed_steps += 1
        # Keep a truncation reported by the wrapped env (e.g. an inner time limit) instead of dropping it.
        episode_truncated = bool(truncated)
        if self._max_episode_steps is not None and self._elapsed_steps >= self._max_episode_steps:
            # TimeLimit.truncated key may have been already set by the environment;
            # do not overwrite a true value with False.
            episode_truncated = episode_truncated or not done or info.get("TimeLimit.truncated", False)
            info["TimeLimit.truncated"] = episode_truncated
            done = True
        return observation, reward, done, episode_truncated, info

    def reset(self, **kwargs):
        """Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.

        Args:
            **kwargs: The kwargs to reset the environment with.

        Returns:
            Whatever the wrapped environment's ``reset`` returns.
        """
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)
