import os
import sys
import torch
import numpy as np
from typing import List, Optional
from termcolor import cprint

from isaacsim.core.api import World
from isaacsim.core.prims import SingleXFormPrim, SingleRigidPrim
from isaacsim.core.api.robots import Robot
from isaacsim.core.api.objects import DynamicCuboid, VisualCuboid
from isaacsim.core.utils.stage import add_reference_to_stage, get_stage_units
from isaacsim.core.utils.rotations import euler_angles_to_quat
from isaacsim.core.utils.types import ArticulationAction
from isaacsim.robot.manipulators.examples.franka import Franka as BaseFranka
from isaacsim.robot.manipulators.grippers.parallel_gripper import ParallelGripper
from isaacsim.robot_motion.motion_generation.lula.motion_policies import RmpFlow
from isaacsim.robot_motion.motion_generation.interface_config_loader import load_supported_motion_policy_config
from isaacsim.robot_motion.motion_generation.articulation_motion_policy import ArticulationMotionPolicy

sys.path.append(os.getcwd())
from Env_Config.Utils_Project.Set_Drive import set_drive
from Env_Config.Utils_Project.Transforms import quat_diff_rad, Rotation, get_pose_relat, get_pose_world
from Env_Config.Utils_Project.Code_Tools import float_truncate, dense_trajectory_points_generation

