import os
import sys
import numpy as np
from termcolor import cprint

from isaacsim.core.api import World
from isaacsim.core.utils.rotations import euler_angles_to_quat
from pxr import UsdPhysics, Gf, UsdGeom, Usd
from omni.isaac.core.utils.prims import delete_prim

sys.path.append(os.getcwd())
from Env_Config.Robot.Franka import Franka
from Env_Config.Utils_Project.Transforms import Rotation
from Env_Config.Utils_Project.Code_Tools import dense_trajectory_points_generation

class Bimanual_Franka:
    """Coordinate a left and a right Franka arm that live in the same ``World``.

    Provides paired (``*_Both_*``) and single-arm gripper commands, RMPflow
    point-to-point and densely interpolated motions, gripper-pose queries, and
    helpers that rigidly attach a board prim to both grippers with USD Physics
    fixed joints.
    """

    # Local-frame offset that is rotated by the end-effector orientation and
    # added to the end-effector position to obtain the "gripper point" used as
    # the position target throughout this class. Kept as a tuple so it cannot be
    # mutated by accident; a fresh array is built on every use.
    _GRIPPER_OFFSET = (0.0, 0.0, 0.1)

    def __init__(
        self,
        world: World,
        left_pos,
        left_ori,
        right_pos,
        right_ori,
        left_franka_scale=np.array([1.0, 1.0, 1.0]),
        right_franka_scale=np.array([1.0, 1.0, 1.0]),
        gripper_open_position_nounits=np.array([0.05, 0.05]),
    ):
        """Create both arms in *world*.

        Args:
            world: Simulation world both arms are created in.
            left_pos, left_ori: Pose of the left arm, forwarded to ``Franka``.
            right_pos, right_ori: Pose of the right arm, forwarded to ``Franka``.
            left_franka_scale, right_franka_scale: Per-arm scale forwarded to ``Franka``.
            gripper_open_position_nounits: Gripper open position forwarded to both arms.
        """
        self.world = world
        self.left_franka = Franka(
            world,
            left_pos,
            left_ori,
            robot_name="Franka_Left",
            scale=left_franka_scale,
            gripper_open_position_nounits=gripper_open_position_nounits,
        )
        self.right_franka = Franka(
            world,
            right_pos,
            right_ori,
            robot_name="Franka_Right",
            scale=right_franka_scale,
            gripper_open_position_nounits=gripper_open_position_nounits,
        )

        # Stall bookkeeping used by Rmpflow_Both_Move: the last position error
        # of each arm and how many loop iterations it changed by less than 1e-4
        # (counted cumulatively over a call, not reset when progress resumes).
        self.left_pre_error = 0.0
        self.left_error_nochange_epoch = 0

        self.right_pre_error = 0.0
        self.right_error_nochange_epoch = 0

    # ------------------------------------------------------------------ helpers

    def _gripper_pose(self, franka):
        """Return ``(gripper_pos, ee_ori)`` for *franka* from its current end-effector pose."""
        ee_pos, ee_ori = franka.get_cur_ee_pos()
        gripper_pos = ee_pos + Rotation(ee_ori, np.array(self._GRIPPER_OFFSET))
        return gripper_pos, ee_ori

    def _step_world(self, num_steps=20):
        """Advance the simulation *num_steps* rendered steps so a gripper command can take effect."""
        for _ in range(num_steps):
            self.world.step(render=True)

    @staticmethod
    def _euler_deg_to_quat(euler_deg):
        """Convert Euler angles given in degrees to a quaternion; ``None`` is passed through."""
        if euler_deg is None:
            return None
        return euler_angles_to_quat(euler_deg, degrees=True)

    @staticmethod
    def _fix_gripper_to_board(stage, gripper_prim, board_prim, board_world_transform):
        """Author ``<gripper>/board_fix_joint`` fixing *board_prim* to *gripper_prim*; return its path.

        body0 is the gripper and body1 the board. The joint frame is placed at the
        gripper origin: identity in body0 space and the gripper's current pose
        expressed in body1 (board) space. Both descriptions therefore coincide at
        the current poses, and activating the joint does not move either body.
        """
        gripper_world_transform = UsdGeom.Xformable(gripper_prim).ComputeLocalToWorldTransform(
            Usd.TimeCode.Default()
        )
        # USD composes row-vector matrices (local * parent), so
        # gripper_world * board_world^-1 is the gripper pose in the board frame,
        # which is what localPos1/localRot1 must hold. The reverse product
        # (board_world * gripper_world^-1) is its inverse and would displace the
        # board. Scale/shear is stripped so the rotation extraction is well defined.
        gripper_in_board = (
            gripper_world_transform * board_world_transform.GetInverse()
        ).RemoveScaleShear()
        rel_pos = gripper_in_board.ExtractTranslation()
        rel_rot = gripper_in_board.ExtractRotation().GetQuat()

        joint_path = f"{gripper_prim.GetPath()}/board_fix_joint"
        joint = UsdPhysics.FixedJoint.Define(stage, joint_path)
        joint.GetBody0Rel().SetTargets([gripper_prim.GetPath()])
        joint.GetBody1Rel().SetTargets([board_prim.GetPath()])

        joint.GetLocalPos0Attr().Set(Gf.Vec3f(0, 0, 0))
        joint.GetLocalRot0Attr().Set(Gf.Quatf(1.0))
        joint.GetLocalPos1Attr().Set(Gf.Vec3f(rel_pos))
        joint.GetLocalRot1Attr().Set(Gf.Quatf(rel_rot.GetReal(), *rel_rot.GetImaginary()))
        return joint_path

    # ------------------------------------------------------- board attachment

    def attach_board_with_grippers(self, board_prim_path: str):
        """Rigidly attach the board prim at *board_prim_path* to both grippers.

        One ``UsdPhysics.FixedJoint`` is authored under each gripper's end-effector
        prim, using the poses of gripper and board at the time of the call.

        Returns:
            ``[left_joint_path, right_joint_path]`` on success (pass it to
            ``detach_board_joints``), or ``(None, None)`` if the board or either
            gripper prim is invalid. All prims are validated before any joint is
            authored, so a failure never leaves the board half-attached.
        """
        stage = self.world.stage

        left_gripper_prim = self.left_franka.end_effector.prim
        board_prim = stage.GetPrimAtPath(board_prim_path)
        if not board_prim.IsValid() or not left_gripper_prim.IsValid():
            cprint("Error: Prims for joint are not valid.", "red")
            return None, None

        right_gripper_prim = self.right_franka.end_effector.prim
        if not right_gripper_prim.IsValid():
            cprint("Error: Right gripper prim is not valid.", "red")
            return None, None

        board_world_transform = UsdGeom.Xformable(board_prim).ComputeLocalToWorldTransform(
            Usd.TimeCode.Default()
        )

        left_joint_path = self._fix_gripper_to_board(
            stage, left_gripper_prim, board_prim, board_world_transform
        )
        cprint(f"✓ Left gripper attached to board via FixedJoint at '{left_joint_path}'", "green")

        right_joint_path = self._fix_gripper_to_board(
            stage, right_gripper_prim, board_prim, board_world_transform
        )
        cprint(f"✓ Right gripper attached to board via FixedJoint at '{right_joint_path}'", "green")

        return [left_joint_path, right_joint_path]

    def detach_board_joints(self, joint_ids: list):
        """Delete the joint prims in *joint_ids*; falsy entries (and a ``None`` list) are skipped."""
        for joint_path in joint_ids or ():
            if not joint_path:
                continue
            delete_prim(joint_path)
            cprint(f"✓ Joint '{joint_path}' deleted.", "green")

    # ---------------------------------------------------------------- grippers

    def Gripper_Both_Open(self):
        """Open both grippers, then step the world so they can move."""
        self.left_franka.gripper.open()
        self.right_franka.gripper.open()
        self._step_world()
        cprint("[Info]: Gripper_Both_Open", "white", attrs=["bold"])

    def Gripper_Both_Close(self):
        """Close both grippers, then step the world so they can move."""
        self.left_franka.gripper.close()
        self.right_franka.gripper.close()
        self._step_world()
        cprint("Gripper_Both_Close", "green")

    def Gripper_Left_Open(self):
        """Open the left gripper, then step the world."""
        self.left_franka.gripper.open()
        self._step_world()
        cprint("Gripper_Left_Open", "green")

    def Gripper_Left_Close(self):
        """Close the left gripper, then step the world."""
        self.left_franka.gripper.close()
        self._step_world()
        cprint("Gripper_Left_Close", "green")

    def Gripper_Right_Open(self):
        """Open the right gripper, then step the world."""
        self.right_franka.gripper.open()
        self._step_world()
        cprint("Gripper_Right_Open", "green")

    def Gripper_Right_Close(self):
        """Close the right gripper, then step the world."""
        cprint("Gripper_Right_Close  Be ready", "green")
        self.right_franka.gripper.close()
        cprint("Gripper_Right_Close OK！", "green")
        self._step_world()
        cprint("Gripper_Right_Close", "green")

    # ----------------------------------------------------------------- motion

    def Rmpflow_Left_Move(
        self,
        target_position,
        target_orientation=np.array([180.0, 0.0, 0.0]),
    ):
        """Move the left arm with ``Franka.Rmpflow_Move``.

        The default orientation is presumably Euler angles in degrees, as in
        ``Rmpflow_Both_Move``; confirm against ``Franka.Rmpflow_Move``.
        """
        self.left_franka.Rmpflow_Move(target_position, target_orientation)

    def Rmpflow_Right_Move(
        self,
        target_position,
        target_orientation=np.array([180.0, 0.0, 180.0]),
    ):
        """Move the right arm with ``Franka.Rmpflow_Move`` (see ``Rmpflow_Left_Move``)."""
        self.right_franka.Rmpflow_Move(target_position, target_orientation)

    def Rmpflow_Both_Move(
        self,
        left_target_position,
        right_target_position,
        left_target_orientation=None,
        right_target_orientation=None,
    ):
        """Drive both gripper points to their targets with RMPflow until done or stalled.

        Args:
            left_target_position, right_target_position: Targets for the gripper
                points (end-effector position plus the rotated gripper offset).
            left_target_orientation, right_target_orientation: Optional Euler
                angles in degrees; ``None`` is forwarded as no orientation target.

        Returns:
            ``True`` once both gripper points are within 1e-3 of their targets;
            ``False`` if exactly one arm reached its target while the other
            stalled for more than 100 iterations; ``100`` if both arms stalled
            for more than 50 iterations.
        """
        cprint(f'left_target_position:{left_target_position},right_target_position:{right_target_position}',"white")
        left_target_ee_orientation = self._euler_deg_to_quat(left_target_orientation)
        right_target_ee_orientation = self._euler_deg_to_quat(right_target_orientation)

        self.left_error_nochange_epoch = 0
        self.right_error_nochange_epoch = 0
        self.left_pre_error = 0.0
        self.right_pre_error = 0.0

        while True:
            left_gripper_pos, _ = self._gripper_pose(self.left_franka)
            left_error = np.linalg.norm(left_target_position - left_gripper_pos)
            if abs(left_error - self.left_pre_error) < 1e-4:
                self.left_error_nochange_epoch += 1
            self.left_pre_error = left_error

            right_gripper_pos, _ = self._gripper_pose(self.right_franka)
            right_error = np.linalg.norm(right_target_position - right_gripper_pos)
            if abs(right_error - self.right_pre_error) < 1e-4:
                self.right_error_nochange_epoch += 1
            self.right_pre_error = right_error

            left_reached = left_error < 0.001
            right_reached = right_error < 0.001

            # Success is checked first: the stall counters also grow while an arm
            # creeps slowly onto its target, and must not override a reached goal.
            if left_reached and right_reached:
                cprint("Both Frankas RMPflow Controller succeeded", "green")
                return True
            if self.left_error_nochange_epoch > 50 and self.right_error_nochange_epoch > 50:
                cprint("Both Frankas RMPflow Controller failed", "red")
                return 100
            if self.left_error_nochange_epoch > 100 and right_reached:
                cprint("Only Right Franka RMPflow Controller succeeded", "yellow")
                return False
            if self.right_error_nochange_epoch > 100 and left_reached:
                cprint("Only Left Franka RMPflow Controller succeeded", "yellow")
                return False

            # NOTE(review): unlike Dense_Rmpflow_Both_Move, this loop never calls
            # self.world.step(); presumably Rmpflow_Step_Action advances the
            # simulation itself. Confirm, otherwise the errors can never change.
            self.left_franka.Rmpflow_Step_Action(left_target_position, left_target_ee_orientation)
            self.right_franka.Rmpflow_Step_Action(right_target_position, right_target_ee_orientation)

    def Dense_Rmpflow_Left_Move(
        self,
        target_position,
        target_orientation=np.array([180.0, 0.0, 0.0]),
        dense_sample_scale=0.02,
    ):
        """Move the left arm along an interpolated path via ``Franka.Dense_Rmpflow_Move``."""
        self.left_franka.Dense_Rmpflow_Move(target_position, target_orientation, dense_sample_scale)

    def Dense_Rmpflow_Right_Move(
        self,
        target_position,
        target_orientation=np.array([180.0, 0.0, 180.0]),
        dense_sample_scale=0.02,
    ):
        """Move the right arm along an interpolated path via ``Franka.Dense_Rmpflow_Move``."""
        self.right_franka.Dense_Rmpflow_Move(target_position, target_orientation, dense_sample_scale)

    def Dense_Rmpflow_Both_Move(
        self,
        left_target_position,
        right_target_position,
        left_target_orientation=np.array([180.0, 0.0, 0.0]),
        right_target_orientation=np.array([180.0, 0.0, 180.0]),
        dense_sample_scale=0.02,
    ):
        """Move both gripper points along interpolated waypoints in lock-step.

        Both paths get the same number of waypoints: ``distance //
        dense_sample_scale`` of the longer path. Each waypoint is tracked for 5
        RMPflow actions, each followed by one ``world.step()``. An orientation of
        ``None`` is forwarded as no orientation target.
        """
        left_target_ee_orientation = self._euler_deg_to_quat(left_target_orientation)
        right_target_ee_orientation = self._euler_deg_to_quat(right_target_orientation)

        left_gripper_pos, _ = self._gripper_pose(self.left_franka)
        right_gripper_pos, _ = self._gripper_pose(self.right_franka)

        # NOTE(review): this is 0 when both targets are closer than
        # dense_sample_scale; what dense_trajectory_points_generation returns then
        # (possibly no waypoints, i.e. no motion) is not visible here. Confirm.
        dense_sample_num = max(
            int(np.linalg.norm(left_target_position - left_gripper_pos) // dense_sample_scale),
            int(np.linalg.norm(right_target_position - right_gripper_pos) // dense_sample_scale),
        )

        left_interp_pos = dense_trajectory_points_generation(
            start_pos=left_gripper_pos,
            end_pos=left_target_position,
            num_points=dense_sample_num,
        )
        right_interp_pos = dense_trajectory_points_generation(
            start_pos=right_gripper_pos,
            end_pos=right_target_position,
            num_points=dense_sample_num,
        )

        cprint("--Both Frankas Dense_Rmpflow_Move Begin", "green")
        for i, (left_waypoint, right_waypoint) in enumerate(zip(left_interp_pos, right_interp_pos)):
            print(f"-------step {i}-------")
            for _ in range(5):
                self.left_franka.Rmpflow_Step_Action(left_waypoint, left_target_ee_orientation)
                self.right_franka.Rmpflow_Step_Action(right_waypoint, right_target_ee_orientation)
                self.world.step()
        cprint("--Both Frankas Dense_Rmpflow_Move End", "green")

    # ------------------------------------------------------------------ poses

    def get_left_gripper_pose(self):
        """Return ``(gripper_pos, ee_ori)`` of the left arm."""
        return self._gripper_pose(self.left_franka)

    def get_right_gripper_pose(self):
        """Return ``(gripper_pos, ee_ori)`` of the right arm."""
        return self._gripper_pose(self.right_franka)