import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax import vmap # Import vmap

# # Fallback or explicit import from jax.ops (Keep as is)
# _segment_max_func = None # Initialize
# from jax.lax import segment_max as _segment_max_func_lax
# _segment_max_func = _segment_max_func_lax
from jax.ops import segment_max as _segment_max_func





from gymnax.environments import environment, spaces
from typing import Tuple, Optional, Dict, Any
import chex
from flax import struct

# Direct imports (Keep as is)
import brax
from brax import base
from brax.mjx import pipeline as mjx_pipeline
from brax.io import mjcf

from functools import partial
import time
import os


# --- Constants ---
# Number of angular bins in each lidar ring. The observation contains two
# rings (goal and hazards), each with this many entries.
_LIDAR_NUM_BINS_CONST = 16
# Number of physics sub-steps executed per call to `step`. With the MJCF
# timestep of 0.005 s declared in the XML below, one environment step
# advances the simulation by 20 * 0.005 = 0.1 s.
_ACTION_REPEAT_CONST = 20
# Number of leading rows of the hazard position table that are active.
# `step`, `reset` and `_get_obs` all slice the table with `[:num_hazards]`.
num_hazards = 8
# ==================== XML Definition ====================
# MJCF model of the point agent. The agent body has two slide joints
# (root_x, root_y) and one hinge about the z axis (root_z).
# NOTE: despite the "2ACT"/"2act" suffix in the variable and model names, the
# <actuator> section defines THREE actuators: motors act_x and act_y plus a
# velocity actuator act_z. The 2 refers only to the policy's action size.
# Gravity is zero, and the goal/hazard geoms are commented out inside the XML,
# so goal and hazards exist only as positions evaluated in the Python code.
POINTGOAL_MJCF_XML_FIRST_PHYSICS_2ACT = """
<mujoco model="point_goal_first_physics_2act">
  <option timestep="0.005" gravity="0 0 0"/> <default>
    <geom condim="6" density="1"/>
    <joint damping=".001"/>
    <motor ctrlrange="-1 1" ctrllimited="true" forcerange="-0.5 0.5" forcelimited="true"/>
    <velocity ctrlrange="-1 1" ctrllimited="true" forcerange="-0.5 0.5" forcelimited="true"/>
    <site size="0.032" type="sphere"/>
  </default>

  <worldbody>
    <light pos="0 0 5" diffuse="1 1 1"/>
    <body name="agent" pos="0 0 0.1">
      <joint name="root_x" type="slide" axis="1 0 0" limited="false" damping="0.01"/> 
      <joint name="root_y" type="slide" axis="0 1 0" limited="false" damping="0.01"/> 
      <joint name="root_z" type="hinge" axis="0 0 1" limited="false" damping="0.005"/>
      <inertial pos="0 0 0" mass="1" diaginertia="0.01 0.01 0.01"/>
      <geom name="agent_sphere" type="sphere" size="0.1" rgba=".8 .2 .2 1" density="1"/>
      <geom name="agent_marker" type="box" size="0.05 0.05 0.05" pos="0.1 0 0" rgba=".8 .8 .2 1" contype="0" conaffinity="0"/>
      <site name="agent_site" pos="0 0 0" size="0.01"/>
    </body>
    <!-- 
    <geom name="goal_geom" type="capsule" pos="0 0 0.1" size="0.3 0.05" rgba=".2 .8 .2 0.5" contype="0" conaffinity="0"/>
    <geom name="hazard_geom_0" type="capsule" pos="-0.8 -1.1 0.01" size="0.2 0.01" rgba=".5 .5 .5 1"/>
    <geom name="hazard_geom_1" type="capsule" pos="1.2 -0.9 0.01" size="0.2 0.01" rgba=".5 .5 .5 1"/>
    <geom name="hazard_geom_2" type="capsule" pos="-1.1 0.7 0.01" size="0.2 0.01" rgba=".5 .5 .5 1"/>
    <geom name="hazard_geom_3" type="capsule" pos="0.9 1.3 0.01" size="0.2 0.01" rgba=".5 .5 .5 1"/>
    <geom name="hazard_geom_4" type="capsule" pos="0.1 -1.2 0.01" size="0.2 0.01" rgba=".5 .5 .5 1"/>
    <geom name="hazard_geom_5" type="capsule" pos="-1.3 -0.1 0.01" size="0.2 0.01" rgba=".5 .5 .5 1"/>
    <geom name="hazard_geom_6" type="capsule" pos="0.6 0.1 0.01" size="0.2 0.01" rgba=".5 .5 .5 1"/>
    <geom name="hazard_geom_7" type="capsule" pos="-0.2 1.0 0.01" size="0.2 0.01" rgba=".5 .5 .5 1"/>
    -->
    </worldbody>

  <actuator>
    <motor name="act_x" joint="root_x" gear="10"/>
    <motor name="act_y" joint="root_y" gear="10"/> 
    <velocity name="act_z" joint="root_z" gear="10"/>
  </actuator>
</mujoco>
"""
# ===========================================================

