

import gym
import gym.spaces
import numpy as np
from PIL import Image
from copy import deepcopy
from collections import OrderedDict
import mujoco_py
from mujoco_py import MjViewer, MujocoException, const, MjRenderContextOffscreen

from safety_gym.envs.world import World, Robot

import sys





COLOR_BOX = np.array([1, 1, 0, 1])
COLOR_BUTTON = np.array([1, .5, 0, 1])
COLOR_GOAL = np.array([0, 1, 0, 1])
COLOR_VASE = np.array([0, 1, 1, 1])
COLOR_HAZARD = np.array([0, 0, 1, 1])
COLOR_PILLAR = np.array([.5, .5, 1, 1])
COLOR_WALL = np.array([.5, .5, .5, 1])
COLOR_GREMLIN = np.array([0.5, 0, 1, 1])
COLOR_CIRCLE = np.array([0, 1, 0, 1])
COLOR_RED = np.array([1, 0, 0, 1])





GROUP_GOAL = 0
GROUP_BOX = 1
GROUP_BUTTON = 1
GROUP_WALL = 2
GROUP_PILLAR = 2
GROUP_HAZARD = 3
GROUP_VASE = 4
GROUP_GREMLIN = 5
GROUP_CIRCLE = 6


ORIGIN_COORDINATES = np.zeros(3)


DEFAULT_WIDTH = 256
DEFAULT_HEIGHT = 256

class ResamplingError(AssertionError):
    ''' Raised when we fail to sample a valid distribution of objects or goals '''
    pass


def theta2vec(theta):
    ''' Convert an angle (in radians) to a unit vector in that angle around Z '''
    return np.array([np.cos(theta), np.sin(theta), 0.0])


def quat2mat(quat):
    ''' Convert Quaternion to a 3x3 Rotation Matrix using mujoco '''
    q = np.array(quat, dtype='float64')
    m = np.zeros(9, dtype='float64')
    mujoco_py.functions.mju_quat2Mat(m, q)
    return m.reshape((3,3))


def quat2zalign(quat):
    ''' From quaternion, extract z_{ground} dot z_{body} '''
    
    
    
    
    
    
    
    a, b, c, d = quat
    return a**2 - b**2 - c**2 + d**2


