from dataclasses import dataclass
import numpy as np, time, os, mujoco as mj
import gymnasium as gym
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.envs.mujoco.ant_v4 import AntEnv
from gymnasium.spaces import Box
from gymnasium import utils
from warnings import filterwarnings
from utils import *
from dotenv import load_dotenv
from gymnasium.envs.registration import register
from gymnasium import spaces


# Load variables from a .env file (BASE_XML_DIR below is read via os.getenv and
# may presumably come from there) and silence the DeprecationWarning about the
# `np.bool8` alias.
load_dotenv()
filterwarnings(action="ignore", category=DeprecationWarning, 
               message="`np.bool8` is a deprecated alias for `np.bool_`")

# (gym id, class name) pairs. Each class is registered with the entry point
# "envs:<ClassName>", i.e. it is looked up in a module importable as `envs`.
env_configs = [
    ('ant', 'Ant'),
    ('ant_bounded', 'BoundedSquareAnt'),
    ('ant_px', 'AntPosX'),
    ('ant_nx', 'AntNegX'),
    ('ant_py', 'AntPosY'),
    ('ant_ny', 'AntNegY'),
    ('antr_px', 'AntRPosX'),
    ('antr_nx', 'AntRNegX'),
    ('antr_py', 'AntRPosY'),
    ('antr_ny', 'AntRNegY'),
    ('ant_food', 'AntFood'),
    ('ant_goal', 'AntGoal'),
    ('predator_prey', 'AntPredatorPrey'),
    ('random_predator_prey', 'AntRandomPredatorPrey'),
    ('ant_corridor_predator_prey', 'AntCorridorPredatorPrey'),
    ('ant_circle_predator_prey', 'AntCirclePredatorPrey'),
    ('random_terran_ant', 'RandomTerranAnt'),
    ('random_terran_predator_prey', 'AntRandomTerranPredatorPrey'),
    ('random_terran_predator_prey_forward', 'AntRandomTerranPredatorPreyForward')
]

# Register all environments at once. The loop variables are underscore-prefixed
# so they do not shadow the builtin `id` at module scope and are not exported by
# `from envs import *`.
for _env_id, _class_name in env_configs:
    register(id=_env_id, entry_point=f'envs:{_class_name}')

# Directory containing the MuJoCo XML model files used by the classes below.
# NOTE(review): when the variable is unset this is None, and model paths
# silently become "None/<file>.xml".
BASE_XML_DIR = os.getenv('BASE_XML_DIR') 


class Ant(AntEnv):
    """Ant on `ant.xml` with an overridable (by default zero) reward.

    The observation excludes the torso x/y position. Every step reports the
    torso position plus each leg site's height and sensor reading in `info`.
    The episode terminates as soon as `sensordata[0]` is positive; no
    truncation is applied here.
    """

    # (site name -- also the info key for its sensor reading,
    #  info key for the site's z coordinate, index into data.sensordata)
    _LEG_SITES = (
        ("front_left_leg_site", "front_left_leg_height", 1),
        ("front_right_leg_site", "front_right_leg_height", 2),
        ("back_left_leg_site", "back_left_leg_height", 3),
        ("back_right_leg_site", "back_right_leg_height", 4),
    )

    def __init__(self, test_mode=False, **kwargs):
        # `test_mode` is accepted for a uniform constructor signature; unused here.
        super().__init__(
            xml_file=f"{BASE_XML_DIR}/ant.xml",
            healthy_z_range=(0.1, 10.0),
            exclude_current_positions_from_observation=True,
            **kwargs,
        )

    def calculate_reward(self, x_velocity, y_velocity):
        """Reward hook for subclasses; this base version always returns 0.

        NOTE: `step` actually passes the torso x/y positions before and after
        the simulation step as the two arguments, despite the parameter names.
        """
        return 0

    def _leg_info(self):
        """Return each leg site's world z-coordinate and its sensor reading."""
        leg_info = {}
        for site_name, height_key, sensor_index in self._LEG_SITES:
            site_id = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, site_name)
            leg_info[height_key] = self.data.site_xpos[site_id][2]
            leg_info[site_name] = self.data.sensordata[sensor_index]
        return leg_info

    def step(self, action):
        """Advance one control step and return the gymnasium 5-tuple."""
        start_xy = self.get_body_com("torso")[:2].copy()
        self.do_simulation(action, self.frame_skip)
        end_xy = self.get_body_com("torso")[:2].copy()

        step_reward = self.calculate_reward(start_xy, end_xy)

        torso = self.get_body_com("torso").copy()
        obs = self._get_obs()

        info = {
            "rewards": step_reward,
            "ant_x_position": torso[0],
            "ant_y_position": torso[1],
            "distance_from_origin": np.linalg.norm(torso, ord=2),
        }
        info.update(self._leg_info())

        if self.render_mode == "human":
            self.render()

        # Any positive reading on sensor 0 ends the episode.
        terminated = bool(self.data.sensordata[0] > 0.0)
        return obs, step_reward, terminated, False, info

    def get_obs(self):
        """Public alias for `_get_obs`."""
        return self._get_obs()


class RandomTerranAnt(AntEnv):
    """Stock gymnasium Ant behaviour on the `ant_random_terran.xml` model.

    Only the model file, the observation layout (torso x/y excluded) and the
    healthy z-range are customised; `step`/`reset` are inherited unchanged.
    """

    def __init__(self, test_mode=False, **kwargs):
        model_path = f"{BASE_XML_DIR}/ant_random_terran.xml"
        super().__init__(
            xml_file=model_path,
            healthy_z_range=(0.1, 10.0),
            exclude_current_positions_from_observation=True,
            **kwargs,
        )
        # Stored under the (misspelled) name `test_model`, kept verbatim in
        # case other code reads it.
        self.test_model = test_mode


