import gymnasium as gym
from gymnasium.envs.mujoco.half_cheetah_v5 import HalfCheetahEnv

class HalfCheetahGrad(HalfCheetahEnv):
    """HalfCheetah variant that pays a single, sparse reward at the horizon.

    The parent environment's per-step reward is discarded. Every step returns
    a reward of 0 until ``steps_taken`` reaches ``max_horizon``; from then on
    the reward is the x position mapped linearly onto [0, 1] (0 below
    ``POS_MIN``, 1 above ``POS_TARGET``), but only if the control constraint
    was never violated during the episode.

    Two extra entries are written into the ``info`` dict returned by ``step``:

    * ``control_ok`` (every step): False once any step of the episode had a
      ``reward_ctrl`` that was not strictly greater than ``CTRL_REWARD_LIMIT``.
    * ``pos_ok`` (only once the horizon is reached): whether the x position
      is strictly greater than ``POS_TARGET``.

    This class never sets ``terminated``/``truncated`` itself; both are passed
    through unchanged from the parent ``step``.
    """

    # ``info['reward_ctrl']`` must stay strictly above this on every step.
    CTRL_REWARD_LIMIT = -0.6
    # x position at or below which the final reward is 0.
    POS_MIN = 2.5
    # x position above which the final reward saturates at 1.
    POS_TARGET = 5

    def __init__(self, **kwargs):
        """Build the parent environment and initialise episode bookkeeping.

        Args:
            **kwargs: Forwarded unchanged to ``HalfCheetahEnv.__init__``.
        """
        super().__init__(**kwargs)
        self.final_pos = 0        # x position reported by the latest step
        self.control_ok = True    # latched False on a control violation
        self.max_horizon = 100    # step count at which the reward is paid
        self.steps_taken = 0      # steps since the last reset

    def reset(self, seed=None, options=None):
        """Reset the simulation and the per-episode bookkeeping.

        Args:
            seed: Forwarded to the parent ``reset``.
            options: Forwarded to the parent ``reset``.

        Returns:
            The ``(observation, info)`` pair produced by the parent ``reset``.
        """
        # Forward ``options`` as well as ``seed`` so nothing the caller passes
        # is silently dropped.
        observation, info = super().reset(seed=seed, options=options)
        self.final_pos = 0
        self.control_ok = True
        self.steps_taken = 0
        return observation, info

    def step(self, action):
        """Advance one step and replace the reward with the sparse one.

        Args:
            action: Forwarded unchanged to the parent ``step``.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where
            everything except ``reward`` and the added ``info`` entries comes
            straight from the parent ``step``.
        """
        self.steps_taken += 1
        observation, _, terminated, truncated, info = super().step(action)

        # Written as ``not (x > limit)`` rather than ``x <= limit`` so that a
        # NaN control reward also counts as a violation.
        if not info['reward_ctrl'] > self.CTRL_REWARD_LIMIT:
            self.control_ok = False
        self.final_pos = info['x_position']
        info['control_ok'] = self.control_ok

        reward = 0
        if self.steps_taken >= self.max_horizon:
            if self.control_ok:
                # Map the final x position linearly onto [0, 1].
                if self.final_pos < self.POS_MIN:
                    reward = 0
                elif self.final_pos > self.POS_TARGET:
                    reward = 1
                else:
                    span = self.POS_TARGET - self.POS_MIN
                    reward = (self.final_pos - self.POS_MIN) / span

            # bool() keeps this a plain Python bool even when the position is
            # a NumPy scalar.
            info['pos_ok'] = bool(self.final_pos > self.POS_TARGET)

        return observation, reward, terminated, truncated, info

# Make the environment available under the id "HalfCheetahGrad-v1".
# The entry point is a "module:attribute" path; it presumably requires this
# file to be importable as ``stable_baselines3.envs.HalfCheetahGrad`` --
# TODO confirm against the package layout.
# ``max_episode_steps`` is the same value as ``max_horizon`` in the class.
gym.register(
    id="HalfCheetahGrad-v1",
    max_episode_steps=100,
    entry_point="stable_baselines3.envs.HalfCheetahGrad:HalfCheetahGrad",
)