import mplib.planner
import mplib
import numpy as np
import pdb
import traceback
import numpy as np
import toppra as ta
from mplib.sapien_utils import SapienPlanner, SapienPlanningWorld
import transforms3d as t3d
import envs._GLOBAL_CONFIGS as CONFIGS
import os
import sapien 

# ********************** MplibPlanner **********************
class MplibPlanner:
    # links=None, joints=None
    def __init__(
        self,
        urdf_path,
        srdf_path,
        move_group,
        robot_origion_pose,
        robot_entity,
        planner_type="mplib_RRT",
        scene=None,
    ):
        super().__init__()
        ta.setup_logging("CRITICAL")  # hide logging

        links = [link.get_name() for link in robot_entity.get_links()]
        joints = [joint.get_name() for joint in robot_entity.get_active_joints()]

        if scene is None:
            self.planner = mplib.Planner(
                urdf=urdf_path,
                srdf=srdf_path,
                move_group=move_group,
                user_link_names=links,
                user_joint_names=joints,
                use_convex=False,
            )
            self.planner.set_base_pose(robot_origion_pose)
        else:
            planning_world = SapienPlanningWorld(scene, [robot_entity])
            self.planner = SapienPlanner(planning_world, move_group)

        self.planner_type = planner_type
        self.plan_step_lim = 2500
        self.TOPP = self.planner.TOPP

    def show_info(self):
        print("joint_limits", self.planner.joint_limits)
        print("joint_acc_limits", self.planner.joint_acc_limits)

    def plan_pose(
        self,
        now_qpos,
        target_pose,
        use_point_cloud=False,
        use_attach=False,
        arms_tag=None,
        try_times=2,
        log=True,
    ):
        result = {}
        result["status"] = "Fail"

        now_try_times = 1
        while result["status"] != "Success" and now_try_times < try_times:
            result = self.planner.plan_pose(
                goal_pose=target_pose,
                current_qpos=now_qpos,
                time_step=1 / 250,
                planning_time=2, 
                fix_joint_limits=False
                # rrt_range=0.05
                # =================== mplib 0.1.1 ===================
                # use_point_cloud=use_point_cloud,
                # use_attach=use_attach,
                # planner_name="RRTConnect"
            )
            now_try_times += 1

        if result["status"] != "Success":
            if log:
                print(f"\n {arms_tag} arm planning failed ({result['status']}) !")
        else:
            n_step = result["position"].shape[0]
            if n_step > self.plan_step_lim:
                if log:
                    print(f"\n {arms_tag} arm planning wrong! (step = {n_step})")
                result["status"] = "Fail"

        return result

    def plan_screw(
        self,
        now_qpos,
        target_pose,
        use_point_cloud=False,
        use_attach=False,
        arms_tag=None,
        log=False,
    ):
        """
        Interpolative planning with screw motion.
        Will not avoid collision and will fail if the path contains collision.
        """
        result = self.planner.plan_screw(
            goal_pose=target_pose,
            current_qpos=now_qpos,
            time_step=1 / 250,
            # =================== mplib 0.1.1 ===================
            # use_point_cloud=use_point_cloud,
            # use_attach=use_attach,
        )

        # plan fail
        if result["status"] != "Success":
            if log:
                print(f"\n {arms_tag} arm planning failed ({result['status']}) !")
            # return result
        else:
            n_step = result["position"].shape[0]
            # plan step lim
            if n_step > self.plan_step_lim:
                if log:
                    print(f"\n {arms_tag} arm planning wrong! (step = {n_step})")
                result["status"] = "Fail"

        return result

    def plan_path(
        self,
        now_qpos,
        target_pose,
        use_point_cloud=False,
        use_attach=False,
        arms_tag=None,
        log=True,
    ):
        """
        Interpolative planning with screw motion.
        Will not avoid collision and will fail if the path contains collision.
        """
        if self.planner_type == "mplib_RRT":
            result = self.plan_pose(
                now_qpos,
                target_pose,
                use_point_cloud,
                use_attach,
                arms_tag,
                try_times=10,
                log=log,
            )
        elif self.planner_type == "mplib_screw":
            result = self.plan_screw(now_qpos, target_pose, use_point_cloud, use_attach, arms_tag, log)

        return result

    def plan_grippers(self, now_val, target_val):
        num_step = 200  # TODO
        dis_val = target_val - now_val
        per_step = dis_val / num_step
        res = {}
        vals = np.linspace(now_val, target_val, num_step)
        res["num_step"] = num_step
        res["per_step"] = per_step  # dis per step
        res["result"] = vals
        return res