class AntFood(AntEnv):
    """Ant that must keep returning to a food source to avoid running out of energy.

    `self.food` starts at `food_init`, drops by 1 every step, and is refilled to
    `food_max` whenever the torso COM is closer than `food_distance` to the
    fixed point (0, 0, food_z). The step reward is always 0. The episode
    terminates when food reaches 0, or when `sensordata[0]` is positive while
    the torso is below z = 0.3.

    Observation: the parent AntEnv observation (torso x/y included) with the
    current food value appended as the last element.
    """

    def __init__(self, 
        test_mode=False, 
        food_distance=0.5,
        food_init=100,
        food_max=100,
        food_z=1.2,
        **kwargs
    ):
        super(AntFood, self).__init__(
            exclude_current_positions_from_observation=False,
            healthy_z_range=(0.1, 10.0),
            xml_file=f"{BASE_XML_DIR}/ant_food.xml",
            **kwargs,
        )
        self.food_source_position = np.array([0.0, 0.0, food_z])
        self.food = food_init
        self.food_init = food_init
        self.food_max = food_max
        self.food_distance = food_distance  # Distance threshold for energy replenishment
        # Size the space from a real observation instead of hard-coding it: the
        # parent observation length depends on the model (nq + nv) and grows
        # when e.g. `use_contact_forces=True` is forwarded through **kwargs.
        obs_dim = self._get_obs().shape[0]
        self.observation_space = Box(
            low=-np.inf, high=np.inf, shape=(obs_dim,), dtype=np.float64
        )
        logger.debug(f"{YELLOW}Energy: {self.food}{ENDC}")
        logger.debug(f"{YELLOW}Food source position: {self.food_source_position}{ENDC}")
        logger.debug(f"{YELLOW}Energy distance threshold: {self.food_distance}{ENDC}")
        logger.debug(f"{YELLOW}Food max energy: {self.food_max}{ENDC}")
        logger.debug(f"{YELLOW}Food initial energy: {self.food}{ENDC}")
        logger.debug(f"{YELLOW}Food energy z: {food_z}{ENDC}")

    def calculate_reward(self, x_velocity, y_velocity):
        """Reward hook kept for interface parity with the other envs; unused by `step`."""
        return 0

    def step(self, action):
        """Advance one control step, update the food counter, return the 5-tuple."""
        self.do_simulation(action, self.frame_skip)
        reward = 0
        self.food -= 1  # Consume energy at each step
        position_torso = self.get_body_com("torso").copy()
        distance_to_food = np.linalg.norm(position_torso - self.food_source_position)

        if distance_to_food < self.food_distance:
            # Replenish energy to max when near food
            self.food = self.food_max

        observation = self._get_obs()

        front_left_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "front_left_leg_site")][2]
        front_right_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "front_right_leg_site")][2]
        back_left_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "back_left_leg_site")][2]
        back_right_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "back_right_leg_site")][2]

        info = {
            "rewards": reward,
            "ant_x_position": position_torso[0],
            "ant_y_position": position_torso[1],
            "ant_z_position": position_torso[2],
            "distance_from_origin": np.linalg.norm(position_torso, ord=2),
            "front_left_leg_height": front_left_leg_height,
            "front_left_leg_site": self.data.sensordata[1],
            "front_right_leg_height": front_right_leg_height,
            "front_right_leg_site": self.data.sensordata[2],
            "back_left_leg_height": back_left_leg_height,
            "back_left_leg_site": self.data.sensordata[3],
            "back_right_leg_height": back_right_leg_height,
            "back_right_leg_site": self.data.sensordata[4],
            "energy": self.food,  # Add energy to info dictionary
            "distance_to_food": distance_to_food,  # Add distance to food
        }

        if self.render_mode == "human":
            self.render()

        terminated = False
        contact_detected = self.data.sensordata[0]
        # Sensor contact only ends the episode when the torso is also low.
        if contact_detected > 0.0 and position_torso[2] < 0.3:
            terminated = True

        # Starvation ends the episode.
        if self.food <= 0:
            terminated = True
        return observation, reward, terminated, False, info

    def reset_model(self):
        """Reset joints to the initial pose (plus qvel noise) and refill food."""
        # Position noise is disabled (both bounds are 0), but the uniform draw
        # is kept so the RNG stream, and therefore the qvel noise, is unchanged.
        noise_low = 0 #-self._reset_noise_scale
        noise_high = 0 #self._reset_noise_scale

        qpos = self.init_qpos + self.np_random.uniform(
            low=noise_low, high=noise_high, size=self.model.nq
        )

        qvel = (
            self.init_qvel
            + self._reset_noise_scale * self.np_random.standard_normal(self.model.nv)
        )

        # Reset energy to initial value
        self.food = self.food_init

        self.set_state(qpos, qvel)
        observation = self._get_obs()
        return observation

    def _get_obs(self):
        """Parent observation with the current food value appended."""
        original_obs = super()._get_obs()
        return np.append(original_obs, [self.food])

    def get_obs(self):
        """Public alias for `_get_obs`."""
        return self._get_obs()

    def set_obs(self, obs):
        """Restore simulator state and food counter from an observation of `get_obs`.

        Layout assumed: qpos (nq) | qvel (nv) | [any parent extras] | food.
        Anything between qvel and food (e.g. contact forces) is ignored.

        Returns:
            The observation recomputed from the restored state.
        """
        obs = np.asarray(obs)
        # Extract energy from the last element of obs
        self.food = float(obs[-1])

        nq, nv = self.model.nq, self.model.nv
        qpos = obs[:nq]
        # Bound the slice so a trailing contact-force block does not end up in qvel.
        qvel = obs[nq:nq + nv]

        self.set_state(qpos, qvel)
        return self._get_obs()


class BoundedSquareAnt(AntEnv):
    """Ant on `ant.xml` with an optional square boundary.

    The square is |x| <= boundary_limit and |y| <= boundary_limit, measured on
    the torso COM. Leaving it terminates the episode only when `test_mode` is
    true; a positive `sensordata[0]` reading always terminates. Observations
    include the torso x/y position.
    """

    def __init__(self, boundary_limit=1, test_mode=False, **kwargs):
        super(BoundedSquareAnt, self).__init__(
            exclude_current_positions_from_observation=False,
            healthy_z_range=(0.1, 10.0),
            xml_file=f"{BASE_XML_DIR}/ant.xml",
            **kwargs,
        )
        print(f"Boundary limit: {boundary_limit}, test mode: {test_mode}")
        self.boundary_limit = boundary_limit
        # Stored under the (misspelled) name `test_model`, kept in case other code reads it.
        self.test_model = test_mode

    def calculate_reward(self, x_velocity, y_velocity):
        """Reward hook for subclasses; 0 here.

        NOTE: `step` passes the torso x/y positions before and after the step,
        despite the parameter names.
        """
        return 0

    def step(self, action):
        """Advance one control step and return the gymnasium 5-tuple."""
        xy_position_before = self.get_body_com("torso")[:2].copy()
        self.do_simulation(action, self.frame_skip)
        xy_position_after = self.get_body_com("torso")[:2].copy()

        # The hook's value is both reported in info and returned as the reward,
        # so subclass overrides of `calculate_reward` take effect.
        reward = self.calculate_reward(xy_position_before, xy_position_after)

        position_torso = self.get_body_com("torso").copy()
        observation = self._get_obs()

        # Check if agent is outside the boundary
        x, y = position_torso[0], position_torso[1]
        outside_boundary = abs(x) > self.boundary_limit or abs(y) > self.boundary_limit

        front_left_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "front_left_leg_site")][2]
        front_right_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "front_right_leg_site")][2]
        back_left_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "back_left_leg_site")][2]
        back_right_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "back_right_leg_site")][2]

        info = {
            "rewards": reward,
            "ant_x_position": position_torso[0],
            "ant_y_position": position_torso[1],
            "distance_from_origin": np.linalg.norm(position_torso, ord=2),
            "front_left_leg_height": front_left_leg_height,
            "front_left_leg_site": self.data.sensordata[1],
            "front_right_leg_height": front_right_leg_height,
            "front_right_leg_site": self.data.sensordata[2],
            "back_left_leg_height": back_left_leg_height,
            "back_left_leg_site": self.data.sensordata[3],
            "back_right_leg_height": back_right_leg_height,
            "back_right_leg_site": self.data.sensordata[4],
        }
        if self.render_mode == "human":
            self.render()

        # Check termination conditions
        terminated = False
        contact_detected = self.data.sensordata[0]
        if contact_detected > 0.0:
            terminated = True

        # The boundary is only enforced in test mode.
        if outside_boundary and self.test_model:
            terminated = True

        return observation, reward, terminated, False, info

    def reset_model(self):
        """Reset joints to the initial pose (plus qvel noise); return the observation."""
        # Position noise is disabled (both bounds are 0), but the uniform draw
        # is kept so the RNG stream, and therefore the qvel noise, is unchanged.
        noise_low = 0 #-self._reset_noise_scale
        noise_high = 0 #self._reset_noise_scale

        qpos = self.init_qpos + self.np_random.uniform(
            low=noise_low, high=noise_high, size=self.model.nq
        )

        qvel = (
            self.init_qvel
            + self._reset_noise_scale * self.np_random.standard_normal(self.model.nv)
        )
        self.set_state(qpos, qvel)
        observation = self._get_obs()
        return observation

    def get_obs(self):
        """Public alias for `_get_obs`."""
        return self._get_obs()

    def set_obs(self, obs):
        """Restore simulator state from an observation laid out as qpos | qvel | [extras].

        Returns:
            The observation recomputed from the restored state.
        """
        nq, nv = self.model.nq, self.model.nv
        qpos = obs[:nq]
        # Bound the slice so any trailing contact-force block is not fed into qvel.
        qvel = obs[nq:nq + nv]
        self.set_state(qpos, qvel)
        return self._get_obs()
    

