"""Environment class."""

import os
import random
import re
import string
import tempfile
import time

import cv2
import gymnasium as gym
import imageio
import numpy as np
import pybullet as p

from environments import cameras
from utils import pybullet_utils
from utils import general_utils as utils

# Placing parameters (not referenced in this file; presumably consumed by task
# primitives -- TODO confirm).
PLACE_STEP = 3e-4
PLACE_DELTA_THRESHOLD = 5e-3

# URDF locations, relative to the environment's ``assets_root``.
UR5_URDF_PATH = "ur5/ur5.urdf"
UR5_WORKSPACE_URDF_PATH = "ur5/workspace.urdf"
PLANE_URDF_PATH = "plane/plane.urdf"


class Environment(gym.Env):
    """OpenAI Gym-style environment backed by PyBullet.

    Loads a UR5 arm with a task-provided end effector and renders RGB-D images
    from ``cameras.RealSenseD415.CONFIG``. Also offers helpers for adding
    objects, moving the robot, querying object poses, and recording videos
    (and optionally Blender keyframes).
    """

    # "<name> with obj_id <N>" -> N
    _OBJ_ID_PATTERN = re.compile(r"obj_id\s+(-?\d+)")
    # Fallback: first signed integer anywhere in a string.
    _FIRST_INT_PATTERN = re.compile(r"-?\d+")

    def __init__(
        self,
        assets_root,
        task=None,
        disp=False,  # for cap-options: True
        shared_memory=False,
        hz=240,
        record_cfg=None,
    ):
        """Creates OpenAI Gym-style environment with PyBullet.

        Args:
          assets_root: root directory of assets.
          task: the task to use. If None, the user must call set_task for the
            environment to work properly.
          disp: show environment with PyBullet's built-in display viewer.
          shared_memory: run with shared memory (only used when ``disp`` is True).
          hz: PyBullet physics simulation step speed. Set to 480 for deformables.
          record_cfg: optional mapping of recording options. Keys read by this
            class: "save_video_path", "fps", "video_height", "video_width",
            "add_text", "add_task_text", "blender_render".

        Raises:
          RuntimeError: if pybullet cannot load fileIOPlugin.
        """
        self.pix_size = 0.003125
        self.obj_ids = {"fixed": [], "rigid": [], "deformable": []}
        # Alias (not a copy) of ``obj_ids``; re-synced whenever ``reset``
        # replaces ``obj_ids``.
        self.objects = self.obj_ids

        # Defined up front so a missing task is reported by reset/get_lang_goal
        # instead of surfacing as an AttributeError.
        self.task = None

        self.homej = np.array([-1, -0.5, 0.5, -0.5, -0.5, 0]) * np.pi
        self.agent_cams = cameras.RealSenseD415.CONFIG
        self.record_cfg = record_cfg
        self.save_video = False
        self.video_path = None
        self.step_counter = 0

        self.assets_root = assets_root

        color_tuple = [
            gym.spaces.Box(0, 255, config["image_size"] + (3,), dtype=np.uint8)
            for config in self.agent_cams
        ]
        depth_tuple = [
            gym.spaces.Box(0.0, 20.0, config["image_size"], dtype=np.float32)
            for config in self.agent_cams
        ]
        self.observation_space = gym.spaces.Dict(
            {
                "color": gym.spaces.Tuple(color_tuple),
                "depth": gym.spaces.Tuple(depth_tuple),
            }
        )
        self.position_bounds = gym.spaces.Box(
            low=np.array([0.25, -0.5, 0.0], dtype=np.float32),
            high=np.array([0.75, 0.5, 0.28], dtype=np.float32),
            shape=(3,),
            dtype=np.float32,
        )
        self.bounds = np.array([[0.25, 0.75], [-0.5, 0.5], [0, 0.3]])

        self.action_space = gym.spaces.Dict(
            {
                "pose0": gym.spaces.Tuple(
                    (
                        self.position_bounds,
                        gym.spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32),
                    )
                ),
                "pose1": gym.spaces.Tuple(
                    (
                        self.position_bounds,
                        gym.spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32),
                    )
                ),
            }
        )

        # Start PyBullet.
        disp_option = p.DIRECT
        if disp:
            disp_option = p.GUI
            if shared_memory:
                disp_option = p.SHARED_MEMORY
        client = p.connect(disp_option)
        file_io = p.loadPlugin("fileIOPlugin", physicsClientId=client)
        if file_io < 0:
            raise RuntimeError("pybullet: cannot load FileIO!")
        p.executePluginCommand(
            file_io,
            textArgument=assets_root,
            intArgs=[p.AddFileIOAction],
            physicsClientId=client,
        )

        p.configureDebugVisualizer(p.COV_ENABLE_GUI, 0)
        p.setPhysicsEngineParameter(enableFileCaching=0)
        p.setAdditionalSearchPath(assets_root)
        # Generated (filled-in template) URDFs are written to the temp dir.
        p.setAdditionalSearchPath(tempfile.gettempdir())
        p.setTimeStep(1.0 / hz)

        # If using --disp, move default camera closer to the scene.
        if disp:
            target = p.getDebugVisualizerCamera()[11]
            p.resetDebugVisualizerCamera(
                cameraDistance=1.1,
                cameraYaw=90,
                cameraPitch=-25,
                cameraTargetPosition=target,
            )

        if task:
            self.set_task(task)

    def __del__(self):
        if hasattr(self, "video_writer"):
            self.video_writer.close()

    # ---------------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------------

    def _blender_render_enabled(self):
        """Return True if ``record_cfg`` requests Blender keyframe recording.

        Safe when ``record_cfg`` is None (the constructor default).
        """
        cfg = getattr(self, "record_cfg", None)
        return cfg is not None and "blender_render" in cfg and bool(cfg["blender_render"])

    @staticmethod
    def _to_rgba(color):
        """Return ``color`` as an RGBA list, appending alpha 1.0 to RGB input.

        Accepts lists, tuples or numpy arrays. Four-element input is passed
        through unchanged.
        """
        rgba = list(color)
        if len(rgba) == 3:
            rgba.append(1.0)
        return rgba

    @classmethod
    def _parse_obj_id(cls, name):
        """Extract the id from a name like ``"<obj> with obj_id <N>"``.

        Falls back to the first (signed) integer in ``name``. Returns None if
        ``name`` contains no integer at all.
        """
        match = cls._OBJ_ID_PATTERN.search(name)
        if match:
            return int(match.group(1))
        match = cls._FIRST_INT_PATTERN.search(name)
        return int(match.group(0)) if match else None

    @property
    def is_static(self):
        """Return true if objects are no longer moving."""
        # Only rigid objects' base linear velocities are checked.
        v = [np.linalg.norm(p.getBaseVelocity(i)[0]) for i in self.obj_ids["rigid"]]
        return all(np.array(v) < 5e-3)

    def fill_dummy_template(self, template):
        """Fill unresolved placeholders in a URDF template with random sizes.

        Placeholders of the form ``<FIELD><i>`` (i in 0..2) are replaced with
        values drawn from ``np.random.uniform``. DIM* fields get values in
        [0.03, 0.05) and HALF fields get values in [0.01, 0.03).

        Args:
          template: template path relative to ``assets_root``.

        Returns:
          Path of a filled copy written to the system temp dir if any
          placeholder was found, otherwise ``template`` unchanged.
        """
        full_template_path = os.path.join(self.assets_root, template)
        with open(full_template_path, "r") as file:
            fdata = file.read()

        # (placeholder fields, low, high); order matters for the RNG stream.
        replacements = (
            (("DIMH", "DIMR", "DIMX", "DIMY", "DIMZ", "DIM"), 0.03, 0.05),
            (("HALF",), 0.01, 0.03),
        )
        fill = False
        for fields, low, high in replacements:
            for field in fields:
                if field not in fdata:
                    continue
                # Usually 3 placeholder indices are enough.
                values = np.random.uniform(low, high, size=(3,)).tolist()
                for i, value in enumerate(values):
                    fdata = fdata.replace(f"{field}{i}", str(value))
                fill = True

        if not fill:
            return template

        alphabet = string.ascii_lowercase + string.digits
        rname = "".join(random.choices(alphabet, k=16))
        template_filename = os.path.split(template)[-1]
        fname = os.path.join(tempfile.gettempdir(), f"{template_filename}.{rname}")
        with open(fname, "w") as file:
            file.write(fdata)
        return fname

    def add_sdf_object(self, sdf):
        """Load an SDF file from ``<assets_root>/../other_assets/<sdf>``.

        Returns:
          Whatever ``pybullet_utils.load_sdf`` returns (the loaded body ids).
        """
        import pathlib

        path = pathlib.Path(self.assets_root)
        sdf_path = path.parent / "other_assets" / sdf
        return pybullet_utils.load_sdf(p, str(sdf_path))

    def add_object(self, urdf, pose, category="rigid", color=None, scale=1, **kwargs):
        """Load a URDF into the scene and register it under ``category``.

        Args:
          urdf: URDF path relative to ``assets_root``. Paths containing
            "template" have their placeholders filled first. A missing file
            falls back to "stacking/block.urdf".
          pose: ``(position, orientation)``, or a bare 3-element position (an
            identity quaternion is then used).
          category: "fixed", "rigid" or "deformable". "fixed" loads with a
            fixed base.
          color: color name (key of ``utils.COLORS``) or an RGB/RGBA sequence.
          scale: global URDF scaling.
          **kwargs: accepted for compatibility; unused.

        Returns:
          The PyBullet body id, or None if loading failed.
        """
        fixed_base = 1 if category == "fixed" else 0

        if "template" in urdf:
            if not os.path.exists(os.path.join(self.assets_root, urdf)):
                urdf = urdf.replace("-template", "")
            urdf = self.fill_dummy_template(urdf)

        if not os.path.exists(os.path.join(self.assets_root, urdf)):
            print(
                f"missing urdf error: {os.path.join(self.assets_root, urdf)}. use dummy block."
            )
            urdf = "stacking/block.urdf"

        if len(pose) == 3 and not hasattr(pose[0], "__len__"):
            # Add a default orientation if only a position was given.
            pose = (pose, (0, 0, 0, 1))

        obj_id = pybullet_utils.load_urdf(
            p,
            os.path.join(self.assets_root, urdf),
            pose[0],
            pose[1],
            globalScaling=scale,
            useFixedBase=fixed_base,
        )

        if obj_id is not None:
            self.obj_ids[category].append(obj_id)

        if color is not None:
            if isinstance(color, str):
                color = utils.COLORS[color]
            color = self._to_rgba(color)
            if obj_id is not None:
                p.changeVisualShape(obj_id, -1, rgbaColor=color)

        if self._blender_render_enabled() and obj_id is not None:
            print("color:", color)
            self.blender_recorder.register_object(
                obj_id, os.path.join(self.assets_root, urdf), color=color
            )

        return obj_id

    def set_color(self, obj_id, color):
        """Set the base-link RGBA of ``obj_id`` (alpha 1 is appended to RGB)."""
        p.changeVisualShape(obj_id, -1, rgbaColor=self._to_rgba(color))

    def set_object_color(self, *args, **kwargs):
        """Alias of ``set_color``."""
        return self.set_color(*args, **kwargs)

    # ---------------------------------------------------------------------------
    # Standard Gym Functions
    # ---------------------------------------------------------------------------

    def seed(self, seed=None):
        """Seed the noise RNG used by ``render_camera``.

        Falsy seeds (None or 0) are stored in ``_seed`` as -100.
        """
        self._random = np.random.RandomState(seed)
        self._seed = seed if seed else -100
        return seed

    def reset(self):
        """Performs common reset functionality for all supported tasks.

        Returns:
          The initial observation.

        Raises:
          ValueError: if no task has been set.
        """
        if not self.task:
            raise ValueError(
                "environment task must be set. Call set_task or pass "
                "the task arg in the environment constructor."
            )
        self.obj_ids = {"fixed": [], "rigid": [], "deformable": []}
        self.objects = self.obj_ids  # keep the alias pointing at the live dict
        p.resetSimulation(p.RESET_USE_DEFORMABLE_WORLD)
        p.setGravity(0, 0, -9.8)

        # Temporarily disable rendering to load scene faster.
        p.configureDebugVisualizer(p.COV_ENABLE_RENDERING, 0)

        plane = pybullet_utils.load_urdf(
            p, os.path.join(self.assets_root, PLANE_URDF_PATH), [0, 0, -0.001]
        )
        workspace = pybullet_utils.load_urdf(
            p, os.path.join(self.assets_root, UR5_WORKSPACE_URDF_PATH), [0.5, 0, 0]
        )

        # Load UR5 robot arm equipped with suction end effector.
        # TODO(andyzeng): add back parallel-jaw grippers.
        self.ur5 = pybullet_utils.load_urdf(
            p, os.path.join(self.assets_root, UR5_URDF_PATH)
        )
        self.ee = self.task.ee(self.assets_root, self.ur5, 9, self.obj_ids)
        self.ee_tip = 10  # Link ID of suction cup.

        if self._blender_render_enabled():
            from misc.pyBulletSimRecorder import PyBulletRecorder

            self.blender_recorder = PyBulletRecorder()
            self.blender_recorder.register_object(
                plane, os.path.join(self.assets_root, PLANE_URDF_PATH)
            )
            self.blender_recorder.register_object(
                workspace, os.path.join(self.assets_root, UR5_WORKSPACE_URDF_PATH)
            )
            self.blender_recorder.register_object(
                self.ur5, os.path.join(self.assets_root, UR5_URDF_PATH)
            )
            self.blender_recorder.register_object(self.ee.base, self.ee.base_urdf_path)
            if hasattr(self.ee, "body"):
                self.blender_recorder.register_object(self.ee.body, self.ee.urdf_path)

        # Get revolute joint indices of robot (skip fixed joints).
        n_joints = p.getNumJoints(self.ur5)
        joints = [p.getJointInfo(self.ur5, i) for i in range(n_joints)]
        self.joints = [j[0] for j in joints if j[2] == p.JOINT_REVOLUTE]

        # Move robot to home joint configuration.
        for joint, home in zip(self.joints, self.homej):
            p.resetJointState(self.ur5, joint, home)

        # Reset end effector.
        self.ee.release()

        # Reset task.
        self.task.reset(self)

        # Re-enable rendering.
        p.configureDebugVisualizer(p.COV_ENABLE_RENDERING, 1)

        obs, _, _, _ = self.step()
        return obs

    def step(self, action=None):
        """Execute action with specified primitive.

        Args:
          action: dict with "pose0" and "pose1" passed to the task primitive,
            or None to only let the scene settle and observe.

        Returns:
          (obs, reward, done, info) tuple containing MDP step data.
        """
        if action is not None:
            timeout = self.task.primitive(
                self.movej, self.movep, self.ee, action["pose0"], action["pose1"]
            )

            # Exit early if action times out. We still return an observation
            # so that we don't break the Gym API contract.
            if timeout:
                return self._get_obs(), 0.0, True, self.info

        # Step simulator until objects settle (wall-clock timeout of 5 s).
        start_time = time.time()
        while not self.is_static:
            self.step_simulation()
            if time.time() - start_time > 5:
                break

        # Get task rewards.
        reward, info = self.task.reward() if action is not None else (0, {})
        done = self.task.done()

        # Add ground truth robot state into info.
        info.update(self.info)

        obs = self._get_obs()

        return obs, reward, done, info

    def step_simulation(self):
        """Advance physics one step; record a video frame every 5th step."""
        p.stepSimulation()
        self.step_counter += 1

        if self.save_video and self.step_counter % 5 == 0:
            self.add_video_frame()

    def render(self, mode="rgb_array"):
        """Return the color image of the first agent camera.

        Raises:
          NotImplementedError: for any mode other than "rgb_array".
        """
        if mode != "rgb_array":
            raise NotImplementedError("Only rgb_array implemented")
        color, _, _ = self.render_camera(self.agent_cams[0])
        return color

    def render_camera(self, config, image_size=None, shadow=1):
        """Render RGB-D image with specified camera configuration.

        Args:
          config: camera config with "image_size", "rotation", "position",
            "intrinsics", "zrange" and "noise" entries.
          image_size: optional (height, width) override.
          shadow: passed through to ``p.getCameraImage``.

        Returns:
          color: (H, W, 3) uint8 image.
          depth: (H, W) depth, linearized from the OpenGL z-buffer using the
            camera's (znear, zfar).
          segm: (H, W) uint8 segmentation mask.
        """
        if not image_size:
            image_size = config["image_size"]

        # OpenGL camera settings.
        lookdir = np.float32([0, 0, 1]).reshape(3, 1)
        updir = np.float32([0, -1, 0]).reshape(3, 1)
        rotation = p.getMatrixFromQuaternion(config["rotation"])
        rotm = np.float32(rotation).reshape(3, 3)
        lookdir = (rotm @ lookdir).reshape(-1)
        updir = (rotm @ updir).reshape(-1)
        lookat = config["position"] + lookdir
        focal_len = config["intrinsics"][0]
        znear, zfar = config["zrange"]
        viewm = p.computeViewMatrix(config["position"], lookat, updir)
        fovh = (image_size[0] / 2) / focal_len
        fovh = 180 * np.arctan(fovh) * 2 / np.pi

        # Notes: 1) FOV is vertical FOV 2) aspect must be float
        aspect_ratio = image_size[1] / image_size[0]
        projm = p.computeProjectionMatrixFOV(fovh, aspect_ratio, znear, zfar)

        # Render with OpenGL camera settings.
        _, _, color, depth, segm = p.getCameraImage(
            width=image_size[1],
            height=image_size[0],
            viewMatrix=viewm,
            projectionMatrix=projm,
            shadow=shadow,
            flags=p.ER_SEGMENTATION_MASK_OBJECT_AND_LINKINDEX,
            renderer=p.ER_BULLET_HARDWARE_OPENGL,
        )

        # Get color image.
        color_image_size = (image_size[0], image_size[1], 4)
        color = np.array(color, dtype=np.uint8).reshape(color_image_size)
        color = color[:, :, :3]  # remove alpha channel
        if config["noise"]:
            color = np.int32(color)
            # Noise must match the (H, W, 3) image; an (H, W) array does not
            # broadcast against it.
            color += np.int32(self._random.normal(0, 3, color.shape))
            color = np.uint8(np.clip(color, 0, 255))

        # Get depth image.
        depth_image_size = (image_size[0], image_size[1])
        zbuffer = np.array(depth).reshape(depth_image_size)
        depth = zfar + znear - (2.0 * zbuffer - 1.0) * (zfar - znear)
        depth = (2.0 * znear * zfar) / depth
        if config["noise"]:
            depth += self._random.normal(0, 0.003, depth_image_size)

        # Get segmentation image.
        segm = np.uint8(segm).reshape(depth_image_size)

        return color, depth, segm

    @property
    def info(self):
        """Environment info: ``{obj_id: (position, rotation, dimensions)}``.

        Also contains ``"lang_goal"`` taken from the current task.
        """
        info = {}
        for obj_ids in self.obj_ids.values():
            for obj_id in obj_ids:
                pos, rot = p.getBasePositionAndOrientation(obj_id)
                dim = p.getVisualShapeData(obj_id)[0][3]
                info[obj_id] = (pos, rot, dim)

        info["lang_goal"] = self.get_lang_goal()
        return info

    def set_task(self, task):
        """Attach ``task`` and hand it the assets root."""
        task.set_assets_root(self.assets_root)
        self.task = task

    def get_task_name(self):
        """Return the class name of the current task."""
        return type(self.task).__name__

    def get_lang_goal(self):
        """Return the current task's language goal.

        Raises:
          Exception: if no task has been set.
        """
        if self.task:
            return self.task.get_lang_goal()
        raise Exception("No task was set")

    # ---------------------------------------------------------------------------
    # Robot Movement Functions
    # ---------------------------------------------------------------------------

    def movej(self, targj, speed=0.01, timeout=5, effector=None):
        """Move UR5 to target joint configuration.

        Without ``effector``, this steps revolute UR5 joints toward ``targj`` at
        constant speed. With ``effector``, it drives joint 1 of
        ``effector.body`` with velocity ``targj * speed`` until the gripper
        reports contact force > 1 or a closed state.

        Returns:
          False on success, True if the (wall-clock) timeout was exceeded.
        """
        if self.save_video:
            timeout = timeout * 30  # 50?

        if not effector:
            body = self.ur5
            n_joints = p.getNumJoints(body)
            joints = [p.getJointInfo(body, i) for i in range(n_joints)]
            joints = [j[0] for j in joints if j[2] == p.JOINT_REVOLUTE]
            gains = np.ones(len(joints))
            t0 = time.time()
            while (time.time() - t0) < timeout:
                currj = np.array([p.getJointState(body, i)[0] for i in joints])
                diffj = targj - currj
                if all(np.abs(diffj) < 1e-2):
                    return False

                # Move with constant velocity
                norm = np.linalg.norm(diffj)
                v = diffj / norm if norm > 0 else 0
                stepj = currj + v * speed
                p.setJointMotorControlArray(
                    bodyIndex=body,
                    jointIndices=joints,
                    controlMode=p.POSITION_CONTROL,
                    targetPositions=stepj,
                    positionGains=gains,
                )
                # NOTE(review): step_simulation also increments step_counter,
                # so each loop advances it by 2 (denser video frames during
                # arm motion). Kept as-is; confirm whether intentional.
                self.step_counter += 1
                self.step_simulation()

            print(f"Warning: movej exceeded {timeout} second timeout. Skipping.")
            return True

        body = effector.body
        t0 = time.time()
        while (time.time() - t0) < timeout / 4:
            # Move with constant velocity
            p.setJointMotorControl2(
                body, 1, p.VELOCITY_CONTROL, targetVelocity=targj * speed, force=10
            )
            self.step_simulation()
            if effector.gripper_force()[1] > 1 or np.isclose(
                effector.gripper_state()[1], 0, atol=0.01
            ):
                return False

        print(f"Warning: movej exceeded {timeout/4} second timeout. Skipping.")
        return True

    def start_rec(self, video_filename):
        """Start recording ``<video_path>/<video_filename>.mp4``.

        Uses ``self.video_path`` if set, else ``record_cfg["save_video_path"]``.
        Any previously open writer is closed first.
        """
        assert self.record_cfg

        video_path = self.video_path or self.record_cfg["save_video_path"]
        os.makedirs(video_path, exist_ok=True)

        # Close and save existing writer.
        if hasattr(self, "video_writer"):
            self.video_writer.close()

        self.video_writer = imageio.get_writer(
            os.path.join(video_path, f"{video_filename}.mp4"),
            fps=self.record_cfg["fps"],
            format="FFMPEG",
            codec="h264",
        )
        self.save_video = True

    def end_rec(self):
        """Close the video writer (if any) and stop recording frames."""
        if hasattr(self, "video_writer"):
            self.video_writer.close()
        self.save_video = False

    def add_video_frame(self):
        """Render a frame from the first agent camera and append it to the video.

        Records a Blender keyframe when enabled. Overlays the language goal
        ("add_text") and/or task name ("add_task_text") per ``record_cfg``.
        """
        config = self.agent_cams[0]
        image_size = (self.record_cfg["video_height"], self.record_cfg["video_width"])
        color, _, _ = self.render_camera(config, image_size, shadow=0)
        # Contiguous copy: cv2.putText cannot draw on the sliced RGB view.
        color = np.array(color)

        if self._blender_render_enabled() and hasattr(self, "blender_recorder"):
            self.blender_recorder.add_keyframe()

        if self.record_cfg["add_text"]:
            lang_goal = self.get_lang_goal()

            font = cv2.FONT_HERSHEY_DUPLEX
            font_scale = 0.65
            font_thickness = 1

            # Write language goal, wrapped every ``line_length`` characters.
            line_length = 60
            line_count = len(lang_goal) // line_length + 1
            for i in range(line_count):
                line = lang_goal[i * line_length : (i + 1) * line_length]
                text_width = cv2.getTextSize(line, font, font_scale, font_thickness)[0][0]
                text_x = (image_size[1] - text_width) // 2
                color = cv2.putText(
                    color,
                    line,
                    org=(text_x, 570 + i * 30),
                    fontScale=font_scale,
                    fontFace=font,
                    color=(0, 0, 0),
                    thickness=font_thickness,
                    lineType=cv2.LINE_AA,
                )

            color = np.array(color)

        if "add_task_text" in self.record_cfg and self.record_cfg["add_task_text"]:
            task_name = self.get_task_name()

            font = cv2.FONT_HERSHEY_DUPLEX
            font_scale = 1
            font_thickness = 2

            line_length = 60
            line_count = len(task_name) // line_length + 1
            text_width = cv2.getTextSize(task_name, font, font_scale, font_thickness)[0][0]
            text_x = (image_size[1] - text_width) // 2

            color = cv2.putText(
                color,
                task_name,
                org=(text_x, 570 + line_count * 30),
                fontScale=font_scale,
                fontFace=font,
                color=(255, 0, 0),
                thickness=font_thickness,
                lineType=cv2.LINE_AA,
            )

            color = np.array(color)

        self.video_writer.append_data(color)

    def movep(self, pose, speed=0.001):
        """Move UR5 to target end effector pose."""
        targj = self.solve_ik(pose)
        return self.movej(targj, speed)

    def solve_ik(self, pose):
        """Calculate joint configuration with inverse kinematics.

        Joints 2 onward are wrapped into [-pi, pi).
        """
        joints = p.calculateInverseKinematics(
            bodyUniqueId=self.ur5,
            endEffectorLinkIndex=self.ee_tip,
            targetPosition=pose[0],
            targetOrientation=pose[1],
            lowerLimits=[-3 * np.pi / 2, -2.3562, -17, -17, -17, -17],
            upperLimits=[-np.pi / 2, 0, 17, 17, 17, 17],
            jointRanges=[np.pi, 2.3562, 34, 34, 34, 34],  # * 6,
            restPoses=np.float32(self.homej).tolist(),
            maxNumIterations=100,
            residualThreshold=1e-5,
        )
        joints = np.float32(joints)
        joints[2:] = (joints[2:] + np.pi) % (2 * np.pi) - np.pi
        return joints

    def _get_obs(self):
        """Return ``{"color": (...), "depth": (...)}``, one entry per agent camera."""
        obs = {"color": (), "depth": ()}
        for config in self.agent_cams:
            color, depth, _ = self.render_camera(config)
            obs["color"] += (color,)
            obs["depth"] += (depth,)
        return obs

    def get_object_pose(self, obj_id):
        """Return (position, yaw-only quaternion) of ``obj_id``."""
        pose = p.getBasePositionAndOrientation(obj_id)
        return pose[0], self.ignore_roll_pitch(pose[1])

    def get_bounding_box(self, obj_id):
        """Return PyBullet's AABB for ``obj_id``."""
        return p.getAABB(obj_id)

    # ----------------------------------- CapRavens utils -------------------------------

    def on_top_of(self, obj_a, obj_b):
        """Check if ``obj_a`` is on top of ``obj_b``.

        Corners (names in ``utils.CORNER_POS``) only need xy distance < 0.06.
        Names containing "bowl"/"container"/"fixture" need xy distance < 0.06
        and ``obj_a`` higher than ``obj_b``. Anything else needs xy distance
        < 0.04 and ``obj_a`` higher.
        """
        # get_obj_pos returns a list of positions; use the first match.
        obj_a_pos = np.asarray(self.get_obj_pos(obj_a)[0], dtype=np.float64)
        obj_b_pos = np.asarray(self.get_obj_pos(obj_b)[0], dtype=np.float64)
        xy_dist = np.linalg.norm(obj_a_pos[:2] - obj_b_pos[:2])
        is_higher = obj_a_pos[2] > obj_b_pos[2]

        b_name = obj_b if isinstance(obj_b, str) else ""
        if b_name and b_name in utils.CORNER_POS:
            return xy_dist < 0.06
        if any(key in b_name for key in ("bowl", "container", "fixture")):
            return xy_dist < 0.06 and is_higher
        return xy_dist < 0.04 and is_higher

    def get_obj_id(self, obj_name, count=1):
        """Resolve an object name (or list of names) to PyBullet id(s).

        Names containing "with obj_id <N>" resolve directly to N. Otherwise the
        name is matched against ``self.task.obj`` via ``utils.find_best_keys``.

        Returns:
          A single id when ``count == 1``, the first ``count`` ids otherwise
          (all of them when ``count == -1``), or None if nothing matched.
        """
        if isinstance(obj_name, str):  # also covers np.str_
            if "with obj_id" in obj_name:
                obj_id = self._parse_obj_id(obj_name)
                if obj_id is not None:
                    return obj_id
        if isinstance(obj_name, list) and all(
            isinstance(obj, str) and "with obj_id" in obj for obj in obj_name
        ):
            ids = [self._parse_obj_id(obj) for obj in obj_name]
            if all(obj_id is not None for obj_id in ids):
                return ids

        obj_ids = utils.find_best_keys(self.task.obj, obj_name)
        if len(obj_ids) == 0:
            # ``object_list`` is not defined in this class; fall back to the
            # task's object table for the diagnostic.
            available = getattr(self, "object_list", self.task.obj)
            print(f'requested_name: "{obj_name}"')
            print(f"available_id_and_object:\n{available}")
            return None
        if count == 1:
            return obj_ids[0]
        return obj_ids[:count] if count != -1 else obj_ids

    def get_obj_pos(self, obj_name, count=1):
        """Return a list of base positions for an object id, corner, or name.

        Corner names (keys of ``utils.CORNER_POS``) map to their stored
        positions. For names containing "zone", a small offset (0.01, 0, -0.01)
        rotated into the zone's frame is added to its base position.
        """
        if isinstance(obj_name, (int, np.integer)) and not isinstance(obj_name, bool):
            return [p.getBasePositionAndOrientation(int(obj_name))[0]]

        if isinstance(obj_name, str) and obj_name in utils.CORNER_POS:
            return [np.float32(np.array(utils.CORNER_POS[obj_name]))]
        if (
            isinstance(obj_name, list)
            and obj_name
            and all(isinstance(e, str) and e in utils.CORNER_POS for e in obj_name)
        ):
            return [np.float32(np.array(utils.CORNER_POS[e])) for e in obj_name]

        pick_id = self.get_obj_id(obj_name, count)
        if isinstance(pick_id, (int, np.integer)):
            pick_id = [pick_id]

        if "zone" in obj_name:
            bias = np.array([0.01, 0, -0.01])
            positions = []
            for body_id in pick_id:
                pos, rot = p.getBasePositionAndOrientation(body_id)
                rotm = np.array(utils.rotation_to_rotation_matrix(rot))
                positions.append(tuple(np.array(pos) + rotm @ bias))
            return positions

        return [p.getBasePositionAndOrientation(body_id)[0] for body_id in pick_id]

    def get_obj_rot(self, obj_id, count=1):
        """Return the base orientation quaternion of ``obj_id`` (``count`` unused)."""
        return p.getBasePositionAndOrientation(obj_id)[1]

    def get_ee_pose(self):
        """Return ``p.getLinkState`` for the end-effector tip link."""
        return p.getLinkState(self.ur5, self.ee_tip)

    def ignore_roll_pitch(self, rotation):
        """Return ``rotation`` (XYZW quaternion) with roll and pitch zeroed."""
        euler = utils.quatXYZW_to_eulerXYZ(rotation)
        return utils.eulerXYZ_to_quatXYZW([0, 0, euler[2]])


