import pybullet as p
import pybullet_data
import time
import numpy as np
from scipy.spatial.transform import Rotation
import torch
import utils.rot as rot

def to_torch(x, dtype=torch.float, requires_grad=False, device='cpu'):
    """Build a new ``torch.Tensor`` from array-like data.

    Args:
        x: Data accepted by ``torch.tensor`` (list, tuple, NumPy array, scalar, ...).
        dtype: Desired tensor dtype. Defaults to ``torch.float``.
        requires_grad: Whether autograd should record operations on the result.
        device: Target device. Defaults to ``'cpu'``, which was previously
            hard-coded; the parameter is appended last so existing positional
            calls keep their meaning.

    Returns:
        A tensor holding a copy of ``x`` on ``device``.
    """
    return torch.tensor(x, dtype=dtype, device=device, requires_grad=requires_grad)

def quat_mul(a, b):
    """Multiply two batches of quaternions stored as ``[x, y, z, w]``.

    Both inputs must have the same shape with a trailing dimension of 4; the
    result has that same shape. The product is evaluated with a
    reduced-multiplication formula, so the arithmetic below is intentionally
    kept in this exact form.
    """
    assert a.shape == b.shape
    out_shape = a.shape
    lhs = a.reshape(-1, 4)
    rhs = b.reshape(-1, 4)

    # Split each operand into its component columns (x, y, z, w order).
    x1, y1, z1, w1 = (lhs[:, k] for k in range(4))
    x2, y2, z2, w2 = (rhs[:, k] for k in range(4))

    # Shared partial products reused by several output components.
    ww = (z1 + x1) * (x2 + y2)
    yy = (w1 - y1) * (w2 + z2)
    zz = (w1 + y1) * (w2 - z2)
    xx = ww + yy + zz
    qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))

    w = qq - ww + (z1 - y1) * (y2 - z2)
    x = qq - xx + (x1 + w1) * (x2 + w2)
    y = qq - yy + (w1 - x1) * (y2 + z2)
    z = qq - zz + (z1 + y1) * (w2 - x2)

    return torch.stack([x, y, z, w], dim=-1).view(out_shape)


def rotation_matrix_to_euler(rotation_matrix):
    """Extract Euler angles ``[x, y, z]`` from a 3x3 rotation matrix.

    The extraction matches the matrix layout produced by
    ``euler_to_rotation_matrix`` in this file. When the matrix is
    (numerically) at the pitch singularity, ``z`` is fixed to 0 and ``x`` is
    recovered from the second row instead.

    Args:
        rotation_matrix: Object indexable as ``m[i, j]`` for ``i, j`` in 0..2.

    Returns:
        ``np.ndarray`` of three angles in radians.
    """
    r00 = rotation_matrix[0, 0]
    r10 = rotation_matrix[1, 0]
    sy = np.sqrt(r00 * r00 + r10 * r10)

    # The pitch expression is the same in both branches.
    angle_y = np.arctan2(-rotation_matrix[2, 0], sy)

    if sy < 1e-6:
        # Singular case: x and z are coupled, so z is pinned to zero.
        angle_x = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1])
        return np.array([angle_x, angle_y, 0])

    angle_x = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2])
    angle_z = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0])
    return np.array([angle_x, angle_y, angle_z])


def euler_to_rotation_matrix(euler_angles):
    """Build a 3x3 rotation matrix from ``(roll, pitch, yaw)`` angles in radians.

    Args:
        euler_angles: Iterable of exactly three angles, unpacked as
            roll, pitch, yaw.

    Returns:
        ``np.ndarray`` of shape (3, 3).
    """
    roll, pitch, yaw = euler_angles
    cr, sr = np.cos(roll), np.sin(roll)
    cp, sp = np.cos(pitch), np.sin(pitch)
    cy, sy = np.cos(yaw), np.sin(yaw)

    return np.array([
        [cy * cp, cy * sp * sr - sy * cr, cy * sp * cr + sy * sr],
        [sy * cp, sy * sp * sr + cy * cr, sy * sp * cr - cy * sr],
        [-sp, cp * sr, cp * cr],
    ])