class AntOrgDir(AntEnv):
    """Directional Ant on `ant.xml` using the stock gymnasium reward shaping.

    reward = calculate_reward(xy_before, xy_after) + healthy_reward
             - control cost (- contact cost when contact forces are enabled).
    Episodes stop when `self.terminated` is true or after `max_steps` steps.
    """

    def __init__(self, test_mode=False, **kwargs):
        # Zero-argument super(): AntOrgDir does not derive from AntDir, so
        # `super(AntDir, self)` would raise TypeError on construction.
        super().__init__(
            exclude_current_positions_from_observation=True,
            healthy_z_range=(0.1, 10.0),
            xml_file=f"{BASE_XML_DIR}/ant.xml",
            **kwargs,
        )

        self.step_count = 0
        self.max_steps = 1000

    def calculate_reward(self, x_velocity, y_velocity):
        """Directional reward hook; 0 here.

        NOTE: `step` passes the torso x/y positions before and after the step,
        despite the parameter names.
        """
        return 0

    def step(self, action):
        """Advance one control step and return the gymnasium 5-tuple."""
        self.step_count += 1
        xy_position_before = self.get_body_com("torso")[:2].copy()
        self.do_simulation(action, self.frame_skip)
        xy_position_after = self.get_body_com("torso")[:2].copy()

        xy_velocity = (xy_position_after - xy_position_before) / self.dt
        x_velocity, y_velocity = xy_velocity

        dir_reward = self.calculate_reward(xy_position_before, xy_position_after)
        healthy_reward = self.healthy_reward

        rewards = dir_reward + healthy_reward

        costs = ctrl_cost = self.control_cost(action)

        terminated = self.terminated
        observation = self._get_obs()

        position_torso = self.get_body_com("torso").copy()

        front_left_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "front_left_leg_site")][2]
        front_right_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "front_right_leg_site")][2]
        back_left_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "back_left_leg_site")][2]
        back_right_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "back_right_leg_site")][2]

        info = {
            "reward_forward": dir_reward,
            "reward_ctrl": -ctrl_cost,
            "reward_survive": healthy_reward,
            "x_position": xy_position_after[0],
            "y_position": xy_position_after[1],
            # Single entry (a duplicate key used to be silently overridden):
            # 3-D norm of the torso COM, matching AntDir / AntRDir.
            "distance_from_origin": np.linalg.norm(position_torso, ord=2),
            "x_velocity": x_velocity,
            "y_velocity": y_velocity,
            "forward_reward": dir_reward,
            "ant_y_position": position_torso[1],
            "front_left_leg_height": front_left_leg_height,
            "front_left_leg_site": self.data.sensordata[1],
            "front_right_leg_height": front_right_leg_height,
            "front_right_leg_site": self.data.sensordata[2],
            "back_left_leg_height": back_left_leg_height,
            "back_left_leg_site": self.data.sensordata[3],
            "back_right_leg_height": back_right_leg_height,
            "back_right_leg_site": self.data.sensordata[4],
        }
        if self._use_contact_forces:
            contact_cost = self.contact_cost
            costs += contact_cost
            # Report the contact penalty under its own key rather than
            # overwriting the control penalty in "reward_ctrl".
            info["reward_contact"] = -contact_cost

        reward = rewards - costs

        if self.render_mode == "human":
            self.render()
        truncated = self.step_count >= self.max_steps
        # NOTE(review): truncation is also reported as termination (as in
        # AntDir); Gymnasium treats `terminated` as a true terminal state.
        terminated = terminated or truncated
        return observation, reward, terminated, truncated, info

    def get_obs(self):
        """Public alias for `_get_obs`."""
        return self._get_obs()

    def set_obs(self, obs):
        """Restore simulator state from `obs` split as qpos (nq) | qvel (rest)."""
        qpos = obs[:self.model.nq]
        qvel = obs[self.model.nq:]
        self.set_state(qpos, qvel)
        return self._get_obs()

    def reset_model(self):
        """Reset to the exact initial pose and velocity (no noise) and the step counter."""
        self.step_count = 0
        qpos = self.init_qpos
        qvel = self.init_qvel
        self.set_state(qpos, qvel)

        observation = self._get_obs()

        return observation


class AntDir(AntEnv):
    """Base class for directional Ant tasks on `ant.xml`.

    The returned reward is exactly `calculate_reward(xy_before, xy_after)`:
    no healthy bonus and no control cost. Subclasses override the hook to pick
    a direction. Episodes end when `self.terminated` (AntEnv's health check,
    configured with healthy_z_range=(0.2, 1.0)) is true, or after `max_steps`
    steps.
    """

    def __init__(self, test_mode=False, **kwargs):
        super(AntDir, self).__init__(
            exclude_current_positions_from_observation=True,
            healthy_z_range=(0.2, 1.0),
            xml_file=f"{BASE_XML_DIR}/ant.xml",
            **kwargs,
        )

        self.step_count = 0
        self.max_steps = 1000

    def calculate_reward(self, x_velocity, y_velocity):
        """Directional reward hook; 0 here.

        NOTE: `step` passes the torso x/y positions before and after the step,
        despite the parameter names.
        """
        return 0

    def step(self, action):
        """Advance one control step and return the gymnasium 5-tuple."""
        self.step_count += 1
        xy_position_before = self.get_body_com("torso")[:2].copy()
        self.do_simulation(action, self.frame_skip)
        xy_position_after = self.get_body_com("torso")[:2].copy()

        xy_velocity = (xy_position_after - xy_position_before) / self.dt
        x_velocity, y_velocity = xy_velocity

        rewards = self.calculate_reward(xy_position_before, xy_position_after)

        terminated = self.terminated
        observation = self._get_obs()

        position_torso = self.get_body_com("torso").copy()

        front_left_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "front_left_leg_site")][2]
        front_right_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "front_right_leg_site")][2]
        back_left_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "back_left_leg_site")][2]
        back_right_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "back_right_leg_site")][2]

        info = {
            "reward_forward": rewards,
            "x_position": xy_position_after[0],
            "y_position": xy_position_after[1],
            # Single entry: the 3-D norm of the torso COM. This is the value the
            # dict has always held; a duplicate key with the x-y norm was
            # silently overridden and is removed.
            "distance_from_origin": np.linalg.norm(position_torso, ord=2),
            "x_velocity": x_velocity,
            "y_velocity": y_velocity,
            "ant_y_position": position_torso[1],
            "front_left_leg_height": front_left_leg_height,
            "front_left_leg_site": self.data.sensordata[1],
            "front_right_leg_height": front_right_leg_height,
            "front_right_leg_site": self.data.sensordata[2],
            "back_left_leg_height": back_left_leg_height,
            "back_left_leg_site": self.data.sensordata[3],
            "back_right_leg_height": back_right_leg_height,
            "back_right_leg_site": self.data.sensordata[4],
        }

        if self.render_mode == "human":
            self.render()
        truncated = False
        if self.step_count >= self.max_steps:
            truncated = True
        # NOTE(review): truncation is also reported as termination. Gymnasium
        # treats `terminated` as a true terminal state (affects bootstrapping);
        # kept as-is for compatibility with existing training loops.
        terminated = terminated or truncated
        return observation, rewards, terminated, truncated, info

    def get_obs(self):
        """Public alias for `_get_obs`."""
        return self._get_obs()

    def set_obs(self, obs):
        """Restore simulator state from `obs` split as qpos (nq) | qvel (rest)."""
        qpos = obs[:self.model.nq]
        qvel = obs[self.model.nq:]
        self.set_state(qpos, qvel)
        return self._get_obs()

    def reset_model(self):
        """Reset to the exact initial pose and velocity (no noise) and the step counter."""
        self.step_count = 0
        qpos = self.init_qpos
        qvel = self.init_qvel
        self.set_state(qpos, qvel)
        observation = self._get_obs()

        return observation


