import gymnasium as gym
import numpy as np
class ObsIndexSelectionWrapper(gym.Wrapper):
    """Sub-sample stacked dict observations to a fixed set of relative steps.

    Every observation value is indexed along axis 0 (``[L, ...]``). The last
    entry along that axis is treated as offset ``0``, and the configured delta
    indices are non-positive offsets from it. For example, ``[-4, -2, 0]``
    picks rows ``L-5``, ``L-3`` and ``L-1``. The newest step is presumably
    the last row; confirm this against the wrapped env.

    Keys starting with ``"video"`` use ``video_delta_indices``. Keys starting
    with ``"state"`` use ``state_delta_indices``. When that is ``None``, they
    are dropped from both the observation and the observation space. Any
    other key raises ``ValueError``.
    """

    def __init__(self, env, video_delta_indices, state_delta_indices):
        """
        Args:
            env: wrapped environment whose observation space is a dict-like
                space of ``Box``-like entries (``.low``/``.high``/``.shape``/
                ``.dtype`` are read from each entry).
            video_delta_indices: array-like of non-positive, evenly spaced,
                strictly increasing offsets ending in 0
                (see ``assert_delta_indices``).
            state_delta_indices: same format as ``video_delta_indices``, or
                ``None`` to drop every ``"state*"`` key.
        """
        super().__init__(env)
        # Normalize to ndarray. The validation and index arithmetic below are
        # vectorized and would raise TypeError on a plain list or tuple.
        self.video_delta_indices = np.asarray(video_delta_indices)
        self.video_horizon = len(self.video_delta_indices)
        self.assert_delta_indices(self.video_delta_indices, self.video_horizon)
        if state_delta_indices is not None:
            self.state_delta_indices = np.asarray(state_delta_indices)
            self.state_horizon = len(self.state_delta_indices)
            self.assert_delta_indices(self.state_delta_indices, self.state_horizon)
        else:
            self.state_horizon = None
            self.state_delta_indices = None
        # Pass the delta indices so the new bounds are taken from the same rows
        # that select_steps_for_obs returns, not simply the first `horizon` rows.
        self._observation_space = self.convert_observation_space(
            self.observation_space,
            self.video_horizon,
            self.state_horizon,
            video_delta_indices=self.video_delta_indices,
            state_delta_indices=self.state_delta_indices,
        )

    def assert_delta_indices(self, delta_indices: np.ndarray, horizon: int):
        """Validate a delta-index array, raising ``AssertionError`` if invalid.

        Requirements:
            * non-empty and of length ``horizon``;
            * every entry ``<= 0`` and the last entry ``== 0``;
            * with more than one entry, a constant, strictly positive step
              between consecutive entries.
        """
        delta_indices = np.asarray(delta_indices)
        # Without this check an empty array would fail with an IndexError on [-1].
        assert len(delta_indices) > 0, f"{delta_indices=}"
        assert len(delta_indices) == horizon, f"{delta_indices=}, {horizon=}"
        assert np.all(delta_indices <= 0), f"{delta_indices=}"
        assert delta_indices[-1] == 0, f"{delta_indices=}"
        if len(delta_indices) > 1:
            step = delta_indices[1] - delta_indices[0]
            assert step > 0, f"{delta_indices=}"
            assert np.all(np.diff(delta_indices) == step), f"{delta_indices=}"

    def select_steps_for_values(self, data_value, delta_indices):
        """Return the rows of ``data_value`` at ``(L - 1) + delta_indices``.

        Args:
            data_value: array-like indexed along axis 0, shape ``[L, ...]``
                (assumes numpy-style integer-array indexing is supported).
            delta_indices: offsets satisfying ``assert_delta_indices``.

        Returns:
            ``data_value`` indexed at the selected rows; for a numpy array the
            result has shape ``[len(delta_indices), ...]``.
        """
        delta_indices = np.asarray(delta_indices)
        L = data_value.shape[0]
        assert L >= len(delta_indices), f"{L=}, {len(delta_indices)=}"
        selected_indices = (L - 1) + delta_indices
        # Indices are increasing, so checking the first one is enough. Guard
        # explicitly because a negative index would silently wrap to the end.
        assert selected_indices[0] >= 0, f"{L=}, {selected_indices=}"
        return data_value[selected_indices]

    def select_steps_for_obs(self, obs):
        """Apply the configured step selection to every key of ``obs``.

        ``"state*"`` keys are omitted when ``state_delta_indices`` is ``None``,
        mirroring the observation space built in ``__init__``.

        Raises:
            ValueError: if a key starts with neither ``"video"`` nor ``"state"``.
        """
        new_obs = {}
        for k, value in obs.items():
            if k.startswith("video"):
                new_obs[k] = self.select_steps_for_values(value, self.video_delta_indices)
            elif k.startswith("state"):
                if self.state_delta_indices is not None:
                    new_obs[k] = self.select_steps_for_values(value, self.state_delta_indices)
            else:
                raise ValueError(f"Unknown key: {k}")
        return new_obs

    def convert_observation_space(
        self,
        observation_space,
        video_horizon,
        state_horizon,
        video_delta_indices=None,
        state_delta_indices=None,
    ):
        """Build the ``Dict`` space describing the sub-sampled observations.

        Args:
            observation_space: dict-like space of ``Box``-like entries.
            video_horizon: number of selected steps for ``"video*"`` keys.
            state_horizon: number of selected steps for ``"state*"`` keys, or
                ``None`` to drop those keys.
            video_delta_indices: optional. When given, the bounds of
                ``"video*"`` entries are taken from exactly the rows that
                ``select_steps_for_values`` picks. When omitted, the first
                ``video_horizon`` rows are used, as in earlier versions.
            state_delta_indices: same as ``video_delta_indices`` for
                ``"state*"`` keys.

        Raises:
            ValueError: if a key starts with neither ``"video"`` nor ``"state"``.
        """
        new_observation_space = {}
        for k in observation_space.keys():
            box = observation_space[k]
            if k.startswith("video"):
                horizon, delta_indices = video_horizon, video_delta_indices
            elif k.startswith("state"):
                if state_horizon is None:
                    continue
                horizon, delta_indices = state_horizon, state_delta_indices
            else:
                raise ValueError(f"Unknown key: {k}")
            if delta_indices is None:
                # Legacy behavior: only correct when bounds are equal across steps.
                low, high = box.low[:horizon], box.high[:horizon]
            else:
                low = self.select_steps_for_values(box.low, delta_indices)
                high = self.select_steps_for_values(box.high, delta_indices)
            new_observation_space[k] = gym.spaces.Box(
                low=low,
                high=high,
                shape=(horizon, *box.shape[1:]),
                dtype=box.dtype,
            )
        return gym.spaces.Dict(new_observation_space)

    def reset(self, seed=None, options=None):
        """Reset the wrapped env and sub-sample the initial observation."""
        obs, info = super().reset(seed=seed, options=options)
        obs = self.select_steps_for_obs(obs)
        return obs, info

    def step(self, action):
        """Step the wrapped env and sub-sample the returned observation."""
        obs, reward, terminated, truncated, info = super().step(action)
        obs = self.select_steps_for_obs(obs)
        return obs, reward, terminated, truncated, info
