import time
from collections import deque
import gym
import numpy as np
from gym.spaces import Box
from copy import deepcopy
import logging


class ObsToNumpy(gym.ObservationWrapper):
    """Hand back every observation of the wrapped env as a NumPy array.

    ``np.array`` copies its input, so the returned array never aliases an
    observation buffer owned by the wrapped env.
    """

    def observation(self, observation) -> np.ndarray:
        as_array = np.array(observation, copy=True)
        return as_array


class RecordEpisodeStatistics(gym.Wrapper):
    """Accumulate per-episode return/length and publish them when an episode ends.

    Works for a single env or a vector env (detected via the wrapped env's
    ``is_vector_env`` attribute). When sub-env ``i`` reports ``done``, its
    info dict is shallow-copied and given an ``"episode"`` entry with keys
    ``"r"`` (summed reward), ``"l"`` (step count) and ``"t"`` (seconds since
    this wrapper was constructed, rounded to 6 decimals). The finished values
    are also appended to ``return_queue`` / ``length_queue`` (each bounded by
    ``deque_size``), ``episode_count`` is incremented, and the running
    counters of that sub-env are zeroed.

    ``reset()`` must be called before ``step()``: it allocates the counters.
    """

    def __init__(self, env, deque_size=100):
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.t0 = time.perf_counter()
        self.episode_count = 0
        # Running per-env counters; allocated in reset().
        self.episode_returns = None
        self.episode_lengths = None
        # Most recent finished-episode values, oldest dropped first.
        self.return_queue = deque(maxlen=deque_size)
        self.length_queue = deque(maxlen=deque_size)
        self.is_vector_env = getattr(env, "is_vector_env", False)

    def reset(self, **kwargs):
        """Reset the wrapped env and zero the running counters for every sub-env."""
        reset_result = super().reset(**kwargs)
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        return reset_result

    def step(self, action):
        """Step the wrapped env, attaching ``info["episode"]`` for finished episodes."""
        observations, rewards, dones, infos = super().step(action)
        self.episode_returns += rewards
        self.episode_lengths += 1

        # View single and vector envs as per-env sequences so one loop serves both.
        if self.is_vector_env:
            info_list = list(infos)  # tuples are immutable; copy into a list
            done_list = dones
        else:
            info_list = [infos]
            done_list = [dones]

        for idx, finished in enumerate(done_list):
            if not finished:
                continue
            ep_return = self.episode_returns[idx]
            ep_length = self.episode_lengths[idx]
            # Copy so the wrapped env's own info dict is not mutated.
            info_copy = info_list[idx].copy()
            info_copy["episode"] = {
                "r": ep_return,
                "l": ep_length,
                "t": round(time.perf_counter() - self.t0, 6),
            }
            info_list[idx] = info_copy
            self.return_queue.append(ep_return)
            self.length_queue.append(ep_length)
            self.episode_count += 1
            self.episode_returns[idx] = 0
            self.episode_lengths[idx] = 0

        if not self.is_vector_env:
            return observations, rewards, dones, info_list[0]
        return observations, rewards, dones, tuple(info_list)


