import gymnasium as gym
from gymnasium.envs.mujoco.walker2d_v5 import Walker2dEnv

class Walker2dGrad(Walker2dEnv):
    """Walker2d variant that pays a sparse reward at the end of a fixed horizon.

    The per-step reward of the parent environment is discarded.  Every step
    returns a reward of 0 until ``max_horizon`` steps have been taken since
    the last reset.  From that step onward the reward is the walker's x
    position mapped linearly from ``[0.4, 0.8]`` onto ``[0, 1]`` (clamped
    outside that interval), but only if the control cost stayed within
    bounds on every step of the episode; otherwise it remains 0.

    Extra ``info`` entries written by :meth:`step`:

    * ``control_ok`` (every step): ``False`` once any step of the episode
      reported ``info['reward_ctrl']`` that was not greater than -0.006.
    * ``pos_ok`` (only once ``max_horizon`` steps have been taken): whether
      the x position exceeds 0.8.
    """

    # ``reward_ctrl`` must stay strictly above this value on every step.
    _CTRL_REWARD_LIMIT = -0.006
    # x positions below the lower bound earn 0, above the upper bound earn 1.
    _POS_LOWER = 0.4
    _POS_UPPER = 0.8

    def __init__(self, **kwargs):
        """Create the environment; all keyword arguments go to the parent."""
        super().__init__(**kwargs)
        self.final_pos = 0
        self.control_ok = True
        self.max_horizon = 100
        self.steps_taken = 0

    def reset(self, seed=None, options=None):
        """Reset the simulation and the per-episode bookkeeping.

        Args:
            seed: Seed forwarded to the parent ``reset``.
            options: Reset options forwarded to the parent ``reset``.

        Returns:
            The ``(observation, info)`` pair produced by the parent ``reset``.
        """
        # Forward ``options`` as well as ``seed`` so caller-supplied reset
        # options are not silently dropped.
        observation, info = super().reset(seed=seed, options=options)
        self.final_pos = 0
        self.control_ok = True
        self.steps_taken = 0
        return observation, info

    @classmethod
    def _position_reward(cls, x_position):
        """Map an x position from [_POS_LOWER, _POS_UPPER] linearly onto [0, 1]."""
        if x_position < cls._POS_LOWER:
            return 0.0
        if x_position > cls._POS_UPPER:
            return 1.0
        return (x_position - cls._POS_LOWER) / (cls._POS_UPPER - cls._POS_LOWER)

    def step(self, action):
        """Advance the simulation one step and compute the sparse reward.

        Args:
            action: Action forwarded unchanged to the parent ``step``.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where
            everything except ``reward`` and the added ``info`` keys comes
            straight from the parent ``step``.
        """
        self.steps_taken += 1
        # The parent's dense reward is intentionally ignored below.
        observation, _, terminated, truncated, info = super().step(action)

        # A single violation of the control-cost bound disqualifies the whole
        # episode; the flag is sticky until the next reset.  Written as
        # ``not (a > b)`` so a NaN control reward also counts as a violation.
        if not (info['reward_ctrl'] > self._CTRL_REWARD_LIMIT):
            self.control_ok = False
        self.final_pos = info['x_position']
        info['control_ok'] = self.control_ok

        reward = 0.0
        if self.steps_taken >= self.max_horizon:
            if self.control_ok:
                reward = self._position_reward(self.final_pos)
            # ``bool`` keeps this a plain Python bool even if the position is
            # a NumPy scalar.
            info['pos_ok'] = bool(self.final_pos > self._POS_UPPER)

        return observation, reward, terminated, truncated, info

# Expose the environment to ``gym.make`` under the id "Walker2dGrad-v1".
# The episode step limit is set to 100 steps.
gym.register(
    "Walker2dGrad-v1",
    max_episode_steps=100,
    entry_point="stable_baselines3.envs.Walker2dGrad:Walker2dGrad",
)