import gym, warnings
from action_masking.sb3_contrib.common.utils import fetch_fn


class InformerWrapper(gym.Wrapper):
    """Gym wrapper that annotates every step's ``info`` with a ``"baseline"`` entry.

    On each ``step`` the wrapper optionally unwraps a single-element action
    to a scalar, optionally passes it through a transformation function,
    forwards it to the wrapped environment, and stores the forwarded action
    together with the environment reward under ``info["baseline"]``.
    """

    def __init__(self,
                 env: gym.Env,
                 alter_action_space=None,
                 transform_action_space_fn=None):
        """Wrap ``env`` and configure the optional action-space handling.

        Args:
            env: Environment to wrap.
            alter_action_space: If not ``None``, replaces this wrapper's
                ``action_space``.
            transform_action_space_fn: Passed to ``fetch_fn`` together with
                the wrapped environment; the result is applied to each action
                in ``step`` when it is not ``None``.
                NOTE(review): presumably a callable or the name of an
                attribute on ``env`` -- confirm against ``fetch_fn``.
        """
        super().__init__(env)

        self._transform_action_space_fn = fetch_fn(self.env, transform_action_space_fn)

        if not hasattr(self.env, "action_space"):
            warnings.warn("Environment has no attribute ``action_space``")

        if alter_action_space is not None:
            self.action_space = alter_action_space
            # An altered space without a transformation means actions sampled
            # from the new space reach the wrapped env unchanged.
            if transform_action_space_fn is None:
                warnings.warn("Set ``alter_action_space`` but no ``transform_action_space_fn``")

    def step(self, action):
        """Forward ``action`` to the wrapped environment and annotate ``info``.

        Args:
            action: Action chosen for this step. When this wrapper's action
                space is a ``Box`` whose first dimension is 1 and the action
                exposes ``.item()``, it is converted to a scalar first.

        Returns:
            The ``(obs, reward, done, info)`` tuple of the wrapped
            environment, where ``info["baseline"]`` holds the action that was
            actually forwarded (after unwrapping and transformation) as
            ``"policy_action"`` and the reward as ``"env_reward"``.
        """
        space = self.action_space

        # ``shape[:1]`` instead of ``shape[0]`` so a Box with the empty shape
        # ``()`` does not raise IndexError.
        is_single_box = isinstance(space, gym.spaces.Box) and space.shape[:1] == (1,)

        # Plain Python scalars have no ``.item()`` and are already in the
        # form ``.item()`` would produce, so they are passed through as-is.
        if is_single_box and hasattr(action, "item"):
            action = action.item()

        # Optional action transformation
        if self._transform_action_space_fn is not None:
            action = self._transform_action_space_fn(action)

        obs, reward, done, info = self.env.step(action)
        info["baseline"] = {"policy_action": action, "env_reward": reward}

        return obs, reward, done, info