import math
import time
from typing import List, Optional, Tuple, Union
import logging

import numpy as np
import numpy.typing as npt
import torch

try:
    from torchcontrol.transform import Rotation as R
    from torchcontrol.transform import Transformation as torchcontrol_T
except Exception:
    from torchcontrol.utils.transform import Rotation as R
    from torchcontrol.utils.transform import Transformation as torchcontrol_T
from torchcontrol.utils.tensor_utils import to_tensor

from furniture_bench.utils.pose import is_similar_rot, rot_mat
from furniture_bench.config import config
from furniture_bench.controllers.osc import osc_factory
from furniture_bench.envs.initialization_mode import Randomness
from furniture_bench.robot.robot_state import PandaState, PandaError
import furniture_bench.utils.transform as T
from furniture_bench.controllers.impedance import HybridJointImpedanceControl
from furniture_bench.controllers.trajectory import JointTrajectoryExecutor
from furniture_bench.async_utils.async_ws_client import AsyncWebsocketClient
import threading
import time

from robosuite.controllers.interpolators import TOPPRAInterpolator

import torch.nn.functional as F

# Module-level logger used by the code in this file.
logger = logging.getLogger("panda")

# Gripper speed constant. It is not referenced in the visible part of this
# file and its units are not stated here -- TODO confirm against its users.
GRIPPER_SPEED = 10.0

def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """
    Build rotation matrices from the 6D rotation representation of
    Zhou et al. [1].

    The six numbers are read as two 3-vectors which are orthonormalized with a
    Gram--Schmidt step (Section B of [1]); the third axis is the cross product
    of the first two.

    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3); the three orthonormal
        vectors are stacked along the second-to-last dimension (as rows).

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    first_raw = d6[..., :3]
    second_raw = d6[..., 3:]

    axis_one = F.normalize(first_raw, dim=-1)
    # Drop the part of the second vector that lies along the first axis.
    overlap = torch.sum(axis_one * second_raw, dim=-1, keepdim=True)
    axis_two = F.normalize(second_raw - overlap * axis_one, dim=-1)
    axis_three = torch.cross(axis_one, axis_two, dim=-1)

    return torch.stack([axis_one, axis_two, axis_three], dim=-2)


def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Returns torch.sqrt(torch.max(0, x))
    but with a zero subgradient where x is 0.
    """
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    ret[positive_mask] = torch.sqrt(x[positive_mask])
    return ret