# --- Fixed Positions ---
# World-frame (x, y) position of the goal centre.
_DEFAULT_GOAL_POS = jnp.array([0.0, 0.0], dtype=jnp.float32)
# World-frame (x, y) positions of the eight hazards, one row per hazard.
# Only the first `num_hazards` rows are used; with the module-level
# `num_hazards = 8` every row is active.
# NOTE(review): row 6 is (0.8, 0.1) here, while the commented-out
# `hazard_geom_6` in the MJCF string lists "0.6 0.1" - confirm which is meant.
_DEFAULT_HAZARDS_POS = jnp.array([
     [-0.8, -1.1], [1.2, -0.9], [-1.1, 0.7], [0.9, 1.3],
     [0.1, -1.2], [-1.3, -0.1], [0.8, 0.1], [-0.2, 1.0]
 ], dtype=jnp.float32)

@struct.dataclass
class EnvParams:
    """Configuration values read by the environment's step/reset/observation code.

    The number of active hazards is NOT a field here; it is the module-level
    constant `num_hazards`, which is used to slice `hazards_positions_fixed`.
    """
    # Step count at which `step` reports the episode as done.
    max_steps_in_episode: int = 1000
    # Agent radius; added to `hazard_radius` to form the collision distance.
    agent_radius: float = 0.1
    # World-frame (x, y) of the goal centre.
    goal_pos_fixed: jax.Array = struct.field(default_factory=lambda: _DEFAULT_GOAL_POS)
    # The goal counts as reached when the agent is within this distance of the
    # goal centre; also passed to the goal lidar as the object radius.
    goal_size: float = 0.3
    # World-frame (x, y) rows of hazard centres (sliced with `[:num_hazards]`).
    hazards_positions_fixed: jax.Array = struct.field(default_factory=lambda: _DEFAULT_HAZARDS_POS)
    # Hazard radius; used for collision checks and as the hazard lidar radius.
    hazard_radius: float = 0.1
    # Half-width of the square arena: |x| or |y| beyond this is out of bounds.
    # `reset` samples start positions inside it with a fixed margin.
    boundary_size: float = 1.5
    # Value assigned to `g` when the goal is reached.
    g_goal_val: float = -300.0
    # Otherwise `g` is this scale times the distance to the goal boundary.
    g_dist_scale: float = 100.0
    # Value assigned to `h` when the agent hits a hazard or leaves the arena.
    h_hazard_val: float = 300.0
    # Value assigned to `h` otherwise.
    h_safe_val: float = -300.0
    # Lidar distances are clipped to this value and normalised by it.
    lidar_max_dist: float = 3.0

    # Index bookkeeping. NOTE(review): the environment code visible in this
    # file reads its own copies of these (set in `PointGoalJax.__init__`), not
    # these fields - confirm whether anything else consumes them.
    agent_qpos_idx: int = 0
    agent_qvel_idx: int = 0
    agent_body_idx: int = 1
    num_agent_qpos: int = 3
    num_agent_qvel: int = 3


