from typing import Any
from typing import Union, Optional

from gymnasium.core import WrapperObsType

from misc.rng_modules import fix_seed
import gym as classic_gym
import gymnasium as gym
import numpy as np
from gym.core import ObsType


class NormalizedBoxEnv(classic_gym.Wrapper):
    """
    Expose the wrapped env's action space as the box [-1, 1].

    Actions given to ``step`` are mapped affinely from [-1, 1] onto the
    wrapped env's ``[action_space.low, action_space.high]`` and clipped to
    those bounds. Optionally normalizes observations with a fixed mean/std
    and multiplies rewards by ``reward_scale``.

    This wrapper uses the classic gym step API: ``step`` unpacks a 4-tuple
    ``(obs, reward, done, info)`` from the wrapped env.
    """

    def __init__(
            self,
            env,
            reward_scale=1.,
            obs_mean=None,
            obs_std=None,
    ):
        """
        :param env: environment to wrap; its ``action_space`` must provide
            ``low``, ``high`` and ``shape``.
        :param reward_scale: factor every reward is multiplied by.
        :param obs_mean: per-dimension observation mean, or None.
        :param obs_std: per-dimension observation std, or None.

        Observation normalization is enabled iff at least one of
        ``obs_mean`` / ``obs_std`` is given; the missing one defaults to
        zeros / ones shaped like ``env.observation_space.low``.
        """
        super().__init__(env)
        self._should_normalize = not (obs_mean is None and obs_std is None)
        if self._should_normalize:
            if obs_mean is None:
                obs_mean = np.zeros_like(env.observation_space.low)
            else:
                obs_mean = np.array(obs_mean)
            if obs_std is None:
                obs_std = np.ones_like(env.observation_space.low)
            else:
                obs_std = np.array(obs_std)
        self._reward_scale = reward_scale
        self._obs_mean = obs_mean
        self._obs_std = obs_std
        # gym.Wrapper stores the wrapped env as ``self.env``; there is no
        # ``_wrapped_env`` attribute, and Wrapper.__getattr__ refuses to
        # forward underscore-prefixed names to the inner env.
        ub = np.ones(self.env.action_space.shape)
        self.action_space = classic_gym.spaces.Box(-1 * ub, ub)

    def estimate_obs_stats(self, obs_batch, override_values=False):
        """
        Set the observation mean/std from a batch of observations.

        Statistics are taken along axis 0 of ``obs_batch``.

        :raises Exception: if a mean is already set and ``override_values``
            is False.
        """
        # NOTE(review): this does not touch ``_should_normalize``, so if the
        # wrapper was built without obs_mean/obs_std the estimated statistics
        # are stored but never applied. Confirm whether that is intended.
        if self._obs_mean is not None and not override_values:
            raise Exception("Observation mean and std already set. To "
                            "override, set override_values to True.")
        self._obs_mean = np.mean(obs_batch, axis=0)
        self._obs_std = np.std(obs_batch, axis=0)

    def reset(self, *, seed: Optional[int] = None) -> Union[ObsType, tuple[ObsType, dict]]:
        """
        Reset the wrapped env and return the (optionally normalized) observation.

        When ``seed`` is given it is passed to ``fix_seed``; it is not
        forwarded to the wrapped env's ``reset``.
        """
        if seed is not None:
            fix_seed(seed)
        obs = super().reset()
        if self._should_normalize:
            # Previously called the non-existent ``_apply_normalize``.
            obs = self._apply_normalize_obs(obs)
        return obs

    def _apply_normalize_obs(self, obs):
        """Return ``(obs - mean) / (std + 1e-8)``; the epsilon avoids division by zero."""
        return (obs - self._obs_mean) / (self._obs_std + 1e-8)

    def step(self, action):
        """
        Rescale ``action`` from [-1, 1] to the wrapped env's bounds and step it.

        :return: ``(next_obs, reward * reward_scale, done, info)`` with
            ``next_obs`` normalized when normalization is enabled.
        """
        lb = self.env.action_space.low
        ub = self.env.action_space.high
        # asarray lets plain lists/tuples be used as actions.
        scaled_action = lb + (np.asarray(action) + 1.) * 0.5 * (ub - lb)
        # Inputs outside [-1, 1] are clamped to the wrapped env's bounds.
        scaled_action = np.clip(scaled_action, lb, ub)

        next_obs, reward, done, info = self.env.step(scaled_action)
        if self._should_normalize:
            next_obs = self._apply_normalize_obs(next_obs)
        return next_obs, reward * self._reward_scale, done, info

    def __str__(self):
        return f"Normalized: {self.env}"


