import gymnasium
from gym import spaces
import numpy as np


class ObsErrorNoiseWrapper(gymnasium.ObservationWrapper):
    """Observation wrapper that adds zero-mean Gaussian noise to observations.

    Args:
        env: Environment to wrap.
        obs_noise: Factor the standard-normal samples are multiplied by,
            i.e. the standard deviation of the added noise.
            NOTE(review): presumably a scalar; anything that broadcasts
            against the observation shape also works -- confirm with callers.
    """

    def __init__(self, env, obs_noise):
        super().__init__(env)
        self.obs_noise = obs_noise

    def observation(self, obs):
        """Return ``obs`` with independent Gaussian noise added to each element.

        Args:
            obs: Array observation produced by the wrapped environment.

        Returns:
            A new array ``obs + obs_noise * N(0, 1)``; ``obs`` itself is not
            modified.
        """
        # Draw one sample per element. Sampling only ``obs.shape[0]`` values
        # (as before) fails to broadcast against 2-D observations, or silently
        # reuses one noise vector for every row when the array is square.
        # For 1-D observations this consumes the global NumPy random stream
        # exactly as ``np.random.randn(obs.shape[0])`` did.
        noise = self.obs_noise * np.random.standard_normal(obs.shape)
        return obs + noise


class ObsErrorHiddenDims(gymnasium.ObservationWrapper):
    """Observation wrapper that hides trailing feature columns by zeroing them.

    Args:
        env: Environment to wrap.
        obs_hidden_dim: Index of the first column to hide; every column from
            this index onward is set to zero.
    """

    def __init__(self, env, obs_hidden_dim):
        super().__init__(env)
        self.obs_hidden_dim = obs_hidden_dim

    def observation(self, obs):
        """Return a copy of ``obs`` with columns ``obs_hidden_dim:`` set to 0.

        Args:
            obs: 2-D array observation (indexed as ``[row, column]``).

        Returns:
            A new array of the same shape and dtype as ``obs`` whose columns
            from ``obs_hidden_dim`` onward are zero.
        """
        # Mask a copy: the array handed in may still be referenced by the
        # wrapped environment, and zeroing it in place would corrupt that
        # state for anything else reading it.
        masked = obs.copy()
        masked[:, self.obs_hidden_dim:] = 0.0
        return masked


class ObsErrorCars(gymnasium.ObservationWrapper):
    """Observation wrapper that zero-pads observations to a fixed 2-D shape.

    The wrapped environment's observation space is replaced by a ``Box`` of
    shape ``padded_shape`` whose scalar bounds are taken from element
    ``[0, 0]`` of the original space's ``low`` / ``high`` arrays.

    Args:
        env: Environment to wrap. Its observation space must expose 2-D
            ``low`` and ``high`` arrays.
        padded_shape: ``(rows, columns)`` of the padded observation.
            Defaults to ``(20, 5)``.
    """

    def __init__(self, env, padded_shape=(20, 5)):
        super().__init__(env)
        self.observation_space = gymnasium.spaces.Box(
            low=env.observation_space.low[0, 0],
            high=env.observation_space.high[0, 0],
            shape=tuple(padded_shape),
        )

    def observation(self, obs):
        """Copy ``obs`` into the top rows of a zero-filled padded array.

        Args:
            obs: 2-D array observation whose column count matches the padded
                shape and whose row count does not exceed it.

        Returns:
            Array with the shape and dtype of ``self.observation_space``;
            rows beyond ``obs.shape[0]`` are zero.
        """
        space = self.observation_space
        # Allocate from the declared space so the returned array always agrees
        # with it in shape and dtype. A plain ``np.zeros`` is float64, which a
        # Box created without an explicit dtype does not accept in
        # ``contains``.
        padded = np.zeros(space.shape, dtype=space.dtype)
        padded[:obs.shape[0], :] = obs
        return padded


