from .static_env import StaticEnv
import numpy as np
import gymnasium as gym


class Cheetah(StaticEnv):
    """
    HalfCheetah environment wrapped in the StaticEnv template.

    This class adapts the standard Gymnasium HalfCheetah-v5 environment to fit
    the structure of the StaticEnv base class. Since the environment's physics
    are handled by the MuJoCo simulator and not defined analytically, we call
    the underlying environment's `step()` method within `_next_obs()` and cache
    the results (reward, done) to be returned by `_reward()` and `_done()`.

    Safety constraint: an observation is safe when
    ``abs(obs[SAFE_OBS_INDEX]) <= SAFE_BOUND``. Stepping into an unsafe state
    ends the episode and subtracts ``UNSAFE_PENALTY`` from the step reward.
    NOTE(review): per the Gymnasium HalfCheetah-v5 observation table, index 8
    of the default observation is presumably the forward (rootx) velocity of
    the torso; confirm if the env is created with non-default observation
    options.
    """

    # Index of the observation component the safety constraint acts on.
    SAFE_OBS_INDEX = 8
    # Symmetric bound on obs[SAFE_OBS_INDEX]; the state is safe inside [-b, b].
    SAFE_BOUND = 2.8795
    # Subtracted from the simulator reward when a step lands in an unsafe state.
    UNSAFE_PENALTY = 100.0

    def __init__(self):
        # Initialize the underlying Gymnasium environment
        self.env = gym.make("HalfCheetah-v5", render_mode="rgb_array")
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space

        # Call the parent constructor with details from the HalfCheetah env
        super().__init__(
            obs_dim=self.observation_space.shape[0],
            act_dim=self.action_space.shape[0],
            obs_scale_low=self.observation_space.low,
            obs_scale_high=self.observation_space.high,
            act_scale_low=self.action_space.low,
            act_scale_high=self.action_space.high,
            num_steps=1000
        )

        # Caching variables to store results from the last step
        self._cached_reward = 0.
        self._cached_done = False

        # Training-related parameters (similar to the Road2d example)
        self.kwargs = dict(ret_low=0., ret_high=1., td3_unsafe_penalty=-20., eval_freq=5000)
        self.num_steps_to_train = 300000

    def _init_obs(self):
        """
        Reset the underlying environment and return the initial observation.

        The cached reward/done values are cleared as well. Without this, a
        `_done()` call made right after a reset would still report the
        previous episode's final flag.
        """
        obs, _ = self.env.reset()
        self._cached_reward = 0.
        self._cached_done = False
        return obs

    def _next_obs(self, obs, act):
        """
        Advance the simulator by one step with `act` and return the next observation.

        The wrapped MuJoCo env keeps its own internal state, so `obs` is
        ignored. The step is taken from wherever the simulator currently is,
        not from `obs`. The step's reward and done flag (terminated or
        truncated) are cached for `_reward()` and `_done()`.

        If the resulting observation violates the safety constraint, the
        episode is marked done and `UNSAFE_PENALTY` is subtracted from the
        cached reward.
        """
        next_obs, reward, terminated, truncated, _ = self.env.step(act)
        self._cached_reward = reward
        self._cached_done = terminated or truncated

        if not self._safe(next_obs):
            self._cached_done = True
            self._cached_reward -= self.UNSAFE_PENALTY
        return next_obs

    def _reward(self, obs, act, next_obs):
        """
        Return the reward cached by the most recent `_next_obs()` call.

        The arguments are ignored; the reward comes from the simulator, plus
        any unsafe-state penalty applied in `_next_obs()`.
        """
        return self._cached_reward

    def _backup(self, obs):
        """
        Return a zero action as a simple failsafe (do-nothing) controller.

        The array is built with the action space's shape and dtype, so
        `self.action_space.contains(...)` accepts it.
        """
        return np.zeros(self.action_space.shape, dtype=self.action_space.dtype)

    def _done(self, obs):
        """
        Return the done flag cached by the most recent `_next_obs()` call.

        `obs` is ignored; the flag is True when the simulator reported
        termination/truncation or the last step entered an unsafe state.
        """
        return self._cached_done

    def _safe(self, obs):
        """
        Return True when `obs` satisfies the safety constraint.

        The constraint is the single bound
        ``-SAFE_BOUND <= obs[SAFE_OBS_INDEX] <= SAFE_BOUND``. A NaN component
        is treated as unsafe.
        """
        return bool(abs(obs[self.SAFE_OBS_INDEX]) <= self.SAFE_BOUND)

    def is_safe(self, obs):
        """
        Public alias for `_safe()`: return True when `obs` satisfies the safety constraint.
        """
        return self._safe(obs)

    def _stable(self, obs):
        """
        Defines stability as being equivalent to safety.
        """
        return self._safe(obs)

    # --- Passthrough methods for gym.Env compatibility ---
    def render(self):
        """
        Renders the environment (an RGB array, given render_mode="rgb_array").
        """
        return self.env.render()

    def close(self):
        """
        Closes the environment and releases resources.
        """
        return self.env.close()
