import numpy as np
import itertools
from gym import Env
from gym.spaces import Box
from gym.spaces import Discrete

from collections import deque
import copy


class ProxyEnv(Env):
    """Transparent wrapper that delegates to an inner gym ``Env``.

    The action and observation spaces are taken from the wrapped env at
    construction time (subclasses may overwrite them). Any attribute that
    is not found on the proxy itself is looked up on the wrapped env via
    ``__getattr__``.
    """

    def __init__(self, wrapped_env):
        self._wrapped_env = wrapped_env
        inner = self._wrapped_env
        self.action_space = inner.action_space
        self.observation_space = inner.observation_space

    @property
    def wrapped_env(self):
        """The environment this proxy forwards calls to."""
        return self._wrapped_env

    def reset(self, **kwargs):
        """Reset the wrapped env, passing through any keyword arguments."""
        inner = self._wrapped_env
        return inner.reset(**kwargs)

    def step(self, action):
        """Step the wrapped env and return its result unchanged."""
        inner = self._wrapped_env
        return inner.step(action)

    def render(self, *args, **kwargs):
        """Delegate rendering to the wrapped env."""
        inner = self._wrapped_env
        return inner.render(*args, **kwargs)

    @property
    def horizon(self):
        """The wrapped env's ``horizon`` attribute."""
        inner = self._wrapped_env
        return inner.horizon

    def terminate(self):
        """Call ``terminate()`` on the wrapped env, if it defines one."""
        if not hasattr(self.wrapped_env, "terminate"):
            return
        self.wrapped_env.terminate()

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. '_wrapped_env' itself must
        # not be forwarded: before __dict__ is populated (e.g. during
        # unpickling) that would recurse forever.
        if attr != '_wrapped_env':
            return getattr(self._wrapped_env, attr)
        raise AttributeError()

    def __getstate__(self):
        """
        Pickle the proxy's own ``__dict__`` instead of delegating.

        Without this override, pickle's lookup of ``__getstate__`` could be
        forwarded to the wrapped env through ``__getattr__``. That goes
        wrong for envs with unusual pickling schemes (notably gym's
        EzPickle).
        :return: the instance ``__dict__``.
        """
        return self.__dict__

    def __setstate__(self, state):
        """Restore attributes from the dict produced by ``__getstate__``."""
        vars(self).update(state)

    def __str__(self):
        cls_name = type(self).__name__
        return '{}({})'.format(cls_name, self.wrapped_env)


class HistoryEnv(ProxyEnv, Env):
    """Expose the last ``history_len`` observations as one flat vector.

    Each observation returned by ``step``/``reset`` is the concatenation of
    the most recent observations (oldest first). While fewer than
    ``history_len`` observations are stored, the tail is zero-padded.
    """

    def __init__(self, wrapped_env, history_len):
        super().__init__(wrapped_env)
        self.history_len = history_len

        # At this point self.observation_space is still the wrapped env's
        # space (set by ProxyEnv.__init__).
        high = np.inf * np.ones(
            self.history_len * self.observation_space.low.size)
        low = -high
        self.observation_space = Box(low=low,
                                     high=high,
                                     )
        self.history = deque(maxlen=self.history_len)

    def step(self, action):
        state, reward, done, info = super().step(action)
        self.history.append(state)
        flattened_history = self._get_history().flatten()
        return flattened_history, reward, done, info

    def reset(self, **kwargs):
        # Forward reset options (e.g. a seed) to the wrapped env.
        state = super().reset(**kwargs)
        self.history = deque(maxlen=self.history_len)
        self.history.append(state)
        flattened_history = self._get_history().flatten()
        return flattened_history

    def _get_history(self):
        """Return a (history_len, obs_size) array of stored observations.

        Every observation is flattened, so padding rows and real rows always
        share one shape, even for multi-dimensional observations. Missing
        entries are zero rows appended after the real ones.
        """
        obs_size = self._wrapped_env.observation_space.low.size
        observations = [np.ravel(obs) for obs in self.history]
        n_missing = self.history_len - len(observations)
        observations.extend(np.zeros(obs_size) for _ in range(n_missing))
        return np.c_[observations]