class RecordStepStatistics(gym.Wrapper):
    """Report return/length statistics over fixed-size windows of env steps.

    Meant for envs without a time-limit wrapper, whose episodes may never end
    on their own. The ``dones`` flags do not trigger a report. Instead, every
    time sub-env ``i`` has accumulated ``update_info_every`` steps since its
    counters were last cleared:

    - its info dict is shallow-copied and given an ``"episode"`` entry
      ``{"r": summed reward, "l": step count, "t": seconds since this wrapper
      was constructed, rounded to 6 decimals}``;
    - the values are appended to ``return_queue`` / ``length_queue``;
    - ``episode_count`` (here: number of reported windows) is incremented;
    - that sub-env's counters are zeroed.

    ``reset()`` also zeroes all counters and must be called before ``step()``.

    Args:
        env: Env to wrap. Single or vector, detected via its ``is_vector_env``
            attribute.
        deque_size: Max number of windows kept in ``return_queue`` /
            ``length_queue``.
        update_info_every: Window length in steps; must be positive.

    Raises:
        ValueError: If ``update_info_every`` is not positive.
    """

    def __init__(self, env, deque_size=100, update_info_every: int = 1_000):
        if update_info_every <= 0:
            # The modulo window check in step() only has the documented
            # meaning for a positive window length.
            raise ValueError(
                f"update_info_every must be positive, got {update_info_every}"
            )
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.t0 = time.perf_counter()
        self.episode_count = 0
        # Running per-env counters; allocated in reset().
        self.episode_returns = None
        self.episode_lengths = None
        self.return_queue = deque(maxlen=deque_size)
        self.length_queue = deque(maxlen=deque_size)
        self.is_vector_env = getattr(env, "is_vector_env", False)
        self.update_info_every = update_info_every

    def reset(self, **kwargs):
        """Reset the wrapped env and zero the running counters for every sub-env."""
        observations = super().reset(**kwargs)
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        return observations

    def step(self, action):
        """Step the wrapped env, attaching ``info["episode"]`` when a window fills."""
        observations, rewards, dones, infos = super().step(action)
        self.episode_returns += rewards
        self.episode_lengths += 1
        if not self.is_vector_env:
            infos = [infos]
            dones = [dones]
        else:
            infos = list(infos)  # Convert infos to mutable type
        for i in range(len(dones)):
            # Counters restart from 0 after each report, so this fires once
            # every `update_info_every` steps per sub-env.
            if self.episode_lengths[i] % self.update_info_every != 0:
                continue
            window_return = self.episode_returns[i]
            window_length = self.episode_lengths[i]
            # Copy so the wrapped env's own info dict is not mutated.
            infos[i] = infos[i].copy()
            infos[i]["episode"] = {
                "r": window_return,
                "l": window_length,
                "t": round(time.perf_counter() - self.t0, 6),
            }
            self.return_queue.append(window_return)
            self.length_queue.append(window_length)
            self.episode_count += 1
            self.episode_returns[i] = 0
            self.episode_lengths[i] = 0
        if self.is_vector_env:
            return observations, rewards, dones, tuple(infos)
        return observations, rewards, dones[0], infos[0]


class NormalizedBoxEnv(gym.Wrapper):
    """Expose a [-1, 1] Box action space and rescale actions to the wrapped env's bounds.

    Actions given to ``step`` are mapped linearly from [-1, 1] onto the
    wrapped env's ``[low, high]`` and clipped to that range. Rewards are
    multiplied by ``reward_scale``.

    Observations returned by ``step`` are normalized as
    ``(obs - obs_mean) / (obs_std + 1e-8)`` when normalization is enabled.
    It is enabled when ``obs_mean`` or ``obs_std`` is passed (the missing one
    defaults to zeros / ones shaped like ``observation_space.low``), or once
    ``estimate_obs_stats`` has been called.

    NOTE(review): ``reset`` is not overridden, so the first observation of
    each episode is returned unnormalized. Confirm that callers expect this.
    """

    def __init__(
        self,
        env,
        reward_scale=1.0,
        obs_mean=None,
        obs_std=None,
    ):
        super().__init__(env)
        self._should_normalize = not (obs_mean is None and obs_std is None)
        if self._should_normalize:
            if obs_mean is None:
                obs_mean = np.zeros_like(env.observation_space.low)
            else:
                obs_mean = np.array(obs_mean)
            if obs_std is None:
                obs_std = np.ones_like(env.observation_space.low)
            else:
                obs_std = np.array(obs_std)
        self._reward_scale = reward_scale
        self._obs_mean = obs_mean
        self._obs_std = obs_std
        ub = np.ones(self.env.action_space.shape)
        self.action_space = Box(-1 * ub, ub)

    def estimate_obs_stats(self, obs_batch, override_values=False):
        """Set the normalization statistics from a batch of observations and enable normalization.

        Args:
            obs_batch: Observations stacked along axis 0.
            override_values: Allow replacing statistics that are already set.

        Raises:
            Exception: If statistics are already set and ``override_values``
                is False.
        """
        if self._obs_mean is not None and not override_values:
            raise Exception(
                "Observation mean and std already set. To "
                "override, set override_values to True."
            )
        self._obs_mean = np.mean(obs_batch, axis=0)
        self._obs_std = np.std(obs_batch, axis=0)
        # Without this flag the freshly estimated statistics would never be applied.
        self._should_normalize = True

    def _apply_normalize_obs(self, obs):
        # The epsilon guards against division by zero for constant features.
        return (obs - self._obs_mean) / (self._obs_std + 1e-8)

    def step(self, action):
        """Rescale ``action`` from [-1, 1] to the wrapped env's bounds and step it.

        Returns the wrapped env's ``(obs, reward, done, info)`` with ``obs``
        normalized (if enabled) and ``reward`` multiplied by ``reward_scale``.
        """
        # gym.Wrapper keeps the inner env in ``self.env``; it does not forward
        # private attribute names such as ``_wrapped_env``.
        lb = self.env.action_space.low
        ub = self.env.action_space.high
        action = np.asarray(action)  # accept lists/tuples as well as arrays
        scaled_action = lb + (action + 1.0) * 0.5 * (ub - lb)
        scaled_action = np.clip(scaled_action, lb, ub)

        next_obs, reward, done, info = self.env.step(scaled_action)
        if self._should_normalize:
            next_obs = self._apply_normalize_obs(next_obs)
        return next_obs, reward * self._reward_scale, done, info