@struct.dataclass
class EnvState:
    """Per-episode state carried between `reset` and successive `step` calls."""
    # Physics state as returned by the pipeline's init/step functions.
    pipeline_state: Optional[base.State]
    # Goal-related value computed from the agent's distance to the goal.
    g: jnp.ndarray = jnp.array(0.0, dtype=jnp.float32)
    # Safety-related value: `h_hazard_val` when unsafe, else `h_safe_val`.
    h: jnp.ndarray = jnp.array(0.0, dtype=jnp.float32)
    # Environment step counter; `reset` sets it to 0, `step` adds 1.
    time: jnp.ndarray = jnp.array(0, dtype=jnp.int32)

class PointGoalJax(environment.Environment): # New class name

    def __init__(self, backend='mjx'):
        """Load the MJCF model and bind jitted physics init/step functions.

        Args:
            backend: Physics backend name. Only ``'mjx'`` is implemented.

        Raises:
            ValueError: If the loaded model does not have exactly three
                actuators. `step` always sends a 3-element control vector
                (x force, y force, yaw velocity), so any other count is a
                model/code mismatch.
            NotImplementedError: If ``backend`` is not ``'mjx'``.
        """
        super().__init__()
        # ==================== Load Hybrid XML ====================
        self.sys = mjcf.loads(POINTGOAL_MJCF_XML_FIRST_PHYSICS_2ACT)
        # ===========================================================

        # The XML declares act_x, act_y and act_z. The policy action is 2-D,
        # but it is expanded to these three actuator controls inside `step`.
        if self.sys.nu != 3:
            raise ValueError(
                f"Hybrid XML should have 3 actuators, but model loaded {self.sys.nu}"
            )
        print(f"Loaded model actuators: {self.sys.nu}")

        print(f"Actuator ctrlrange from loaded model: {self.sys.actuator_ctrlrange}")

        if backend == 'mjx':
            # Bind the system once so the jitted callables take only state/ctrl.
            self.pipeline_init = jax.jit(partial(mjx_pipeline.init, self.sys))
            self.pipeline_step = jax.jit(partial(mjx_pipeline.step, self.sys))
        else:
            raise NotImplementedError(f"{backend} Not Implemented")

        # 12 sensor values (accelerometer, velocimeter, gyro, magnetometer;
        # 3 each) followed by two lidar rings (goal, hazards).
        self._obs_shape = (12 + _LIDAR_NUM_BINS_CONST * 2,)
        # The policy action is (forward, turn); it is independent of the
        # actuator count.
        self._act_shape = (2,)

        # Layout of the agent's coordinates inside q / qvel.
        self._agent_qpos_idx = 0
        self._agent_qvel_idx = 0
        self._agent_body_idx = 1
        self._num_agent_qpos = 3
        self._num_agent_qvel = 3
        # Segment-wise max used to bin lidar returns.
        self._segment_max = _segment_max_func

    @property
    def default_params(self) -> EnvParams:
        """Return a new ``EnvParams`` built entirely from its field defaults."""
        params = EnvParams()
        return params

    # ==================== STEP FUNCTION (Direct Mapping) ====================
    @partial(jax.jit, static_argnums=(0,))
    def step(
        self, key: chex.PRNGKey, state: EnvState, action: jax.Array, params: EnvParams,
    ) -> Tuple[chex.Array, EnvState, jnp.ndarray, jnp.ndarray, dict]:
        """Advance the environment by one control step.

        The 2-D action ``(forward, turn)`` is expanded into three actuator
        controls: the forward command is projected onto world x/y using the
        agent's current yaw, the turn command drives the yaw actuator, and
        the result is squashed with ``tanh``. Physics is then stepped
        ``_ACTION_REPEAT_CONST`` times with that constant control.

        If the agent ends outside the square arena, the physics state is
        replaced by the pre-step positions with zero velocity; the step is
        still flagged as unsafe.

        Args:
            key: PRNG key (not used by this method).
            state: Current environment state.
            action: Array whose first two entries are (forward, turn).
            params: Environment parameters.

        Returns:
            ``(obs, next_state, reward, done, info)``. ``obs`` and
            ``next_state`` have gradients stopped. ``reward`` is
            ``3 * sum(ctrl ** 2)``. ``done`` depends only on the step counter
            reaching ``params.max_steps_in_episode``.

        NOTE(review): the reward is the control energy with a positive sign
        even though it is described as a penalty - confirm the intended sign.
        """
        action = jnp.array(action, dtype=jnp.float32)

        yaw_index = self._agent_qpos_idx + 2
        xy_window = slice(self._agent_qpos_idx, self._agent_qpos_idx + 2)
        previous_physics = state.pipeline_state

        # --- Map the 2-D action to the three actuators (order: x, y, yaw) ---
        heading = previous_physics.q[yaw_index]
        heading_cos = jnp.cos(heading)
        heading_sin = jnp.sin(heading)
        forward_cmd = action[0]
        turn_cmd = action[1]
        raw_controls = jnp.array(
            [forward_cmd * heading_cos, forward_cmd * heading_sin, turn_cmd],
            dtype=jnp.float32,
        )
        # tanh keeps every control inside the actuators' [-1, 1] ctrlrange.
        ctrl = jnp.tanh(raw_controls)

        # --- Physics: hold the same control for every sub-step ---
        def advance_once(physics, _):
            return self.pipeline_step(physics, ctrl), None

        stepped_physics, _ = lax.scan(
            advance_once, previous_physics, (), length=_ACTION_REPEAT_CONST
        )

        # --- Arena boundary: undo the move and stop the agent if it left ---
        is_out_of_bounds = jnp.any(
            jnp.abs(stepped_physics.q[xy_window]) > params.boundary_size
        )

        def restore_previous_pose_at_rest():
            return self.pipeline_init(
                previous_physics.q, jnp.zeros_like(previous_physics.qvel)
            )

        def keep_stepped_physics():
            return stepped_physics

        final_pipeline_state = lax.cond(
            is_out_of_bounds, restore_previous_pose_at_rest, keep_stepped_physics
        )
        agent_xy = final_pipeline_state.q[xy_window]

        # --- Goal value g ---
        dist_to_goal_center = jnp.linalg.norm(agent_xy - params.goal_pos_fixed)
        goal_reached = dist_to_goal_center <= params.goal_size
        gap_to_goal_edge = jnp.maximum(0.0, dist_to_goal_center - params.goal_size)
        g_value = jnp.where(
            goal_reached, params.g_goal_val, params.g_dist_scale * gap_to_goal_edge
        )

        # --- Hazard check and safety value h ---
        active_hazards = params.hazards_positions_fixed[:num_hazards]
        hazard_dist_sq = jnp.sum(jnp.square(active_hazards - agent_xy), axis=1)
        contact_dist_sq = (params.agent_radius + params.hazard_radius)**2

        def any_hazard_touched():
            return jnp.any(hazard_dist_sq < contact_dist_sq)

        def no_hazards_configured():
            return jnp.array(False)

        hazard_collision = lax.cond(
            num_hazards > 0, any_hazard_touched, no_hazards_configured
        )

        # The boundary flag refers to the position BEFORE the undo above.
        unsafe = jnp.logical_or(hazard_collision, is_out_of_bounds)
        h_value = jnp.where(unsafe, params.h_hazard_val, params.h_safe_val)

        # --- Bookkeeping ---
        next_time = state.time + 1
        done = (next_time >= params.max_steps_in_episode)
        next_state = EnvState(
            pipeline_state=final_pipeline_state, g=g_value, h=h_value, time=next_time
        )

        # Control energy over the three actuator commands.
        reward = 3 * jnp.sum(jnp.square(ctrl))

        obs = self._get_obs(final_pipeline_state, params)

        info = {
            "dist_to_goal": dist_to_goal_center,
            "goal_reached": goal_reached,
            "hazard_collision": hazard_collision,
            "boundary_violation_attempted": is_out_of_bounds,
            "raw_action": action,
            "applied_control": ctrl,
            "agent_pos": agent_xy,
            "agent_yaw": final_pipeline_state.q[yaw_index],
            "g_value": g_value,
            "h_value": h_value,
        }
        return lax.stop_gradient(obs), lax.stop_gradient(next_state), reward, done, info
    # ==================== END STEP FUNCTION ====================

    # @partial(jax.jit, static_argnums=(0,))
    # def reset(
    #     self, key: chex.PRNGKey, params: EnvParams
    # ) -> Tuple[chex.Array, EnvState]:
    #     rng_pos, rng_yaw, rng_fallback = jax.random.split(key, 3)
        
    #     safe_margin = 0.2
        
    #     initial_agent_pos_xy = jax.random.uniform(
    #         rng_pos, (2,),
    #         minval=-params.boundary_size + safe_margin, 
    #         maxval=params.boundary_size - safe_margin, 
    #         dtype=jnp.float32
    #     )
        
    #     def is_safe_from_hazards(pos):
    #         if num_hazards <= 0:
    #             return True
            
    #         hazards_to_check = params.hazards_positions_fixed[:num_hazards]
    #         dist_sq_hazards = jnp.sum(jnp.square(hazards_to_check - pos), axis=1)
    #         min_safe_dist_sq = (params.agent_radius + params.hazard_radius + 0.1)**2
    #         return jnp.all(dist_sq_hazards >= min_safe_dist_sq)
        
    #     is_pos_safe = is_safe_from_hazards(initial_agent_pos_xy)
        
    #     fallback_pos = jax.random.uniform(
    #         rng_fallback, (2,),
    #         minval=jnp.array([-0.75, -0.75], dtype=jnp.float32),
    #         maxval=jnp.array([0.25, 0.25], dtype=jnp.float32),
    #         dtype=jnp.float32
    #     )
        
    #     final_agent_pos_xy = lax.cond(
    #         is_pos_safe,
    #         lambda: initial_agent_pos_xy,
    #         lambda: fallback_pos
    #     )
        
    #     init_yaw = jax.random.uniform(rng_yaw, (), minval=-jnp.pi, maxval=jnp.pi, dtype=jnp.float32)
        
    #     qpos = self.sys.init_q.at[self._agent_qpos_idx : self._agent_qpos_idx + self._num_agent_qpos].set(
    #         jnp.array([final_agent_pos_xy[0], final_agent_pos_xy[1], init_yaw])
    #     )
        
        
    #     qvel = jnp.zeros_like(self.sys.init_q)
        
      
    #     pipeline_state = self.pipeline_init(qpos, qvel)
        
      
    #     dist_to_goal_center_init = jnp.linalg.norm(final_agent_pos_xy - params.goal_pos_fixed)
    #     goal_reached_init = dist_to_goal_center_init <= params.goal_size
    #     dist_to_goal_boundary_init = jnp.maximum(0.0, dist_to_goal_center_init - params.goal_size)
    #     g_value_init = jnp.where(goal_reached_init, params.g_goal_val, params.g_dist_scale * dist_to_goal_boundary_init)
        
      
    #     hazards_positions_to_check_init = params.hazards_positions_fixed[:num_hazards]
    #     dist_sq_hazards_init = jnp.sum(jnp.square(hazards_positions_to_check_init - final_agent_pos_xy), axis=1)
    #     min_dist_sq_hazard_init = (params.agent_radius + params.hazard_radius)**2
    #     hazard_collision_init = lax.cond(
    #          num_hazards > 0,
    #          lambda: jnp.any(dist_sq_hazards_init < min_dist_sq_hazard_init),
    #          lambda: jnp.array(False)
    #     )
        
       
    #     is_out_of_bounds_init = jnp.any(jnp.abs(final_agent_pos_xy) > params.boundary_size)
        
      
    #     is_unsafe_init = jnp.logical_or(hazard_collision_init, is_out_of_bounds_init)
    #     h_value_init = jnp.where(is_unsafe_init, params.h_hazard_val, params.h_safe_val)
        
       
    #     state = EnvState(
    #         pipeline_state=pipeline_state, g=g_value_init, h=h_value_init, time=jnp.array(0, dtype=jnp.int32)
    #     )

    #     obs = self._get_obs(pipeline_state, params)
    #     return lax.stop_gradient(obs), lax.stop_gradient(state)

    def reset(
        self, key: chex.PRNGKey, params: EnvParams
    ) -> Tuple[chex.Array, EnvState]:
        """Start a new episode from a random pose inside the arena.

        The agent's (x, y) is sampled uniformly within the arena, shrunk by a
        fixed 0.2 margin on every side, and its yaw uniformly in [-pi, pi).
        No attempt is made to avoid hazards, so the initial ``h`` may already
        be the unsafe value.

        Args:
            key: PRNG key, split into position and yaw sub-keys.
            params: Environment parameters.

        Returns:
            ``(obs, state)`` with gradients stopped; ``state.time`` is 0 and
            ``state.g`` / ``state.h`` are evaluated at the sampled position.
        """
        rng_pos, rng_yaw = jax.random.split(key, 2)

        safe_margin = 0.2

        # Only sample inside boundary with margin (no fallback)
        final_agent_pos_xy = jax.random.uniform(
            rng_pos, (2,),
            minval=-params.boundary_size + safe_margin,
            maxval=params.boundary_size - safe_margin,
            dtype=jnp.float32
        )

        init_yaw = jax.random.uniform(
            rng_yaw, (), minval=-jnp.pi, maxval=jnp.pi, dtype=jnp.float32
        )

        qpos = self.sys.init_q.at[
            self._agent_qpos_idx : self._agent_qpos_idx + self._num_agent_qpos
        ].set(jnp.array([final_agent_pos_xy[0], final_agent_pos_xy[1], init_yaw], dtype=jnp.float32))

        # Velocities live in the model's velocity space, whose size need not
        # equal the position size. Size the zero vector from qd_size() rather
        # than from init_q, which only matched because nq == nv in this model.
        qvel = jnp.zeros(self.sys.qd_size(), dtype=self.sys.init_q.dtype)

        pipeline_state = self.pipeline_init(qpos, qvel)

        # g init (goal-related)
        dist_to_goal_center_init = jnp.linalg.norm(final_agent_pos_xy - params.goal_pos_fixed)
        goal_reached_init = dist_to_goal_center_init <= params.goal_size
        dist_to_goal_boundary_init = jnp.maximum(0.0, dist_to_goal_center_init - params.goal_size)
        g_value_init = jnp.where(
            goal_reached_init,
            params.g_goal_val,
            params.g_dist_scale * dist_to_goal_boundary_init
        )

        # h init (hazard/bounds-related)
        hazards_positions_to_check_init = params.hazards_positions_fixed[:num_hazards]
        dist_sq_hazards_init = jnp.sum(jnp.square(hazards_positions_to_check_init - final_agent_pos_xy), axis=1)
        min_dist_sq_hazard_init = (params.agent_radius + params.hazard_radius)**2
        hazard_collision_init = lax.cond(
            num_hazards > 0,
            lambda: jnp.any(dist_sq_hazards_init < min_dist_sq_hazard_init),
            lambda: jnp.array(False)
        )

        # With the sampling margin this is expected to be False; it is kept so
        # h is computed the same way as in `step`.
        is_out_of_bounds_init = jnp.any(jnp.abs(final_agent_pos_xy) > params.boundary_size)

        is_unsafe_init = jnp.logical_or(hazard_collision_init, is_out_of_bounds_init)
        h_value_init = jnp.where(is_unsafe_init, params.h_hazard_val, params.h_safe_val)

        state = EnvState(
            pipeline_state=pipeline_state,
            g=g_value_init,
            h=h_value_init,
            time=jnp.array(0, dtype=jnp.int32)
        )

        obs = self._get_obs(pipeline_state, params)
        return lax.stop_gradient(obs), lax.stop_gradient(state)

  
    def _get_obs(self, pipeline_state: base.State, params: EnvParams) -> jnp.ndarray:
        """Build the flat float32 observation for a physics state.

        Layout, concatenated in this order:
            accelerometer (3): always zero
            velocimeter   (3): (qvel x, qvel y, 0), the raw slide-joint rates
            gyro          (3): (0, 0, yaw rate)
            magnetometer  (3): (cos yaw, -sin yaw, 0)
            goal lidar    (_LIDAR_NUM_BINS_CONST)
            hazard lidar  (_LIDAR_NUM_BINS_CONST)

        Args:
            pipeline_state: Physics state providing ``q`` and ``qvel``.
            params: Environment parameters (goal/hazard geometry, lidar range).

        Returns:
            1-D float32 array; NaN/inf entries are replaced via ``nan_to_num``.
        """
        qpos = pipeline_state.q
        qvel = pipeline_state.qvel
        agent_pos_xy = qpos[self._agent_qpos_idx : self._agent_qpos_idx + 2]
        agent_yaw = qpos[self._agent_qpos_idx + 2]
        vx_joint, vy_joint = qvel[self._agent_qvel_idx], qvel[self._agent_qvel_idx + 1]
        wz_hinge = qvel[self._agent_qvel_idx + 2]
        cos_yaw_obs, sin_yaw_obs = jnp.cos(agent_yaw), jnp.sin(agent_yaw)
        velocimeter = jnp.array([vx_joint, vy_joint, 0.0], dtype=jnp.float32)
        gyro = jnp.array([0.0, 0.0, wz_hinge], dtype=jnp.float32)
        accelerometer = jnp.zeros(3, dtype=jnp.float32)
        magnetometer = jnp.array([cos_yaw_obs, -sin_yaw_obs, 0.0], dtype=jnp.float32)
        # The lidar helper computes `row_vector @ rot_matrix_T`, so it needs
        # the TRANSPOSE of the world->local rotation [[c, s], [-s, c]], i.e.
        # [[c, -s], [s, c]]. Passing the untransposed matrix rotated offsets
        # by +yaw instead of -yaw. Sanity check: an object straight ahead at
        # world offset (c, s) must map to local (1, 0).
        rot_matrix_world_to_local_T = jnp.array(
            [[cos_yaw_obs, -sin_yaw_obs], [sin_yaw_obs, cos_yaw_obs]], dtype=jnp.float32
        )
        _lidar_max_dist = params.lidar_max_dist
        goal_lidar = self._get_lidar_observation(
            agent_pos_xy, rot_matrix_world_to_local_T, params.goal_pos_fixed.reshape(1, 2),
            params.goal_size, _lidar_max_dist, _LIDAR_NUM_BINS_CONST
        )
        # Only the first `num_hazards` rows of the hazard table are scanned.
        hazards_positions_to_scan = params.hazards_positions_fixed[:num_hazards]
        hazards_lidar = self._get_lidar_observation(
            agent_pos_xy, rot_matrix_world_to_local_T, hazards_positions_to_scan,
            params.hazard_radius, _lidar_max_dist, _LIDAR_NUM_BINS_CONST
        )
        obs = jnp.concatenate([
            accelerometer, velocimeter, gyro, magnetometer, goal_lidar, hazards_lidar
        ], axis=0)
        obs = jnp.nan_to_num(obs)
        return obs.astype(jnp.float32)



    def _get_lidar_observation(
        self, agent_pos_xy: jnp.ndarray, rot_matrix_T: jnp.ndarray, obj_pos_world: jnp.ndarray,
        obj_radius: float, lidar_max_dist: float, lidar_num_bins: int
    ) -> jnp.ndarray:
        # Lidar logic kept identical to original PointGoalJax
        if self._segment_max is None:
            raise RuntimeError("segment_max not set")
        num_objs = obj_pos_world.shape[0]
        # Check added previously to handle 0 hazards correctly is kept
        if num_objs == 0:
             return jnp.zeros(lidar_num_bins, dtype=jnp.float32)
        # Calculation logic remains the same
        rel_pos_world = obj_pos_world - agent_pos_xy
        local_x = rel_pos_world[:, 0] * rot_matrix_T[0, 0] + rel_pos_world[:, 1] * rot_matrix_T[1, 0]
        local_y = rel_pos_world[:, 0] * rot_matrix_T[0, 1] + rel_pos_world[:, 1] * rot_matrix_T[1, 1]
        local_pos = jnp.stack([local_x, local_y], axis=-1)
        dist_to_center = jnp.linalg.norm(local_pos, axis=1)
        dist_to_surface = jnp.maximum(0.0, dist_to_center - obj_radius)
        angles = jnp.arctan2(local_pos[:, 1], local_pos[:, 0])
        bin_size = 2 * jnp.pi / lidar_num_bins
        bin_indices = jnp.floor((angles + jnp.pi) / bin_size).astype(jnp.int32)
        bin_indices = jnp.clip(bin_indices, 0, lidar_num_bins - 1)
        clipped_dist_to_surface = jnp.minimum(dist_to_surface, lidar_max_dist)
        normalized_dist = clipped_dist_to_surface / lidar_max_dist
        lidar_values = 1.0 - normalized_dist
        lidar_values = jnp.maximum(0.0, lidar_values)
        lidar = self._segment_max(
            data=lidar_values,
            segment_ids=bin_indices,
            num_segments=lidar_num_bins,
        )
        lidar = jnp.where(lidar == -jnp.inf, 0.0, lidar)
        return lidar.astype(jnp.float32)
   

    
    @property
    def name(self) -> str: return "PointGoal_Hybrid_Direct_2Act" # New name
    @property
    def num_actions(self) -> int: return self._act_shape[0] # Still 2

    def action_space(self, params: Optional[EnvParams] = None) -> spaces.Box:
        """Return the unbounded float32 box matching the action shape.

        The space is unbounded because `step` squashes controls with tanh
        itself. ``params`` is accepted but not used.
        """
        shape = self._act_shape
        upper = jnp.full(shape, jnp.inf, dtype=jnp.float32)
        lower = jnp.full(shape, -jnp.inf, dtype=jnp.float32)
        return spaces.Box(low=lower, high=upper, shape=shape, dtype=jnp.float32)

    def observation_space(self, params: EnvParams) -> spaces.Box:
        """Return the unbounded float32 box matching the observation shape.

        ``params`` is accepted but not used.
        """
        shape = self._obs_shape
        upper = jnp.full(shape, jnp.inf, dtype=jnp.float32)
        lower = jnp.full(shape, -jnp.inf, dtype=jnp.float32)
        return spaces.Box(low=lower, high=upper, shape=shape, dtype=jnp.float32)

    def state_space(self, params: EnvParams) -> spaces.Dict:
        """Describe the scalar parts of the state: ``g``, ``h`` and ``time``.

        ``g`` and ``h`` are unbounded float32 scalars; ``time`` is discrete
        with ``max_steps_in_episode + 1`` values. The physics state is not
        described here.
        """
        def unbounded_scalar():
            return spaces.Box(-jnp.inf, jnp.inf, (), dtype=jnp.float32)

        step_counter = spaces.Discrete(params.max_steps_in_episode + 1)
        layout = {
            "g": unbounded_scalar(),
            "h": unbounded_scalar(),
            "time": step_counter,
        }
        return spaces.Dict(layout)
   