class NormalizedGymnasiumBoxEnv(gym.Wrapper):
    """
    Expose the wrapped env's action space as the box [-1, 1].

    Actions given to ``step`` are mapped affinely from [-1, 1] onto the
    wrapped env's ``[action_space.low, action_space.high]`` and clipped to
    those bounds. Optionally normalizes observations with a fixed mean/std
    (rescaling ``observation_space`` bounds accordingly) and multiplies
    rewards by ``reward_scale``.

    The raw observation is always reported as ``info['denormalized_obs']``
    by both ``reset`` and ``step``.
    """

    def __init__(
            self,
            env,
            reward_scale=1.,
            obs_mean=None,
            obs_std=None,
    ):
        """
        :param env: environment to wrap; its observation and action spaces
            must provide ``low``, ``high`` and ``shape``.
        :param reward_scale: factor every reward is multiplied by.
        :param obs_mean: observation mean with as many elements as the
            observation space, or None.
        :param obs_std: observation std with as many elements as the
            observation space, or None. Given values are clipped from
            below at 1e-12 to avoid division by zero.

        Observation normalization is enabled iff at least one of
        ``obs_mean`` / ``obs_std`` is given; the missing one defaults to
        zeros / ones.
        """
        super().__init__(env)
        self._should_normalize = not (obs_mean is None and obs_std is None)
        if self._should_normalize:
            obs_shape = self.env.observation_space.shape
            obs_low = self.env.observation_space.low
            obs_high = self.env.observation_space.high

            # Bring both statistics to the observation shape up front so the
            # bound arithmetic below is element-wise for any rank (the old
            # flatten()-based code only lined up for 1-D observations).
            if obs_mean is None:
                obs_mean = np.zeros_like(obs_low)
            else:
                obs_mean = np.array(obs_mean).reshape(obs_shape)
            if obs_std is None:
                obs_std = np.ones_like(obs_low)
            else:
                obs_std = np.array(obs_std).clip(1e-12).reshape(obs_shape)

            new_low = (obs_low - obs_mean) / obs_std
            new_high = (obs_high - obs_mean) / obs_std
            self.observation_space = gym.spaces.Box(new_low, new_high, shape=obs_shape)
        else:
            self.observation_space = self.env.observation_space

        self._reward_scale = reward_scale
        # Both stay None when normalization is disabled; previously the
        # unconditional ``.squeeze()`` crashed on None for the default
        # arguments.
        self._obs_mean = obs_mean
        self._obs_std = obs_std

        ub = np.ones(self.env.action_space.shape)
        self.action_space = gym.spaces.Box(-1 * ub, ub)

        # Not read anywhere in this class; kept for external users.
        self._should_normalize_reward = False

    def reset(
        self, *, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """
        Reset the wrapped env, forwarding ``seed`` and ``options``.

        :return: ``(obs, info)`` where ``obs`` is normalized when
            normalization is enabled and ``info['denormalized_obs']`` holds
            a copy of the raw observation.
        """
        obs, info = super().reset(seed=seed, options=options)
        # np.array copies, so later changes to obs cannot leak into info.
        info['denormalized_obs'] = np.array(obs)
        if self._should_normalize:
            obs = self._apply_normalize_obs(obs)

        return obs, info

    def _apply_normalize_obs(self, obs):
        """Return ``(obs - mean) / std`` as a new array; ``obs`` is not modified."""
        return (np.asarray(obs) - self._obs_mean) / self._obs_std

    def step(self, action):
        """
        Rescale ``action`` from [-1, 1] to the wrapped env's bounds and step it.

        :return: ``(next_obs, reward * reward_scale, done, timeout, info)``,
            where the third and fourth items are passed through unchanged
            from the wrapped env's 5-tuple. ``next_obs`` is normalized when
            normalization is enabled and ``info['denormalized_obs']`` holds
            a copy of the raw observation.
        """
        lb = self.env.action_space.low
        ub = self.env.action_space.high
        # asarray accepts lists/tuples; the arithmetic allocates a new
        # array, so the caller's action is never modified.
        scaled_action = lb + (np.asarray(action) + 1.) * 0.5 * (ub - lb)
        # Inputs outside [-1, 1] are clamped to the wrapped env's bounds.
        scaled_action = np.clip(scaled_action, lb, ub)
        next_obs, reward, done, timeout, info = self.env.step(scaled_action)

        info['denormalized_obs'] = np.array(next_obs)
        if self._should_normalize:
            next_obs = self._apply_normalize_obs(next_obs)
        else:
            # Hand back a copy so the caller never aliases the env's buffer.
            next_obs = np.array(next_obs)
        return next_obs, reward * self._reward_scale, done, timeout, info

    def __str__(self):
        return f"Normalized: {self.env}"