class AntPosX(AntDir):
    """AntDir task rewarding motion in the +x direction."""

    def calculate_reward(self, xy_position_before, xy_position_after):
        """Return the torso x-velocity over the last step (displacement / dt)."""
        velocity = np.subtract(xy_position_after, xy_position_before) / self.dt
        return velocity[0]


class AntNegX(AntDir):
    """AntDir task rewarding motion in the -x direction."""

    def calculate_reward(self, xy_position_before, xy_position_after):
        """Return the negated torso x-velocity over the last step."""
        velocity = np.subtract(xy_position_after, xy_position_before) / self.dt
        return -velocity[0]


class AntPosY(AntDir):
    """AntDir task rewarding motion in the +y direction."""

    def calculate_reward(self, xy_position_before, xy_position_after):
        """Return the torso y-velocity over the last step (displacement / dt)."""
        velocity = np.subtract(xy_position_after, xy_position_before) / self.dt
        return velocity[1]


class AntNegY(AntDir):
    """AntDir task rewarding motion in the -y direction."""

    def calculate_reward(self, xy_position_before, xy_position_after):
        """Return the negated torso y-velocity over the last step."""
        velocity = np.subtract(xy_position_after, xy_position_before) / self.dt
        return -velocity[1]




class AntRDir(AntEnv):
    """Base class for directional Ant tasks on `ant_random_terran.xml`.

    Same logic as the `ant.xml` directional tasks: the returned reward is
    exactly `calculate_reward(xy_before, xy_after)`. Episodes end when
    `self.terminated` (AntEnv's health check, healthy_z_range=(0.2, 1.0)) is
    true, or after `max_steps` steps.
    """

    def __init__(self, test_mode=False, **kwargs):
        super(AntRDir, self).__init__(
            exclude_current_positions_from_observation=True,
            healthy_z_range=(0.2, 1.0),
            xml_file=f"{BASE_XML_DIR}/ant_random_terran.xml",
            **kwargs,
        )

        self.step_count = 0
        self.max_steps = 1000

    def calculate_reward(self, x_velocity, y_velocity):
        """Directional reward hook; 0 here.

        NOTE: `step` passes the torso x/y positions before and after the step,
        despite the parameter names.
        """
        return 0

    def step(self, action):
        """Advance one control step and return the gymnasium 5-tuple."""
        self.step_count += 1
        xy_position_before = self.get_body_com("torso")[:2].copy()
        self.do_simulation(action, self.frame_skip)
        xy_position_after = self.get_body_com("torso")[:2].copy()

        xy_velocity = (xy_position_after - xy_position_before) / self.dt
        x_velocity, y_velocity = xy_velocity

        rewards = self.calculate_reward(xy_position_before, xy_position_after)

        terminated = self.terminated
        observation = self._get_obs()

        position_torso = self.get_body_com("torso").copy()

        front_left_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "front_left_leg_site")][2]
        front_right_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "front_right_leg_site")][2]
        back_left_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "back_left_leg_site")][2]
        back_right_leg_height = self.data.site_xpos[mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, "back_right_leg_site")][2]

        info = {
            "reward_forward": rewards,
            "x_position": xy_position_after[0],
            "y_position": xy_position_after[1],
            # Single entry: the 3-D norm of the torso COM. This is the value the
            # dict has always held; a duplicate key with the x-y norm was
            # silently overridden and is removed.
            "distance_from_origin": np.linalg.norm(position_torso, ord=2),
            "x_velocity": x_velocity,
            "y_velocity": y_velocity,
            "ant_y_position": position_torso[1],
            "front_left_leg_height": front_left_leg_height,
            "front_left_leg_site": self.data.sensordata[1],
            "front_right_leg_height": front_right_leg_height,
            "front_right_leg_site": self.data.sensordata[2],
            "back_left_leg_height": back_left_leg_height,
            "back_left_leg_site": self.data.sensordata[3],
            "back_right_leg_height": back_right_leg_height,
            "back_right_leg_site": self.data.sensordata[4],
        }

        if self.render_mode == "human":
            self.render()
        truncated = False
        if self.step_count >= self.max_steps:
            truncated = True
        # NOTE(review): truncation is also reported as termination. Gymnasium
        # treats `terminated` as a true terminal state (affects bootstrapping);
        # kept as-is for compatibility with existing training loops.
        terminated = terminated or truncated
        return observation, rewards, terminated, truncated, info

    def get_obs(self):
        """Public alias for `_get_obs`."""
        return self._get_obs()

    def set_obs(self, obs):
        """Restore simulator state from `obs` split as qpos (nq) | qvel (rest)."""
        qpos = obs[:self.model.nq]
        qvel = obs[self.model.nq:]
        self.set_state(qpos, qvel)
        return self._get_obs()

    def reset_model(self):
        """Reset to the exact initial pose and velocity (no noise) and the step counter."""
        self.step_count = 0
        qpos = self.init_qpos
        qvel = self.init_qvel
        self.set_state(qpos, qvel)
        observation = self._get_obs()

        return observation


class AntRPosX(AntRDir):
    """Random-terrain task: reward is the torso's velocity along +x."""

    def calculate_reward(self, xy_position_before, xy_position_after):
        """Finite-difference x-velocity of the torso over one control step."""
        vx, _vy = np.subtract(xy_position_after, xy_position_before) / self.dt
        return vx


class AntRNegX(AntRDir):
    """Random-terrain task: reward is the torso's velocity along -x."""

    def calculate_reward(self, xy_position_before, xy_position_after):
        """Negated finite-difference x-velocity of the torso over one control step."""
        vx, _vy = np.subtract(xy_position_after, xy_position_before) / self.dt
        return -vx


class AntRPosY(AntRDir):
    """Random-terrain task: reward is the torso's velocity along +y."""

    def calculate_reward(self, xy_position_before, xy_position_after):
        """Finite-difference y-velocity of the torso over one control step."""
        _vx, vy = np.subtract(xy_position_after, xy_position_before) / self.dt
        return vy


class AntRNegY(AntRDir):
    """Random-terrain task: reward is the torso's velocity along -y."""

    def calculate_reward(self, xy_position_before, xy_position_after):
        """Negated finite-difference y-velocity of the torso over one control step."""
        _vx, vy = np.subtract(xy_position_after, xy_position_before) / self.dt
        return -vy