class SwapColorEnv(ProxyEnv, Env):

    """
    Observation space is of form (H, W, C). Move C dimension to front.

    Observations returned by ``step``/``reset`` are transposed to (C, H, W)
    and divided by 255.
    """

    def __init__(self, wrapped_env):
        super().__init__(wrapped_env)

        self._obs_shape = wrapped_env.observation_space.shape

        # Unbounded box with the transposed (C, H, W) shape.
        high = np.inf * np.ones(self._obs_shape).transpose((2, 0, 1))
        low = -high

        self.observation_space = Box(low=low, high=high)

    def obs_proc(self, s):
        """Transpose an (H, W, C) observation to (C, H, W) and scale by 1/255."""
        return s.transpose((2, 0, 1)) / 255

    def step(self, action):
        s, r, d, info = super().step(action)
        return self.obs_proc(s), r, d, info

    def reset(self, **kwargs):
        # Accept and forward reset options, matching ProxyEnv.reset.
        s = super().reset(**kwargs)
        return self.obs_proc(s)


class DiscretizeEnv(ProxyEnv, Env):
    """Replace a Box action space with a Discrete grid over its bounds.

    Each action dimension is split into ``num_bins`` evenly spaced values
    between its low and high bound (inclusive). Every combination of those
    values becomes one discrete action index.
    """

    def __init__(self, wrapped_env, num_bins):
        super().__init__(wrapped_env)
        box = self.wrapped_env.action_space
        per_dim_values = [
            np.linspace(lo, hi, num_bins)
            for lo, hi in zip(box.low, box.high)
        ]
        # Cartesian product of the per-dimension grids, one array per action.
        grid = itertools.product(*per_dim_values)
        self.idx_to_continuous_action = list(map(np.array, grid))
        self.action_space = Discrete(len(self.idx_to_continuous_action))

    def step(self, action):
        """Map the discrete index to its continuous action and step."""
        return super().step(self.idx_to_continuous_action[action])


class NormalizedBoxEnv(ProxyEnv):
    """
    Normalize action to in [-1, 1].

    Optionally normalize observations and scale reward.

    - Actions in [-1, 1] are mapped affinely onto the wrapped env's Box
      bounds, then clipped to them.
    - Observations become ``(obs - obs_mean) / (obs_std + 1e-8)`` once
      observation stats are set, either at construction or via
      ``estimate_obs_stats``.
    - Rewards become ``(rew - rew_mean) / (rew_std + 1e-8)`` when reward
      stats are given. They are always multiplied by ``reward_scale``.
    """

    def __init__(
            self,
            env,
            reward_scale=1.,
            obs_mean=None,
            obs_std=None,
            rew_mean=None,
            rew_std=None
    ):
        ProxyEnv.__init__(self, env)
        self._should_normalize = not (obs_mean is None and obs_std is None)
        # (sic) Attribute name kept unchanged for compatibility with existing
        # pickled state and external references.
        self._shold_normalize_rew = not (rew_mean is None and rew_std is None)
        if self._should_normalize:
            # If only one of the pair is given, the other defaults to the
            # identity (mean 0 / std 1).
            obs_mean = np.zeros_like(env.observation_space.low) if obs_mean is None else np.array(obs_mean)
            obs_std = np.ones_like(env.observation_space.low) if obs_std is None else np.array(obs_std)
        if self._shold_normalize_rew:
            # 0-d defaults so a scalar reward stays scalar instead of being
            # broadcast into a shape-(1,) array.
            rew_mean = np.array(0.) if rew_mean is None else np.array(rew_mean)
            rew_std = np.array(1.) if rew_std is None else np.array(rew_std)
        self._reward_scale = reward_scale
        self._obs_mean, self._obs_std = obs_mean, obs_std
        self._rew_mean, self._rew_std = rew_mean, rew_std
        ub = np.ones(self._wrapped_env.action_space.shape)
        self.action_space = Box(-1 * ub, ub)

    def estimate_obs_stats(self, obs_batch, override_values=False):
        """Set observation mean/std from a batch (stacked along axis 0).

        Also turns on observation normalization, so ``step`` and ``reset``
        both apply the new statistics.

        :raises Exception: if stats are already set and ``override_values``
            is False.
        """
        if self._obs_mean is not None and not override_values:
            raise Exception("Observation mean and std already set. To "
                            "override, set override_values to True.")
        self._obs_mean = np.mean(obs_batch, axis=0)
        self._obs_std = np.std(obs_batch, axis=0)
        self._should_normalize = True

    def _apply_normalize_obs(self, obs):
        return (obs - self._obs_mean) / (self._obs_std + 1e-8)

    def _apply_normalize_rew(self, rew):
        return (rew - self._rew_mean) / (self._rew_std + 1e-8)

    def step(self, action):
        lb = self._wrapped_env.action_space.low
        ub = self._wrapped_env.action_space.high
        # Map [-1, 1] -> [lb, ub]; clip guards against out-of-range inputs.
        scaled_action = lb + (action + 1.) * 0.5 * (ub - lb)
        scaled_action = np.clip(scaled_action, lb, ub)

        next_obs, reward, done, info = self._wrapped_env.step(scaled_action)
        if self._should_normalize:
            next_obs = self._apply_normalize_obs(next_obs)
        if self._shold_normalize_rew:
            reward = self._apply_normalize_rew(reward)

        return next_obs, reward * self._reward_scale, done, info

    def reset(self, **kwargs):
        obs = self._wrapped_env.reset(**kwargs)
        # Normalize only when stats are active; otherwise the None stats
        # would make the arithmetic fail.
        if self._should_normalize:
            obs = self._apply_normalize_obs(obs)
        return obs

    def __str__(self):
        return "Normalized: %s" % self._wrapped_env


