"""Wraps the terrain-mass environment to output a `ground_truth_mode`
value in the `info` dict in each step, corresponding to the
terrain in which the mass is positioned.
"""
from gymnasium.envs.registration import register
from gymnasium.envs.registration import WrapperSpec
from terrain_mass.gymnasium import TerrainMassEnv

from swmpo_experiments.terrain_mass_utils.ground_truth_states import get_ground_truth_state


class GroundTruthTerrainMassEnv(TerrainMassEnv):
    """`TerrainMassEnv` that also reports the ground-truth terrain mode.

    `reset` and `step` delegate to the parent environment unchanged. They then
    add a `"ground_truth_mode"` entry to the `info` dict they return. The value
    is whatever `get_ground_truth_state` produces for this environment's
    `instance` and `current_state`, read after the parent call has run.
    """

    def _label_ground_truth(self, info):
        """Write the current ground-truth mode into `info` (in place) and return it.

        NOTE(review): relies on the parent class exposing `instance` and
        `current_state` attributes; confirm against `TerrainMassEnv`.
        """
        info["ground_truth_mode"] = get_ground_truth_state(
            environment_instance=self.instance,
            state=self.current_state,
        )
        return info

    def reset(self, *args, **kwargs):
        """Reset the parent environment and tag its `info` with the ground-truth mode."""
        observation, reset_info = super().reset(*args, **kwargs)
        # The parent reset has already run here, so `current_state` is the
        # post-reset state.
        return observation, self._label_ground_truth(reset_info)

    def step(self, *args, **kwargs):
        """Step the parent environment and tag its `info` with the ground-truth mode."""
        transition = super().step(*args, **kwargs)
        observation, reward, terminated, truncated, step_info = transition
        # Labelled after stepping, so the mode reflects the state the
        # environment has just moved into.
        return (
            observation,
            reward,
            terminated,
            truncated,
            self._label_ground_truth(step_info),
        )


# Entry point of the wrapper that `gym.make` stacks on top of the environment.
# The WrapperSpec name below is derived from it so the two cannot drift apart.
_GROUND_TRUTH_WRAPPER_ENTRY_POINT = "swmpo_experiments.ground_truth_wrapper:GroundTruthWrapper"

# Registering at import time: importing this module is enough to make
# `gym.make("GroundTruthTerrainMass-v0")` available.
register(
    id="GroundTruthTerrainMass-v0",
    entry_point="swmpo_experiments.terrain_mass_utils.ground_truth_wrapper:GroundTruthTerrainMassEnv",
    # Gymnasium's `make` instantiates each WrapperSpec's entry point with
    # `env=<env>` plus the spec's `kwargs`, in order, around the environment.
    additional_wrappers=(
        WrapperSpec(
            # The spec name identifies the *wrapper* class; Gymnasium records
            # a wrapper's class name here when rebuilding `env.spec`. The
            # previous value (the environment class name) mislabelled the
            # wrapper and made the registered spec disagree with `env.spec`.
            name=_GROUND_TRUTH_WRAPPER_ENTRY_POINT.rpartition(":")[2],
            entry_point=_GROUND_TRUTH_WRAPPER_ENTRY_POINT,
            # NOTE(review): forwarded verbatim to GroundTruthWrapper; their
            # meaning is defined by that class, not in this module.
            kwargs=dict(
                extrinsic_reward_scale=1.0,
                mode_n=2,
                exploration_window_size=60,
            ),
        ),
    ),
    kwargs=dict(),
)
