import gym
import numpy as np


class ObsErrorNoiseWrapper(gym.ObservationWrapper):
    """Perturb every observation with additive zero-mean Gaussian noise.

    Each element of the observation receives an independent sample of
    ``obs_noise * N(0, 1)`` drawn from NumPy's global random state, so
    seeding ``np.random`` makes the noise reproducible.

    Args:
        env: The environment to wrap.
        obs_noise: Scale (standard deviation) of the noise. A scalar, or any
            array that broadcasts against the observation shape.
    """

    def __init__(self, env, obs_noise):
        super().__init__(env)
        self.obs_noise = obs_noise

    def observation(self, obs):
        """Return a noisy copy of ``obs``; the input array is not modified."""
        obs = np.asarray(obs)
        # Draw one sample per element so multi-dimensional observations get
        # full-shape noise instead of a single vector broadcast over the
        # leading axes. For 1-D observations this consumes the RNG exactly
        # like randn(obs.shape[0]) did, so seeded runs are unchanged.
        noise = self.obs_noise * np.random.randn(*obs.shape)
        return obs + noise


class ObsErrorHiddenDims(gym.ObservationWrapper):
    """Hide selected dimensions of an observation by setting them to 0.0.

    The observation keeps its original shape; only the chosen entries along
    the first axis are overwritten, so ``observation_space`` is unchanged.

    Args:
        env: The environment to wrap.
        obs_hidden_dims: A single index, or a list/tuple of indices, along
            the first axis of the observation to hide.

    Raises:
        TypeError: If ``obs_hidden_dims`` is not an int or a list/tuple of
            ints (bools are rejected because NumPy treats them as masks).
    """

    def __init__(self, env, obs_hidden_dims):
        super().__init__(env)
        if isinstance(obs_hidden_dims, (list, tuple)):
            dims = list(obs_hidden_dims)
        else:
            dims = [obs_hidden_dims]
        for dim in dims:
            # bool is a subclass of int, but obs[True] is a boolean mask in
            # NumPy and would zero the whole observation.
            if isinstance(dim, bool) or not isinstance(dim, (int, np.integer)):
                raise TypeError('obs_hidden_dims should be either int or list of ints')
        # Store our own list of plain ints so later mutation of the caller's
        # list cannot change which dimensions are hidden.
        self.obs_hidden_dims = [int(dim) for dim in dims]

    def observation(self, obs):
        """Return a copy of ``obs`` with the hidden dimensions set to 0.0.

        Raises:
            ValueError: If any hidden index is outside ``[0, obs.shape[0])``.
        """
        # Work on a copy: the wrapped env (or another wrapper) may still hold
        # a reference to the array it returned, and writing into it in place
        # would leak the zeroed values back into that state.
        obs = np.array(obs)
        size = obs.shape[0]
        # Validate every index before writing anything.
        if any(not 0 <= dim < size for dim in self.obs_hidden_dims):
            raise ValueError(
                f'Error: accepted self.obs_hidden_dims values are between 0 and obs.shape[0]. (self.obs_hidden_dims = {self.obs_hidden_dims}, obs.shape[0] = {size})')
        obs[self.obs_hidden_dims] = 0.0
        return obs


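# Disabled alternative implementation: instead of zeroing the hidden
# dimensions, it removes them with np.delete and shrinks observation_space.
# NOTE(review): fix before re-enabling.
# - The range check compares indices against the *reduced* shape, so it
#   rejects valid trailing indices.
# - max()/min() fail when obs_hidden_dims is a single int.
# - Box receives the original full-shape low/high arrays together with the
#   reduced shape, which gym will likely reject.
# class ObsErrorHiddenDims(gym.ObservationWrapper):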
#     def __init__(self, env, obs_hidden_dims):
#         super().__init__(env)
#         # self.observation_space = gym.spaces.Box(0, 1, (self.n,))
#         if type(obs_hidden_dims) == int:
#             self.obs_hidden_dims = [obs_hidden_dims]
#         elif type(obs_hidden_dims) == list:
#             self.obs_hidden_dims = obs_hidden_dims
#         else:
#             raise TypeError('obs_hidden_dims should be either int or list of ints')
#
#         self.observation_space = gym.spaces.Box(env.observation_space.low,
#                                                 env.observation_space.high,
#                                                 (env.observation_space.shape[0] - len(self.obs_hidden_dims),))
#
#         if max(obs_hidden_dims) >= self.observation_space.shape[0] or min(obs_hidden_dims) < 0:
#             raise ValueError('obs_hidden_dims should have values from 0 to the new observation shape')
#
#
#     def observation(self, obs):
#         obs = np.delete(obs, self.obs_hidden_dims)
#         # for obs_hidden_dim in self.obs_hidden_dims:
#         #     if 0 <= obs_hidden_dim < obs.shape[0]:
#         #         obs[obs_hidden_dim] = 0.0
#         #     else:
#         #         raise ValueError(
#         #             f'Error: accepted self.obs_hidden_dims values are between 0 and obs.shape[0]. (self.obs_hidden_dims = {self.obs_hidden_dims}, obs.shape[0 = {obs.shape[0]})')
#
#         return obs