class FakeSequentialEnv(gym.Wrapper):
    """Present a single env through a vector-env-shaped interface.

    ``action_space`` and ``observation_space`` are copied from a temporary
    ``gym.vector.SyncVectorEnv`` built from a deep copy of ``env_fns``, so
    they carry that vector env's batched shapes. Only ``env_fns[0]`` is
    instantiated for actual stepping. Its outputs get a leading axis of size
    1: observations via ``np.expand_dims``, rewards and dones as 1-element
    arrays, and infos as 1-element lists.

    Args:
        env_fns: Sequence of zero-argument callables that each build an env.
        num_envs: Unused; kept for backward compatibility.
    """

    def __init__(self, env_fns, num_envs=1):
        # Build a throw-away vector env only to read its batched spaces.
        env_fns_copy = deepcopy(env_fns)
        vec_env = gym.vector.SyncVectorEnv(env_fns_copy)
        try:
            action_space = deepcopy(vec_env.action_space)
            observation_space = deepcopy(vec_env.observation_space)
        finally:
            # Release the temporary sub-envs explicitly instead of leaving
            # them to garbage collection.
            vec_env.close()
        del vec_env
        del env_fns_copy

        # Let gym.Wrapper set up its own state (including ``self.env``)
        # before the space overrides are applied.
        super().__init__(env_fns[0]())
        self.action_space = action_space
        self.observation_space = observation_space
        self.is_vector_env = True

    def reset(self, **kwargs):
        """Reset the inner env, forwarding kwargs (e.g. ``seed``); batch obs and info."""
        # Assumes the inner env's reset returns an (obs, info) pair.
        obs, info = self.env.reset(**kwargs)
        obs = np.expand_dims(obs, axis=0)
        return obs, [info]

    def step(self, action):
        """Step the inner env with ``action`` as-is and batch every returned value."""
        obs, reward, done, info = self.env.step(action)
        obs = np.expand_dims(obs, axis=0)
        reward = np.array([reward])
        done = np.array([done])
        info = [info]
        return obs, reward, done, info


class RoundAction(gym.Wrapper):
    """Round every action to a fixed number of decimals before stepping.

    The same policy evaluated on CPU and on GPU can produce actions that
    differ in their last few digits. Rounding them makes the wrapped env see
    the same action in both cases.

    Args:
        env (gym.Env): Environment to wrap. A warning is logged when its
            action space is not a ``Box``.
        precision (int): Number of decimal places passed to ``np.around``.
    """

    def __init__(self, env, precision: int = 7):
        super().__init__(env)
        self.precision = precision
        has_box_actions = isinstance(env.action_space, Box)
        if not has_box_actions:
            logging.warning("RoundAction only works for Box action space.")

    def step(self, action):
        rounded = np.around(action, decimals=self.precision)
        return self.env.step(rounded)