class MPEnv(gym.Wrapper):
    """
    Environment wrapper that uses manual indices as actions and runs them through the policy network.

    Each discrete action is passed to `policy_network` as `manual_indices`
    together with the latest observation. The first value the network returns
    is converted to numpy and its first row is executed in the wrapped env.

    Args:
        env: environment to wrap; its observation space and metadata are reused.
        policy_network: callable invoked as
            policy_network(obs_tensor, deterministic=..., bias_config=..., manual_indices=...)
            and returning a sequence whose first item is the action tensor
            (assumed batch-first with batch size 1 -- confirm against the model).
        device: torch device the observation and index tensors are moved to.
        num_actions: size of the discrete action space.
        std: passed as bias_config={"std": std} when `hard` is false.
        hard: if true, query the network deterministically with no bias config.
    """
    def __init__(self, env, policy_network, device, num_actions=7, std=0.2, hard=False):
        # Initialise gym.Wrapper's bookkeeping (sets self.env and the private
        # slots backing properties such as `spec`), consistent with the other
        # wrappers in this module.
        super().__init__(env)
        self.policy_network = policy_network
        self.device = device
        # Create discrete action space for manual indices
        self.action_space = spaces.Discrete(num_actions)
        self.observation_space = env.observation_space
        self.metadata = env.metadata
        self.steps = 0
        self.std = std
        self.hard = hard
        logger.info(f"{GREEN} hard: {self.hard}, std: {self.std} {ENDC}")
        # Latest observation; None until reset() is called.
        self.obs = None

    def step(self, action):
        """
        Takes a discrete action (manual index), runs it through policy network,
        and executes the resulting continuous action in the environment.

        Raises:
            RuntimeError: if called before `reset()`.
        """
        if self.obs is None:
            raise RuntimeError("MPEnv.step() called before reset()")
        self.steps += 1
        # Add a batch dimension of 1 for the policy network.
        obs_tensor = torch.as_tensor(
            np.expand_dims(self.obs, axis=0),
            dtype=torch.float32,
        ).to(self.device)
        # Get continuous action from policy network using manual index
        with torch.no_grad():
            a, *_ = self.policy_network(
                obs_tensor,
                deterministic=bool(self.hard),
                bias_config=None if self.hard else {"std": self.std},
                manual_indices=torch.tensor([action]).to(self.device),
            )

        # Execute continuous action in environment (drop the batch dimension)
        continuous_action = a.cpu().numpy()
        observation, rewards, terminated, truncated, info = self.env.step(continuous_action[0])
        self.obs = observation
        return observation, rewards, terminated, truncated, info

    def reset(self, **kwargs):
        """
        Reset the wrapped environment and the step counter; cache the observation.
        """
        self.steps = 0
        observation, info = self.env.reset(**kwargs)
        self.obs = observation
        return observation, info
    


# class MPAntMazeEnv(gym.Wrapper):
#     """
#     Environment wrapper that uses manual indices as actions and runs them through the policy network.
#     Returns a concatenated observation with achieved_goal and desired_goal at the beginning,
#     but only passes the raw observation to the policy network.
#     """
#     def __init__(self, env, policy_network, device, num_actions=7, std=0.2, hard=False):
#         super().__init__(env)
#         self.policy_network = policy_network
#         self.device = device
#         from gymnasium import spaces
#         import numpy as np
#         print(f"num_actions: {num_actions}") 
#         # Create discrete action space for manual indices
#         self.action_space = spaces.Discrete(num_actions)
#         self.observation_space = spaces.Box(
#             low=-np.inf, high=np.inf, shape=(31,), dtype=np.float64
#         )
#         self.steps = 0
#         self.std = std
#         self.hard = hard
#         logger.info(f"{GREEN} hard: {self.hard}, std: {self.std} {ENDC}") 
    
#     def step(self, action):
#         self.steps += 1
#         # Convert raw observation to tensor for policy network (without goals)
#         obs_tensor = torch.as_tensor(
#             np.expand_dims(self.obs[:27], axis=0), 
#             dtype=torch.float32
#         ).to(self.device)
        
#         # Get continuous action from policy network using manual index
#         with torch.no_grad():
#             a, logp_a, pi, mu, std, cov, *_ = self.policy_network(
#                 obs_tensor,
#                 deterministic=True if self.hard else False,
#                 bias_config=None if self.hard else {"std": self.std},
#                 manual_indices=torch.tensor([action]).to(self.device)
#             )
            
#         # Execute continuous action in environment
#         continuous_action = a.cpu().numpy()
#         observation, rewards, terminated, truncated, info = self.env.step(continuous_action[0])
        
#         if self.steps > 999:
#             terminated = True
#             truncated = True
        
#         contact_detected = self.env.env.env.env.data.sensordata[0] 
#         if contact_detected > 0.0:
#             terminated = True
#             truncated = True
        
        
#         observation = np.concatenate(
#             (observation['observation'], observation['achieved_goal'], observation['desired_goal'])
#         )
#         self.obs = observation
#         rewards *= 100 
#         return observation, rewards, terminated, truncated, info 
    
#     def reset(self, **kwargs):
#         """
#         Reset the environment.
#         Returns concatenated observation with achieved_goal and desired_goal at the beginning.
#         """
#         self.steps = 0
#         options = {
#             "reset_cell": np.array([3, 1]),
#             "goal_cell": np.array([3, 3]),
#         }
#         observation, info = self.env.reset(options=options, **kwargs)
#         # Concatenate achieved_goal and desired_goal with raw observation
#         observation = np.concatenate(
#             (observation['observation'], observation['achieved_goal'], observation['desired_goal'])
#         )
#         self.obs = observation
#         return observation, info   