class EnvironmentNoRotationsWithHeightmap(Environment):
    """Environment that disables any rotations and always passes [0, 0, 0, 1].

    Observations are a fused (color, height) heightmap instead of per-camera
    RGB-D images, and actions carry only positions.
    """

    def __init__(
        self,
        assets_root,
        task=None,
        disp=False,
        shared_memory=False,
        hz=240,
        record_cfg=None,
    ):
        """Create the environment; ``record_cfg`` is forwarded to the base class."""
        super().__init__(
            assets_root, task, disp, shared_memory, hz, record_cfg=record_cfg
        )

        heightmap_tuple = [
            gym.spaces.Box(0.0, 20.0, (320, 160, 3), dtype=np.float32),
            gym.spaces.Box(0.0, 20.0, (320, 160), dtype=np.float32),
        ]
        self.observation_space = gym.spaces.Dict(
            {
                "heightmap": gym.spaces.Tuple(heightmap_tuple),
            }
        )
        self.action_space = gym.spaces.Dict(
            {
                "pose0": gym.spaces.Tuple((self.position_bounds,)),
                "pose1": gym.spaces.Tuple((self.position_bounds,)),
            }
        )

    def step(self, action=None):
        """Execute action with specified primitive.

        Each pose's position (element 0) is kept and its orientation is
        replaced by the identity quaternion before delegating to the base class.

        Args:
          action: action to execute.

        Returns:
          (obs, reward, done, info) tuple containing MDP step data.
        """
        if action is not None:
            identity = [0.0, 0.0, 0.0, 1.0]
            action = {
                "pose0": (action["pose0"][0], list(identity)),
                "pose1": (action["pose1"][0], list(identity)),
            }
        return super().step(action)

    def _get_obs(self):
        """Return ``{"heightmap": (cmap, hmap)}`` fused from all agent cameras."""
        # The base implementation yields per-camera color/depth tuples.
        color_depth_obs = super()._get_obs()
        cmap, hmap = utils.get_fused_heightmap(
            color_depth_obs, self.agent_cams, self.task.bounds, pix_size=self.pix_size
        )
        return {"heightmap": (cmap, hmap)}