class NonTerminatingEnv(ProxyEnv):
    """Proxy whose ``step`` never reports the episode as done."""

    def step(self, action):
        next_obs, reward, _wrapped_done, info = self._wrapped_env.step(action)
        # The wrapped env's done flag is discarded on purpose.
        never_done = False
        return next_obs, reward, never_done, info

    def __str__(self):
        inner = self._wrapped_env
        return 'Non-terminating: %s' % inner


class ContinualLifelongEnv(ProxyEnv):
    """Cycle through ``envs_list`` every ``switch_every`` steps.

    On a switch, the current env's state is read with ``get_env_state``,
    deep-copied, and written into the next env of ``envs_list`` (wrapping
    around) with ``set_env_state``. The transition returned by ``step`` always
    comes from the env that was active before any switch.
    """

    def __init__(self, wrapped_env, switch_every, envs_list):
        super().__init__(wrapped_env)

        self.switch_every = switch_every
        self.envs_list = envs_list

        # _ptr starts at 0; presumably wrapped_env is envs_list[0] — TODO
        # confirm with callers.
        self._n_timesteps, self._ptr = 0, 0

    def step(self, action):
        transition = super().step(action)
        self._n_timesteps += 1
        # Switch only when the step count reaches a multiple of switch_every.
        if self._n_timesteps % self.switch_every == 0:
            self.advance_env()
        return transition

    def advance_env(self):
        """Move to the next env in ``envs_list``, carrying the state over."""
        env_state = copy.deepcopy(self._wrapped_env.get_env_state())
        self._ptr = (self._ptr + 1) % len(self.envs_list)
        self._wrapped_env = self.envs_list[self._ptr]
        self._wrapped_env.set_env_state(env_state)


class FollowerEnv(ProxyEnv):
    """Proxy that wraps a deep copy of ``env_to_follow``.

    Stepping this proxy never mutates the parent env. On every ``reset`` the
    parent is deep-copied again, which picks up any changes made to it since
    the last copy, and the fresh copy is then reset.
    """

    def __init__(self, env_to_follow):
        super().__init__(copy.deepcopy(env_to_follow))

        self.parent_env = env_to_follow

    def reset(self, **kwargs):
        # Accept and forward reset options, matching ProxyEnv.reset.
        self._wrapped_env = copy.deepcopy(self.parent_env)
        return super().reset(**kwargs)