class MPAntMazeEnvDiff(gym.Wrapper):
    """
    Environment wrapper that uses manual indices as actions and runs them through the policy network.

    Each discrete action selects a primitive (``manual_indices``) of a pretrained
    low-level policy, whose continuous output is executed in the wrapped
    goal-conditioned env. Observations returned to the caller are the
    concatenation ``[observation, achieved_goal, desired_goal]`` (goals at the
    END), but only the raw ``"observation"`` entry is passed to the policy network.

    The wrapped env's reward is discarded and replaced by a dense reward: the
    decrease in distance to the goal since the previous step, plus a bonus when
    the goal is reached, multiplied by ``reward_scale``.
    """

    # Keys of the wrapped env's dict observation, in concatenation order.
    _OBS_KEYS = ("observation", "achieved_goal", "desired_goal")

    def __init__(self, env, policy_network, device, num_actions=7, std=0.2, hard=False,
                 goal_cell=(2, 3), goal_threshold=0.1, success_bonus=1.0, reward_scale=100.0):
        """
        Args:
            env: Goal-conditioned env whose observations are dicts with
                "observation", "achieved_goal" and "desired_goal" entries.
            policy_network: Low-level policy; called with a (1, obs_dim) float32
                tensor plus ``manual_indices`` selecting the primitive to run.
            device: Torch device the policy network lives on.
            num_actions: Number of discrete primitives (size of the action space).
            std: Std passed via ``bias_config`` when sampling stochastically.
            hard: If truthy, query the policy deterministically and without ``bias_config``.
            goal_cell: Default maze goal cell sent as ``options["goal_cell"]`` on
                reset; ``None`` omits it so the env chooses the goal.
            goal_threshold: Distance to goal below which the episode is ended as solved.
            success_bonus: Reward added (before scaling) when the goal is reached.
            reward_scale: Multiplier applied to the final per-step reward.
        """
        super().__init__(env)
        self.policy_network = policy_network
        self.device = device
        logger.debug(f"num_actions: {num_actions}")
        # One discrete action per manual index of the low-level policy.
        self.action_space = spaces.Discrete(num_actions)
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(self._flat_obs_dim(env.observation_space),),
            dtype=np.float64,
        )
        self.steps = 0
        self.std = std
        self.hard = hard
        self.goal_cell = None if goal_cell is None else tuple(goal_cell)
        self.goal_threshold = goal_threshold
        self.success_bonus = success_bonus
        self.reward_scale = reward_scale
        # Last raw (dict) observation; None until reset() has been called.
        self.obs = None
        self.prev_pos = None
        self.prev_dist_to_goal = None
        logger.info(f"{GREEN} hard: {self.hard}, std: {self.std} {ENDC}")

    @classmethod
    def _flat_obs_dim(cls, observation_space, default=31):
        """Return the length of the concatenated observation, or ``default`` if it cannot be derived."""
        if not isinstance(observation_space, spaces.Dict):
            return default
        try:
            return int(sum(np.prod(observation_space[key].shape) for key in cls._OBS_KEYS))
        except (KeyError, TypeError):
            return default

    @classmethod
    def _flatten(cls, observation):
        """Concatenate a dict observation as ``[observation, achieved_goal, desired_goal]``."""
        return np.concatenate([observation[key] for key in cls._OBS_KEYS])

    def step(self, action):
        """
        Execute the low-level primitive selected by ``action``.

        Args:
            action: Discrete manual index forwarded to the policy network.

        Returns:
            (flat_observation, reward, terminated, truncated, info)

        Raises:
            RuntimeError: if called before ``reset()``.
        """
        if self.obs is None:
            raise RuntimeError("MPAntMazeEnvDiff.step() called before reset()")
        self.steps += 1

        # Only the raw observation (without goals) is fed to the policy network.
        obs_tensor = torch.as_tensor(
            np.expand_dims(self.obs["observation"], axis=0),
            dtype=torch.float32,
        ).to(self.device)

        with torch.no_grad():
            a, *_ = self.policy_network(
                obs_tensor,
                deterministic=bool(self.hard),
                bias_config=None if self.hard else {"std": self.std},
                manual_indices=torch.tensor([action]).to(self.device),
            )

        continuous_action = a.cpu().numpy()
        # The env's own reward is discarded in favour of the shaped reward below.
        observation, _, terminated, truncated, info = self.env.step(continuous_action[0])

        current_pos = observation["achieved_goal"]
        goal_pos = observation["desired_goal"]
        current_dist_to_goal = np.linalg.norm(current_pos - goal_pos)

        # Dense shaping: positive when the agent moved closer to the goal.
        if self.prev_dist_to_goal is not None:
            rewards = self.prev_dist_to_goal - current_dist_to_goal
        else:
            rewards = 0.0

        # NOTE(review): the wrapped maze env may end the episode on its own
        # success radius before this threshold is reached — confirm the bonus
        # is actually reachable. Both flags are set, matching the contact check below.
        if current_dist_to_goal < self.goal_threshold:
            rewards += self.success_bonus
            terminated = True
            truncated = True

        self.prev_dist_to_goal = current_dist_to_goal
        self.prev_pos = current_pos.copy()

        # A positive reading on the first MuJoCo sensor ends the episode
        # (presumably a contact sensor defined in the XML — confirm).
        # NOTE(review): assumes exactly four wrapper layers around the base env.
        contact_detected = self.env.env.env.env.data.sensordata[0]
        if contact_detected > 0.0:
            terminated = True
            truncated = True

        self.obs = observation
        rewards *= self.reward_scale
        return self._flatten(observation), rewards, terminated, truncated, info

    def reset(self, **kwargs):
        """
        Reset the wrapped env and the goal-distance tracker.

        ``options["goal_cell"]`` defaults to ``self.goal_cell``; any ``options``
        supplied by the caller are merged on top (caller keys win). Remaining
        keyword arguments (e.g. ``seed``) are forwarded unchanged.

        Returns:
            (flat_observation, info), with the goals appended after the raw observation.
        """
        self.steps = 0
        self.prev_pos = None
        self.prev_dist_to_goal = None

        options = {}
        if self.goal_cell is not None:
            options["goal_cell"] = np.array(self.goal_cell)
        # Passing both our options and a caller's options= would raise a TypeError.
        caller_options = kwargs.pop("options", None)
        if caller_options:
            options.update(caller_options)
        observation, info = self.env.reset(options=options, **kwargs)

        # Seed the distance tracker so the first step's reward is a true delta.
        self.prev_pos = observation["achieved_goal"].copy()
        self.prev_dist_to_goal = np.linalg.norm(
            observation["achieved_goal"] - observation["desired_goal"]
        )
        self.obs = observation
        return self._flatten(observation), info


# class MPAntMazeDictEnv(gym.Wrapper):
#     """
#     Environment wrapper that uses manual indices as actions and runs them through the policy network.
#     Returns a concatenated observation with achieved_goal and desired_goal at the beginning,
#     but only passes the raw observation to the policy network.
#     """
#     def __init__(self, env, policy_network, device, num_actions=7, std=0.2, hard=False):
#         super().__init__(env)
#         self.policy_network = policy_network
#         self.device = device
#         from gymnasium import spaces
#         import numpy as np
        
#         # Create discrete action space for manual indices
#         self.action_space = spaces.Discrete(num_actions)
#         self.observation_space = spaces.Box(
#             low=-np.inf, high=np.inf, shape=(31,), dtype=np.float64
#         )
#         self.steps = 0
#         self.std = std
#         self.hard = hard
#         logger.info(f"{GREEN} hard: {self.hard}, std: {self.std} {ENDC}") 
    
#     def step(self, action):
#         self.steps += 1
#         # Convert raw observation to tensor for policy network (without goals)
#         obs_tensor = torch.as_tensor(
#             np.expand_dims(self.obs['observation'], axis=0), 
#             dtype=torch.float32
#         ).to(self.device)
        
#         # Get continuous action from policy network using manual index
#         with torch.no_grad():
#             a, logp_a, pi, mu, std, cov, *_ = self.policy_network(
#                 obs_tensor,
#                 deterministic=True if self.hard else False,
#                 # bias_config=None if self.hard else {"std": self.std},
#                 manual_indices=torch.tensor([action]).to(self.device)
#             )
            
#         # Execute continuous action in environment
#         continuous_action = a.cpu().numpy()
#         observation, rewards, terminated, truncated, info = self.env.step(continuous_action[0])
        
#         if self.steps > 999:
#             terminated = True
#             truncated = True
        
#         contact_detected = self.env.env.env.env.data.sensordata[0] 
#         if contact_detected > 0.0:
#             terminated = True
#             truncated = True
        
#         self.obs = observation 
#         return observation, rewards, terminated, truncated, info 
    
#     def reset(self, **kwargs):
#         """
#         Reset the environment.
#         Returns concatenated observation with achieved_goal and desired_goal at the beginning.
#         """
#         self.steps = 0
#         options = {
#             "reset_cell": np.array([3, 1]),
#             "goal_cell": np.array([3, 3]),
#         }
#         observation, info = self.env.reset(options=options, **kwargs)
#         self.obs = observation
#         return observation, info   


# class MPAntMazeDictEnvJit(gym.Wrapper):
#     """
#     Environment wrapper with optimized policy execution.
#     Uses manual indices as actions and runs them through the policy network.
#     """
#     def __init__(self, env, policy_network, device, num_actions=7, std=0.2, hard=False):
#         super().__init__(env)
#         self.device = device
#         self.policy_network = policy_network
#         from gymnasium import spaces
        
#         # Create discrete action space for manual indices
#         self.action_space = spaces.Discrete(num_actions)
#         self.observation_space = env.observation_space 
#         self.steps = 0
#         self.std = std
#         self.hard = hard
#         logger.info(f"{GREEN} hard: {self.hard}, std: {self.std} {ENDC}")
        
#         # Pre-allocate tensors for better performance
#         self.obs_tensor = torch.zeros((1, 27), dtype=torch.float32, device=device)
#         self.indices_tensor = torch.zeros(1, dtype=torch.long, device=device)
        
