import time
import numpy as np

try:
    import gymnasium as gym

    old_api = False

except Exception:
    import gym

    old_api = True


from src.wrappers.common import TimeStep


class EpisodeMonitor(gym.ActionWrapper):
    """Environment wrapper that tracks per-episode return, length and duration.

    On every step the wrapper accumulates the reward, counts the step, and
    writes the running total of steps taken since construction to
    ``info["total"]["timesteps"]``.  On the step that ends an episode it
    additionally writes ``info["episode"]`` with the keys ``"return"``,
    ``"length"`` and ``"duration"`` (wall-clock seconds since the last reset).

    Works with both the legacy 4-tuple ``step`` API and the 5-tuple
    ``terminated``/``truncated`` API, selected by the module-level ``old_api``
    flag.

    NOTE(review): this subclasses ``ActionWrapper`` but overrides ``step``
    directly and defines no ``action`` method; a plain ``Wrapper`` base looks
    sufficient -- kept as is so existing ``isinstance`` checks are unaffected.
    """

    def __init__(self, env: gym.Env):
        """Wrap ``env`` and start with zeroed episode and lifetime counters."""
        super().__init__(env)
        self._reset_stats()
        # Lifetime step counter: never cleared by ``reset``.
        self.total_timesteps = 0

    def _reset_stats(self):
        """Clear the per-episode accumulators and restart the episode clock."""
        self.reward_sum = 0.0
        self.episode_length = 0
        self.start_time = time.time()

    def step(self, action: np.ndarray) -> TimeStep:
        """Step the wrapped environment and record episode statistics.

        Args:
            action: Action forwarded unchanged to the wrapped environment.

        Returns:
            Exactly what the wrapped environment's ``step`` returned (4-tuple
            when ``old_api`` is true, 5-tuple otherwise), with the ``info``
            dict augmented in place as described in the class docstring.
        """
        if old_api:
            observation, reward, done, info = self.env.step(action)
        else:
            observation, reward, terminated, truncated, info = self.env.step(
                action
            )
            # Either flag ends the episode for bookkeeping purposes.
            done = terminated or truncated

        self.reward_sum += reward
        self.episode_length += 1
        self.total_timesteps += 1
        info["total"] = {"timesteps": self.total_timesteps}

        if done:
            info["episode"] = {
                "return": self.reward_sum,
                "length": self.episode_length,
                "duration": time.time() - self.start_time,
            }

            # If a ``get_normalized_score`` attribute is reachable from this
            # wrapper (presumably forwarded from the wrapped env by the
            # wrapper base class -- TODO confirm for the installed gym
            # version), the raw return is replaced by the normalized score
            # scaled by 100.
            if hasattr(self, "get_normalized_score"):
                info["episode"]["return"] = (
                    self.get_normalized_score(info["episode"]["return"])
                    * 100.0
                )

        # DONOTCHANGE: need to return with same API for compatibility
        if old_api:
            return observation, reward, done, info
        else:
            return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Reset the episode statistics and the wrapped environment.

        Only the arguments that were actually supplied are forwarded, so
        ``options`` is honoured even when no ``seed`` is given, and
        environments whose ``reset`` takes no such keywords still work when
        the caller passes neither.

        Args:
            seed: Optional seed forwarded to the wrapped environment's reset.
            options: Optional reset options forwarded to the wrapped
                environment's reset.

        Returns:
            Whatever the wrapped environment's ``reset`` returns.
        """
        self._reset_stats()

        reset_kwargs = {}
        if seed is not None:
            reset_kwargs["seed"] = seed
        if options is not None:
            reset_kwargs["options"] = options
        return self.env.reset(**reset_kwargs)
