# Original implementation: https://github.com/abaisero/asym-rlpo
#
####
#
# Extended to informed POMDPs by anonymous authors (2025)
#
####

from typing import List

import gym
import gym.spaces
import numpy as np

from asym_rlpo.utils.debugging import checkraise


class IndexWrapper(gym.ObservationWrapper):
    """IndexWrapper.

    Takes a gym.Env with a flat Box observation space (treated as the full
    state), and filters it such that only the dimensions indicated by
    `indices_obs` are observable. The dimensions indicated by `indices_info`
    are exposed separately through `information()`.

    Attributes:
        state_space: the wrapped env's original (unfiltered) observation space.
        observation_space: Box restricted to `indices_obs`.
        information_space: Box restricted to `indices_info`.
        state: the most recent unfiltered observation seen by `observation()`.
    """

    def __init__(
        self,
        env: gym.Env,
        indices_obs: List[int],
        indices_info: List[int],
    ):
        """Validate the index lists and build the derived spaces.

        Args:
            env: environment whose observation space is a flat (1-D) Box.
            indices_obs: unique, non-empty list of state dimensions that are
                observable.
            indices_info: unique list of state dimensions exposed through
                `information()`; may be empty.

        Raises:
            ValueError: if the observation space is not a flat Box, if
                `indices_obs` is empty, or if either index list contains
                duplicates or values outside `[0, state_dim)`.
        """
        checkraise(
            isinstance(env.observation_space, gym.spaces.Box)
            and len(env.observation_space.shape) == 1,
            ValueError,
            'env.observation_space must be flat Box',
        )
        # Narrow the type for static checkers; guaranteed by the check above.
        assert isinstance(env.observation_space, gym.spaces.Box)
        state_dim = env.observation_space.shape[0]

        # Without this, an empty list would fail later inside min()/max()
        # with an unhelpful message.
        checkraise(
            len(indices_obs) > 0,
            ValueError,
            'indices must not be empty',
        )
        self._check_indices(indices_obs, state_dim, 'indices')
        # Information indices get the same validation; negative values would
        # otherwise silently wrap around in numpy indexing.
        self._check_indices(indices_info, state_dim, 'information indices')

        super().__init__(env)

        # Copy so later mutation of the caller's lists cannot alter the wrapper.
        self._indices_obs = list(indices_obs)
        self._indices_info = list(indices_info)
        self.state_space = env.observation_space

        low = env.observation_space.low
        high = env.observation_space.high
        # Preserve the wrapped dtype: Box defaults to float32, which would make
        # e.g. float64 observations fail `observation_space.contains`.
        dtype = env.observation_space.dtype
        self.observation_space = gym.spaces.Box(
            low[self._indices_obs],
            high[self._indices_obs],
            dtype=dtype,
        )

        # Information space
        self.information_space = gym.spaces.Box(
            low[self._indices_info],
            high[self._indices_info],
            dtype=dtype,
        )

        # Assigned on every call to `observation()`; unset until then.
        self.state: np.ndarray

    @staticmethod
    def _check_indices(indices: List[int], dim: int, what: str) -> None:
        """Raise ValueError unless `indices` are unique and lie in `[0, dim)`.

        `what` names the index list in the error messages.
        """
        checkraise(
            len(set(indices)) == len(indices),
            ValueError,
            f'{what} must be unique',
        )
        checkraise(
            len(indices) <= dim,
            ValueError,
            f'number of {what} must not exceed state dimensions',
        )
        checkraise(
            all(i >= 0 for i in indices),
            ValueError,
            f'{what} must be non-negative',
        )
        checkraise(
            all(i < dim for i in indices),
            ValueError,
            f'{what} must be lower than state dimensions',
        )

    def observation(self, observation):
        """Record the full observation as `self.state` and return its
        `indices_obs` components (a new array, via numpy fancy indexing)."""
        self.state = observation
        return observation[self._indices_obs]

    # Return information
    def information(self):
        """Return the `indices_info` components of the last recorded state.

        Must be called after at least one observation has passed through
        `observation()`; otherwise `self.state` is unset and attribute access
        fails.
        """
        return self.state[self._indices_info]

