import gymnasium as gym
from gymnasium.envs.mujoco.hopper_v5 import HopperEnv

class HopperGrad(HopperEnv):
    """Hopper variant with a sparse, bounded, end-of-horizon reward.

    The dense reward produced by ``HopperEnv.step`` is discarded. Instead,
    every step returns ``0.0`` except the step on which ``steps_taken``
    reaches ``max_horizon``. On that step, provided the control check has
    passed on every step of the episode, the reward is the final x position
    mapped linearly from ``[POS_LOW, POS_HIGH]`` onto ``[0, 1]`` and clamped.

    Extra ``info`` keys added by ``step``:
        control_ok: ``False`` once any step's ``info["reward_ctrl"]`` was not
            strictly greater than ``CTRL_REWARD_LIMIT`` (NaN also fails);
            present on every step.
        pos_ok: ``True`` when the final x position exceeds ``POS_HIGH``;
            only present once ``max_horizon`` steps have been taken.
    """

    # The per-step control term reported by HopperEnv in info["reward_ctrl"]
    # must stay strictly above this value for the episode to remain eligible
    # for reward. NOTE(review): presumably this is the negated control cost,
    # so the limit caps the per-step control effort — confirm.
    CTRL_REWARD_LIMIT = -0.004
    # Final x position at or below POS_LOW scores 0, at or above POS_HIGH
    # scores 1, and is linearly interpolated in between.
    POS_LOW = 0.4
    POS_HIGH = 0.8

    def __init__(self, **kwargs):
        """Create the environment; all keyword arguments go to ``HopperEnv``."""
        super().__init__(**kwargs)
        self.final_pos = 0
        self.control_ok = True
        # Keep in sync with max_episode_steps used when registering the env.
        self.max_horizon = 100
        self.steps_taken = 0

    def reset(self, seed=None, options=None):
        """Reset the simulation and the per-episode bookkeeping.

        Args:
            seed: Forwarded to ``HopperEnv.reset``.
            options: Forwarded to ``HopperEnv.reset``.

        Returns:
            The ``(observation, info)`` pair from ``HopperEnv.reset``.
        """
        observation, info = super().reset(seed=seed, options=options)
        self.final_pos = 0
        self.control_ok = True
        self.steps_taken = 0
        return observation, info

    def step(self, action):
        """Advance one step and compute the sparse end-of-horizon reward.

        Args:
            action: Passed unchanged to ``HopperEnv.step``.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. The
            observation and termination flags come from ``HopperEnv.step``;
            ``reward`` is always a Python ``float`` in ``[0, 1]`` (or NaN if
            the reported x position is NaN).
        """
        self.steps_taken += 1
        # Run the underlying Hopper dynamics; its dense reward is replaced below.
        observation, _, terminated, truncated, info = super().step(action)

        # Written as `not >` rather than `<=` so a NaN control term also
        # marks the episode as having violated the control budget.
        if not info["reward_ctrl"] > self.CTRL_REWARD_LIMIT:
            self.control_ok = False
        self.final_pos = info["x_position"]
        info["control_ok"] = self.control_ok

        reward = 0.0
        if self.steps_taken >= self.max_horizon:
            # bool() so info carries a plain Python bool even when the
            # position is a numpy scalar.
            info["pos_ok"] = bool(self.final_pos > self.POS_HIGH)
            if self.control_ok:
                # Linear ramp from POS_LOW (0) to POS_HIGH (1), clamped.
                span = self.POS_HIGH - self.POS_LOW
                score = (self.final_pos - self.POS_LOW) / span
                reward = float(min(max(score, 0.0), 1.0))

        return observation, reward, terminated, truncated, info

# Make the environment available as gym.make("HopperGrad-v1").
# max_episode_steps (100) matches the max_horizon set in HopperGrad.__init__,
# so the sparse reward lands on the final step; keep the two values in sync.
# NOTE(review): the entry point assumes this module is importable as
# stable_baselines3.envs.HopperGrad — confirm it matches the file's location.
gym.register(
    "HopperGrad-v1",
    entry_point="stable_baselines3.envs.HopperGrad:HopperGrad",
    max_episode_steps=100,
)