#         # Store parameter directly
#         self.deterministic = hard
        
#         # Create an optimized policy forward method
#         self._setup_policy()
    
#     def _setup_policy(self):
#         """Set up optimized policy execution."""
#         # Create a wrapper module that doesn't use keyword arguments
#         class PolicyWrapper(torch.nn.Module):
#             def __init__(self, policy):
#                 super().__init__()
#                 self.policy = policy
            
#             def forward(self, obs, indices, deterministic):
#                 # Call the forward method directly with positional arguments
#                 # This avoids using keyword arguments completely
#                 return self.policy.forward(obs, deterministic, False, indices)[0]
        
#         try:
#             # Wrap and script the policy
#             self.policy_wrapper = PolicyWrapper(self.policy_network)
#             self.policy_func = torch.jit.script(self.policy_wrapper)
#             logger.info(f"{GREEN}Successfully created JIT-compiled policy wrapper{ENDC}")
#             self.use_jit = True
            
#         except Exception as e:
#             # Fall back to a non-JIT optimized version
#             logger.warning(f"JIT compilation failed: {str(e)}. Using optimized non-JIT fallback.")
#             self.use_jit = False
    
#     def _policy_step(self, obs_tensor, indices_tensor):
#         """Optimized policy step without JIT."""
#         # Call the policy with specific arguments - no JIT but still optimized
#         with torch.no_grad():
#             a, _, _, _, _, _, _ = self.policy_network(
#                 obs_tensor,
#                 deterministic=self.deterministic,
#                 with_logprob=False,
#                 manual_indices=indices_tensor
#             )
#         return a
    
#     def step(self, action):
#         self.steps += 1
        
#         # Update tensors in-place for better performance
#         self.indices_tensor[0] = action
        
#         # Copy observation data into pre-allocated tensor
#         np_obs = self.obs['observation'].reshape(1, -1).astype(np.float32)
#         self.obs_tensor.copy_(torch.from_numpy(np_obs).to(self.device))
        
#         # Get continuous action from policy network
#         if self.use_jit:
#             # Use JIT-compiled function
#             a = self.policy_func(
#                 self.obs_tensor, 
#                 self.indices_tensor, 
#                 self.deterministic
#             )
#         else:
#             # Use optimized non-JIT function
#             a = self._policy_step(self.obs_tensor, self.indices_tensor)
        
#         # Execute continuous action in environment
#         continuous_action = a.cpu().numpy()
#         observation, rewards, terminated, truncated, info = self.env.step(continuous_action[0])
        
#         if self.steps > 999:
#             terminated = True
#             truncated = True
        
#         contact_detected = self.env.env.env.env.data.sensordata[0] 
#         if contact_detected > 0.0:
#             terminated = True
#             truncated = True
        
#         self.obs = observation 
#         return observation, rewards, terminated, truncated, info 
    
#     def reset(self, **kwargs):
#         """
#         Reset the environment.
#         Returns concatenated observation with achieved_goal and desired_goal at the beginning.
#         """
#         self.steps = 0
#         options = {
#             "reset_cell": np.array([3, 1]),
#             "goal_cell": np.array([3, 3]),
#         }
#         observation, info = self.env.reset(options=options, **kwargs)
#         self.obs = observation
#         return observation, info
    
  

# class MPAntMazeEnvJit(gym.Wrapper):
#     """
#     Environment wrapper with optimized policy execution.
#     Uses manual indices as actions and runs them through the policy network.
#     """
#     def __init__(self, env, policy_network, device, num_actions=7, std=0.2, hard=False):
#         super().__init__(env)
#         logger.debug(f"MPAntMazeEnvJit: {env}")
#         logger.debug(f"MPAntMazeEnvJit: {policy_network}")
#         logger.debug(f"MPAntMazeEnvJit: {device}")
#         logger.debug(f"MPAntMazeEnvJit: {num_actions}")
#         logger.debug(f"MPAntMazeEnvJit: {std}")
#         logger.debug(f"MPAntMazeEnvJit: {hard}")
#         self.device = device
#         self.policy_network = policy_network
#         from gymnasium import spaces
        
#         # Create discrete action space for manual indices
#         self.action_space = spaces.Discrete(num_actions)
#         self.observation_space = spaces.Box(
#             low=-np.inf, high=np.inf, shape=(31,), dtype=np.float64
#         )
#         self.steps = 0
#         self.std = std
#         self.hard = hard
#         logger.info(f"{GREEN} hard: {self.hard}, std: {self.std} {ENDC}")
        
#         # Pre-allocate tensors for better performance
#         self.obs_tensor = torch.zeros((1, 27), dtype=torch.float32, device=device)
#         self.indices_tensor = torch.zeros(1, dtype=torch.long, device=device)
        
#         # Store parameter directly
#         self.deterministic = hard
        
#         # Create an optimized policy forward method
#         self._setup_policy()
    
#     def _setup_policy(self):
#         """Set up optimized policy execution."""
#         # Create a wrapper module that doesn't use keyword arguments
#         class PolicyWrapper(torch.nn.Module):
#             def __init__(self, policy):
#                 super().__init__()
#                 self.policy = policy
            
#             def forward(self, obs, indices, deterministic):
#                 # Call the forward method directly with positional arguments
#                 # This avoids using keyword arguments completely
#                 return self.policy.forward(obs, deterministic, False, indices)[0]
        
#         try:
#             # Wrap and script the policy
#             self.policy_wrapper = PolicyWrapper(self.policy_network)
#             self.policy_func = torch.jit.script(self.policy_wrapper)
#             logger.info(f"{GREEN}Successfully created JIT-compiled policy wrapper{ENDC}")
#             self.use_jit = True
            
#         except Exception as e:
#             # Fall back to a non-JIT optimized version
#             logger.warning(f"JIT compilation failed: {str(e)}. Using optimized non-JIT fallback.")
#             self.use_jit = False
    
#     def _policy_step(self, obs_tensor, indices_tensor):
#         """Optimized policy step without JIT."""
#         # Call the policy with specific arguments - no JIT but still optimized
#         with torch.no_grad():
#             a, _, _, _, _, _, _ = self.policy_network(
#                 obs_tensor,
#                 deterministic=self.deterministic,
#                 with_logprob=False,
#                 manual_indices=indices_tensor
#             )
#         return a
    
#     def step(self, action):
#         self.steps += 1
        
#         # Update tensors in-place for better performance
#         self.indices_tensor[0] = action
        
#         # Copy observation data into pre-allocated tensor
#         np_obs = self.obs[:27].reshape(1, -1).astype(np.float32)
#         self.obs_tensor.copy_(torch.from_numpy(np_obs).to(self.device))
        
#         # Get continuous action from policy network
#         if self.use_jit:
#             # Use JIT-compiled function
#             a = self.policy_func(
#                 self.obs_tensor, 
#                 self.indices_tensor, 
#                 self.deterministic
#             )
#         else:
#             # Use optimized non-JIT function
#             a = self._policy_step(self.obs_tensor, self.indices_tensor)
        
#         # Execute continuous action in environment
#         continuous_action = a.cpu().numpy()
#         observation, rewards, terminated, truncated, info = self.env.step(continuous_action[0])
#         print(f"{rewards=}") 
#         if self.steps > 999:
#             terminated = True
#             truncated = True
        
#         contact_detected = self.env.env.env.env.data.sensordata[0] 
#         if contact_detected > 0.0:
#             terminated = True
#             truncated = True
        