class PybulletIKSolver():
    """Per-finger inverse-kinematics solver for a left/right hand pair in PyBullet.

    Construction opens a PyBullet GUI session, loads both hand URDFs and a box
    object, reads one ARCTIC sequence from hard-coded paths, and places the
    hands at the pose of the first frame. ``solve_ik`` then solves IK for each
    fingertip, applies the result to the simulated hands, and returns the
    joint values reordered by ``mano_to_isaacgym_dof_index``.
    """

    # Per hand model: (end-effector link index of each finger, slice of the IK
    # solution that is kept for that finger).
    _HAND_CONFIGS = {
        "mano": (
            [4, 8, 12, 16, 20],
            [[0, 4], [4, 8], [8, 12], [12, 16], [16, 20]],
        ),
        "shadow": (
            [6, 11, 16, 21, 27],
            [[0, 5], [5, 9], [9, 13], [13, 17], [17, 22]],
        ),
    }

    # Fingertip names, in the order of rows 16..20 of the joints arrays.
    _FINGER_NAMES = ("thumb_tip", "index_tip", "middle_tip", "ring_tip", "pinky_tip")

    # Reordering tables used by mano_to_isaacgym_dof_index.
    _MANO_DOF_ORDER = [16, 17, 18, 19, 12, 13, 14, 15, 0, 1, 4, 2, 3, 5, 6, 7, 8, 9, 10, 11]
    _SHADOW_DOF_ORDER = [7, 8, 9, 10, 22, 23, 24, 25, 26, 12, 13, 14, 15, 17, 18, 19, 20, 1, 2, 3, 4, 5]

    def __init__(self, left_urdf_file, right_urdf_file, hand_type):
        """Start the GUI session, load assets and data, and pose the first frame.

        Args:
            left_urdf_file: Path of the left-hand URDF.
            right_urdf_file: Path of the right-hand URDF.
            hand_type: ``"mano"`` or ``"shadow"``; selects the end-effector
                link indices and the per-finger slices of the IK solution.

        Raises:
            ValueError: If ``hand_type`` is not a supported hand model. This is
                checked before anything is loaded, instead of leaving the
                instance half-initialized and failing later in ``solve_ik``.
        """
        if hand_type not in self._HAND_CONFIGS:
            raise ValueError(
                "Unsupported hand_type {!r}; expected one of {}".format(
                    hand_type, sorted(self._HAND_CONFIGS)))

        p.connect(p.GUI)
        p.setAdditionalSearchPath(pybullet_data.getDataPath())
        self.left_robot_id = p.loadURDF(left_urdf_file, [0, 0, 0], useFixedBase=True)
        self.right_robot_id = p.loadURDF(right_urdf_file, [0, 0, 0], useFixedBase=True)
        p.resetDebugVisualizerCamera(cameraDistance=1.5, cameraYaw=50, cameraPitch=-35, cameraTargetPosition=[0, 0, 0])

        # The joint count of the left hand is used for both hands below.
        num_joints = p.getNumJoints(self.left_robot_id)

        # Object model: textured first, then tinted green; the left hand base is tinted blue.
        self.obj_id = p.loadURDF("/home/user/DexterousHandEnvs/assets/urdf/arctic_assets/object_urdf/box_for_pybullet.urdf", [0, 0, 0], useFixedBase=True)
        texture_path = "/home/user/DexterousHandEnvs/assets/urdf/arctic_assets/object_vtemplates/box/material.jpg"
        texture_id = p.loadTexture(texture_path)
        p.changeVisualShape(self.obj_id, -1, textureUniqueId=texture_id)
        p.changeVisualShape(self.obj_id, -1, rgbaColor=[0, 1, 0, 1])
        p.changeVisualShape(self.left_robot_id, -1, rgbaColor=[0, 0, 1, 1])

        # Raw object parameters; columns 4:7 are scaled by 1/1000
        # (presumably millimetres to metres -- confirm against the dataset).
        obj_p = "/home/user/arctic/data/arctic_data/data/raw_seqs/s01/box_use_01.object.npy"
        self.obj_params = torch.FloatTensor(np.load(obj_p, allow_pickle=True))
        self.obj_params[:, 4:7] /= 1000

        # Processed sequence: a pickled dict, so only load files you trust.
        mano_processed = (
            "/home/user/arctic/outputs/processed/seqs/s01/box_use_01.npy")
        mano_processed_data = np.load(
            mano_processed,
            allow_pickle=True,
        ).item()
        cam_coord = mano_processed_data["cam_coord"]
        world_coord = mano_processed_data["world_coord"]
        view_idx = 1

        # Inverse of the rotation part of the hard-coded 4x4 camera matrix.
        cam2world_matrix = (
            torch.tensor(
                [
                    [0.8946, -0.4464, 0.0197, 0.1542],
                    [-0.1109, -0.2646, -0.9580, 0.9951],
                    [0.4328, 0.8548, -0.2862, 4.6415],
                    [0.0000, 0.0000, 0.0000, 1.0000],
                ]
            )[:3, :3]
            .inverse()
        )
        quat_cam2world = rot.matrix_to_quaternion(cam2world_matrix)

        rot_l_world = self._cam_axis_angle_to_world(
            quat_cam2world, cam_coord["rot_l_cam"][:, view_idx, :])
        rot_l_quat = rot.axis_angle_to_quaternion(rot_l_world).numpy()

        rot_r_world = self._cam_axis_angle_to_world(
            quat_cam2world, cam_coord["rot_r_cam"][:, view_idx, :])
        rot_r_quat = rot.axis_angle_to_quaternion(rot_r_world).numpy()

        obj_r_world = self._cam_axis_angle_to_world(
            quat_cam2world, cam_coord["obj_rot_cam"][:, view_idx, :])
        obj_rot_quat = rot.axis_angle_to_quaternion(obj_r_world).numpy()
        self.obj_params[:, 1:4] = obj_r_world

        # Move the first quaternion component to the end
        # (presumably [w, x, y, z] -> [x, y, z, w] -- confirm utils.rot convention).
        self.obj_rot_quat = obj_rot_quat[:, [1, 2, 3, 0]]

        # Fingertip positions of the first frame (rows 16..20 of the joints array).
        left_init_mano = list(world_coord["joints.left"][0][16:21])
        right_init_mano = list(world_coord["joints.right"][0][16:21])
        self.left_initial_finger_positions = {
            name: left_init_mano[k] for k, name in enumerate(self._FINGER_NAMES)
        }
        self.right_initial_finger_positions = {
            name: right_init_mano[k] for k, name in enumerate(self._FINGER_NAMES)
        }

        # One red marker sphere per fingertip: five for the left hand, then five for the right.
        ball_radius = 0.005
        self.ball_ids = [
            self._create_marker(pos, ball_radius)
            for positions in (self.left_initial_finger_positions,
                              self.right_initial_finger_positions)
            for pos in positions.values()
        ]

        # Initial base position and rotation (quaternion) of each hand.
        self.left_global_position = world_coord["joints.left"][0, 0]
        self.right_global_position = world_coord["joints.right"][0, 0]

        # NOTE(review): the object quaternion above is reordered before being
        # stored, but these hand quaternions are handed to PyBullet in the
        # order returned by rot.axis_angle_to_quaternion. PyBullet expects
        # [x, y, z, w] -- confirm the component order is what is intended.
        self.left_initial_orientation = rot_l_quat[0]
        self.right_initial_orientation = rot_r_quat[0]
        p.resetBasePositionAndOrientation(self.left_robot_id, self.left_global_position, self.left_initial_orientation)
        p.resetBasePositionAndOrientation(self.right_robot_id, self.right_global_position, self.right_initial_orientation)

        # Step once so the simulation reflects the first frame.
        p.stepSimulation()
        self.steps = 0

        # Limits are read from the left hand and reused for the right hand.
        # NOTE(review): get_joint_limits skips fixed joints, while the rest
        # poses have one entry per joint; PyBullet's null-space IK presumably
        # needs these lists to have matching lengths -- confirm.
        self.hand_lower_limits, self.hand_upper_limits, self.hand_joint_ranges = self.get_joint_limits(self.left_robot_id)
        self.left_my_restPoses = [0] * num_joints
        self.right_my_restPoses = [0] * num_joints
        self.num_joints = num_joints
        self.hand_type = hand_type

        efi, slide_range = self._HAND_CONFIGS[hand_type]
        self.efi = list(efi)
        self.slide_range = [list(bounds) for bounds in slide_range]

    @staticmethod
    def _cam_axis_angle_to_world(quat_cam2world, axis_angle_cam):
        """Rotate camera-frame axis-angle rotations by ``quat_cam2world``; returns axis-angle."""
        quat_cam = rot.axis_angle_to_quaternion(torch.FloatTensor(axis_angle_cam))
        return rot.quaternion_to_axis_angle(
            rot.quaternion_multiply(quat_cam2world, quat_cam))

    @staticmethod
    def _create_marker(position, radius):
        """Create a static (mass 0) red sphere at ``position`` and return its body id."""
        visual_shape_id = p.createVisualShape(shapeType=p.GEOM_SPHERE, radius=radius, rgbaColor=[1, 0, 0, 1])
        collision_shape_id = p.createCollisionShape(shapeType=p.GEOM_SPHERE, radius=radius)
        return p.createMultiBody(baseMass=0, baseCollisionShapeIndex=collision_shape_id, baseVisualShapeIndex=visual_shape_id, basePosition=position)

    def _pad_fixed_joints(self, dof_values):
        """Insert 0.0 placeholders into a tuple of solved joint values.

        For ``"mano"`` a single 0.0 is prepended to the first 20 values. For
        ``"shadow"`` a 0.0 is placed before the first finger segment and after
        every finger segment. Any other hand type is returned unchanged.
        """
        if self.hand_type == "mano":
            return (0.0,) + dof_values[0:20]
        if self.hand_type == "shadow":
            padded = (0.0,)
            for start, stop in self._HAND_CONFIGS["shadow"][1]:
                padded += dof_values[start:stop] + (0.0,)
            return padded
        return dof_values

    def _solve_hand(self, robot_id, fingertip_position, base_position, base_rotation, rest_poses, marker_ids):
        """Solve IK for every fingertip of one hand and apply the result to the simulation.

        Returns the padded joint-pose tuple that was written to the hand.
        """
        # Convert once to NumPy rows; the original code computed this
        # conversion but discarded the result.
        targets = list(fingertip_position.cpu().numpy())

        # NOTE(review): IK is solved before the base pose is updated below, so
        # it runs against whatever base pose the body currently has (the one
        # from the previous call) -- confirm that this ordering is intended.
        dof_values = []
        for i, target in enumerate(targets):
            solution = p.calculateInverseKinematics(
                robot_id, self.efi[i], target,
                jointRanges=self.hand_joint_ranges,
                lowerLimits=self.hand_lower_limits,
                upperLimits=self.hand_upper_limits,
                residualThreshold=0.001, maxNumIterations=1000,
                restPoses=rest_poses)
            # Keep only the part of the solution that belongs to this finger.
            start, stop = self.slide_range[i]
            dof_values.extend(solution[start:stop])

        joint_poses = self._pad_fixed_joints(tuple(dof_values))

        p.resetBasePositionAndOrientation(robot_id, base_position, base_rotation)
        for k in range(self.num_joints):
            p.resetJointState(robot_id, k, joint_poses[k])

        # Move the marker spheres onto the requested fingertip targets.
        for k, marker_id in enumerate(marker_ids):
            p.resetBasePositionAndOrientation(marker_id, targets[k], [0, 0, 0, 1])

        return joint_poses

    def solve_ik(self, left_fingertip_position, left_base_position, left_base_rotation, right_fingertip_position, right_base_position, right_base_rotation):
        """Solve fingertip IK for both hands and update the simulated scene.

        Args:
            left_fingertip_position: Tensor of left fingertip targets, one row
                per finger (``.cpu().numpy()`` is called on it).
            left_base_position: Base position applied to the left hand.
            left_base_rotation: Base orientation quaternion applied to the left hand.
            right_fingertip_position: Same as the left counterpart, for the right hand.
            right_base_position: Base position applied to the right hand.
            right_base_rotation: Base orientation quaternion applied to the right hand.

        Returns:
            ``(left_jointPoses, right_jointPoses)`` after reordering by
            ``mano_to_isaacgym_dof_index``.
        """
        left_jointPoses = self._solve_hand(
            self.left_robot_id, left_fingertip_position,
            left_base_position, left_base_rotation,
            self.left_my_restPoses, self.ball_ids[0:5])
        right_jointPoses = self._solve_hand(
            self.right_robot_id, right_fingertip_position,
            right_base_position, right_base_rotation,
            self.right_my_restPoses, self.ball_ids[5:10])

        # Advance the simulation by one step.
        p.stepSimulation()
        self.steps += 1

        # The current solution becomes the rest pose of the next call.
        self.right_my_restPoses = list(right_jointPoses)
        self.left_my_restPoses = list(left_jointPoses)

        return self.mano_to_isaacgym_dof_index(left_jointPoses, right_jointPoses)

    def mano_to_isaacgym_dof_index(self, left_jointPoses, right_jointPoses, test=False):
        """Reorder padded joint poses into the DoF order listed in the comments below.

        Args:
            left_jointPoses: Padded joint-pose sequence of the left hand.
            right_jointPoses: Padded joint-pose sequence of the right hand.
            test: Unused; kept for interface compatibility.

        Returns:
            ``(left, right)`` as lists for ``"mano"`` and ``"shadow"``; for any
            other hand type the inputs are returned unchanged.
        """
        # {'joint10': 0, 'joint11': 1, 'joint12': 2, 'joint13': 3, 'joint14': 4,
        # 'joint15': 5, 'joint16': 6, 'joint17': 7, 'joint18': 8, 'joint19': 9, 'joint2':12,
        # 'joint20': 10, 'joint21': 11, 'joint3': 13, 'joint4': 14, 'joint5': 15, 'joint6': 16, 'joint7': 17, 'joint8': 18, 'joint9': 19}
        if self.hand_type == "mano":
            # +1 skips the 0.0 placeholder at the front of the padded tuple.
            right_jointPoses = [right_jointPoses[i + 1] for i in self._MANO_DOF_ORDER]
            left_jointPoses = [left_jointPoses[i + 1] for i in self._MANO_DOF_ORDER]

        # {'FFJ1': 3, 'FFJ2': 2, 'FFJ3': 1, 'FFJ4': 0, 'LFJ1': 8, 'LFJ2': 7,
        # 'LFJ3': 6, 'LFJ4': 5, 'LFJ5': 4, 'MFJ1': 12, 'MFJ2': 11, 'MFJ3': 10,
        # 'MFJ4': 9, 'RFJ1': 16, 'RFJ2': 15, 'RFJ3': 14, 'RFJ4': 13, 'THJ1': 21,
        # 'THJ2': 20, 'THJ3': 19, 'THJ4': 18, 'THJ5': 17}
        elif self.hand_type == "shadow":
            right_jointPoses = [right_jointPoses[i] for i in self._SHADOW_DOF_ORDER]
            left_jointPoses = [left_jointPoses[i] for i in self._SHADOW_DOF_ORDER]

        return left_jointPoses, right_jointPoses

    def get_joint_limits(self, robot):
        """Collect limits of every non-fixed joint of ``robot``.

        Args:
            robot: PyBullet body id.

        Returns:
            ``(lower_limits, upper_limits, ranges)``: three parallel lists with
            one entry per non-fixed joint, where ``range = upper - lower``.
        """
        joint_lower_limits = []
        joint_upper_limits = []
        joint_ranges = []
        for i in range(p.getNumJoints(robot)):
            joint_info = p.getJointInfo(robot, i)
            # Field 2 is the joint type; fields 8 and 9 are the lower/upper limits.
            if joint_info[2] == p.JOINT_FIXED:
                continue
            lower, upper = joint_info[8], joint_info[9]
            joint_lower_limits.append(lower)
            joint_upper_limits.append(upper)
            joint_ranges.append(upper - lower)
        return joint_lower_limits, joint_upper_limits, joint_ranges
# Usage example