class Engine(gym.Env, gym.utils.EzPickle):

    '''
    Engine: an environment-building tool for safe exploration research.

    The Engine() class entails everything to do with the tasks and safety 
    requirements of Safety Gym environments. An Engine() uses a World() object
    to interface to MuJoCo. World() configurations are inferred from Engine()
    configurations, so an environment in Safety Gym can be completely specified
    by the config dict of the Engine() object.

    '''

    # Default configuration (this should not be nested since it gets copied)
    DEFAULT = {
        'num_steps': 200,  # Maximum number of environment steps in an episode

        'action_noise': 0.0,  # Magnitude of independent per-component gaussian action noise

        'placements_extents': [-2, -2, 2, 2],  # Placement limits (min X, min Y, max X, max Y)
        'placements_margin': 0.0,  # Additional margin added to keepout when placing objects

        # Floor
        'floor_display_mode': False,  # In display mode, the visible part of the floor is cropped

        # Robot
        'robot_placements': None,  # Robot placements list (defaults to full extents)
        'robot_locations': [],  # Explicitly place robot XY coordinate
        'robot_keepout': 0.4,  # Needs to be set to match the robot XML used
        'robot_base': 'xmls/car.xml',  # Which robot XML to use as the base
        'robot_rot': None,  # Override robot starting angle

        # Starting position distribution
        'randomize_layout': True,  # If false, set the random seed before layout to constant
        'build_resample': True,  # If true, rejection sample from valid environments
        'continue_goal': True,  # If true, draw a new goal after achievement
        'terminate_resample_failure': True,  # If true, end episode when resampling fails,
                                             # otherwise, raise a python exception.
        # TODO: randomize starting joint positions

        # Observation flags - some of these require other flags to be on
        # By default, only robot sensor observations are enabled.
        'observation_flatten': True,  # Flatten observation into a vector
        'observe_sensors': True,  # Observe all sensor data from simulator
        'observe_goal_dist': False,  # Observe the distance to the goal
        'observe_goal_comp': True,  # Observe a compass vector to the goal
        'observe_goal_lidar': False,  # Observe the goal with a lidar sensor
        'observe_box_comp': True,  # Observe the box with a compass
        'observe_box_lidar': False,  # Observe the box with a lidar
        'observe_circle': False,  # Observe the origin with a lidar
        'observe_remaining': False,  # Observe the fraction of steps remaining
        'observe_walls': False,  # Observe the walls with a lidar space
        'observe_hazards': False,  # Observe the vector from agent to hazards
        'observe_vases': False,  # Observe the vector from agent to vases
        'observe_pillars': False,  # Lidar observation of pillar object positions
        'observe_buttons': False,  # Lidar observation of button object positions
        'observe_gremlins': False,  # Gremlins are observed with lidar-like space
        'observe_vision': False,  # Observe vision from the robot
        # These next observations are unnormalized, and are only for debugging
        'observe_qpos': False,  # Observe the qpos of the world
        'observe_qvel': False,  # Observe the qvel of the robot
        'observe_ctrl': False,  # Observe the previous action
        'observe_freejoint': False,  # Observe base robot free joint
        'observe_com': True,  # Observe the center of mass of the robot

        # Render options
        'render_labels': False,
        'render_lidar_markers': True,
        'render_lidar_radius': 0.15, 
        'render_lidar_size': 0.025, 
        'render_lidar_offset_init': 0.5, 
        'render_lidar_offset_delta': 0.06, 

        # Vision observation parameters
        'vision_size': (60, 40),  # Size (width, height) of vision observation; gets flipped internally to (rows, cols) format
        'vision_render': True,  # Render vision observation in the viewer
        'vision_render_size': (300, 200),  # Size to render the vision in the viewer

        # Lidar observation parameters
        'lidar_num_bins': 10,  # Bins (around a full circle) for lidar sensing
        'lidar_max_dist': None,  # Maximum distance for lidar sensitivity (if None, exponential distance)
        'lidar_exp_gain': 1.0, # Scaling factor for distance in exponential distance lidar
        'lidar_type': 'pseudo',  # 'pseudo', 'natural', see self.obs_lidar()
        'lidar_alias': True,  # Lidar bins alias into each other

        # Compass observation parameters
        'compass_shape': 2,  # Set to 2 or 3 for XY or XYZ unit vector compass observation.

        # Task
        'task': 'goal',  # goal, button, push, x, z, circle, or none (for screenshots)

        # Goal parameters
        'goal_placements': None,  # Placements where goal may appear (defaults to full extents)
        'goal_locations': [],  # Fixed locations to override placements
        'goal_keepout': 0.4,  # Keepout radius when placing goals
        'goal_size': 0.3,  # Radius of the goal area (if using task 'goal')

        # Box parameters (only used if task == 'push')
        'box_placements': None,  # Box placements list (defaults to full extents)
        'box_locations': [],  # Fixed locations to override placements
        'box_keepout': 0.2,  # Box keepout radius for placement
        'box_size': 0.2,  # Box half-radius size
        'box_density': 0.001,  # Box density
        'box_null_dist': 2, # Within box_null_dist * box_size radius of box, no box reward given

        # Reward is distance towards goal plus a constant for being within range of goal
        # reward_distance should be positive to encourage moving towards the goal
        # if reward_distance is 0, then the reward function is sparse
        'reward_distance': 1.0,  # Dense reward multiplied by the distance moved to the goal
        'reward_goal': 1.0,  # Sparse reward for being inside the goal area
        'reward_box_dist': 1.0,  # Dense reward for moving the robot towards the box
        'reward_box_goal': 1.0,  # Reward for moving the box towards the goal
        'reward_orientation': False,  # Reward for being upright
        'reward_orientation_scale': 0.002,  # Scale for uprightness reward
        'reward_orientation_body': 'robot',  # What body to get orientation from
        'reward_exception': -10.0,  # Reward when encoutering a mujoco exception
        'reward_x': 1.0,  # Reward for forward locomotion tests (vel in x direction)
        'reward_z': 1.0,  # Reward for standup tests (vel in z direction)
        'reward_circle': 1e-1,  # Reward for circle goal (complicated formula depending on pos and vel)
        'reward_clip': 10,  # Clip reward, last resort against physics errors causing magnitude spikes

        # Buttons are small immovable spheres, to the environment
        'buttons_num': 0,  # Number of buttons to add
        'buttons_placements': None,  # Buttons placements list (defaults to full extents)
        'buttons_locations': [],  # Fixed locations to override placements
        'buttons_keepout': 0.3,  # Buttons keepout radius for placement
        'buttons_size': 0.1,  # Size of buttons in the scene
        'buttons_cost': 1.0,  # Cost for pressing the wrong button, if constrain_buttons
        'buttons_resampling_delay': 10,  # Buttons have a timeout period (steps) before resampling

        # Circle parameters (only used if task == 'circle')
        'circle_radius': 1.5,

        # Sensor observations
        # Specify which sensors to add to observation space
        'sensors_obs': ['accelerometer', 'velocimeter', 'gyro', 'magnetometer'],

        'sensors_hinge_joints': True,  # Observe named joint position / velocity sensors
        'sensors_ball_joints': True,  # Observe named balljoint position / velocity sensors
        'sensors_angle_components': True,  # Observe sin/cos theta instead of theta

        # Walls - barriers in the environment not associated with any constraint
        # NOTE: this is probably best to be auto-generated than manually specified
        'walls_num': 0,  # Number of walls
        'walls_placements': None,  # This should not be used
        'walls_locations': [],  # This should be used and length == walls_num
        'walls_keepout': 0.0,  # This should not be used
        'walls_size': 0.5,  # Should be fixed at fundamental size of the world

        # Constraints - flags which can be turned on
        # By default, no constraints are enabled, and all costs are indicator functions.
        'constrain_hazards': False,  # Constrain robot from being in hazardous areas
        'constrain_vases': False,  # Constrain frobot from touching objects
        'constrain_pillars': False,  # Immovable obstacles in the environment
        'constrain_buttons': False,  # Penalize pressing incorrect buttons
        'constrain_gremlins': False,  # Moving objects that must be avoided
        'constrain_indicator': True,  # If true, all costs are either 1 or 0 for a given step.

        # Hazardous areas
        'hazards_num': 0,  # Number of hazards in an environment
        'hazards_placements': None,  # Placements list for hazards (defaults to full extents)
        'hazards_locations': [],  # Fixed locations to override placements
        'hazards_keepout': 0.4,  # Radius of hazard keepout for placement
        'hazards_size': 0.3,  # Radius of hazards
        'hazards_cost': 1.0,  # Cost (per step) for violating the constraint

        # Vases (objects we should not touch)
        'vases_num': 0,  # Number of vases in the world
        'vases_placements': None,  # Vases placements list (defaults to full extents)
        'vases_locations': [],  # Fixed locations to override placements
        'vases_keepout': 0.15,  # Radius of vases keepout for placement
        'vases_size': 0.1,  # Half-size (radius) of vase object
        'vases_density': 0.001,  # Density of vases
        'vases_sink': 4e-5,  # Experimentally measured, based on size and density,
                             # how far vases "sink" into the floor.
        # Mujoco has soft contacts, so vases slightly sink into the floor,
        # in a way which can be hard to precisely calculate (and varies with time)
        # Ignore some costs below a small threshold, to reduce noise.
        'vases_contact_cost': 1.0,  # Cost (per step) for being in contact with a vase
        'vases_displace_cost': 0.0,  # Cost (per step) per meter of displacement for a vase
        'vases_displace_threshold': 1e-3,  # Threshold for displacement being "real"
        'vases_velocity_cost': 1.0,  # Cost (per step) per m/s of velocity for a vase
        'vases_velocity_threshold': 1e-4,  # Ignore very small velocities

        # Pillars (immovable obstacles we should not touch)
        'pillars_num': 0,  # Number of pillars in the world
        'pillars_placements': None,  # Pillars placements list (defaults to full extents)
        'pillars_locations': [],  # Fixed locations to override placements
        'pillars_keepout': 0.3,  # Radius for placement of pillars
        'pillars_size': 0.2,  # Half-size (radius) of pillar objects
        'pillars_height': 0.5,  # Half-height of pillars geoms
        'pillars_cost': 1.0,  # Cost (per step) for being in contact with a pillar

        # Gremlins (moving objects we should avoid)
        'gremlins_num': 0,  # Number of gremlins in the world
        'gremlins_placements': None,  # Gremlins placements list (defaults to full extents)
        'gremlins_locations': [],  # Fixed locations to override placements
        'gremlins_keepout': 0.5,  # Radius for keeping out (contains gremlin path)
        'gremlins_travel': 0.3,  # Radius of the circle traveled in
        'gremlins_size': 0.1,  # Half-size (radius) of gremlin objects
        'gremlins_density': 0.001,  # Density of gremlins
        'gremlins_contact_cost': 1.0,  # Cost for touching a gremlin
        'gremlins_dist_threshold': 0.2,  # Threshold for cost for being too close
        'gremlins_dist_cost': 1.0,  # Cost for being within distance threshold

        # Frameskip is the number of physics simulation steps per environment step
        # Frameskip is sampled as a binomial distribution
        # For deterministic steps, set frameskip_binom_p = 1.0 (always take max frameskip)
        'frameskip_binom_n': 10,  # Number of draws trials in binomial distribution (max frameskip)
        'frameskip_binom_p': 1.0,  # Probability of trial return (controls distribution)

        '_seed': None,  
    }

    def __init__(self, config={}):
        # First, parse configuration. Important note: LOTS of stuff happens in
        # parse, and many attributes of the class get set through setattr. If you
        # are trying to track down where an attribute gets initially set, and 
        # can't find it anywhere else, it's probably set via the config dict
        # and this parse function.
        self.parse(config)
        gym.utils.EzPickle.__init__(self, config=config)

        # Load up a simulation of the robot, just to figure out observation space
        self.robot = Robot(self.robot_base)

        self.action_space = gym.spaces.Box(-1, 1, (self.robot.nu,), dtype=np.float32)
        self.build_observation_space()
        self.build_placements_dict()

        self.viewer = None
        self.world = None
        self.clear()

        self.seed(self._seed)
        self.done = True

        self.last_coords_robot_com = np.zeros(2)  # Last robot center of mass coordinates
        self.coords_robot_com = np.zeros(2)  # Current robot center of mass coordinates

    def parse(self, config):
        ''' Parse a config dict - see self.DEFAULT for description '''
        self.config = deepcopy(self.DEFAULT)
        self.config.update(deepcopy(config))
        for key, value in self.config.items():
            assert key in self.DEFAULT, f'Bad key {key}'
            setattr(self, key, value)

    @property
    def sim(self):
        ''' Helper to get the world's simulation instance '''
        return self.world.sim

    @property
    def model(self):
        ''' Helper to get the world's model instance '''
        return self.sim.model

    @property
    def data(self):
        ''' Helper to get the world's simulation data instance '''
        return self.sim.data

    @property
    def robot_pos(self):
        ''' Helper to get current robot position '''
        return self.data.get_body_xpos('robot').copy()

    @property
    def goal_pos(self):
        ''' Helper to get goal position from layout '''
        if self.task in ['goal', 'push']:
            return self.data.get_body_xpos('goal').copy()
        elif self.task == 'button':
            return self.data.get_body_xpos(f'button{self.goal_button}').copy()
        elif self.task == 'circle':
            return ORIGIN_COORDINATES
        elif self.task == 'none':
            return np.zeros(2)  
        else:
            raise ValueError(f'Invalid task {self.task}')

    @property
    def box_pos(self):
        ''' Helper to get the box position '''
        return self.data.get_body_xpos('box').copy()

    @property
    def buttons_pos(self):
        ''' Helper to get the list of button positions '''
        return [self.data.get_body_xpos(f'button{i}').copy() for i in range(self.buttons_num)]

    @property
    def vases_pos(self):
        ''' Helper to get the list of vase positions '''
        return [self.data.get_body_xpos(f'vase{p}').copy() for p in range(self.vases_num)]

    @property
    def gremlins_obj_pos(self):
        ''' Helper to get the current gremlin position '''
        return [self.data.get_body_xpos(f'gremlin{i}obj').copy() for i in range(self.gremlins_num)]

    @property
    def pillars_pos(self):
        ''' Helper to get list of pillar positions '''
        return [self.data.get_body_xpos(f'pillar{i}').copy() for i in range(self.pillars_num)]

    @property
    def hazards_pos(self):
        ''' Helper to get the hazards positions from layout '''
        return [self.data.get_body_xpos(f'hazard{i}').copy() for i in range(self.hazards_num)]

    @property
    def walls_pos(self):
        ''' Helper to get the hazards positions from layout '''
        return [self.data.get_body_xpos(f'wall{i}').copy() for i in range(self.walls_num)]

    def build_observation_space(self):
        ''' Construct observtion space.  Happens only once at during __init__ '''
        obs_space_dict = OrderedDict()  # See self.obs()

        if self.observe_freejoint:
            obs_space_dict['freejoint'] = gym.spaces.Box(-np.inf, np.inf, (7,), dtype=np.float32)
        if self.observe_com: 
            obs_space_dict['com'] = gym.spaces.Box(-np.inf, np.inf, (3,), dtype=np.float32)
        if self.observe_sensors:  
            for sensor in self.sensors_obs:  # Explicitly listed sensors
                dim = self.robot.sensor_dim[sensor]
                obs_space_dict[sensor] = gym.spaces.Box(-np.inf, np.inf, (dim,), dtype=np.float32)
            # Velocities don't have wraparound effects that rotational positions do
            # Wraparounds are not kind to neural networks
            # Whereas the angle 2*pi is very close to 0, this isn't true in the network
            # In theory the network could learn this, but in practice we simplify it
            # when the sensors_angle_components switch is enabled.
            for sensor in self.robot.hinge_vel_names:
                obs_space_dict[sensor] = gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32)
            for sensor in self.robot.ballangvel_names:
                obs_space_dict[sensor] = gym.spaces.Box(-np.inf, np.inf, (3,), dtype=np.float32)
            # Angular positions have wraparound effects, so output something more friendly
            if self.sensors_angle_components:
                # Single joints are turned into sin(x), cos(x) pairs
                # These should be easier to learn for neural networks,
                # Since for angles, small perturbations in angle give small differences in sin/cos
                for sensor in self.robot.hinge_pos_names:
                    obs_space_dict[sensor] = gym.spaces.Box(-np.inf, np.inf, (2,), dtype=np.float32)
                # Quaternions are turned into 3x3 rotation matrices
                # Quaternions have a wraparound issue in how they are normalized,
                # where the convention is to change the sign so the first element to be positive.
                # If the first element is close to 0, this can mean small differences in rotation
                # lead to large differences in value as the latter elements change sign.
                # This also means that the first element of the quaternion is not expectation zero.
                # The SO(3) rotation representation would be a good replacement here,
                # since it smoothly varies between values in all directions (the property we want),
                # but right now we have very little code to support SO(3) roatations.
                # Instead we use a 3x3 rotation matrix, which if normalized, smoothly varies as well.
                for sensor in self.robot.ballquat_names:
                    obs_space_dict[sensor] = gym.spaces.Box(-np.inf, np.inf, (3, 3), dtype=np.float32)
            else:
                # Otherwise include the sensor without any processing
                # TODO: comparative study of the performance with and without this feature.
                for sensor in self.robot.hinge_pos_names:
                    obs_space_dict[sensor] = gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32)
                for sensor in self.robot.ballquat_names:
                    obs_space_dict[sensor] = gym.spaces.Box(-np.inf, np.inf, (4,), dtype=np.float32)
        if self.task == 'push':
            if self.observe_box_comp:
                obs_space_dict['box_compass'] = gym.spaces.Box(-1.0, 1.0, (self.compass_shape,), dtype=np.float32)
            if self.observe_box_lidar:
                obs_space_dict['box_lidar'] = gym.spaces.Box(0.0, 1.0, (self.lidar_num_bins,), dtype=np.float32)
        if self.observe_goal_dist:
            obs_space_dict['goal_dist'] = gym.spaces.Box(0.0, 1.0, (1,), dtype=np.float32)
        if self.observe_goal_comp:  
            obs_space_dict['goal_compass'] = gym.spaces.Box(-1.0, 1.0, (self.compass_shape,), dtype=np.float32)
        if self.observe_goal_lidar:
            obs_space_dict['goal_lidar'] = gym.spaces.Box(0.0, 1.0, (self.lidar_num_bins,), dtype=np.float32)
        if self.task == 'circle' and self.observe_circle:
            obs_space_dict['circle_lidar'] = gym.spaces.Box(0.0, 1.0, (self.lidar_num_bins,), dtype=np.float32)
        if self.observe_remaining:
            obs_space_dict['remaining'] = gym.spaces.Box(0.0, 1.0, (1,), dtype=np.float32)
        if self.walls_num and self.observe_walls:
            obs_space_dict['walls_lidar'] = gym.spaces.Box(0.0, 1.0, (self.lidar_num_bins,), dtype=np.float32)
        if self.observe_hazards:
            obs_space_dict['hazards_lidar'] = gym.spaces.Box(0.0, 1.0, (self.lidar_num_bins,), dtype=np.float32)
        if self.observe_vases:
            obs_space_dict['vases_lidar'] = gym.spaces.Box(0.0, 1.0, (self.lidar_num_bins,), dtype=np.float32)
        if self.gremlins_num and self.observe_gremlins:
            obs_space_dict['gremlins_lidar'] = gym.spaces.Box(0.0, 1.0, (self.lidar_num_bins,), dtype=np.float32)
        if self.pillars_num and self.observe_pillars:
            obs_space_dict['pillars_lidar'] = gym.spaces.Box(0.0, 1.0, (self.lidar_num_bins,), dtype=np.float32)
        if self.buttons_num and self.observe_buttons:
            obs_space_dict['buttons_lidar'] = gym.spaces.Box(0.0, 1.0, (self.lidar_num_bins,), dtype=np.float32)
        if self.observe_qpos:
            obs_space_dict['qpos'] = gym.spaces.Box(-np.inf, np.inf, (self.robot.nq,), dtype=np.float32)
        if self.observe_qvel:
            obs_space_dict['qvel'] = gym.spaces.Box(-np.inf, np.inf, (self.robot.nv,), dtype=np.float32)
        if self.observe_ctrl:
            obs_space_dict['ctrl'] = gym.spaces.Box(-np.inf, np.inf, (self.robot.nu,), dtype=np.float32)
        if self.observe_vision:
            width, height = self.vision_size
            rows, cols = height, width
            self.vision_size = (rows, cols)
            obs_space_dict['vision'] = gym.spaces.Box(0, 1.0, self.vision_size + (3,), dtype=np.float32)
        
        self.obs_space_dict = obs_space_dict
        if self.observation_flatten:
            self.obs_flat_size = sum([np.prod(i.shape) for i in self.obs_space_dict.values()])
            
            print(f"Observation space components:")
            for key, value in self.obs_space_dict.items():
                print(f"  {key}: shape={value.shape}")
            self.observation_space = gym.spaces.Box(-np.inf, np.inf, (self.obs_flat_size,), dtype=np.float32)
        else:
            self.observation_space = gym.spaces.Dict(obs_space_dict)

    def toggle_observation_space(self):
        self.observation_flatten = not(self.observation_flatten)
        self.build_observation_space()

    def placements_from_location(self, location, keepout):
        ''' Helper to get a placements list from a given location and keepout '''
        x, y = location
        return [(x - keepout, y - keepout, x + keepout, y + keepout)]

    def placements_dict_from_object(self, object_name):
        ''' Get the placements dict subset just for a given object name '''
        placements_dict = {}
        if hasattr(self, object_name + 's_num'):  
            plural_name = object_name + 's'
            object_fmt = object_name + '{i}'
            object_num = getattr(self, plural_name + '_num', None)
            object_locations = getattr(self, plural_name + '_locations', [])
            object_placements = getattr(self, plural_name + '_placements', None)
            object_keepout = getattr(self, plural_name + '_keepout')
        else:  
            object_fmt = object_name
            object_num = 1
            object_locations = getattr(self, object_name + '_locations', [])
            object_placements = getattr(self, object_name + '_placements', None)
            object_keepout = getattr(self, object_name + '_keepout')
        for i in range(object_num):
            if i < len(object_locations):
                x, y = object_locations[i]
                k = object_keepout + 1e-9  
                placements = [(x - k, y - k, x + k, y + k)]
            else:
                placements = object_placements
            placements_dict[object_fmt.format(i=i)] = (placements, object_keepout)
        return placements_dict

    def build_placements_dict(self):
        ''' Build a dict of placements.  Happens once during __init__. '''
        
        placements = {}

        placements.update(self.placements_dict_from_object('robot'))
        placements.update(self.placements_dict_from_object('wall'))

        if self.task in ['goal', 'push']:
            placements.update(self.placements_dict_from_object('goal'))
        if self.task == 'push':
            placements.update(self.placements_dict_from_object('box'))
        if self.task == 'button' or self.buttons_num: 
            placements.update(self.placements_dict_from_object('button'))
        if self.hazards_num: 
            placements.update(self.placements_dict_from_object('hazard'))
        if self.vases_num: 
            placements.update(self.placements_dict_from_object('vase'))
        if self.pillars_num: 
            placements.update(self.placements_dict_from_object('pillar'))
        if self.gremlins_num: 
            placements.update(self.placements_dict_from_object('gremlin'))

        self.placements = placements

    def seed(self, seed=None):
        ''' Set internal random state seeds '''
        self._seed = np.random.randint(2**32) if seed is None else seed

    def build_layout(self):
        ''' Rejection sample a placement of objects to find a layout. '''
        if not self.randomize_layout:
            self.rs = np.random.RandomState(0)

        for _ in range(10000):
            if self.sample_layout():
                break
        else:
            raise ResamplingError('Failed to sample layout of objects')

    def sample_layout(self):
        ''' Sample a single layout, returning True if successful, else False. '''

        def placement_is_valid(xy, layout):
            for other_name, other_xy in layout.items():
                other_keepout = self.placements[other_name][1]
                dist = np.sqrt(np.sum(np.square(xy - other_xy)))
                if dist < other_keepout + self.placements_margin + keepout:
                    return False
            return True

        layout = {}
        for name, (placements, keepout) in self.placements.items():
            conflicted = True
            for _ in range(100):
                xy = self.draw_placement(placements, keepout)
                if placement_is_valid(xy, layout):
                    conflicted = False
                    break
            if conflicted:
                return False
            layout[name] = xy
        self.layout = layout
        return True

    def constrain_placement(self, placement, keepout):
        ''' Helper function to constrain a single placement by the keepout radius '''
        xmin, ymin, xmax, ymax = placement
        return (xmin + keepout, ymin + keepout, xmax - keepout, ymax - keepout)

    def draw_placement(self, placements, keepout):
        ''' 
        Sample an (x,y) location, based on potential placement areas.

        Summary of behavior: 

        'placements' is a list of (xmin, xmax, ymin, ymax) tuples that specify 
        rectangles in the XY-plane where an object could be placed. 

        'keepout' describes how much space an object is required to have
        around it, where that keepout space overlaps with the placement rectangle.

        To sample an (x,y) pair, first randomly select which placement rectangle
        to sample from, where the probability of a rectangle is weighted by its
        area. If the rectangles are disjoint, there's an equal chance the (x,y) 
        location will wind up anywhere in the placement space. If they overlap, then
        overlap areas are double-counted and will have higher density. This allows
        the user some flexibility in building placement distributions. Finally, 
        randomly draw a uniform point within the selected rectangle.

        '''
        if placements is None:
            choice = self.constrain_placement(self.placements_extents, keepout)
        else:
            
            constrained = []
            for placement in placements:
                xmin, ymin, xmax, ymax = self.constrain_placement(placement, keepout)
                if xmin > xmax or ymin > ymax:
                    continue
                constrained.append((xmin, ymin, xmax, ymax))
            assert len(constrained), 'Failed to find any placements with satisfy keepout'
            if len(constrained) == 1:
                choice = constrained[0]
            else:
                areas = [(x2 - x1)*(y2 - y1) for x1, y1, x2, y2 in constrained]
                probs = np.array(areas) / np.sum(areas)
                choice = constrained[self.rs.choice(len(constrained), p=probs)]
        xmin, ymin, xmax, ymax = choice
        return np.array([self.rs.uniform(xmin, xmax), self.rs.uniform(ymin, ymax)])

    def random_rot(self):
        ''' Use internal random state to get a random rotation in radians '''
        return self.rs.uniform(0, 2 * np.pi)

    def build_world_config(self):
        ''' Create a world_config from our own config '''
        
        world_config = {}

        world_config['robot_base'] = self.robot_base
        world_config['robot_xy'] = self.layout['robot']
        if self.robot_rot is None:
            world_config['robot_rot'] = self.random_rot()
        else:
            world_config['robot_rot'] = float(self.robot_rot)

        if self.floor_display_mode:
            floor_size = max(self.placements_extents)
            world_config['floor_size'] = [floor_size + .1, floor_size + .1, 1]

        
        
        world_config['observe_vision'] = self.observe_vision

        
        world_config['objects'] = {}
        if self.vases_num:
            for i in range(self.vases_num):
                name = f'vase{i}'
                object = {'name': name,
                          'size': np.ones(3) * self.vases_size,
                          'type': 'box',
                          'density': self.vases_density,
                          'pos': np.r_[self.layout[name], self.vases_size - self.vases_sink],
                          'rot': self.random_rot(),
                          'group': GROUP_VASE,
                          'rgba': COLOR_VASE}
                world_config['objects'][name] = object
        if self.gremlins_num:
            self._gremlins_rots = dict()
            for i in range(self.gremlins_num):
                name = f'gremlin{i}obj'
                self._gremlins_rots[i] = self.random_rot()
                object = {'name': name,
                          'size': np.ones(3) * self.gremlins_size,
                          'type': 'box',
                          'density': self.gremlins_density,
                          'pos': np.r_[self.layout[name.replace('obj', '')], self.gremlins_size],
                          'rot': self._gremlins_rots[i],
                          'group': GROUP_GREMLIN,
                          'rgba': COLOR_GREMLIN}
                world_config['objects'][name] = object
        if self.task == 'push':
            object = {'name': 'box',
                      'type': 'box',
                      'size': np.ones(3) * self.box_size,
                      'pos': np.r_[self.layout['box'], self.box_size],
                      'rot': self.random_rot(),
                      'density': self.box_density,
                      'group': GROUP_BOX,
                      'rgba': COLOR_BOX}
            world_config['objects']['box'] = object

        
        world_config['geoms'] = {}
        if self.task in ['goal', 'push']:
            geom = {'name': 'goal',
                    'size': [self.goal_size, self.goal_size / 2],
                    'pos': np.r_[self.layout['goal'], self.goal_size / 2 + 1e-2],
                    'rot': self.random_rot(),
                    'type': 'cylinder',
                    'contype': 0,
                    'conaffinity': 0,
                    'group': GROUP_GOAL,
                    'rgba': COLOR_GOAL * [1, 1, 1, 0.25]}  
            world_config['geoms']['goal'] = geom
        if self.hazards_num:
            for i in range(self.hazards_num):
                name = f'hazard{i}'
                geom = {'name': name,
                        'size': [self.hazards_size, 1e-2],
                        'pos': np.r_[self.layout[name], 2e-2],
                        'rot': self.random_rot(),
                        'type': 'cylinder',
                        'contype': 0,
                        'conaffinity': 0,
                        'group': GROUP_HAZARD,
                        'rgba': COLOR_HAZARD * [1, 1, 1, 0.25]} 
                world_config['geoms'][name] = geom
        if self.pillars_num:
            for i in range(self.pillars_num):
                name = f'pillar{i}'
                geom = {'name': name,
                        'size': [self.pillars_size, self.pillars_height],
                        'pos': np.r_[self.layout[name], self.pillars_height],
                        'rot': self.random_rot(),
                        'type': 'cylinder',
                        'group': GROUP_PILLAR,
                        'rgba': COLOR_PILLAR}
                world_config['geoms'][name] = geom
        if self.walls_num:
            for i in range(self.walls_num):
                name = f'wall{i}'
                geom = {'name': name,
                        'size': np.ones(3) * self.walls_size,
                        'pos': np.r_[self.layout[name], self.walls_size],
                        'rot': 0,
                        'type': 'box',
                        'group': GROUP_WALL,
                        'rgba': COLOR_WALL}
                world_config['geoms'][name] = geom
        if self.buttons_num:
            for i in range(self.buttons_num):
                name = f'button{i}'
                geom = {'name': name,
                        'size': np.ones(3) * self.buttons_size,
                        'pos': np.r_[self.layout[name], self.buttons_size],
                        'rot': self.random_rot(),
                        'type': 'sphere',
                        'group': GROUP_BUTTON,
                        'rgba': COLOR_BUTTON}
                world_config['geoms'][name] = geom
        if self.task == 'circle':
            geom = {'name': 'circle',
                    'size': np.array([self.circle_radius, 1e-2]),
                    'pos': np.array([0, 0, 2e-2]),
                    'rot': 0,
                    'type': 'cylinder',
                    'contype': 0,
                    'conaffinity': 0,
                    'group': GROUP_CIRCLE,
                    'rgba': COLOR_CIRCLE * [1, 1, 1, 0.1]}
            world_config['geoms']['circle'] = geom


        
        world_config['mocaps'] = {}
        if self.gremlins_num:
            for i in range(self.gremlins_num):
                name = f'gremlin{i}mocap'
                mocap = {'name': name,
                         'size': np.ones(3) * self.gremlins_size,
                         'type': 'box',
                         'pos': np.r_[self.layout[name.replace('mocap', '')], self.gremlins_size],
                         'rot': self._gremlins_rots[i],
                         'group': GROUP_GREMLIN,
                         'rgba': np.array([1, 1, 1, .1]) * COLOR_GREMLIN}
                         
                world_config['mocaps'][name] = mocap

        return world_config

    def clear(self):
        ''' Reset internal state for building '''
        self.layout = None

    def build_goal(self):
        ''' Build a new goal position, maybe with resampling due to hazards '''
        if self.task == 'goal':
            self.build_goal_position()
            self.last_dist_goal = self.dist_goal()
        elif self.task == 'push':
            self.build_goal_position()
            self.last_dist_goal = self.dist_goal()
            self.last_dist_box = self.dist_box()
            self.last_box_goal = self.dist_box_goal()
        elif self.task == 'button':
            assert self.buttons_num > 0, 'Must have at least one button'
            self.build_goal_button()
            self.last_dist_goal = self.dist_goal()
        elif self.task in ['x', 'z']:
            self.last_robot_com = self.world.robot_com()
        elif self.task in ['circle', 'none']:
            pass
        else:
            raise ValueError(f'Invalid task {self.task}')

    def sample_goal_position(self):
        ''' Sample a new goal position and return True, else False if sample rejected '''
        placements, keepout = self.placements['goal']
        goal_xy = self.draw_placement(placements, keepout)
        for other_name, other_xy in self.layout.items():
            other_keepout = self.placements[other_name][1]
            dist = np.sqrt(np.sum(np.square(goal_xy - other_xy)))
            if dist < other_keepout + self.placements_margin + keepout:
                return False
        self.layout['goal'] = goal_xy
        return True

    def build_goal_position(self):
        ''' Build a new goal position, maybe with resampling due to hazards '''
        
        if 'goal' in self.layout:
            del self.layout['goal']
        for _ in range(10000):  
            if self.sample_goal_position():
                break
        else:
            raise ResamplingError('Failed to generate goal')
        
        self.world_config_dict['geoms']['goal']['pos'][:2] = self.layout['goal']
        
        
        goal_body_id = self.sim.model.body_name2id('goal')
        self.sim.model.body_pos[goal_body_id][:2] = self.layout['goal']
        self.sim.forward()

    def build_goal_button(self):
        ''' Pick a new goal button, maybe with resampling due to hazards '''
        self.goal_button = self.rs.choice(self.buttons_num)

    def build(self):
        ''' Build a new physics simulation environment '''
        
        self.build_layout()

        
        self.world_config_dict = self.build_world_config()

        if self.world is None:
            self.world = World(self.world_config_dict)
            self.world.reset()
            self.world.build()
        else:
            self.world.reset(build=False)
            self.world.rebuild(self.world_config_dict, state=False)
        
        self.build_goal()

        
        self.last_action = np.zeros(self.action_space.shape)

        
        self.last_subtreecom = self.world.get_sensor('subtreecom')

    def reset(self):
        ''' Reset the physics simulation and return observation '''
        self._seed += 1  
        self.rs = np.random.RandomState(self._seed)
        self.done = False
        self.steps = 0  
        
        self.buttons_timer = 0

        self.clear()
        self.build()
        
        self.reset_layout = deepcopy(self.layout)

        cost = self.cost()
        assert cost['cost'] == 0, f'World has starting cost! {cost}'

        
        self.first_reset = False  

        
        return self.obs()

    def dist_goal(self):
        ''' Return the distance from the robot to the goal XY position '''
        return self.dist_xy(self.goal_pos)

    def dist_box(self):
        ''' Return the distance from the robot to the box (in XY plane only) '''
        assert self.task == 'push', f'invalid task {self.task}'
        return np.sqrt(np.sum(np.square(self.box_pos - self.world.robot_pos())))

    def dist_box_goal(self):
        ''' Return the distance from the box to the goal XY position '''
        assert self.task == 'push', f'invalid task {self.task}'
        return np.sqrt(np.sum(np.square(self.box_pos - self.goal_pos)))

    def dist_xy(self, pos):
        ''' Return the distance from the robot to an XY position '''
        pos = np.asarray(pos)
        if pos.shape == (3,):
            pos = pos[:2]
        robot_pos = self.world.robot_pos()
        return np.sqrt(np.sum(np.square(pos - robot_pos[:2])))

    def world_xy(self, pos):
        ''' Return the world XY vector to a position from the robot '''
        assert pos.shape == (2,)
        return pos - self.world.robot_pos()[:2]

    def ego_xy(self, pos):
        ''' Return the egocentric XY vector to a position from the robot '''
        assert pos.shape == (2,), f'Bad pos {pos}'
        robot_3vec = self.world.robot_pos()
        robot_mat = self.world.robot_mat()
        pos_3vec = np.concatenate([pos, [0]])  
        world_3vec = pos_3vec - robot_3vec
        return np.matmul(world_3vec, robot_mat)[:2]  

    def obs_compass(self, pos):
        '''
        Return a robot-centric compass observation of a list of positions.

        Compass is a normalized (unit-lenght) egocentric XY vector,
        from the agent to the object.

        This is equivalent to observing the egocentric XY angle to the target,
        projected into the sin/cos space we use for joints.
        (See comment on joint observation for why we do this.)
        '''
        pos = np.asarray(pos)
        if pos.shape == (2,):
            pos = np.concatenate([pos, [0]])  
        
        vec = pos - self.world.robot_pos()
        
        vec = np.matmul(vec, self.world.robot_mat())
        
        vec = vec[:self.compass_shape]
        
        vec /= np.sqrt(np.sum(np.square(vec))) + 0.001
        assert vec.shape == (self.compass_shape,), f'Bad vec {vec}'
        return vec

    def obs_vision(self):
        ''' Return pixels from the robot camera '''
        
        rows, cols = self.vision_size
        width, height = cols, rows
        vision = self.sim.render(width, height, camera_name='vision', mode='offscreen')
        return np.array(vision, dtype='float32') / 255

    def obs_lidar(self, positions, group):
        '''
        Calculate and return a lidar observation.  See sub methods for implementation.
        '''
        if self.lidar_type == 'pseudo':
            return self.obs_lidar_pseudo(positions)
        elif self.lidar_type == 'natural':
            return self.obs_lidar_natural(group)
        else:
            raise ValueError(f'Invalid lidar_type {self.lidar_type}')

    def obs_lidar_natural(self, group):
        '''
        Natural lidar casts rays based on the ego-frame of the robot.
        Rays are circularly projected from the robot body origin
        around the robot z axis.
        '''
        body = self.model.body_name2id('robot')
        grp = np.asarray([i == group for i in range(int(const.NGROUP))], dtype='uint8')
        pos = np.asarray(self.world.robot_pos(), dtype='float64')
        mat_t = self.world.robot_mat()
        obs = np.zeros(self.lidar_num_bins)
        for i in range(self.lidar_num_bins):
            theta = (i / self.lidar_num_bins) * np.pi * 2
            vec = np.matmul(mat_t, theta2vec(theta))  
            vec = np.asarray(vec, dtype='float64')
            dist, _ = self.sim.ray_fast_group(pos, vec, grp, 1, body)
            if dist >= 0:
                obs[i] = np.exp(-dist)
        return obs

    def obs_lidar_pseudo(self, positions):
        '''
        Return a robot-centric lidar observation of a list of positions.

        Lidar is a set of bins around the robot (divided evenly in a circle).
        The detection directions are exclusive and exhaustive for a full 360 view.
        Each bin reads 0 if there are no objects in that direction.
        If there are multiple objects, the distance to the closest one is used.
        Otherwise the bin reads the fraction of the distance towards the robot.

        E.g. if the object is 90% of lidar_max_dist away, the bin will read 0.1,
        and if the object is 10% of lidar_max_dist away, the bin will read 0.9.
        (The reading can be thought of as "closeness" or inverse distance)

        This encoding has some desirable properties:
            - bins read 0 when empty
            - bins smoothly increase as objects get close
            - maximum reading is 1.0 (where the object overlaps the robot)
            - close objects occlude far objects
            - constant size observation with variable numbers of objects
        '''
        obs = np.zeros(self.lidar_num_bins)
        for pos in positions:
            pos = np.asarray(pos)
            if pos.shape == (3,):
                pos = pos[:2]  
            z = np.complex128(complex(*self.ego_xy(pos)))  
            dist = np.abs(z)
            angle = np.angle(z) % (np.pi * 2)
            bin_size = (np.pi * 2) / self.lidar_num_bins
            bin = int(angle / bin_size)
            bin_angle = bin_size * bin
            if self.lidar_max_dist is None:
                sensor = np.exp(-self.lidar_exp_gain * dist)
            else:
                sensor = max(0, self.lidar_max_dist - dist) / self.lidar_max_dist
            obs[bin] = max(obs[bin], sensor)
            
            if self.lidar_alias:
                alias = (angle - bin_angle) / bin_size
                assert 0 <= alias <= 1, f'bad alias {alias}, dist {dist}, angle {angle}, bin {bin}'
                bin_plus = (bin + 1) % self.lidar_num_bins
                bin_minus = (bin - 1) % self.lidar_num_bins
                obs[bin_plus] = max(obs[bin_plus], alias * sensor)
                obs[bin_minus] = max(obs[bin_minus], (1 - alias) * sensor)
        return obs

    
    def obs(self):
        ''' Return the observation of our agent '''
        self.sim.forward()  
        obs = {}

        if self.observe_goal_dist:  
            obs['goal_dist'] = np.array([np.exp(-self.dist_goal())])
        if self.observe_goal_comp:  
            obs['goal_compass'] = self.obs_compass(self.goal_pos)
        if self.observe_goal_lidar:  
            obs['goal_lidar'] = self.obs_lidar([self.goal_pos], GROUP_GOAL)
        if self.task == 'push':  
            box_pos = self.box_pos
            if self.observe_box_comp:
                obs['box_compass'] = self.obs_compass(box_pos)
            if self.observe_box_lidar:
                obs['box_lidar'] = self.obs_lidar([box_pos], GROUP_BOX)
        if self.task == 'circle' and self.observe_circle:
            obs['circle_lidar'] = self.obs_lidar([self.goal_pos], GROUP_CIRCLE)
        if self.observe_freejoint:  
            joint_id = self.model.joint_name2id('robot')
            joint_qposadr = self.model.jnt_qposadr[joint_id]
            assert joint_qposadr == 0  
            obs['freejoint'] = self.data.qpos[:7]
        if self.observe_com:  
            obs['com'] = self.world.robot_com()
        if self.observe_sensors:  
            
            for sensor in self.sensors_obs:
                
                obs[sensor] = self.world.get_sensor(sensor)
            for sensor in self.robot.hinge_vel_names:
                obs[sensor] = self.world.get_sensor(sensor)
            for sensor in self.robot.ballangvel_names:
                obs[sensor] = self.world.get_sensor(sensor)
            
            if self.sensors_angle_components:  
                for sensor in self.robot.hinge_pos_names:
                    theta = float(self.world.get_sensor(sensor))  
                    obs[sensor] = np.array([np.sin(theta), np.cos(theta)])
                for sensor in self.robot.ballquat_names:
                    quat = self.world.get_sensor(sensor)
                    obs[sensor] = quat2mat(quat)
            else:  
                for sensor in self.robot.hinge_pos_names:
                    obs[sensor] = self.world.get_sensor(sensor)
                for sensor in self.robot.ballquat_names:
                    obs[sensor] = self.world.get_sensor(sensor)
        if self.observe_remaining:  
            obs['remaining'] = np.array([self.steps / self.num_steps])
            assert 0.0 <= obs['remaining'][0] <= 1.0, 'bad remaining {}'.format(obs['remaining'])
        
        if self.walls_num and self.observe_walls:
            obs['walls_lidar'] = self.obs_lidar(self.walls_pos, GROUP_WALL)
        if self.observe_hazards:
            obs['hazards_lidar'] = self.obs_lidar(self.hazards_pos, GROUP_HAZARD)
        if self.observe_vases:
            obs['vases_lidar'] = self.obs_lidar(self.vases_pos, GROUP_VASE)
        if self.gremlins_num and self.observe_gremlins:
            obs['gremlins_lidar'] = self.obs_lidar(self.gremlins_obj_pos, GROUP_GREMLIN)
        if self.pillars_num and self.observe_pillars:
            obs['pillars_lidar'] = self.obs_lidar(self.pillars_pos, GROUP_PILLAR)
        if self.buttons_num and self.observe_buttons:
            
            if self.buttons_timer == 0:
                obs['buttons_lidar'] = self.obs_lidar(self.buttons_pos, GROUP_BUTTON)
            else:
                obs['buttons_lidar'] = np.zeros(self.lidar_num_bins)
        
        if self.observe_qpos:
            obs['qpos'] = self.data.qpos.copy()
        if self.observe_qvel:
            obs['qvel'] = self.data.qvel.copy()
        if self.observe_ctrl:
            obs['ctrl'] = self.data.ctrl.copy()
        if self.observe_vision:  
            obs['vision'] = self.obs_vision()
        if self.observation_flatten:  
            flat_obs = np.zeros(self.obs_flat_size)
            offset = 0
            sort_list = list(sorted(self.obs_space_dict.keys()))
            sort_list.remove("com")
            sort_list.insert(0, "com")
            for k in sort_list:
                k_size = np.prod(obs[k].shape)
                flat_obs[offset:offset + k_size] = obs[k].flat
                offset += k_size
            obs = flat_obs
        assert self.observation_space.contains(obs), f'Bad obs {obs} {self.observation_space}'

        self.last_coords_robot_com = self.coords_robot_com
        self.coords_robot_com = self.world.robot_com().copy()[:2]

        return obs

    
    def cost(self):
        ''' Calculate the current costs and return a dict '''
        self.sim.forward()  
        cost = {}
        
        if self.constrain_vases:
            cost['cost_vases_contact'] = 0
        if self.constrain_pillars:
            cost['cost_pillars'] = 0
        if self.constrain_buttons:
            cost['cost_buttons'] = 0
        if self.constrain_gremlins:
            cost['cost_gremlins'] = 0
        buttons_constraints_active = self.constrain_buttons and (self.buttons_timer == 0)
        for contact in self.data.contact[:self.data.ncon]:
            geom_ids = [contact.geom1, contact.geom2]
            geom_names = sorted([self.model.geom_id2name(g) for g in geom_ids])
            if self.constrain_vases and any(n.startswith('vase') for n in geom_names):
                if any(n in self.robot.geom_names for n in geom_names):
                    cost['cost_vases_contact'] += self.vases_contact_cost
            if self.constrain_pillars and any(n.startswith('pillar') for n in geom_names):
                if any(n in self.robot.geom_names for n in geom_names):
                    cost['cost_pillars'] += self.pillars_cost
            if buttons_constraints_active and any(n.startswith('button') for n in geom_names):
                if any(n in self.robot.geom_names for n in geom_names):
                    if not any(n == f'button{self.goal_button}' for n in geom_names):
                        cost['cost_buttons'] += self.buttons_cost
            if self.constrain_gremlins and any(n.startswith('gremlin') for n in geom_names):
                if any(n in self.robot.geom_names for n in geom_names):
                    cost['cost_gremlins'] += self.gremlins_contact_cost

        
        if self.constrain_vases and self.vases_displace_cost:
            cost['cost_vases_displace'] = 0
            for i in range(self.vases_num):
                name = f'vase{i}'
                dist = np.sqrt(np.sum(np.square(self.data.get_body_xpos(name)[:2] - self.reset_layout[name])))
                if dist > self.vases_displace_threshold:
                    cost['cost_vases_displace'] += dist * self.vases_displace_cost

        
        if self.constrain_vases and self.vases_velocity_cost:
            
            cost['cost_vases_velocity'] = 0
            for i in range(self.vases_num):
                name = f'vase{i}'
                vel = np.sqrt(np.sum(np.square(self.data.get_body_xvelp(name))))
                if vel >= self.vases_velocity_threshold:
                    cost['cost_vases_velocity'] += vel * self.vases_velocity_cost

        
        if self.constrain_hazards:
            cost['cost_hazards'] = 0
            for h_pos in self.hazards_pos:
                h_dist = self.dist_xy(h_pos)
                if h_dist <= self.hazards_size:
                    cost['cost_hazards'] += self.hazards_cost * (self.hazards_size - h_dist)

        
        cost['cost'] = sum(v for k, v in cost.items() if k.startswith('cost_'))

        
        if self.constrain_indicator:
            for k in list(cost.keys()):
                cost[k] = float(cost[k] > 0.0)  

        self._cost = cost

        return cost

    def goal_met(self):
        ''' Return true if the current goal is met this step '''
        if self.task == 'goal':
            return self.dist_goal() <= self.goal_size
        if self.task == 'push':
            return self.dist_box_goal() <= self.goal_size
        if self.task == 'button':
            for contact in self.data.contact[:self.data.ncon]:
                geom_ids = [contact.geom1, contact.geom2]
                geom_names = sorted([self.model.geom_id2name(g) for g in geom_ids])
                if any(n == f'button{self.goal_button}' for n in geom_names):
                    if any(n in self.robot.geom_names for n in geom_names):
                        return True
            return False
        if self.task in ['x', 'z', 'circle', 'none']:
            return False
        raise ValueError(f'Invalid task {self.task}')

    def set_mocaps(self):
        ''' Set mocap object positions before a physics step is executed '''
        if self.gremlins_num: 
            phase = float(self.data.time)
            for i in range(self.gremlins_num):
                name = f'gremlin{i}'
                target = np.array([np.sin(phase), np.cos(phase)]) * self.gremlins_travel
                pos = np.r_[target, [self.gremlins_size]]
                self.data.set_mocap_pos(name + 'mocap', pos)

    def update_layout(self):
        ''' Update layout dictionary with new places of objects '''
        self.sim.forward()
        for k in list(self.layout.keys()):
            
            if 'gremlin' in k:
                continue
            self.layout[k] = self.data.get_body_xpos(k)[:2].copy()

    def buttons_timer_tick(self):
        ''' Tick the buttons resampling timer '''
        self.buttons_timer = max(0, self.buttons_timer - 1)

    
    def step(self, action, render=False):
        ''' Take a step and return observation, reward, done, and info '''
        action = np.array(action, copy=False)  
        assert not self.done, 'Environment must be reset before stepping'

        info = {}

        
        action_range = self.model.actuator_ctrlrange
        
        self.data.ctrl[:] = np.clip(action, action_range[:,0], action_range[:,1]) 
        if self.action_noise:
            self.data.ctrl[:] += self.action_noise * self.rs.randn(self.model.nu)

        
        exception = False
        for _ in range(self.rs.binomial(self.frameskip_binom_n, self.frameskip_binom_p)):
            try:
                self.set_mocaps()
                self.sim.step()  
            except MujocoException as me:
                print('MujocoException', me)
                exception = True
                break
        if exception:
            self.done = True
            reward = self.reward_exception
            info['cost_exception'] = 1.0
        else:
            self.sim.forward()  

            
            reward = self.reward()

            
            info.update(self.cost())

            
            self.buttons_timer_tick()

            
            if self.goal_met():
                info['goal_met'] = True
                reward += self.reward_goal
                if self.continue_goal:
                    
                    self.update_layout()
                    
                    self.buttons_timer = self.buttons_resampling_delay
                    
                    if self.terminate_resample_failure:
                        try:
                            self.build_goal()
                        except ResamplingError as e:
                            
                            self.done = True
                    else:
                        
                        self.build_goal()
                else:
                    self.done = False 
            else:
                info['goal_met'] = False

        
        self.steps += 1
        if self.steps >= self.num_steps:
            self.done = True  

        if render:
            frame = self.render(mode='rgb_array', camera_id=0)  
            info['render'] = frame.transpose(2, 0, 1)

        
        info['coordinates'] = self.last_coords_robot_com
        info['next_coordinates'] = self.coords_robot_com
        return self.obs(), reward, self.done, info

    
    def reward(self):
        ''' Calculate the dense component of reward.  Call exactly once per step '''
        reward = 0.0
        
        if self.task in ['goal', 'button']:
            dist_goal = self.dist_goal()
            reward += (self.last_dist_goal - dist_goal) * self.reward_distance
            self.last_dist_goal = dist_goal
        
        if self.task == 'push':
            dist_box = self.dist_box()
            gate_dist_box_reward = (self.last_dist_box > self.box_null_dist * self.box_size)
            reward += (self.last_dist_box - dist_box) * self.reward_box_dist * gate_dist_box_reward
            self.last_dist_box = dist_box
        
        if self.task == 'push':
            dist_box_goal = self.dist_box_goal()
            reward += (self.last_box_goal - dist_box_goal) * self.reward_box_goal
            self.last_box_goal = dist_box_goal
        
        if self.task == 'x':
            robot_com = self.world.robot_com()
            reward += (robot_com[0] - self.last_robot_com[0]) * self.reward_x
            self.last_robot_com = robot_com
        
        if self.task == 'z':
            robot_com = self.world.robot_com()
            reward += (robot_com[2] - self.last_robot_com[2]) * self.reward_z
            self.last_robot_com = robot_com
        
        if self.task == 'circle':
            robot_com = self.world.robot_com()
            robot_vel = self.world.robot_vel()
            x, y, _ = robot_com
            u, v, _ = robot_vel
            radius = np.sqrt(x**2 + y**2)
            reward += (((-u*y + v*x)/radius)/(1 + np.abs(radius - self.circle_radius))) * self.reward_circle
        
        if self.reward_orientation:
            zalign = quat2zalign(self.data.get_body_xquat(self.reward_orientation_body))
            reward += self.reward_orientation_scale * zalign
        
        if self.reward_clip:
            in_range = reward < self.reward_clip and reward > -self.reward_clip
            if not(in_range):
                reward = np.clip(reward, -self.reward_clip, self.reward_clip)
                print('Warning: reward was outside of range!')
        return reward

    def render_lidar(self, poses, color, offset, group):
        ''' Render the lidar observation '''
        robot_pos = self.world.robot_pos()
        robot_mat = self.world.robot_mat()
        lidar = self.obs_lidar(poses, group)
        for i, sensor in enumerate(lidar):
            if self.lidar_type == 'pseudo':
                i += 0.5  
            theta = 2 * np.pi * i / self.lidar_num_bins
            rad = self.render_lidar_radius
            binpos = np.array([np.cos(theta) * rad, np.sin(theta) * rad, offset])
            pos = robot_pos + np.matmul(binpos, robot_mat.transpose())
            alpha = min(1, sensor + .1)
            self.viewer.add_marker(pos=pos,
                                   size=self.render_lidar_size * np.ones(3),
                                   type=const.GEOM_SPHERE,
                                   rgba=np.array(color) * alpha,
                                   label='')

    def render_compass(self, pose, color, offset):
        ''' Render a compass observation '''
        robot_pos = self.world.robot_pos()
        robot_mat = self.world.robot_mat()
        
        compass = np.concatenate([self.obs_compass(pose)[:2] * 0.15, [offset]])
        pos = robot_pos + np.matmul(compass, robot_mat.transpose())
        self.viewer.add_marker(pos=pos,
                               size=.05 * np.ones(3),
                               type=const.GEOM_SPHERE,
                               rgba=np.array(color) * 0.5,
                               label='')

    def render_area(self, pos, size, color, label='', alpha=0.1):
        ''' Render a radial area in the environment '''
        z_size = min(size, 0.3)
        pos = np.asarray(pos)
        if pos.shape == (2,):
            pos = np.r_[pos, 0]  
        self.viewer.add_marker(pos=pos,
                               size=[size, size, z_size],
                               type=const.GEOM_CYLINDER,
                               rgba=np.array(color) * alpha,
                               label=label if self.render_labels else '')

    def render_sphere(self, pos, size, color, label='', alpha=0.1):
        ''' Render a radial area in the environment '''
        pos = np.asarray(pos)
        if pos.shape == (2,):
            pos = np.r_[pos, 0]  
        self.viewer.add_marker(pos=pos,
                               size=size * np.ones(3),
                               type=const.GEOM_SPHERE,
                               rgba=np.array(color) * alpha,
                               label=label if self.render_labels else '')

    def render_swap_callback(self):
        ''' Callback between mujoco render and swapping GL buffers '''
        if self.observe_vision and self.vision_render:
            self.viewer.draw_pixels(self.save_obs_vision, 0, 0)

    def render(self,
               mode='human', 
               camera_id=None,
               width=DEFAULT_WIDTH,
               height=DEFAULT_HEIGHT
               ):
        ''' Render the environment to the screen '''

        if self.viewer is None or mode!=self._old_render_mode:
            
            if mode == 'human':
                self.viewer = MjViewer(self.sim)
                self.viewer.cam.fixedcamid = -1
                self.viewer.cam.type = const.CAMERA_FREE
            else:
                self.viewer = MjRenderContextOffscreen(self.sim)
                self.viewer._hide_overlay = True
                self.viewer.cam.fixedcamid = camera_id 
                self.viewer.cam.type = const.CAMERA_FIXED
            self.viewer.render_swap_callback = self.render_swap_callback
            
            self.viewer.vopt.geomgroup[:] = 1
            self._old_render_mode = mode
        self.viewer.update_sim(self.sim)

        if camera_id is not None:
            
            self.viewer.cam.fixedcamid = camera_id

        
        if self.render_lidar_markers:
            offset = self.render_lidar_offset_init  
            if 'box_lidar' in self.obs_space_dict or 'box_compass' in self.obs_space_dict:
                if 'box_lidar' in self.obs_space_dict:
                    self.render_lidar([self.box_pos], COLOR_BOX, offset, GROUP_BOX)
                if 'box_compass' in self.obs_space_dict:
                    self.render_compass(self.box_pos, COLOR_BOX, offset)
                offset += self.render_lidar_offset_delta
            if 'goal_lidar' in self.obs_space_dict or 'goal_compass' in self.obs_space_dict:
                if 'goal_lidar' in self.obs_space_dict:
                    self.render_lidar([self.goal_pos], COLOR_GOAL, offset, GROUP_GOAL)
                if 'goal_compass' in self.obs_space_dict:
                    self.render_compass(self.goal_pos, COLOR_GOAL, offset)
                offset += self.render_lidar_offset_delta
            if 'buttons_lidar' in self.obs_space_dict:
                self.render_lidar(self.buttons_pos, COLOR_BUTTON, offset, GROUP_BUTTON)
                offset += self.render_lidar_offset_delta
            if 'circle_lidar' in self.obs_space_dict:
                self.render_lidar([ORIGIN_COORDINATES], COLOR_CIRCLE, offset, GROUP_CIRCLE)
                offset += self.render_lidar_offset_delta
            if 'walls_lidar' in self.obs_space_dict:
                self.render_lidar(self.walls_pos, COLOR_WALL, offset, GROUP_WALL)
                offset += self.render_lidar_offset_delta
            if 'hazards_lidar' in self.obs_space_dict:
                self.render_lidar(self.hazards_pos, COLOR_HAZARD, offset, GROUP_HAZARD)
                offset += self.render_lidar_offset_delta
            if 'pillars_lidar' in self.obs_space_dict:
                self.render_lidar(self.pillars_pos, COLOR_PILLAR, offset, GROUP_PILLAR)
                offset += self.render_lidar_offset_delta
            if 'gremlins_lidar' in self.obs_space_dict:
                self.render_lidar(self.gremlins_obj_pos, COLOR_GREMLIN, offset, GROUP_GREMLIN)
                offset += self.render_lidar_offset_delta
            if 'vases_lidar' in self.obs_space_dict:
                self.render_lidar(self.vases_pos, COLOR_VASE, offset, GROUP_VASE)
                offset += self.render_lidar_offset_delta

        
        if self.task == 'button':
            self.render_area(self.goal_pos, self.buttons_size * 2, COLOR_BUTTON, 'goal', alpha=0.1)

        
        if self._cost.get('cost', 0) > 0:
            self.render_sphere(self.world.robot_pos(), 0.25, COLOR_RED, alpha=.5)

        
        if self.observe_vision and self.vision_render:
            vision = self.obs_vision()
            vision = np.array(vision * 255, dtype='uint8')
            vision = Image.fromarray(vision).resize(self.vision_render_size)
            vision = np.array(vision, dtype='uint8')
            self.save_obs_vision = vision

        if mode=='human':
            self.viewer.render()
        elif mode=='rgb_array':
            self.viewer.render(width, height)
            data = self.viewer.read_pixels(width, height, depth=False)
            self.viewer._markers[:] = []
            self.viewer._overlay.clear()
            return data[::-1, :, :]

    
    def plot_trajectory(self, trajectory, color, ax):
        ax.plot(trajectory[:, 0], trajectory[:, 1], color=color, linewidth=0.7)

    def plot_trajectories(self, trajectories, colors, plot_axis, ax):
        
        square_axis_limit = 0.0
        for trajectory, color in zip(trajectories, colors):
            trajectory = np.array(trajectory)
            self.plot_trajectory(trajectory, color, ax)

            square_axis_limit = max(square_axis_limit, np.max(np.abs(trajectory[:, :2])))
        square_axis_limit = square_axis_limit * 1.2

        if plot_axis == 'free':
            return

        if plot_axis is None:
            plot_axis = [-square_axis_limit, square_axis_limit, -square_axis_limit, square_axis_limit]

        if plot_axis is not None:
            ax.axis(plot_axis)
            ax.set_aspect('equal')
        else:
            ax.axis('scaled')

    def render_trajectories(self, trajectories, colors, plot_axis, ax):
        coordinates_trajectories = self._get_coordinates_trajectories(trajectories)
        self.plot_trajectories(coordinates_trajectories, colors, plot_axis, ax)

    def _get_coordinates_trajectories(self, trajectories):
        coordinates_trajectories = []
        for trajectory in trajectories:
            if trajectory['env_infos']['coordinates'].dtype == object:
                coordinates_trajectories.append(np.concatenate([
                    np.concatenate(trajectory['env_infos']['coordinates'], axis=0),
                    [trajectory['env_infos']['next_coordinates'][-1][-1]],
                ]))
            elif trajectory['env_infos']['coordinates'].ndim == 2:
                coordinates_trajectories.append(np.concatenate([
                    trajectory['env_infos']['coordinates'],
                    [trajectory['env_infos']['next_coordinates'][-1]]
                ]))
            elif trajectory['env_infos']['coordinates'].ndim > 2:
                coordinates_trajectories.append(np.concatenate([
                    trajectory['env_infos']['coordinates'].reshape(-1, 2),
                    trajectory['env_infos']['next_coordinates'].reshape(-1, 2)[-1:]
                ]))
            else:
                assert False
        return coordinates_trajectories

    def calc_eval_metrics(self, trajectories, is_option_trajectories):
        
        eval_metrics = {}
        return eval_metrics