try:
    # ********************** CuroboPlanner (optional) **********************
    #from curobo.types.math import Pose as CuroboPose
    import time
    #from curobo.types.robot import JointState
    # from curobo.wrap.reacher.motion_gen import (
    #     MotionGen,
    #     MotionGenConfig,
    #     MotionGenPlanConfig,
    #     PoseCostMetric,
    # )
    # from curobo.util import logger
    import torch
    import yaml
   
    class CuroboPlanner:
        """
        A simplified planner that mimics CuroboPlanner interface but uses mplib internally
        """
        def __init__(self, robot_origion_pose, active_joints_name, all_joints, yml_path=None):
            super().__init__()
            ta.setup_logging("CRITICAL")
            
            self.robot_origion_pose = robot_origion_pose
            self.active_joints_name = active_joints_name
            self.all_joints = all_joints
            
            # Read frame bias from yml if provided
            self.frame_bias = [0, 0, 0]
            assert  yml_path and os.path.exists(yml_path)
            if yml_path and os.path.exists(yml_path):
                try:
                    with open(yml_path, 'r') as f:
                        yml_data = yaml.safe_load(f)
                    self.frame_bias = yml_data.get('planner', {}).get('frame_bias', [0, 0, 0])
                except:
                    raise
            
            # Store additional info needed for mplib planner
            self.mplib_planner = None
            self._initialized = False
            
        def initialize_mplib(self, urdf_path, srdf_path, move_group, robot_entity, scene=None):
            """
            Initialize the internal mplib planner with additional required information
            This should be called after the CuroboPlanner is created
            """
            self.mplib_planner = MplibPlanner(
                urdf_path=urdf_path,
                srdf_path=srdf_path,
                move_group=move_group,
                robot_origion_pose=self.robot_origion_pose,
                robot_entity=robot_entity,
                planner_type="mplib_RRT",
                scene=scene
            )
            self._initialized = True
        
        # def plan_path(self, curr_joint_pos, target_gripper_pose, constraint_pose=None, arms_tag=None):
        #     """Plan a single path"""
        #     if not self._initialized or self.mplib_planner is None:
        #         raise
        #         #return {"status": "Fail", "error": "mplib planner not initialized"}
            
        #     # Transform from world to arm's base frame (matching original CuroboPlanner)
        #     world_base_pose = np.concatenate([
        #         np.array(self.robot_origion_pose.p),
        #         np.array(self.robot_origion_pose.q)
        #     ])
        #     world_target_pose = np.concatenate([
        #         np.array(target_gripper_pose.p), 
        #         np.array(target_gripper_pose.q)
        #     ])
            
        #     # Transform to base frame
        #     target_pose_p, target_pose_q = self._trans_from_world_to_base(
        #         world_base_pose, world_target_pose
        #     )
            
        #     # Apply frame bias (important!)
        #     target_pose_p[0] += self.frame_bias[0]
        #     target_pose_p[1] += self.frame_bias[1]
        #     target_pose_p[2] += self.frame_bias[2]
            
        #     # Combine position and quaternion for mplib
        #     target_pose = list(target_pose_p) + list(target_pose_q)
            
        #     # Get joint positions for active joints only
        #     joint_indices = [self.all_joints.index(name) for name in self.active_joints_name 
        #                     if name in self.all_joints]
        #     joint_angles = [curr_joint_pos[index] for index in joint_indices]
        #     joint_angles = [round(angle, 5) for angle in joint_angles]  # avoid precision problems
            
        #     # Call mplib planner
        #     result = self.mplib_planner.plan_path(
        #         now_qpos=joint_angles,
        #         target_pose=target_pose,
        #         use_point_cloud=False,
        #         use_attach=False,
        #         arms_tag=arms_tag,
        #         log=True
        #     )
            
        #     # Ensure result format matches CuroboPlanner output
        #     if result.get("status") == "Success":
        #         if not isinstance(result.get("position"), np.ndarray):
        #             result["position"] = np.array(result["position"])
        #         if not isinstance(result.get("velocity"), np.ndarray):
        #             result["velocity"] = np.array(result["velocity"])
            
        #     return result
        
        
        def plan_path(self, curr_joint_pos, target_gripper_pose, constraint_pose=None, arms_tag=None):
            """Plan a single path"""
            
            if not self._initialized or self.mplib_planner is None:
                raise ValueError("mplib planner not initialized")
            
            # 不再进行坐标系转换，直接使用世界坐标系的位姿
            # 创建 Pose 对象
            pose_list = list(target_gripper_pose.p) + list(target_gripper_pose.q)
            pose_obj = mplib.pymp.Pose(pose_list[:3], pose_list[3:])
            
            # Get joint positions for active joints only
            joint_indices = [self.all_joints.index(name) for name in self.active_joints_name 
                            if name in self.all_joints]
            joint_angles = [curr_joint_pos[index] for index in joint_indices]
            joint_angles = [round(angle, 5) for angle in joint_angles]  # avoid precision problems
            joint_angles_array = np.array(joint_angles)
            
            # Call mplib planner with world coordinate pose
            result = self.mplib_planner.plan_path(
                now_qpos=joint_angles_array,
                target_pose=pose_obj,
                use_point_cloud=False,
                use_attach=False,
                arms_tag=arms_tag,
                log=False
            )
            
            # Ensure result format matches CuroboPlanner output
            if result.get("status") == "Success":
                if not isinstance(result.get("position"), np.ndarray):
                    result["position"] = np.array(result["position"])
                if not isinstance(result.get("velocity"), np.ndarray):
                    result["velocity"] = np.array(result["velocity"])
            
            return result
        
        # def plan_batch(self, curr_joint_pos, target_gripper_pose_list, constraint_pose=None, arms_tag=None):
        #     """Plan multiple paths - matching CuroboPlanner's interface"""
        #     num_poses = len(target_gripper_pose_list)
            
        #     # Transform all poses first (matching original implementation)
        #     world_base_pose = np.concatenate([
        #         np.array(self.robot_origion_pose.p),
        #         np.array(self.robot_origion_pose.q)
        #     ])
            
        #     poses_list = []
        #     for target_gripper_pose in target_gripper_pose_list:
        #         world_target_pose = np.concatenate([
        #             np.array(target_gripper_pose.p), 
        #             np.array(target_gripper_pose.q)
        #         ])
        #         base_target_pose_p, base_target_pose_q = self._trans_from_world_to_base(
        #             world_base_pose, world_target_pose
        #         )
        #         # Apply frame bias
        #         base_target_pose_list = list(base_target_pose_p) + list(base_target_pose_q)
        #         base_target_pose_list[0] += self.frame_bias[0]
        #         base_target_pose_list[1] += self.frame_bias[1]
        #         base_target_pose_list[2] += self.frame_bias[2]
        #         poses_list.append(base_target_pose_list)
            
        #     # Get joint angles for active joints
        #     joint_indices = [self.all_joints.index(name) for name in self.active_joints_name 
        #                     if name in self.all_joints]
        #     joint_angles = [curr_joint_pos[index] for index in joint_indices]
        #     joint_angles = [round(angle, 5) for angle in joint_angles]
            
        #     # Plan for each pose
        #     results = {
        #         "status": [],
        #         "position": [],
        #         "velocity": []
        #     }
            
        #     for pose in poses_list:
        #         print(f"joint_angles.shape:{np.array(joint_angles).shape}",flush=True)
        #         print(f"pose.shape:{np.array(pose).shape}",flush=True)
        #         print(f"joint_angles.:{joint_angles}",flush=True)
        #         print(f"pose.:{pose}",flush=True) #np.ndarray Pose,
        #         joint_angles = np.array(joint_angles)
        #         #pose = sapien.Pose(pose[:3], pose[3:]) test !!!
        #         pose = mplib.pymp.Pose(pose[:3], pose[3:]) 
                
        #         result = self.mplib_planner.plan_path(
        #             now_qpos=joint_angles,
        #             target_pose=pose,
        #             use_point_cloud=False,
        #             use_attach=False,
        #             arms_tag=arms_tag,
        #             log=False  # Don't log for batch planning
        #         )
                
        #         results["status"].append("Success" if result["status"] == "Success" else "Failure")
        #         if result["status"] == "Success":
        #             results["position"].append(result["position"])
        #             results["velocity"].append(result["velocity"])
        #         else:
        #             results["position"].append(np.array([]))
        #             results["velocity"].append(np.array([]))
            
        #     # Convert to proper format
        #     results["status"] = np.array(results["status"], dtype=object)
        #     results["position"] = np.array(results["position"], dtype=object)
        #     results["velocity"] = np.array(results["velocity"], dtype=object)
            
        #     print(f"plan_batch results are {results}",flush=True)
        #     return results
        
        #test !!!
        def plan_batch(self, curr_joint_pos, target_gripper_pose_list, constraint_pose=None, arms_tag=None):
            """Plan multiple paths - matching CuroboPlanner's interface"""
            #print(f"\n[DEBUG] ========== plan_batch for {arms_tag} ==========")
            
            num_poses = len(target_gripper_pose_list)
            
            # 先测试不进行坐标转换
            #test_without_transform = True  # 设置为 True 来测试
            
            #if test_without_transform:
            #print("[DEBUG] Testing WITHOUT coordinate transformation")
            poses_list = []
            for i, target_gripper_pose in enumerate(target_gripper_pose_list):
                # 直接使用原始位姿
                pose_list = list(target_gripper_pose.p) + list(target_gripper_pose.q)
                poses_list.append(pose_list)
                #print(f"[DEBUG] Original pose {i}: p={target_gripper_pose.p}, q={target_gripper_pose.q}")
            # else:
            #     print("[DEBUG] Using coordinate transformation")
            #     # 原来的转换代码...
            #     world_base_pose = np.concatenate([
            #         np.array(self.robot_origion_pose.p),
            #         np.array(self.robot_origion_pose.q)
            #     ])
                
            #     poses_list = []
            #     for target_gripper_pose in target_gripper_pose_list:
            #         world_target_pose = np.concatenate([
            #             np.array(target_gripper_pose.p), 
            #             np.array(target_gripper_pose.q)
            #         ])
            #         base_target_pose_p, base_target_pose_q = self._trans_from_world_to_base(
            #             world_base_pose, world_target_pose
            #         )
            #         # Apply frame bias
            #         base_target_pose_list = list(base_target_pose_p) + list(base_target_pose_q)
            #         base_target_pose_list[0] += self.frame_bias[0]
            #         base_target_pose_list[1] += self.frame_bias[1]
            #         base_target_pose_list[2] += self.frame_bias[2]
            #         poses_list.append(base_target_pose_list)
            
            # Get joint angles for active joints
            joint_indices = [self.all_joints.index(name) for name in self.active_joints_name 
                            if name in self.all_joints]
            joint_angles = [curr_joint_pos[index] for index in joint_indices]
            joint_angles = [round(angle, 5) for angle in joint_angles]
            
            # print(f"[DEBUG] active_joints_name: {self.active_joints_name}")
            # print(f"[DEBUG] joint_indices: {joint_indices}")
            # print(f"[DEBUG] joint_angles: {joint_angles}")
            
            # Plan for each pose
            results = {
                "status": [],
                "position": [],
                "velocity": []
            }
            #print(f"len of {len(poses_list)}",flush=True)
            for i, pose in enumerate(poses_list):
                # print(f"\n[DEBUG] Planning pose {i+1}/{len(poses_list)}")
                # print(f"[DEBUG] pose list: {pose}")
                
                joint_angles_array = np.array(joint_angles)
                
                # 测试不同的 Pose 创建方式
                #try:
                    # 方式1：直接使用列表
                pose_obj = mplib.pymp.Pose(pose[:3], pose[3:])
                    #print(f"[DEBUG] Created pose with list: {pose_obj}")
                # except Exception as e:
                #     print(f"[ERROR] Failed to create pose with list: {e}")
                    
                # try:
                #     # 方式2：尝试不同的四元数顺序 (如果原来是 wxyz，试试 xyzw)
                #     if len(pose) == 7:
                #         # 假设原来是 [x,y,z,w,x,y,z]，转换为 [x,y,z,x,y,z,w]
                #         pose_xyzw = pose[:3] + pose[4:] + [pose[3]]
                #         pose_obj2 = mplib.pymp.Pose(pose_xyzw[:3], pose_xyzw[3:])
                #         print(f"[DEBUG] Created pose with xyzw order: {pose_obj2}")
                # except Exception as e:
                #     print(f"[ERROR] Failed to create pose with xyzw order: {e}")
                
                # 检查 mplib planner 的状态
                # if hasattr(self.mplib_planner, 'planner'):
                #     print(f"[DEBUG] mplib planner type: {self.mplib_planner.planner_type}")
                #     print(f"[DEBUG] mplib planner move_group: {getattr(self.mplib_planner.planner, 'move_group', 'N/A')}")
                
                result = self.mplib_planner.plan_path(
                    now_qpos=joint_angles_array,
                    target_pose=pose_obj,
                    use_point_cloud=False,
                    use_attach=False,
                    arms_tag=arms_tag,
                    log=False  # 开启日志
                )
                
                #print(f"[DEBUG] Result: {result}")
                
                results["status"].append("Success" if result["status"] == "Success" else "Failure")
                if result["status"] == "Success":
                    results["position"].append(result["position"])
                    results["velocity"].append(result["velocity"])
                else:
                    results["position"].append(np.array([]))
                    results["velocity"].append(np.array([]))
            
            # Convert to proper format
            results["status"] = np.array(results["status"], dtype=object)
            results["position"] = np.array(results["position"], dtype=object)
            results["velocity"] = np.array(results["velocity"], dtype=object)
            
            return results
        
        
        def plan_grippers(self, now_val, target_val):
            """Plan gripper motion - simple linear interpolation"""
            num_step = 200
            dis_val = target_val - now_val
            per_step = dis_val / num_step
            res = {}
            vals = np.linspace(now_val, target_val, num_step)
            res["num_step"] = num_step
            res["per_step"] = per_step
            res["result"] = vals
            return res
        
        def update_point_cloud(self, pcd, resolution=0.02):
            """Update point cloud for collision checking"""
            # mplib doesn't support dynamic point cloud updates in the same way
            # This is a placeholder
            pass
        
        def _trans_from_world_to_base(self, base_pose, target_pose):
            """Transform from world frame to base frame"""
            base_p, base_q = base_pose[0:3], base_pose[3:]
            target_p, target_q = target_pose[0:3], target_pose[3:]
            rel_p = target_p - base_p
            wRb = t3d.quaternions.quat2mat(base_q)
            wRt = t3d.quaternions.quat2mat(target_q)
            result_p = wRb.T @ rel_p
            result_q = t3d.quaternions.mat2quat(wRb.T @ wRt)
            return result_p, result_q
        
except Exception as e:
    print('[planner.py]: Something wrong happened when importing CuroboPlanner! Please check if Curobo is installed correctly. If the problem still exists, you can install Curobo from https://github.com/NVlabs/curobo manually.',flush=True)
    print('Exception traceback:')
    traceback.print_exc()
    raise