def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotation matrices to quaternions.

    Four candidate quaternions are formed (one per component r, i, j, k) and
    the one whose defining component has the largest magnitude is returned,
    since it has the largest denominator.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).

    Raises:
        ValueError: if the last two dimensions of ``matrix`` are not 3 x 3.
    """
    if not (matrix.size(-1) == 3 and matrix.size(-2) == 3):
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    lead_shape = matrix.shape[:-2]
    flat = matrix.reshape(lead_shape + (9,))
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = flat.unbind(dim=-1)

    # Twice the absolute value of each quaternion component (r, i, j, k).
    diagonal_terms = torch.stack(
        [
            1.0 + m00 + m11 + m22,
            1.0 + m00 - m11 - m22,
            1.0 - m00 + m11 - m22,
            1.0 - m00 - m11 + m22,
        ],
        dim=-1,
    )
    q_abs = _sqrt_positive_part(diagonal_terms)

    # Row n holds the desired quaternion multiplied by its n-th component.
    scaled_by_r = torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1)
    scaled_by_i = torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1)
    scaled_by_j = torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1)
    scaled_by_k = torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1)
    quat_by_rijk = torch.stack([scaled_by_r, scaled_by_i, scaled_by_j, scaled_by_k], dim=-2)

    # The divisor is floored at 0.1; the exact level is unimportant because a
    # candidate with a small q_abs is never the one selected below.
    floor = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    divisor = 2.0 * q_abs[..., None].max(floor)
    candidates = quat_by_rijk / divisor

    # Up to numerical error all candidates agree (up to sign); keep the
    # best-conditioned one, i.e. the row with the largest q_abs.
    best_row = q_abs.argmax(dim=-1)
    selector = F.one_hot(best_row, num_classes=4) > 0.5  # pyre-ignore[16]
    return candidates[selector, :].reshape(lead_shape + (4,))


class Panda:
    """PandaArm with custom methods."""

    def __init__(
        self,
        robot_config,
        max_gripper_width: Optional[float] = None,
        binary_grasping: bool = True,
        abs_action: bool = False,
        act_rot_repr: bool = False,
        gain_scale: float = 1.0,
        vel_gain_scale: float = 1.0,
        toppra: bool = False,
        no_interpolation: bool = False,
    ) -> None:
        """Connect to the arm and gripper servers and initialize controller state.

        Args:
            robot_config: Robot configuration mapping. This method reads the
                keys "server_ip", "reset_joints" and "FR3".
            max_gripper_width: Maximum gripper width. When None, the value
                reported by the gripper metadata (``max_width``) is used.
            binary_grasping: Stored on the instance as ``binary_grasping``.
            abs_action: Stored on the instance as ``abs_action``.
            act_rot_repr: Stored on the instance as ``act_rot_repr``.
            gain_scale: Multiplier applied to the positional gains ``kp``
                (``kv`` is derived from the scaled ``kp``).
            vel_gain_scale: Stored on the instance and printed; it is not
                applied to any gain inside this method.
            toppra: When true, a websocket client for the Toppra server is
                created on port 8767 and ``_toppra_message_handler`` is
                registered as its message handler.
            no_interpolation: Stored on the instance as ``no_interpolation``.

        Raises:
            ValueError: if ``config["robot"]["server_ip"]`` is an empty string
                or ``robot_config["server_ip"]`` is None.
        """

        # Imported here rather than at module import time.
        from polymetis import GripperInterface, RobotInterface

        if config["robot"]["server_ip"] == "":
            from rich import print

            print(f"[bold red]SERVER_IP is not defined.[/bold red]")
            raise ValueError(f"SERVER_IP is not defined.")

        self.robot_config = robot_config
        if robot_config["server_ip"] is None:
            raise ValueError("Please specify the server IP address.")

        # Both interfaces are created against the same server address.
        self.arm = RobotInterface(
            ip_address=robot_config["server_ip"], enforce_version=False
        )
        self.gripper = GripperInterface(ip_address=robot_config["server_ip"])
        self.max_gripper_width = self.gripper.metadata.max_width if max_gripper_width is None else max_gripper_width
        self.binary_grasping = binary_grasping

        # From here on `print` is rich's print within this method.
        from rich import print
        print("arm")
        print(self.arm.metadata.urdf_file)

        self.reset_joints = torch.tensor(robot_config["reset_joints"])

        # The reset joint configuration doubles as the arm's home pose.
        self.arm.set_home_pose(self.reset_joints)
        self.dof = 7

        self.gain_scale = gain_scale
        self.vel_gain_scale = vel_gain_scale

        print("gain_scale", gain_scale)
        print("vel_gain_scale", vel_gain_scale)

        # Positional and velocity gains for robot control.
        # Earlier gain values, kept for reference:
        # self.kp = torch.tensor([40, 40, 40, 25.0, 20.0, 25.0])
        self.kp = torch.tensor([80, 80, 80, 50.0, 40.0, 50.0]) * gain_scale
        # Velocity gains are tied to the positional gains: kv = 2 * sqrt(kp).
        self.kv = torch.ones((6,)) * torch.sqrt(self.kp) * 2.0

        self.grasp_margin = 0.02 - 0.001  # To prevent repeating open and close actions.

        self.max_go_time = 2.5  # 2.5 seconds maximum to move the robot.

        # Count how many times the robot has stopped moving in a row.
        # This is used to declare "done" when the robot stopped moving.
        self.motion_stopped_counter = 0
        self.is_fr3 = robot_config["FR3"]

        self.abs_action = abs_action
        self.act_rot_repr = act_rot_repr
        self.toppra = toppra
        self.no_interpolation = no_interpolation

        # Background-thread plumbing: one stop event and one re-entrant lock
        # are shared by all worker threads and the Toppra message handler.
        self.last_open = True
        self._state_update_thread = None
        self._stop_event = threading.Event()
        self._state_lock = threading.RLock()
        self._error_tracking_thread = None
        self._replan_chunk_index = 0

        # Tracking errors appended by the error-tracking loop.
        self.ee_pos_errors = []
        self.ee_ori_errors = []
        self.joint_errors = []

        # For time-based step calculation & gripper sync
        self._last_policy_update_time = None
        self._last_set_trajectory_length = 0
        self._gripper_schedule = []
        self._waypoint_timesteps = []
        self._gripper_thread = None
        self._candidate_gripper_schedule = None

        # For replanning synchronization
        self._replan_id = 0
        self._last_processed_replan_id = -1
        self.controllable_set_sizes = []
        self.initial_controllable_set_sizes = []

        self._new_traj_received = False
        self._is_pd_fallback = False
        self.REPLAN_WAYPOINT_THRESHOLD = 100 # no replan except PD fallback
        self._last_replan_waypoint_idx = -1
        self._num_waypoints_before_current_segment = 0
        self._waypoints_in_segment = 0
        self._current_waypoints = []
        self._total_completed_waypoints_last_replan = 0

        # The client is only created in TOPPRA mode; it is started later by
        # start_async_services().
        if self.toppra:
            self.toppra_client = AsyncWebsocketClient(port=8767)
            self.toppra_client.set_message_handler(self._toppra_message_handler)

    def start_async_services(self):
        """Start the Toppra client and the background worker threads.

        Nothing is started unless TOPPRA mode is enabled, a client object
        exists, and that client is not already running.
        """
        if not self.toppra:
            return
        client = self.toppra_client
        if not client or client.running:
            return

        client.start()
        self._start_state_update_thread()
        self._start_gripper_supervisor_thread()
        self._start_error_tracking_thread()

    def _start_state_update_thread(self):
        self._stop_event.clear()
        self._state_update_thread = threading.Thread(target=self._send_state_updates)
        self._state_update_thread.daemon = True
        self._state_update_thread.start()

    def _start_error_tracking_thread(self):
        self._stop_event.clear()
        self._error_tracking_thread = threading.Thread(target=self._error_tracking_loop)
        self._error_tracking_thread.daemon = True
        self._error_tracking_thread.start()

    def _start_gripper_supervisor_thread(self):
        """Clear the stop flag and run ``_run_gripper_supervisor`` on a daemon thread."""
        self._stop_event.clear()
        worker = threading.Thread(target=self._run_gripper_supervisor, daemon=True)
        self._gripper_thread = worker
        worker.start()

    def _run_gripper_supervisor(self):
        logged_no_schedule = False
        last_completed_in_segment = -1
        while not self._stop_event.is_set():
            with self._state_lock:
                schedule = self._gripper_schedule
                if self._last_policy_update_time is None:
                    completed_in_segment = 0
                else:
                    completed_in_segment = self.get_total_completed_waypoints() - self._num_waypoints_before_current_segment
            if not schedule:
                if not logged_no_schedule:
                    logger.info(f"No gripper schedule. Waiting for commands.")
                    logged_no_schedule = True
                time.sleep(0.01)
                continue
            
            logger.info(f"Completed in segment: {completed_in_segment}, Last completed in segment: {last_completed_in_segment}")
            if completed_in_segment != last_completed_in_segment:
                last_completed_in_segment = completed_in_segment
                if completed_in_segment >= len(schedule):
                    logger.warning(f"Completed in segment is greater than the length of the schedule. Skipping gripper state change.")
                    continue
                logger.info(f"Changing gripper state: {schedule[completed_in_segment]}")
                self._change_gripper_state(schedule[completed_in_segment])
                
            # logged_no_schedule = False
            
            # # Use a while loop to process all due gripper commands
            # # time_elapsed = time.time() - update_time
            # while schedule and completed_in_segment >= schedule[0][0]:
            #     next_waypoint_idx, gripper_command = schedule[0]
                
            #     if next_waypoint_idx < self._waypoints_in_segment:
            #         self._change_gripper_state(gripper_command)
            #         with self._state_lock:
            #             schedule.pop(0)
            #     else:
            #         # Should not happen, but as a safeguard
            #         with self._state_lock:
            #             schedule.pop(0)
            
            logger.info(f"Gripper state: {self.last_grasp}")
            time.sleep(0.005) # 200Hz check

    def _send_state_updates(self):
        while not self._stop_event.is_set():
            with self._state_lock:
                update_time = self._last_policy_update_time
                timesteps = self._waypoint_timesteps
                is_pd_fallback = self._is_pd_fallback
                last_replan_waypoint_idx = self._last_replan_waypoint_idx
                total_completed_waypoints_last_replan = self._total_completed_waypoints_last_replan
            
            replan = False
            completed_in_segment = 0
            
            if update_time is None: # wait for the first trajectory to be received
                time.sleep(0.01)
                continue
            else:
                # Adaptive phase
                time_elapsed = time.time() - update_time
                

                if timesteps:
                    # Find the index of the last waypoint that should have been passed
                    current_waypoint_idx = np.searchsorted(timesteps, time_elapsed, side='right') - 1
                    
                    replan_threshold = 0 if is_pd_fallback else self.REPLAN_WAYPOINT_THRESHOLD
                    if current_waypoint_idx >= last_replan_waypoint_idx + replan_threshold:
                        replan = True
                        logger.info(f"Triggering replan. Current waypoint index {current_waypoint_idx} exceeds threshold for {'PD fallback' if self._is_pd_fallback else 'replan'} from last replan index {self._last_replan_waypoint_idx}.")
                        with self._state_lock:
                            self._last_replan_waypoint_idx = current_waypoint_idx

            if replan:
                with self._state_lock:
                    self._replan_id += 1
                    # --- Create the state snapshot ---
                    snapshot = {
                        "replan_id": self._replan_id,
                        "total_completed_waypoints": self.get_total_completed_waypoints(),
                        "_num_waypoints_before_current_segment": self._num_waypoints_before_current_segment,
                        "_last_policy_update_time": self._last_policy_update_time,
                        "_waypoint_timesteps": self._waypoint_timesteps,
                    }

                joint_pos = self.arm.get_joint_positions().numpy()
                joint_vel = self.arm.get_joint_velocities().numpy()
                state_update = {
                    "joint_pos": joint_pos, 
                    "joint_vel": joint_vel, 
                    "replan": True,
                    "completed_in_segment": self.get_total_completed_waypoints(**snapshot) - total_completed_waypoints_last_replan,
                    **snapshot
                }
                self.toppra_client.send_message({"state_update": state_update})
            
            time.sleep(1.0 / 30) # Still check at 30Hz

    def _error_tracking_loop(self):
        last_completed_waypoints = 0
        try:
            while not self._stop_event.is_set():
                with self._state_lock:
                    # Make a copy of these to avoid holding lock for long
                    current_waypoints = list(self._current_waypoints)
                    num_waypoints_before = self._num_waypoints_before_current_segment

                if len(current_waypoints) == 0:
                    time.sleep(0.01)
                    continue

                total_completed_waypoints = self.get_total_completed_waypoints()
                logger.info(f"[ErrorTracker] Total completed waypoints: {total_completed_waypoints}, num waypoints before: {num_waypoints_before}")

                if total_completed_waypoints > last_completed_waypoints:
                    completed_in_segment = total_completed_waypoints - num_waypoints_before
                    
                    # process the most recent completed waypoint
                    waypoint_idx = completed_in_segment - 1

                    if 0 <= waypoint_idx < len(current_waypoints):
                        # Get current state using get_robot_state for consistency
                        robot_state = self.arm.get_robot_state()
                        current_joint_pos = torch.tensor(robot_state.joint_positions, dtype=torch.float32)
                        
                        try:
                            current_ee_pos_np, current_ee_quat_xyzw_np = T.mat2pose(np.array(robot_state.ee_pose).reshape(4, 4).T)
                            current_ee_pos_torch = torch.from_numpy(current_ee_pos_np).float()
                            current_ee_quat_xyzw = torch.from_numpy(current_ee_quat_xyzw_np).float()
                        except AttributeError:
                            current_ee_pos_torch, current_ee_quat_xyzw = self.arm.robot_model.forward_kinematics(
                                current_joint_pos
                            )
                        
                        # Get target state from waypoint
                        target_joint_pos = torch.tensor(current_waypoints[waypoint_idx]).float().to(current_joint_pos.device)
                        target_ee_pos_torch, target_ee_quat_xyzw = self.arm.robot_model.forward_kinematics(target_joint_pos)
                        
                        # Calculate errors
                        joint_err = torch.linalg.norm(current_joint_pos - target_joint_pos).item()
                        pos_err = torch.linalg.norm(current_ee_pos_torch - target_ee_pos_torch).item()
                        
                        # rotation error
                        current_rot_mat = R.from_quat(current_ee_quat_xyzw).as_matrix()
                        target_rot_mat = R.from_quat(target_ee_quat_xyzw).as_matrix()

                        rot_error_mat = torch.matmul(target_rot_mat.transpose(-2, -1), current_rot_mat)
                        rot_err_rad = torch.linalg.norm(R.from_matrix(rot_error_mat).as_rotvec()).item()

                        with self._state_lock:
                            self.joint_errors.append(joint_err)
                            self.ee_pos_errors.append(pos_err)
                            self.ee_ori_errors.append(rot_err_rad)
                        
                        logger.info(f"[ErrorTracker] Waypoint {waypoint_idx} completed. Joint err: {joint_err:.4f}, Pos err: {pos_err:.4f}, Rot err: {rot_err_rad:.4f}")

                    last_completed_waypoints = total_completed_waypoints

                time.sleep(0.01) # 100Hz polling
        except Exception as e:
            logger.error(f"[ErrorTracker] Thread crashed: {e}", exc_info=True)

    def _toppra_message_handler(self, message):
        """Install a trajectory received from the Toppra server.

        Messages without a "trajectory" key are ignored. Otherwise the message
        is treated as the echoed replan context and the handler:

        1. drops responses whose "replan_id" is not newer than the last one
           processed;
        2. for a "pd_fallback" response, keeps the current plan while it still
           has waypoints remaining;
        3. computes how many waypoints were completed while the replan was in
           flight ("drift") and trims the start of the new trajectory
           accordingly;
        4. pads (holding the last position, zero velocity/acceleration) or
           truncates the trajectory to the controller length ``self.ctrl.N``
           and sends it with ``self.arm.update_current_policy``;
        5. updates the waypoint, timestep, gripper-schedule and bookkeeping
           state under ``_state_lock``.

        Args:
            message: Mapping from the server. Keys read here: "trajectory"
                (with "qs", "qds", "qdds"), "replan_id", "pd_fallback",
                "total_completed_waypoints", "completed_in_segment",
                "_num_waypoints_before_current_segment", "waypoints",
                "waypoint_timesteps", "request_type", "controllable_set_size"
                and "initial_controllable_set_size".
        """
        if "trajectory" not in message:
            return

        # The entire echoed context is the message itself.
        echoed_context = message
        replan_id = echoed_context.get("replan_id", -1)

        with self._state_lock:
            if replan_id <= self._last_processed_replan_id:
                logger.warning(f"Ignoring stale replan response. Current ID: {self._last_processed_replan_id}, Received ID: {replan_id}")
                return
            self._last_processed_replan_id = replan_id

            self._is_pd_fallback = echoed_context.get("pd_fallback", False)
            if self._is_pd_fallback:
                logger.warning("Received a trajectory from PD fallback. Will attempt to replan sooner.")
                num_remaining_waypoints = self.get_num_remaining_waypoints()
                logger.info(f"Num Remaining Waypoints: {num_remaining_waypoints}")
                if num_remaining_waypoints > 0:
                    # The state lock is already held here.
                    self._total_completed_waypoints_last_replan = echoed_context["total_completed_waypoints"]
                    logger.info("Remaining waypoints > 0, using previous plan.")
                    return
                logger.info("Remaining waypoints == 0, using PD fallback.")

        # --- Use the echoed context for a historically consistent calculation ---
        completed_in_segment_at_replan = echoed_context.get("completed_in_segment", 0)

        with self._state_lock:
            # Progress since the snapshot was taken.
            current_completed_in_segment = self.get_total_completed_waypoints(**echoed_context) - echoed_context["_num_waypoints_before_current_segment"]
            drift_waypoints = current_completed_in_segment - completed_in_segment_at_replan
            logger.info(f"Completed in segment at replan: {completed_in_segment_at_replan}. Current completed in segment: {current_completed_in_segment}. Drift waypoints: {drift_waypoints}")

        # Now, process the new trajectory from the message.
        traj = echoed_context["trajectory"]
        qs = torch.tensor(traj["qs"], dtype=torch.float32)
        qds = torch.tensor(traj["qds"], dtype=torch.float32)
        qdds = torch.tensor(traj["qdds"], dtype=torch.float32)
        new_waypoints = echoed_context.get("waypoints", [])
        new_waypoint_timesteps = echoed_context.get("waypoint_timesteps", [])

        with self._state_lock:
            # Promote the pending gripper schedule now that a plan has arrived.
            if self._candidate_gripper_schedule is not None:
                logger.info(f"Updating gripper schedule from with candidate gripper schedule with length {len(new_waypoints)}")
                self._gripper_schedule = self._candidate_gripper_schedule[-len(new_waypoints):]
                self._candidate_gripper_schedule = None

        # Drift can only be compensated when the new trajectory carries
        # waypoint timesteps. The drift offsets are always defined so later
        # code never reads an unbound name.
        apply_drift = bool(drift_waypoints > 0 and new_waypoint_timesteps)
        t_drift = 0.0
        i_drift = 0
        if apply_drift:
            if drift_waypoints >= len(new_waypoint_timesteps):
                logger.warning("Drift is larger than the new trajectory's waypoint count. Skipping update.")
                return
            # Time and controller step to start from in the new trajectory.
            t_drift = new_waypoint_timesteps[drift_waypoints - 1]
            i_drift = int(t_drift * self.arm.hz)
            if i_drift >= len(qs):
                logger.warning("Drift adjustment would consume the entire new trajectory. Skipping update.")
                return
            # Drop the part of the trajectory that was passed during the replan.
            qs = qs[i_drift:]
            qds = qds[i_drift:]
            qdds = qdds[i_drift:]

        traj_len = qs.shape[0]
        if traj_len == 0:
            # Padding below needs a last position to hold.
            logger.warning("Received an empty trajectory. Skipping update.")
            return

        # Pad or truncate to the controller's expected trajectory length.
        max_len = self.ctrl.N
        if traj_len < max_len:
            pad_len = max_len - traj_len
            hold_position = qs[-1].unsqueeze(0).repeat(pad_len, 1)
            zero_padding = torch.zeros(pad_len, qs.shape[1], dtype=qs.dtype)
            qs_full = torch.cat([qs, hold_position], dim=0)
            qds_full = torch.cat([qds, zero_padding], dim=0)
            qdds_full = torch.cat([qdds, zero_padding], dim=0)
        else:
            qs_full = qs[:max_len]
            qds_full = qds[:max_len]
            qdds_full = qdds[:max_len]

        logger.info(f"Updating current policy with trajectory length {qs_full.shape[0]}.")
        self.arm.update_current_policy(
            {
                "joint_pos_trajectory": qs_full,
                "joint_vel_trajectory": qds_full,
                "joint_acc_trajectory": qdds_full,
                "i": torch.tensor(0.),
            }
        )

        with self._state_lock:
            if apply_drift:
                # Re-base waypoints and timesteps onto the shortened trajectory.
                self._current_waypoints = new_waypoints[drift_waypoints:]
                self._waypoint_timesteps = [t - t_drift for t in new_waypoint_timesteps[drift_waypoints:]]
            else:
                if drift_waypoints > 0:
                    logger.warning(f"Drift of {drift_waypoints} waypoints cannot be compensated without waypoint timesteps. Using trajectory as is.")
                else:
                    logger.info("No drift in waypoints. Using as is")
                self._current_waypoints = new_waypoints
                self._waypoint_timesteps = new_waypoint_timesteps

            # Keep one gripper entry per remaining waypoint (tail of the schedule).
            # NOTE(review): with zero timesteps, [-0:] keeps the whole schedule
            # (unchanged from the previous behaviour) -- confirm this is intended.
            if self._candidate_gripper_schedule is not None:
                schedule_source = self._candidate_gripper_schedule
            else:
                schedule_source = self._gripper_schedule
            self._gripper_schedule = schedule_source[-len(self._waypoint_timesteps):]

            if apply_drift:
                if len(self._current_waypoints) != len(self._waypoint_timesteps):
                    logger.warning(f"Current waypoints and waypoint timesteps have different lengths. Current waypoints: {len(self._current_waypoints)}, waypoint timesteps: {len(self._waypoint_timesteps)}")
                logger.info(f"Adjusted for {drift_waypoints} drift waypoints. Starting new trajectory from step {i_drift}.")

            # Bank the new total progress.
            self._total_completed_waypoints_last_replan = echoed_context["total_completed_waypoints"]
            self._num_waypoints_before_current_segment = echoed_context["_num_waypoints_before_current_segment"] + current_completed_in_segment
            self._last_policy_update_time = time.time()
            self._last_set_trajectory_length = len(qs)
            self._new_traj_received = True
            self._last_replan_waypoint_idx = -1
            if echoed_context.get("request_type") == "replan_request":
                self._replan_chunk_index += 1
            if "controllable_set_size" in echoed_context:
                self.controllable_set_sizes.append(echoed_context["controllable_set_size"])
            if "initial_controllable_set_size" in echoed_context:
                self.initial_controllable_set_sizes.append(echoed_context["initial_controllable_set_size"])

            expected_duration = self._waypoint_timesteps[-1] if self._waypoint_timesteps else 0
            logger.info(
                f"SimPanda: Received new trajectory (len={self._last_set_trajectory_length}). "
                f"Waypoints: {self._current_waypoints}. "
                f"Base waypoints: {self._num_waypoints_before_current_segment}. "
                f"Expected duration: {expected_duration:.4f}s. "
                f"Timesteps: {self._waypoint_timesteps}"
            )

    def get_total_completed_waypoints(self, **historical_context):
        with self._state_lock:
            # Use historical context if provided, otherwise use current state.
            timesteps = historical_context.get("_waypoint_timesteps", self._waypoint_timesteps)
            update_time = historical_context.get("_last_policy_update_time", self._last_policy_update_time)
            base_waypoints = historical_context.get("_num_waypoints_before_current_segment", self._num_waypoints_before_current_segment)

            if not timesteps or update_time is None:
                return base_waypoints

            time_elapsed = time.time() - update_time
            completed_in_segment = np.searchsorted(timesteps, time_elapsed, side='right')
            
            return base_waypoints + completed_in_segment

    def get_replan_chunk_index(self):
        with self._state_lock:
            return self._replan_chunk_index

    def get_num_remaining_waypoints(self):
        """Get the number of waypoints remaining in the current trajectory."""
        with self._state_lock:
            if not self._waypoint_timesteps or self._last_policy_update_time is None:
                return 0
            
            total_waypoints_in_segment = len(self._waypoint_timesteps)
            
            time_elapsed = time.time() - self._last_policy_update_time
            completed_in_segment = np.searchsorted(self._waypoint_timesteps, time_elapsed, side='right')
            
            # logger.info(f"Total waypoints in segment: {total_waypoints_in_segment}, length of waypoint_timesteps: {len(self._waypoint_timesteps)}, completed in segment: {completed_in_segment}")
            
            remaining_waypoints = total_waypoints_in_segment - completed_in_segment
            return max(0, remaining_waypoints)

    def get_remaining_execution_time(self):
        """Get the remaining execution time for the current trajectory."""
        with self._state_lock:
            if not self._waypoint_timesteps or self._last_policy_update_time is None:
                return 0.0
            
            total_duration = self._waypoint_timesteps[-1]
            time_elapsed = time.time() - self._last_policy_update_time
            
            remaining_time = total_duration - time_elapsed
            return max(0.0, remaining_time)

    def get_expected_execution_during_inference(self, delay=0.):
        """
        Estimates how many waypoints will be executed from now until a specified time delay in the future.

        Args:
            delay (float): The time delay in seconds to project forward.

        Returns:
            int: The estimated number of waypoints that will be executed.
                0 when there are no waypoint timestamps or no policy update
                has been recorded.
        """
        with self._state_lock:
            timesteps = self._waypoint_timesteps
            # Explicit emptiness check: ``not timesteps`` raises ValueError
            # when the timestamps are held in a multi-element numpy array.
            if (
                timesteps is None
                or len(timesteps) == 0
                or self._last_policy_update_time is None
            ):
                return 0

            time_elapsed = time.time() - self._last_policy_update_time

            # Waypoints completed so far in this segment
            completed_now = np.searchsorted(timesteps, time_elapsed, side="right")

            # Waypoints expected to be completed after the delay
            completed_future = np.searchsorted(
                timesteps, time_elapsed + delay, side="right"
            )

            # The difference is what gets executed during the delay period;
            # cast so the documented ``int`` return type actually holds.
            return int(completed_future - completed_now)

    def get_controllable_set_sizes(self):
        """Return the recorded controllable-set sizes and empty the buffer.

        Returns:
            list: A copy of ``controllable_set_sizes`` as it was before it
            was cleared.
        """
        drained = [*self.controllable_set_sizes]
        self.controllable_set_sizes.clear()
        return drained

    def get_initial_controllable_set_sizes(self):
        """Return the recorded initial controllable-set sizes and empty the buffer.

        Returns:
            list: A copy of ``initial_controllable_set_sizes`` as it was
            before it was cleared.
        """
        drained = [*self.initial_controllable_set_sizes]
        self.initial_controllable_set_sizes.clear()
        return drained

    def close(self):
        """Stop the TOPPRA client (when enabled) and shut down worker threads.

        For each of the state-update, gripper and error-tracking threads that
        is not None, the shared stop event is set and the thread is joined
        with a one-second timeout.
        """
        if self.toppra:
            self.toppra_client.stop()

        for thread_attr in (
            "_state_update_thread",
            "_gripper_thread",
            "_error_tracking_thread",
        ):
            worker = getattr(self, thread_attr)
            if worker is None:
                continue
            self._stop_event.set()
            worker.join(timeout=1.0)
      
    def init_interpolator(self, interpolator: str = "TOPPRA", toppra_last_vel: float = 1.0):
        """Create the trajectory interpolator and store it on the instance.

        Args:
            interpolator: Interpolator name; only "TOPPRA" is accepted.
            toppra_last_vel: Stored as ``self.toppra_last_vel``.

        Raises:
            NotImplementedError: If ``interpolator`` is not "TOPPRA".
        """
        if interpolator != "TOPPRA":
            raise NotImplementedError("Only TOPPRA is supported for now.")

        # Per-joint torque magnitudes from joint_limits.yaml, franka_panda.urdf;
        # the [lower, upper] pair is scaled to 99% of those values.
        max_torque = np.array([87.0, 87.0, 87.0, 87.0, 12.0, 12.0, 12.0])
        torque_limits = np.stack([-max_torque, max_torque]) * 0.99

        # Velocity limits come from the robot model, scaled to 70%.
        vel_limits = self.arm.robot_model.get_joint_velocity_limits().numpy() * 0.7
        print("vel_limits", vel_limits)

        self.interpolator = TOPPRAInterpolator(
            ndim=self.dof,
            controller_freq=self.arm.hz,
            policy_freq=self.robot_config["hz"],
            torque_limits=torque_limits,
            vel_limits=vel_limits,
        )
        self.toppra_last_vel = toppra_last_vel

    def init_controller(self, kp: torch.Tensor, kv: torch.Tensor, controller: str = "OSC"):
        """Build the requested controller and send it to the arm as a torch policy.

        Args:
            kp: Position gain (passed to the "OSC" controller only).
            kv: Velocity gain (passed to the "OSC" controller only).
            controller: One of "OSC", "JOINT_POSITION" or
                "JOINT_POSITION_TRACKING".

        Raises:
            NotImplementedError: For any other controller name.
        """
        if controller not in ("OSC", "JOINT_POSITION", "JOINT_POSITION_TRACKING"):
            raise NotImplementedError(f"Controller {controller} not implemented.")

        if controller == "OSC":
            cur_pos, cur_quat = self.get_ee_pose()
            self.ctrl = osc_factory(
                ee_pos_current=torch.tensor(cur_pos, dtype=torch.float32),
                ee_quat_current=torch.tensor(cur_quat, dtype=torch.float32),
                init_joints=self.reset_joints,
                kp=kp,
                kv=kv,
                position_limits=torch.tensor(self.robot_config["position_limits"]),
            )
        else:
            # Both joint-space controllers share the same impedance settings:
            # the arm's default gains scaled by the instance gain factors.
            impedance_kwargs = dict(
                joint_pos_current=self.arm.get_joint_positions(),
                Kq=self.arm.Kq_default * self.gain_scale,
                Kqd=self.arm.Kqd_default * self.vel_gain_scale,
                Kx=self.arm.Kx_default * self.gain_scale,
                Kxd=self.arm.Kxd_default * self.vel_gain_scale,
                robot_model=self.arm.robot_model,
                ignore_gravity=self.arm.use_grav_comp,
            )
            if controller == "JOINT_POSITION":
                self.ctrl = HybridJointImpedanceControl(**impedance_kwargs)
            else:
                max_torque = [87.0, 87.0, 87.0, 87.0, 12.0, 12.0, 12.0]
                self.ctrl = JointTrajectoryExecutor(
                    # assume not slower than the original policy time
                    # TODO make it more robust
                    max_length=24 * 100,
                    torque_limits=torch.tensor(
                        [[-limit for limit in max_torque], max_torque]
                    ),
                    **impedance_kwargs,
                )

        logger.info(f"Initialized controller {controller}. Sending torch policy.")

        self.arm.send_torch_policy(torch_policy=self.ctrl, blocking=False)

    def get_state(self) -> Tuple[Optional[PandaState], PandaError]:
        """Read the arm and gripper and bundle the result into a ``PandaState``.

        Returns:
            ``(state, PandaError.OK)`` on success, or
            ``(None, PandaError.Gripper)`` when the gripper returns no state.
            The end-effector velocities in ``state`` stay ``None`` when the
            forward-kinematics fallback below is taken.
        """
        arm_reading = self.arm.get_robot_state()
        gripper_reading = self.gripper.get_state()
        if gripper_reading is None:
            print("Could not get gripper state. Please rerun the gripper server.")
            return None, PandaError.Gripper

        ee_pos = ee_quat = None
        ee_pos_vel = ee_ori_vel = None
        try:
            # Pose and Jacobian arrive flattened; the reshape-then-transpose
            # presumably undoes a column-major layout -- TODO confirm.
            ee_pos, ee_quat = T.mat2pose(np.array(arm_reading.ee_pose).reshape(4, 4).T)
            jacobian = torch.tensor(arm_reading.jacobian).reshape(7, 6).T

            # End-effector twist = Jacobian @ joint velocities.
            twist = jacobian @ torch.tensor(arm_reading.joint_velocities)
            ee_pos_vel = twist[:3].numpy()
            ee_ori_vel = twist[3:].numpy()
        except AttributeError:
            # The robot state lacks one of the fields used above: recover
            # the pose from forward kinematics on the joint positions.
            fk_pos, fk_quat = self.arm.robot_model.forward_kinematics(
                torch.tensor(arm_reading.joint_positions, dtype=torch.float32)
            )
            ee_pos = fk_pos.numpy()
            ee_quat = fk_quat.numpy()

        state = PandaState(
            joint_positions=np.array(arm_reading.joint_positions),
            joint_velocities=np.array(arm_reading.joint_velocities),
            joint_torques=np.array(arm_reading.joint_torques_computed),
            ee_pos=ee_pos,
            ee_quat=ee_quat,
            ee_pos_vel=ee_pos_vel,
            ee_ori_vel=ee_ori_vel,
            gripper_width=np.array([gripper_reading.width], dtype=np.float32),
        )
        return state, PandaError.OK
        
    def get_joint_state(self) -> np.ndarray:
        """Get the current joint positions followed by the scaled gripper width.

        Returns:
            np.ndarray: The arm joint positions with one extra trailing entry,
            the gripper width divided by 0.09.
        """
        arm_joints = self.arm.get_joint_positions()
        scaled_width = self.gripper.get_state().width / 0.09
        return np.append(arm_joints, scaled_width)

    def solve_inverse_kinematics(
        self,
        position: torch.Tensor,
        orientation: torch.Tensor,
        q0: torch.Tensor,
        tol: float = 1e-3,
        dt: float = 0.1,
        max_iters: int = 1000,
    ) -> Tuple[torch.Tensor, bool]:
        """Compute inverse kinematics given desired EE pose.

        Args:
            position: Desired end-effector position.
            orientation: Desired end-effector orientation (quaternion).
            q0: Rest pose handed to the solver.
            tol: Maximum pose-error norm for the solution to count as found.
            dt: Forwarded to the robot model's IK solver.
            max_iters: Forwarded to the robot model's IK solver.

        Returns:
            The joint positions from the solver and a flag telling whether
            the pose they reach is within ``tol`` of the desired pose.
        """
        model = self.arm.robot_model
        joint_pos_output = model.inverse_kinematics(
            position, orientation, rest_pose=q0, dt=dt, max_iters=max_iters
        )

        # Validate by running forward kinematics on the solution and
        # measuring the twist between the desired and the reached pose.
        reached_pos, reached_quat = model.forward_kinematics(joint_pos_output)
        target_pose = torchcontrol_T.from_rot_xyz(R.from_quat(orientation), position)
        reached_pose = torchcontrol_T.from_rot_xyz(R.from_quat(reached_quat), reached_pos)
        pose_error = torch.linalg.norm((target_pose * reached_pose.inv()).as_twist())

        return joint_pos_output, pose_error < tol

    def get_inv_dyn(self):
        """Return a callable that evaluates the robot model's inverse dynamics.

        The returned function takes joint positions, velocities and
        accelerations, converts each with ``to_tensor`` and returns the
        model's result as a numpy array.
        """
        def inv_dyn(joint_pos, joint_vel, joint_acc):
            q = to_tensor(joint_pos)
            qd = to_tensor(joint_vel)
            qdd = to_tensor(joint_acc)
            torques = self.arm.robot_model.inverse_dynamics(q, qd, qdd)
            return torques.numpy()

        return inv_dyn

    def command_joint_state(self, joint_state: np.ndarray, joint_pos_track: bool = False) -> None:
        """Command the robot to a joint state, or to track a joint trajectory.

        With ``joint_pos_track=False`` a single joint target is sent to the arm
        and the gripper is toggled from the last entry of ``joint_state``.

        With ``joint_pos_track=True`` the input is treated as a sequence of
        waypoints (it is indexed as ``joint_state[:, :-1]`` / ``[:, -1]``):
          * if ``self.toppra`` is set, a replan request is sent through
            ``self.toppra_client``;
          * otherwise, if ``self.interpolator`` is a ``TOPPRAInterpolator``,
            the trajectory is interpolated locally and pushed to the current
            arm policy.

        Args:
            joint_state (np.ndarray): Joint positions followed by one gripper
                value; the last dimension must equal ``self.dof + 1``.
            joint_pos_track (bool): Select trajectory tracking (see above).

        Raises:
            NotImplementedError: In tracking mode when neither ``self.toppra``
                is set nor the interpolator is a ``TOPPRAInterpolator``.
        """
        import torch

        # print("Kq_default", self.arm.Kq_default)
        # print("Kqd_default", self.arm.Kqd_default)
        # print("Kx_default", self.arm.Kx_default)
        # print("Kxd_default", self.arm.Kxd_default)

        assert joint_state.shape[-1] == self.dof + 1, "joint_state should be (Dof + 1,)"

        print("joint_pos_track", joint_pos_track)
        print("self.toppra", self.toppra)

        if joint_pos_track:
            if self.toppra:
                # Create gripper schedule: one gripper value per waypoint.
                gripper_actions = joint_state[:, -1]
                # schedule = []

                # # Check for change at the beginning of the chunk
                # if np.sign(gripper_actions[0]) != np.sign(self.last_grasp):
                #     schedule.append((0, gripper_actions[0]))

                # # Check for changes within the chunk
                # for i in range(1, len(gripper_actions)):
                #     if np.sign(gripper_actions[i]) != np.sign(gripper_actions[i - 1]):
                #         schedule.append((i, gripper_actions[i]))

                # Stored as the *candidate* schedule; where it gets promoted to
                # the active schedule is not visible in this method.
                with self._state_lock:
                    self._candidate_gripper_schedule = gripper_actions.tolist()
                    self._waypoints_in_segment = len(gripper_actions)

                with self._state_lock:
                    self._replan_id += 1
                    # --- Create the state snapshot ---
                    # Progress fields and the current joint state are read under
                    # the same lock acquisition.
                    # NOTE(review): get_total_completed_waypoints() is called while
                    # holding _state_lock -- confirm the lock is re-entrant if that
                    # helper also acquires it.
                    snapshot = {
                        "replan_id": self._replan_id,
                        "total_completed_waypoints": self.get_total_completed_waypoints(),
                        "_num_waypoints_before_current_segment": self._num_waypoints_before_current_segment,
                        "_last_policy_update_time": self._last_policy_update_time,
                        "_waypoint_timesteps": self._waypoint_timesteps,
                    }
                    joint_pos = self.arm.get_joint_positions().numpy()
                    joint_vel = self.arm.get_joint_velocities().numpy()

                logger.info(f"Sending replan request with new {joint_state[:, :-1].shape[0]} waypoints.")

                # NOTE(review): the whole snapshot is forwarded as keyword arguments
                # to get_total_completed_waypoints -- verify its signature accepts
                # every key above.
                replan_request = { 
                    "waypoints": joint_state[:, :-1], 
                    "completed_in_segment": self.get_total_completed_waypoints(**snapshot) - snapshot["_num_waypoints_before_current_segment"],
                    "joint_pos": joint_pos,
                    "joint_vel": joint_vel,
                    **snapshot
                }
                self.toppra_client.send_message({"replan_request": replan_request})
            elif isinstance(self.interpolator, TOPPRAInterpolator):
                # Local planning: interpolate from the current joint position and
                # velocity through the waypoints, then replace the running
                # trajectory starting at index 0.
                self.interpolator.set_goal(joint_state[:, :-1], start=self.arm.get_joint_positions(), start_vel=self.arm.get_joint_velocities(), inv_dyn=self.arm.robot_model, last_vel=self.toppra_last_vel)
                qs, qds, qdds = self.interpolator.get_interpolated_trajectory()
                qs = torch.tensor(np.stack(qs, axis=0))
                qds = torch.tensor(np.stack(qds, axis=0))
                qdds = torch.tensor(np.stack(qdds, axis=0))
                print("qs", qs.shape)
                self.arm.update_current_policy(
                    {"joint_pos_trajectory": qs, "joint_vel_trajectory": qds, "joint_acc_trajectory": qdds, "i": torch.tensor(0.)}
                )
            else:
                raise NotImplementedError("Joint position tracking is only supported for TOPPRA interpolator.")

        else:
            # Single target: everything but the last entry is the arm command.
            # (update_idx is not used further in this method.)
            update_idx = self.arm.update_desired_joint_positions(torch.tensor(joint_state[:-1]))

            # Gripper action. (from Panda.execute) # changed by D2 # for joint state control w/ SM data
            # Toggle only when the sign flips relative to the last grasp command,
            # the gripper is idle, and the magnitude exceeds grasp_margin.
            grasp = joint_state[-1]
            if (
                np.sign(grasp) != np.sign(self.last_grasp)
                and not self.gripper.get_state().is_moving
                and np.abs(grasp) > self.grasp_margin
            ):
                self._change_gripper_state(grasp)
                self.last_grasp = grasp

            # Count consecutive calls during which the arm is not moving.
            if self._robot_stopped():
                self.motion_stopped_counter += 1
            else:
                self.motion_stopped_counter = 0

            # gripper method for GELLO # changed by D4
            # if joint_state[-1] > 0.25 and not self.gripper.get_state().is_moving: # Open the gripper:
            #     if self.last_open:
            #         self.gripper.grasp(speed=GRIPPER_SPEED, force=9.0, blocking=False)
            #     else:
            #         gripper_width = None
            #         width = self.max_gripper_width if gripper_width is None else gripper_width
            #         self.gripper.goto(width=width, speed=GRIPPER_SPEED, force=0.0, blocking=False)

            #     self.last_open = not self.last_open # Toggle the gripper state.


            # self.gripper.goto(width=(0.09 * (1 - joint_state[-1])), speed=1, force=1)

    def execute(
        self, action: npt.NDArray[np.float32], action_filtering: bool = True, cartesian_impedance: bool = False, joint_pos: bool = False, joint_pos_track: bool = False
    ) -> bool:
        """Execute robot action.

        The caller's ``action`` array is never modified; clipping of a delta
        action is applied to a private copy.

        Args:
            action: Arm command followed by one gripper value. The arm part is
                position + quaternion (7D), or position + 6D rotation (9D)
                when ``self.act_rot_repr`` is "rot_6d". With
                ``joint_pos_track`` the whole array is handed to
                ``execute_ee_sequence`` instead.
            action_filtering: For delta actions whose position step exceeds
                0.11 on any axis: skip the action when True, otherwise clip
                the step to [-0.10, 0.10].
            cartesian_impedance: Force the ``update_desired_ee_pose`` path even
                when ``self.is_fr3`` is set.
            joint_pos: When ``self.is_fr3`` is set and ``cartesian_impedance``
                is not, solve IK and send a joint-position target.
            joint_pos_track: Delegate to ``execute_ee_sequence``.
        Returns: True if the action was successful, False otherwise.
        """
        if joint_pos_track:
            return self.execute_ee_sequence(action)

        arm_action, grasp = action[:-1], action[-1]
        if self.act_rot_repr == "rot_6d":
            assert arm_action.shape[0] == 9, "arm_action should be 9D for rot_6d"
            rot_6d = torch.tensor(arm_action[3:9])
            rotation = rotation_6d_to_matrix(rot_6d)
            quat = matrix_to_quaternion(rotation).numpy()
            arm_action = np.concatenate([arm_action[:3], quat], axis=0)
        assert arm_action.shape[0] == 7, "arm_action should be 7D for quat"

        if not self.abs_action:
            if np.abs(arm_action[:3]).max() > 0.11:  # 11 cm.
                if action_filtering:
                    print(f"[env] Position action too big: {arm_action[:3]}, skipping it.")
                    return False
                else:
                    # Clip the action to be within the range. ``arm_action``
                    # may be a view into the caller's array, so work on a copy
                    # instead of writing through to ``action``.
                    arm_action = arm_action.copy()
                    arm_action[:3] = np.clip(arm_action[:3], -0.10, 0.10)
        # Arm action.
        if not self.abs_action: # Delta action.
            ee_pos, ee_quat = self.get_ee_pose()
            goal_ee_pos = torch.tensor(ee_pos, dtype=torch.float32) + torch.tensor(
                arm_action[:3], dtype=torch.float32
            )
            act_quat = arm_action[3:]

            goal_ee_quat = torch.tensor(
                T.quat_multiply(ee_quat, act_quat), dtype=torch.float32
            )
        else:
            goal_ee_pos = torch.tensor(arm_action[:3], dtype=torch.float32)
            goal_ee_quat = torch.tensor(arm_action[3:], dtype=torch.float32)

        if not self.is_fr3 or cartesian_impedance:
            print("Cartesian Impedance")
            self.arm.update_desired_ee_pose(
                position=goal_ee_pos, orientation=goal_ee_quat
            )
        elif joint_pos:
            print("Joint Position")
            # Seed IK with the current joints.
            joint_pos_current = self.arm.get_joint_positions()
            joint_pos_desired, success = self.solve_inverse_kinematics(
                goal_ee_pos, goal_ee_quat, joint_pos_current
            )
            if not success:
                print(
                    "(WARNING) Unable to find valid joint target. Skipping execution..."
                )
                return False

            self.arm.update_current_policy(
                {"joint_pos_desired": joint_pos_desired}
            )
        else:
            print("goal_ee_pos", goal_ee_pos)
            print("goal_ee_quat", goal_ee_quat)
            ee_pos, ee_quat = self.arm.get_ee_pose()
            print("ee_pos", ee_pos)
            print("ee_quat", ee_quat)
            self.arm.update_current_policy(
                {"ee_pos_desired": goal_ee_pos, "ee_quat_desired": goal_ee_quat}
            )

        # Gripper action: toggle only when the sign flips relative to the last
        # grasp command, the gripper is idle, and the magnitude exceeds the
        # grasp margin.
        if (
            np.sign(grasp) != np.sign(self.last_grasp)
            and not self.gripper.get_state().is_moving
            and np.abs(grasp) > self.grasp_margin
        ):
            self._change_gripper_state(grasp)
            self.last_grasp = grasp

        # Count consecutive calls during which the arm is not moving.
        if self._robot_stopped():
            self.motion_stopped_counter += 1
        else:
            self.motion_stopped_counter = 0

        return True
    
    def execute_ee_sequence(self, ee_actions: npt.NDArray[np.float32]) -> bool:
        """Convert a sequence of end-effector actions to joint waypoints and track them.

        Each row of ``ee_actions`` is an end-effector pose followed by one
        gripper value. Rows are solved with inverse kinematics one after the
        other, each solve seeded with the previous solution (the first with
        the current joint positions). The resulting joint waypoints are passed
        to ``command_joint_state`` in tracking mode.

        Returns:
            True when every waypoint was solved and the sequence was
            commanded; False as soon as one IK solve fails (nothing is
            commanded in that case).
        """
        sequence = np.array(ee_actions)
        joint_actions = np.zeros((sequence.shape[0], self.dof + 1))

        # Seed for the first IK solve.
        ik_seed = self.arm.get_joint_positions()

        for idx, step in enumerate(sequence):
            pose, grip = step[:-1], step[-1]

            if self.act_rot_repr == "rot_6d":
                # Replace the 6D rotation with the equivalent quaternion.
                rotation = rotation_6d_to_matrix(torch.tensor(pose[3:9]))
                pose = np.concatenate([pose[:3], matrix_to_quaternion(rotation).numpy()])

            solution, solved = self.solve_inverse_kinematics(
                torch.tensor(pose[:3], dtype=torch.float32),
                torch.tensor(pose[3:], dtype=torch.float32),
                ik_seed,
            )
            if not solved:
                print(f"(WARNING) IK failed for waypoint {idx}. Skipping execution of sequence.")
                return False

            joint_actions[idx, :-1] = solution.numpy()
            joint_actions[idx, -1] = grip

            # Use result of this IK as seed for the next one.
            ik_seed = solution

        self.command_joint_state(joint_actions, joint_pos_track=True)
        return True

    def execute_joint(
        self, action: npt.NDArray[np.float32], action_filtering: bool = True, joint_pos_track: bool = False
    ) -> bool:
        """Execute a joint-space action by forwarding it to ``command_joint_state``.

        Args:
            action: Joint command followed by the gripper command.
            action_filtering: Accepted but not used by this method.
            joint_pos_track: Forwarded to ``command_joint_state``.
        Returns: Always True.
        """
        track = joint_pos_track
        self.command_joint_state(action, track)
        return True

    def _motion_stopped_for_too_long(self) -> bool:
        """Check if the robot has stopped for too long.

        Returns True (and resets the counter to 0) once
        ``motion_stopped_counter`` exceeds the configured
        ``motion_stopped_counter_threshold``; otherwise returns False and
        leaves the counter untouched.
        """
        limit = self.robot_config["motion_stopped_counter_threshold"]
        if not (self.motion_stopped_counter > limit):
            return False

        print("[env] Robot stopped for too long.")
        self.motion_stopped_counter = 0
        return True

    def _robot_stopped(self) -> bool:
        """Tell whether every joint velocity magnitude is below 0.0055."""
        joint_speeds = self.arm.get_joint_velocities().abs()
        fastest = joint_speeds.max()
        return fastest < 0.0055

    def _change_gripper_state(self, grasp: float):
        """Apply a gripper command.

        With binary grasping a negative ``grasp`` opens the gripper and any
        other value closes it. Otherwise the gripper is sent to the width
        ``max_gripper_width * (1 - grasp)``.
        """
        if not self.binary_grasping:
            target_width = self.max_gripper_width * (1 - grasp)
            self.gripper.goto(width=target_width, speed=10.0, force=1.0)
            return

        toggle = self.open_gripper if grasp < 0 else self.close_gripper
        toggle()

    def open_gripper(
        self, blocking: bool = False, gripper_width: Optional[float] = None
    ):
        """Open the gripper and record the open state in ``last_grasp``.

        Args:
            blocking: Forwarded to the gripper's ``goto``.
            gripper_width: Target width; ``max_gripper_width`` when None.
        """
        target_width = gripper_width if gripper_width is not None else self.max_gripper_width
        self.gripper.goto(
            width=target_width,
            speed=GRIPPER_SPEED,
            force=0.0,
            blocking=blocking,
        )
        self.last_grasp = -1

    def open_gripper_delta(self, blocking: bool = False):
        """Open the gripper 0.01 wider than its currently reported width."""
        current_width = self.gripper.get_state().width
        self.open_gripper(blocking=blocking, gripper_width=current_width + 0.01)

    def close_gripper(self, blocking: bool = False):
        """Close the gripper and record the closed state in ``last_grasp``.

        The gripper is driven to width 0.0 with ``goto`` (force 9.0) rather
        than with the gripper's ``grasp`` call.
        """
        self.gripper.goto(
            width=0.0,
            speed=GRIPPER_SPEED,
            force=9.0,
            blocking=blocking,
        )
        self.last_grasp = 1

    def get_ee_pose(self):
        """Return the end-effector pose as produced by ``T.mat2pose``.

        The pose matrix is read from the arm and converted to numpy first.
        """
        pose_matrix = self.arm.get_ee_pose_mat().numpy()
        return T.mat2pose(pose_matrix)

    def init_reset(self):
        """Open the gripper, home the arm, and start a fresh controller.

        Also clears the grasp state and the motion-stopped counter before
        resetting the newly created controller.
        """
        self.open_gripper()
        self.arm.go_home(blocking=True)
        self.init_controller(self.kp, self.kv)

        self.last_grasp, self.motion_stopped_counter = -1, 0
        self.ctrl.reset()
    
    def get_errors(self):
        """Return the buffered tracking errors and empty the buffers.

        Copying and clearing happen under ``_state_lock``.

        Returns:
            tuple: ``(ee_pos_errors, ee_ori_errors, joint_errors)`` as lists.
        """
        with self._state_lock:
            buffers = (self.ee_pos_errors, self.ee_ori_errors, self.joint_errors)
            pos_snapshot, ori_snapshot, joint_snapshot = (list(buf) for buf in buffers)
            for buf in buffers:
                buf.clear()
        return pos_snapshot, ori_snapshot, joint_snapshot

    def init_controller_wapper(self):
        """Initialize the default controller with the instance gains ``kp``/``kv``."""
        # target eef pose control
        position_gain, velocity_gain = self.kp, self.kv
        self.init_controller(position_gain, velocity_gain)
    
    def reset(self, randomness=Randomness.LOW):
        """Home the robot, optionally perturb its pose, and clear bookkeeping.

        Args:
            randomness: For the MEDIUM/HIGH levels (including their _COLLECT
                variants) the end-effector is raised by 5 cm and then moved by
                a random position/rotation offset drawn from the
                ``pos_noise_med`` / ``rot_noise_med`` config ranges.

        Returns:
            Always True.
        """
        self.open_gripper()
        self.arm.go_home(blocking=True)
        # TODO: developer3 implement
        self.init_controller_wapper()  # target eef pose control

        perturbed_levels = (
            Randomness.MEDIUM,
            Randomness.MEDIUM_COLLECT,
            Randomness.HIGH,
            Randomness.HIGH_COLLECT,
        )
        if randomness in perturbed_levels:
            # Move z 5cm up so it doesn't collide.
            self.go_delta_pos([0, 0, 0.05])

            pos_limit = config["robot"]["pos_noise_med"]
            pos_noise = np.random.uniform(low=-pos_limit, high=pos_limit, size=3)

            # One independent angle per axis, drawn in degrees and converted.
            rot_limit = config["robot"]["rot_noise_med"]
            axis_angle = [
                np.radians(np.random.uniform(-rot_limit, rot_limit)) for _ in range(3)
            ]
            self.go_delta(pos_noise, T.axisangle2quat(axis_angle))

        self.last_grasp = -1
        self.motion_stopped_counter = 0

        # Drop any tracking errors collected so far.
        for error_buffer in (self.ee_pos_errors, self.ee_ori_errors, self.joint_errors):
            error_buffer.clear()

        # Trajectory / replanning bookkeeping starts from scratch.
        # (_replan_id and _last_processed_replan_id are intentionally left as is.)
        self._last_policy_update_time = None
        self._last_set_trajectory_length = 0
        self._gripper_schedule = []
        self._waypoint_timesteps = []
        self._new_traj_received = False
        self._is_pd_fallback = False
        self._last_replan_waypoint_idx = -1
        self._num_waypoints_before_current_segment = 0
        self._waypoints_in_segment = 0
        self._total_completed_waypoints_last_replan = 0
        self.controllable_set_sizes = []
        self.initial_controllable_set_sizes = []

        # TODO: Add failure checking.
        return True

    def gripper_face_front(self):
        """Rotate the end-effector to ``rot_mat([pi, 0, 0])``."""
        front_rotation = rot_mat([np.pi, 0, 0])
        self.go_rot_mat(front_rotation)

    def gripper_face_back(self):
        """Rotate the end-effector to ``rot_mat([pi, 0, 0]) @ rot_mat([0, 0, pi])``."""
        flip_about_x = rot_mat([np.pi, 0, 0])
        half_turn_about_z = rot_mat([0, 0, np.pi])
        self.go_rot_mat(flip_about_x @ half_turn_about_z)

    def update_sleep(
        self,
        position: torch.Tensor,
        orientation: Optional[torch.Tensor] = None,
        sleep_time=2.0,
    ):
        """Update the end-effector pose and sleep for a given amount of time.

        On FR3 the target is pushed into the running policy; otherwise the
        arm's ``update_desired_ee_pose`` is used.
        """
        if self.is_fr3:
            self.arm.update_current_policy(
                {"ee_pos_desired": position, "ee_quat_desired": orientation}
            )
        else:
            self.arm.update_desired_ee_pose(position=position, orientation=orientation)
        time.sleep(sleep_time)

    def go(
        self,
        goal_pos: Union[npt.NDArray[np.float32], list],
        goal_quat: Union[npt.NDArray[np.float32], list],
        z_last: bool = True,
    ) -> None:
        """Go to a desired pose within given time limit.

        The target is re-sent every 0.2 s until the end-effector is within
        5 mm of the goal on every axis and its rotation is similar to the
        goal rotation, or until ``self.max_go_time`` seconds have passed.

        Args:
            goal_pos: Goal position in robot coordinate. A 4-vector is treated
                as homogeneous and truncated to its first three entries.
            goal_quat: Goal orientation in robot coordinate.
            z_last: Whether the z-positional move should be done after every other moves.
        """
        if isinstance(goal_pos, list):
            goal_pos = np.array(goal_pos)
        if goal_pos.shape == (4,):
            # Homogeneous.
            goal_pos = goal_pos[:3]
        if isinstance(goal_quat, list):
            goal_quat = np.array(goal_quat)

        ee_pos, ee_quat = self.get_ee_pose()
        if z_last:
            # First reach the goal at the current height, then do the z move.
            planar_goal = goal_pos.copy()
            planar_goal[2] = ee_pos[2]
            self.go(planar_goal, goal_quat, z_last=False)

        started_at = time.time()
        while True:
            position_reached = (np.abs(ee_pos - goal_pos) < 0.005).all()
            if position_reached and is_similar_rot(
                T.quat2mat(ee_quat), T.quat2mat(goal_quat)
            ):
                break

            self.update_sleep(
                position=torch.tensor(goal_pos),
                orientation=torch.tensor(goal_quat),
                sleep_time=0.2,
            )

            if time.time() - started_at > self.max_go_time:
                break
            ee_pos, ee_quat = self.get_ee_pose()

    def go_mat(self, goal_mat: npt.NDArray[np.float32]):
        """Matrix form input of `self.go` method."""
        position, quaternion = T.mat2pose(goal_mat)
        self.go(position, quaternion)

    def go_nearest_90_z(self):
        """Rotate end-effector to the nearest 90 degree angle in the z axis."""
        ee_pos, ee_quat = self.get_ee_pose()

        # Work in the end-effector frame (robot frame rotated by pi about x).
        euler_angles = T.mat2euler(rot_mat([np.pi, 0, 0]) @ T.quat2mat(ee_quat))
        yaw_deg = math.degrees(euler_angles[2])

        # Snap the magnitude to a multiple of 90 degrees, then restore the sign.
        snapped_deg = round(np.abs(yaw_deg) / 90) * 90
        if np.sign(yaw_deg) < 0:
            snapped_deg = -snapped_deg
        euler_angles[2] = math.radians(snapped_deg)

        # Back to the robot frame before commanding the move.
        goal_quat = T.mat2quat(rot_mat([-np.pi, 0, 0]) @ T.euler2mat(euler_angles))
        self.go(ee_pos, goal_quat)

    def go_pos(self, goal_pos):
        """Move to ``goal_pos`` while keeping the current orientation (z move last)."""
        current_quat = self.get_ee_pose()[1]
        self.go(goal_pos, current_quat, z_last=True)

    def go_rot(self, goal_quat):
        """Rotate to ``goal_quat`` while keeping the current position."""
        current_pos = self.get_ee_pose()[0]
        self.go(current_pos, goal_quat)

    def go_rot_mat(self, rot_mat):
        """Rotation-matrix form of ``go_rot``."""
        self.go_rot(T.mat2quat(rot_mat))

    def go_delta(self, delta_pos, delta_quat):
        """Move by a position offset and a rotation offset in a single motion.

        The new orientation is ``quat_multiply(delta_quat, current_quat)``;
        the move is issued with ``z_last=False``.
        """
        current_pos, current_quat = self.get_ee_pose()
        self.go(
            current_pos + delta_pos,
            T.quat_multiply(delta_quat, current_quat),
            z_last=False,
        )

    def go_delta_xy(self, delta_xy):
        """Shift the end-effector by ``delta_xy`` in x and y, keeping z and orientation."""
        target_pos, current_quat = self.get_ee_pose()
        # The freshly read position is updated in place and used as the goal.
        for axis in (0, 1):
            target_pos[axis] = target_pos[axis] + delta_xy[axis]
        self.go(target_pos, current_quat)

    def go_delta_pos(self, delta_pos):
        """Move the end-effector by ``delta_pos``, keeping the current orientation."""
        offset = np.array(delta_pos) if isinstance(delta_pos, list) else delta_pos
        current_pos, current_quat = self.get_ee_pose()
        self.go(current_pos + offset, current_quat)

    def go_delta_quat(self, delta_quat):
        """Rotate in place to ``quat_multiply(current_quat, delta_quat)``."""
        current_pos, current_quat = self.get_ee_pose()
        self.go(current_pos, T.quat_multiply(current_quat, delta_quat))

    def rotate_z(self, pi: bool):
        """Set the end-effector to the base orientation, optionally turned by pi about z.

        Args:
            pi: When true, a rotation of pi about z is applied on top of the
                base orientation.
        """
        ee_pos, _ = self.get_ee_pose()
        # this is due to ee-frame and robot frame is different.
        # rotate robot_frame x-axis pi is ee_frame.
        base_rotation = rot_mat([np.pi, 0, 0])
        target_rotation = rot_mat([0, 0, np.pi]) @ base_rotation if pi else base_rotation

        self.go(ee_pos, T.mat2quat(target_rotation))

    def tilt_place(self, pos):
        """Place the held object at ``pos`` with a tilted gripper.

        Sequence: lift by 0.1, move over ``pos`` in xy, tilt, lower to
        ``pos[2]``, then open the gripper slightly (blocking).
        """
        lift = [0, 0, 0.1]
        self.go_delta_pos(lift)
        self.move_xy(pos)
        self.tilt()
        self.z_move(pos[2])
        self.open_gripper_delta(blocking=True)

    def tilt(self):
        """Rotate to the fixed tilt orientation ``rot_mat([0, pi/2 + pi/6, pi])``."""
        tilt_matrix = rot_mat([0, np.pi / 2 + np.pi / 6, np.pi])
        self.go_rot(T.mat2quat(tilt_matrix))

    def tilt_ee(self, angles: List[float]):
        """Tilt the end-effector (x,y,z) given angles."""
        assert len(angles) == 3
        ee_rotation = rot_mat([angles[axis] for axis in range(3)], hom=True)
        robot_rotation = self.ee_to_robot_coord(ee_rotation)
        self.go_rot(T.mat2quat(robot_rotation))

    def move_xy(self, pos: List[float]):
        """Move the end-effector (x,y) given position.

        Only the first two entries of ``pos`` are used; the current height
        is kept.
        """
        current_pos, _ = self.get_ee_pose()
        target_xy = (np.array(pos) if isinstance(pos, list) else pos)[:2]
        current_z = current_pos[2:3]
        self.go_pos(np.concatenate([target_xy, current_z]))

    def move_z(self, z: float):
        """Move the end-effector to height ``z``, keeping the current x and y."""
        current_xy = self.get_ee_pose()[0][:2]
        self.go_pos(np.concatenate([current_xy, [z]]))

    def z_move(self, z_pos):
        """Move the end-effector to height ``z_pos``, keeping the current x and y."""
        target_pos = self.get_ee_pose()[0]
        # The freshly read position is updated in place and used as the goal.
        target_pos[2] = z_pos
        self.go_pos(target_pos)

    def ee_clock_grasp_mat(self):
        """Return the homogeneous matrix with zero translation and rotation ``rot_mat([-pi/2, pi, 0])``."""
        translation = [0.0, 0.0, 0.0]
        rotation = rot_mat([-np.pi / 2, np.pi, 0])
        return T.to_homogeneous(translation, rotation)

    def ee_counter_clock_grasp_mat(self):
        """Return the homogeneous matrix with zero translation and rotation ``rot_mat([-pi/2, 0, 0])``."""
        translation = [0.0, 0.0, 0.0]
        rotation = rot_mat([-np.pi / 2, 0, 0])
        return T.to_homogeneous(translation, rotation)

    def to_robot_coord(self, pose):
        """Left-multiply ``pose`` by the configured ``tag_base_from_robot_base`` transform."""
        base_transform = config["robot"]["tag_base_from_robot_base"]
        return base_transform @ pose

    def ee_to_robot_coord(self, pose):
        """Convert end-effector pose to robot coordinate.

        The pose is left-multiplied by a homogeneous rotation of pi about x.
        """
        return rot_mat([np.pi, 0, 0], hom=True) @ pose

    def check_grasp_success(self):
        """Report whether the gripper is holding something.

        Returns:
            bool: False when the reported gripper width is at most 0.001
            (fingers fully closed), True otherwise.
        """
        fully_closed = self.gripper.get_state().width <= 0.001
        return not fully_closed

    def __del__(self):
        """Best-effort cleanup: stop the arm policy, the TOPPRA client and worker threads.

        A finalizer also runs on instances whose ``__init__`` did not finish,
        so every attribute is looked up defensively; a missing attribute is
        skipped instead of raising and aborting the rest of the cleanup.
        """
        arm = getattr(self, "arm", None)
        if arm is not None:
            arm.terminate_current_policy()

        toppra_client = getattr(self, "toppra_client", None)
        if toppra_client is not None:
            toppra_client.stop()

        stop_event = getattr(self, "_stop_event", None)
        for thread_attr in (
            "_state_update_thread",
            "_gripper_thread",
            "_error_tracking_thread",
        ):
            worker = getattr(self, thread_attr, None)
            if worker is None:
                continue
            # Signal the shared stop event, then wait briefly for the worker.
            if stop_event is not None:
                stop_event.set()
            worker.join(timeout=1.0)