#         # Concatenate observation components
#         observation = np.concatenate(
#             (observation['observation'], observation['achieved_goal'], observation['desired_goal'])
#         )
#         self.obs = observation
#         # rewards = rewards * 100
#         print(f"{rewards=}") 
#         return observation, rewards, terminated, truncated, info 
    
#     def reset(self, **kwargs):
#         """
#         Reset the environment.
#         Returns concatenated observation with achieved_goal and desired_goal at the beginning.
#         """
#         self.steps = 0
#         options = {
#             "reset_cell": np.array([3, 1]),
#             "goal_cell": np.array([3, 3]),
#         }
#         observation, info = self.env.reset(options=options, **kwargs)
        
#         # observation, info = self.env.reset(**kwargs)
#         # Concatenate achieved_goal and desired_goal with raw observation
#         observation = np.concatenate(
#             (observation['observation'], observation['achieved_goal'], observation['desired_goal'])
#         )
#         self.obs = observation
#         return observation, info
    

# class AntGoal(AntDir):
#     """
#     Ant environment that rewards movement toward the point (5,5) in the arena.
#     The reward is based on the improvement in distance to the goal position.
#     """
#     def __init__(self, test_mode=False, **kwargs):
#         super(AntGoal, self).__init__(test_mode=test_mode, **kwargs)
#         # Target position in the arena
#         self.goal_position = np.array([5.0, 5.0])
#         # Track previous distance to goal for reward calculation
#         self.prev_distance_to_goal = None
        
#     def calculate_reward(self, xy_position_before, xy_position_after):
#         """
#         Calculate reward based on improvement in distance to the goal position (5,5).
#         Reward is positive when agent moves closer to the goal, negative when moving away.
#         """
#         # Calculate current distance to goal
#         current_distance_to_goal = np.linalg.norm(xy_position_after - self.goal_position)
        
#         # If this is the first step, initialize previous distance
#         if self.prev_distance_to_goal is None:
#             self.prev_distance_to_goal = np.linalg.norm(xy_position_before - self.goal_position)
        
#         # Reward is the improvement in distance (positive when getting closer)
#         reward = self.prev_distance_to_goal - current_distance_to_goal
        
#         # Update previous distance for next step
#         self.prev_distance_to_goal = current_distance_to_goal
        
#         # Optional: Scale reward for better learning
#         reward = reward * 10.0
        
#         # Add small bonus if very close to goal
#         if current_distance_to_goal < 0.5:
#             reward += 1.0
            
#         # Add larger bonus for reaching the goal
#         if current_distance_to_goal < 0.2:
#             reward += 5.0
            
#         return reward
    
#     def reset_model(self):
#         """Reset the environment and the distance tracking variable."""
#         self.prev_distance_to_goal = None
#         self.step_count = 0
#         qpos = self.init_qpos 
#         qvel = self.init_qvel
#         self.set_state(qpos, qvel)
#         observation = self._get_obs()
        
#         # Initialize previous distance after reset
#         xy_position = self.get_body_com("torso")[:2].copy()
#         self.prev_distance_to_goal = np.linalg.norm(xy_position - self.goal_position)
        
#         return observation


# class MPAntGoalEnv(MPEnv):
#     """
#     Environment wrapper that uses manual indices as actions and runs them through the policy network.
#     Customized for the AntGoalPos environment to track additional goal-related information.
#     """
#     def __init__(self, env, policy_network, device, num_actions=7, std=0.2, hard=False):
#         super(MPAntGoalEnv, self).__init__(
#             env=env, 
#             policy_network=policy_network,
#             device=device, 
#             num_actions=num_actions,
#             std=std, 
#             hard=hard
#         )
        
#     def step(self, action):
#         """
#         Takes a discrete action (manual index), runs it through policy network,
#         and executes the resulting continuous action in the environment.
#         Adds goal-related information to the info dictionary.
#         """
#         observation, rewards, terminated, truncated, info = super().step(action)
#         # Add distance to goal in the info dictionary
#         xy_position = self.env.env.env.get_body_com("torso")[:2].copy()
#         distance_to_goal = np.linalg.norm(xy_position - self.env.env.env.goal_position)
#         info["distance_to_goal"] = distance_to_goal
#         info["goal_x"] = self.env.env.env.goal_position[0]
#         info["goal_y"] = self.env.env.env.goal_position[1]
        
#         # Add success indicator
#         info["success"] = float(distance_to_goal < 0.2)
#         if random.random() < 0.01:
#             print(f"info: {info}, steps: {self.steps}, action: {action}") 
#         return observation, rewards, terminated, truncated, info
    
#     def reset(self, **kwargs):
#         """Reset the environment and return the observation and info."""
#         observation, info = super().reset(**kwargs)
        
#         # Add goal information to info dictionary
#         print(self.env)
#         xy_position = self.env.env.env.get_body_com("torso")[:2].copy()
#         info["distance_to_goal"] = np.linalg.norm(xy_position - self.env.env.env.goal_position)
#         info["goal_x"] = self.env.env.env.goal_position[0]
#         info["goal_y"] = self.env.env.env.goal_position[1]
        
#         return observation, info
     

def setup_hrl_environment(args, render_mode="rgb_array"):
    """
    Build the high-level (HRL) environment around a pretrained low-level policy.

    Loads the actor-critic module saved at ``args.path``, puts its policy
    ``ac.pi`` in eval mode, creates ``args.env`` and wraps it so that each
    discrete action selects a primitive of that policy.

    Args:
        args: Namespace providing ``path``, ``env``, ``hrl_nc``, ``hrl_std`` and ``hrl_hard``.
        render_mode: Forwarded to ``gym.make``.

    Returns:
        ``MPAntMazeEnvDiff`` for env ids containing "AntMaze", otherwise ``MPEnv``.
    """
    # The checkpoint is a whole pickled module (we call .to() and read .pi), not a
    # state_dict; torch>=2.6 defaults to weights_only=True, which refuses to load it.
    # SECURITY: unpickling executes arbitrary code — only load trusted checkpoints.
    ac = torch.load(args.path, map_location=device, weights_only=False).to(device)
    ac.pi.eval()

    # Shared by both wrapper types below.
    wrapper_kwargs = dict(
        policy_network=ac.pi,
        device=device,
        num_actions=args.hrl_nc,
        std=args.hrl_std,
        hard=args.hrl_hard,
    )

    if "AntMaze" in args.env:
        base_env = gym.make(
            args.env,
            width=600,
            height=600,
            render_mode=render_mode,
            include_cfrc_ext_in_observation=False,
            camera_id=0,
            xml_file=f"{BASE_XML_DIR}/ant.xml",
            continuing_task=False,
        )
        return MPAntMazeEnvDiff(base_env, **wrapper_kwargs)

    logger.info(args.env)
    base_env = gym.make(
        args.env,
        width=600,
        height=600,
        render_mode=render_mode,
    )
    return MPEnv(base_env, **wrapper_kwargs)

if __name__ == "__main__":
    # env = AntFood(render_mode="human", width=800, height=800)
    # env.reset()
    # for _ in range(1000):
    #     observation, reward, terminated, _, info = env.step(env.action_space.sample()/4)
    #     print(observation[:3].round(2))
    #     # print("="*100)
    #     # for k in info.keys():
    #     #     print(f"{k}: {info[k]}")
            
    #     # time.sleep(0.1) 

    env = gym.make("AntMaze_UMaze-v5", width=600, height=600, render_mode="human")
    options = {
        "reset_cell": np.array([3, 1]),
        "goal_cell": np.array([1, 1]),
    }
    env.reset(options=options)
    for _ in range(1000):
        action = env.action_space.sample()
        observation, reward, terminated, truncated, info = env.step(action)
        # print(observation)