class Franka(Robot):
    def __init__(
        self,
        world: World,
        position: np.ndarray,
        orientation: np.ndarray,
        scale: np.ndarray,
        robot_name: str = "Franka",
        gripper_open_position_nounits=None,
    ) -> None:
        """Spawn a Franka from the bundled USD asset and set up its gripper and RMPflow.

        Args:
            world: Isaac Sim ``World``; the robot is added to ``world.scene`` and
                the world is stepped by the motion helpers of this class.
            position: Base position, passed to the ``Robot`` base class and to
                ``RmpFlow.set_robot_base_pose``.
            orientation: Base orientation as Euler angles in degrees; converted to a
                quaternion with ``euler_angles_to_quat(..., degrees=True)``.
            scale: Scale passed to the ``Robot`` base class.
            robot_name: Robot name; the prim is created at ``/World/<robot_name>``.
            gripper_open_position_nounits: Opened finger-joint positions before
                division by the stage units. ``None`` (default) means
                ``[0.05, 0.05]``. Any array-like (list, tuple, ndarray) is accepted.
        """
        # Build a fresh array per instance. A module-level ndarray default would be
        # one mutable object shared by every Franka created without this argument.
        # Copying also lets callers pass a list/tuple, which the division by
        # get_stage_units() below would otherwise reject.
        if gripper_open_position_nounits is None:
            gripper_open_position_nounits = np.array([0.05, 0.05])
        else:
            gripper_open_position_nounits = np.array(gripper_open_position_nounits)

        self.world = world
        self._name = robot_name
        self._prim_path = "/World/" + self._name

        # Robot USD shipped with the project, resolved relative to this file.
        self.asset_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../Assets/Robots/franka/franka.usd")

        self.position = position
        self.scale = scale
        self.orientation = euler_angles_to_quat(orientation, degrees=True)
        self.gripper_open_position_nounits = gripper_open_position_nounits

        # Reference the asset onto the stage before the base class wraps the prim path.
        add_reference_to_stage(self.asset_file, self._prim_path)

        super().__init__(
            prim_path=self._prim_path,
            name=self._name,
            position=self.position,
            orientation=self.orientation,
            articulation_controller=None,
            scale=self.scale,
        )

        self._end_effector_prim_path = self._prim_path + "/panda_hand"

        # Finger joints are opened/closed by the ParallelGripper; positions are
        # divided by the stage units.
        gripper_dof_names = ["panda_finger_joint1", "panda_finger_joint2"]
        gripper_open_position = self.gripper_open_position_nounits / get_stage_units()
        gripper_closed_position = np.array([0.0, 0.0])
        deltas = np.array([0.05, 0.05]) / get_stage_units()

        self._gripper = ParallelGripper(
            end_effector_prim_path=self._end_effector_prim_path,
            joint_prim_names=gripper_dof_names,
            joint_opened_positions=gripper_open_position,
            joint_closed_positions=gripper_closed_position,
            action_deltas=deltas,
        )

        self.world.scene.add(self)

        # RMPflow with the Franka config bundled with Isaac Sim, anchored at this
        # robot's base pose.
        self.rmp_flow_config = load_supported_motion_policy_config("Franka", "RMPflow")
        self.rmp_flow = RmpFlow(**self.rmp_flow_config)
        self.rmp_flow.set_robot_base_pose(
            self.position, self.orientation
        )
        # 1/60 s is the physics dt given to the policy wrapper.
        # NOTE(review): presumably this should match the World's physics dt — confirm.
        self.articulation_rmp = ArticulationMotionPolicy(self, self.rmp_flow, 1.0 / 60.0)
        self.articulation_controller = self.get_articulation_controller()

        # Stall-detection state used (and reset) by Rmpflow_Move.
        self.pre_error = 0.0
        self.error_nochange_epoch = 0

        return

    def initialize(self, physics_sim_view=None) -> None:
        """Initialize physics handles for the arm, its end effector and its gripper.

        Runs the base-class initialization first, then wraps the ``panda_hand``
        prim as a ``SingleRigidPrim``, wires the gripper to this articulation's
        action / joint-position accessors, and finally disables gravity on the
        articulation.

        Args:
            physics_sim_view: Physics simulation view forwarded to every
                ``initialize`` call (may be ``None``).
        """
        super().initialize(physics_sim_view)
        # Hand prim wrapper; its world pose is what get_cur_ee_pos reports.
        self._end_effector = SingleRigidPrim(prim_path=self._end_effector_prim_path, name=self.name + "_end_effector")
        self._end_effector.initialize(physics_sim_view)
        # The gripper drives its finger joints through this articulation's own
        # apply_action / joint-position functions and DOF name list, which are
        # available only after the base-class initialize above.
        self._gripper.initialize(
            physics_sim_view=physics_sim_view,
            articulation_apply_action_func=self.apply_action,
            get_joint_positions_func=self.get_joint_positions,
            set_joint_positions_func=self.set_joint_positions,
            dof_names=self.dof_names,
        )
        self.disable_gravity()

        return

    def post_reset(self) -> None:
        """Reset the robot and gripper, then put both finger DOFs in position control.

        The base class and gripper ``post_reset`` hooks run first. Then the
        articulation controller switches the first two gripper DOF indices to
        ``"position"`` control mode.
        """
        super().post_reset()
        self._gripper.post_reset()

        controller = self._articulation_controller
        # Indices 0 and 1 are the two finger joints listed in __init__.
        for finger_slot in (0, 1):
            controller.switch_dof_control_mode(
                dof_index=self.gripper.joint_dof_indicies[finger_slot],
                mode="position",
            )
        return

    @property
    def end_effector(self) -> SingleRigidPrim:
        """Rigid-prim wrapper of ``<prim_path>/panda_hand``, created in ``initialize``."""
        hand_prim = self._end_effector
        return hand_prim

    @property
    def gripper(self) -> ParallelGripper:
        """The two-finger ``ParallelGripper`` built in ``__init__``."""
        finger_gripper = self._gripper
        return finger_gripper

    def open_gripper(self, settle_steps: int = 20) -> None:
        """Command the gripper open, then step the world so the fingers can move.

        Args:
            settle_steps: Number of ``world.step()`` calls made after the open
                command. Defaults to 20, the previously hard-coded value.
        """
        self.gripper.open()
        for _ in range(settle_steps):
            self.world.step()
        return

    def close_gripper(self, settle_steps: int = 20) -> None:
        """Command the gripper closed, then step the world so the fingers can move.

        Args:
            settle_steps: Number of ``world.step()`` calls made after the close
                command. Defaults to 20, the previously hard-coded value.
        """
        self.gripper.close()
        for _ in range(settle_steps):
            self.world.step()
        return

    def get_cur_ee_pos(self):
        """Return the world pose of the ``panda_hand`` prim as ``(position, orientation)``.

        This is the hand prim itself, not the fingertip point; callers in this
        class add a 0.1 offset along the hand's local z to get the gripper tip.
        """
        hand_pose = self.end_effector.get_world_pose()
        hand_position, hand_orientation = hand_pose
        return hand_position, hand_orientation

    def add_obstacle(self, obstacle):
        """Register ``obstacle`` with RMPflow, then run 10 rendered world steps.

        The second argument to ``RmpFlow.add_obstacle`` is ``False``; it is
        presumably the ``static`` flag (obstacle may move) — confirm against the
        RmpFlow API.
        """
        is_static = False
        self.rmp_flow.add_obstacle(obstacle, is_static)

        render_steps = 10
        for _ in range(render_steps):
            self.world.step(render=True)
        return

    def Rmpflow_Step_Action(self, position, orientation=None):
        """Step the world once, then compute and apply the next RMPflow action.

        The world is stepped *before* the new target is set. As a result, the
        action computed here is only sent to the articulation controller and is
        not simulated until a later step.

        Args:
            position: End-effector target position handed to RMPflow.
            orientation: Target orientation quaternion, or ``None`` to leave
                orientation unconstrained.
        """
        self.world.step(render=True)

        policy = self.rmp_flow
        policy.set_end_effector_target(
            target_position=position,
            target_orientation=orientation,
        )
        # Let RMPflow refresh its internal world state before querying it.
        policy.update_world()

        next_action = self.articulation_rmp.get_next_articulation_action()
        self._articulation_controller.apply_action(next_action)

    def Rmpflow_Move(self, target_position, target_orientation=np.array([180.0, 0.0, 180.0]), max_steps=None):
        """Drive the gripper tip toward ``target_position`` with RMPflow, blocking until done.

        The tip is taken as the hand position plus ``Rotation(ori, [0, 0, 0.1])``,
        presumably a 0.1 offset along the hand's local z — verify against the
        ``Rotation`` helper.

        Args:
            target_position: World-frame target; passed to RMPflow unchanged and
                compared against the tip position.
            target_orientation: Euler angles in degrees, converted with
                ``euler_angles_to_quat(..., degrees=True)``; ``None`` leaves the
                orientation unconstrained.
            max_steps: Optional cap on the number of control steps. ``None``
                (default) means no cap, which was the previous behaviour.

        Returns:
            True once the tip is within 0.001 of the target. False when the
            stall counter exceeds 100, or when ``max_steps`` is exhausted.
        """
        target_ee_position = target_position

        if target_orientation is None:
            target_ee_orientation = None
        else:
            target_ee_orientation = euler_angles_to_quat(target_orientation, degrees=True)

        self.error_nochange_epoch = 0
        self.pre_error = 0.0

        cprint(f"[Info]: RMPflow Move to {target_position}, RMPflow Orientation {target_ee_orientation}", "white", attrs=["bold"])

        steps_taken = 0
        while True:
            pos, ori = self.get_cur_ee_pos()
            gripper_pos = pos + Rotation(ori, np.array([0.0, 0.0, 0.1]))

            error = np.linalg.norm(target_position - gripper_pos)

            error_gap = abs(error - self.pre_error)
            self.pre_error = error

            # The counter is cumulative: it counts every low-progress step since
            # the call started and is never reset when progress resumes.
            if error_gap < 0.0002:
                self.error_nochange_epoch += 1

            # Check success before the stall test, so that an iteration which has
            # already converged is never reported as a failure.
            if error < 0.001:
                cprint("[Info]: Single Franka RMPflow controller success", "green")
                return True

            if self.error_nochange_epoch > 100:
                cprint(f"[Info]: Single Franka RMPflow controller failed (stuck for {self.error_nochange_epoch} steps)", "red")
                return False

            # Optional guard against looping forever when the error keeps changing
            # without ever converging.
            if max_steps is not None and steps_taken >= max_steps:
                cprint(f"[Info]: Single Franka RMPflow controller failed (exceeded {max_steps} steps)", "red")
                return False

            self.Rmpflow_Step_Action(target_ee_position, target_ee_orientation)
            steps_taken += 1

    def Dense_Rmpflow_Move(
        self,
        target_position,
        target_orientation=np.array([180.0, 0.0, 180.0]),
        dense_sample_scale=0.02,
    ):
        """Move the gripper tip along interpolated waypoints toward ``target_position``.

        The straight segment from the current tip to the target is sampled with
        ``dense_trajectory_points_generation``. The point count is
        ``int(distance // dense_sample_scale)``, which is 0 when the distance is
        below the spacing. NOTE(review): what the helper returns for 0 points is
        not visible here — confirm short moves still reach the target.

        For each waypoint, ``Rmpflow_Step_Action`` runs 5 times. Each run is
        followed by an extra ``world.step()``, so the world advances twice per
        action. Convergence is not checked; the method returns ``None``.

        Args:
            target_position: World-frame end point of the interpolated path.
            target_orientation: Euler angles in degrees, or ``None`` for an
                unconstrained orientation.
            dense_sample_scale: Approximate spacing between waypoints.
        """
        goal_quat = (
            None
            if target_orientation is None
            else euler_angles_to_quat(target_orientation, degrees=True)
        )

        hand_pos, hand_ori = self.get_cur_ee_pos()
        tip_start = hand_pos + Rotation(hand_ori, np.array([0.0, 0.0, 0.1]))

        path_length = np.linalg.norm(tip_start - target_position)
        waypoint_count = int(path_length // dense_sample_scale)

        waypoints = dense_trajectory_points_generation(
            start_pos=tip_start,
            end_pos=target_position,
            num_points=waypoint_count,
        )

        cprint("--Single Franka Dense_Rmpflow_Move Begin", "green")
        for step_idx, waypoint in enumerate(waypoints):
            print(f"-------step {step_idx}-------")
            for _ in range(5):
                self.Rmpflow_Step_Action(waypoint, goal_quat)
                self.world.step()
        cprint("--Single Franka Dense_Rmpflow_Move End", "green")
        return

    def move_block_follow_gripper(
        self, attach_block, target_position, target_orientation=None
    ):
        """Move the gripper tip to a target, then snap a block just beyond the tip.

        Args:
            attach_block: Object exposing a ``.block`` prim with
                ``set_world_pose(position, orientation)`` (project type, not
                visible here).
            target_position: World-frame target for the gripper tip.
            target_orientation: Euler angles in degrees, or ``None`` (default)
                for an unconstrained orientation.

        Returns:
            The success flag returned by ``Rmpflow_Move``. The block is attached
            whether or not the move succeeded, as before.
        """
        # Rmpflow_Move takes (target_position, target_orientation); the keywords
        # position= / orientation= used previously raised TypeError.
        reached = self.Rmpflow_Move(target_position, target_orientation)

        # Compute the gripper tip exactly as Rmpflow_Move does.
        # NOTE(review): assumes the former get_cur_grip_pos (not defined in this
        # class) meant this tip position — confirm.
        hand_position, gripper_orientation = self.get_cur_ee_pos()
        gripper_position = hand_position + Rotation(
            gripper_orientation, np.array([0.0, 0.0, 0.1])
        )

        # Rotation is the module-level helper (not a method). Use a numpy offset to
        # match the numpy poses used everywhere else in this class.
        block_position = gripper_position + Rotation(
            gripper_orientation, np.array([0.0, 0.0, 0.05])
        )
        attach_block.block.set_world_pose(block_position, gripper_orientation)
        self.world.step(render=True)
        return reached