from motionblender.app.utils import RobotInterface, MotionBlender
import roma
import motionblender.lib.misc as misc
import torch.nn.functional as F
from loguru import logger
import torch.nn as nn
import torch.optim as optim
import numpy as np
import motionblender.lib.animate as anim
from jaxtyping import Float32
from tqdm.auto import tqdm, trange
import kinpy as kp
import torch
from torch import Tensor

def map_gripper_ctrl_to_joint_v(gv):
    """Convert a gripper control value into first-joint values for the three fingers.

    The control value is clamped to [30, 100] and then mapped linearly onto
    [0, 0.8]; the same result is assigned to all three finger joints.

    Args:
        gv: Gripper control value (any number; out-of-range values are clamped).

    Returns:
        Dict mapping 'finger_1_joint_1', 'finger_2_joint_1' and
        'finger_middle_joint_1' to the shared joint value.
    """
    # Clamp to the [30, 100] window before scaling.
    if gv > 100:
        clamped = 100
    elif gv < 30:
        clamped = 30
    else:
        clamped = gv

    joint_value = 0.8 * (clamped - 30) / 70
    finger_joint_names = (
        'finger_1_joint_1',
        'finger_2_joint_1',
        'finger_middle_joint_1',
    )
    return {joint_name: joint_value for joint_name in finger_joint_names}

# Ordered node names of the simplified skeleton: eight arm nodes, the
# end-effector, the palm, and two nodes for each of the two finger groups.
robot_joints = [f'joint_{arm_idx}' for arm_idx in range(8)] + [
    'ee',
    'palm',
    'left_finger_1',
    'right_finger_1',
    'left_finger_2',
    'right_finger_2',
]

# Skeleton node name -> URDF link name(s) that give its position.
# A list value means the node is placed at the mean position of those links
# (see the averaging in Kuka.__init__).
robot_joints_from_origin = {
    f'joint_{arm_idx}': f'iiwa_link_{arm_idx}' for arm_idx in range(8)
}
robot_joints_from_origin.update({
    'ee': 'iiwa_link_ee',
    'palm': 'palm',
    'left_finger_1': ['finger_middle_link_0'],
    'right_finger_1': ['finger_1_link_0', 'finger_2_link_0'],
    'left_finger_2': ['finger_middle_link_3'],
    'right_finger_2': ['finger_1_link_3', 'finger_2_link_3'],
})

# Skeleton edges as [parent, child] name pairs: first the serial run from
# 'joint_0' through 'palm', then the two finger branches hanging off the palm.
_serial_nodes = robot_joints[:robot_joints.index('palm') + 1]
robot_connections = [
    [parent, child] for parent, child in zip(_serial_nodes, _serial_nodes[1:])
]
robot_connections.extend([
    ['palm', 'left_finger_1'],
    ['palm', 'right_finger_1'],
    ['left_finger_1', 'left_finger_2'],
    ['right_finger_1', 'right_finger_2'],
])

# The same edges expressed as (parent_index, child_index) into robot_joints.
_joint_index = {joint_name: pos for pos, joint_name in enumerate(robot_joints)}
robot_connections_int = [
    (_joint_index[parent], _joint_index[child])
    for parent, child in robot_connections
]


# Values for 'iiwa_joint_1' .. 'iiwa_joint_7', in that order.
_arm_defaults = (
    -0.15783709287643433,
    0.48583802580833435,
    1.0546152225288097e-05,
    -1.6795016527175903,
    0.9391155242919922,
    1.027316689491272,
    -1.273979663848877,
)

# Gripper joints; the three '*_joint_1' entries are the ones overridden by
# map_gripper_ctrl_to_joint_v when the gripper table is built.
_gripper_defaults = {
    'palm_finger_1_joint': -0.16,
    'palm_finger_2_joint': 0.16,
    'finger_middle_joint_3': 0,
    'finger_middle_joint_2': 0,
    'finger_2_joint_3': 0,
    'finger_2_joint_2': 0,
    'finger_1_joint_3': 0,
    'finger_1_joint_2': 0,
    'finger_1_joint_1': 0.0,
    'finger_2_joint_1': 0.0,
    'finger_middle_joint_1': 0.0,
}

# Baseline URDF joint name -> joint value used for forward kinematics.
default_joint_values = {
    **{
        f'iiwa_joint_{joint_no}': joint_value
        for joint_no, joint_value in enumerate(_arm_defaults, start=1)
    },
    **_gripper_defaults,
}


class Kuka(RobotInterface):
    """Kuka iiwa arm plus three-finger gripper, posed from an end-effector target.

    At construction a table of gripper link rotations is precomputed for every
    integer gripper control value in [0, 100].  ``get_joint_rotations`` then
    solves for the remaining (arm) link rotations by gradient descent so that
    the end-effector link reaches the pose stored in ``self.buf['pose']``.

    NOTE(review): ``self.buf`` is not created in this class; it is presumably
    set up by ``RobotInterface.__init__`` -- confirm.
    """

    def __init__(self, iiwa_path="./iiwa/kuka.urdf"):
        """Load the URDF and precompute the gripper rotation table.

        Args:
            iiwa_path: Path to the URDF file describing the arm and gripper.
        """
        super().__init__()
        # Read through a context manager so the file handle is always closed.
        with open(iiwa_path) as urdf_file:
            txt = urdf_file.read()
        iiwa_chain = kp.build_chain_from_urdf(txt)

        # Index (into robot_connections) of the link whose child node is 'palm'.
        self.ee_link_index = robot_joints.index('palm') - 1
        logger.info(f"ee_link_index: {self.ee_link_index}, ee_link name: {robot_connections[self.ee_link_index]}")

        # Gripper control value -> rot6d rows of the links after the ee link.
        # Values up to 30 are clamped by map_gripper_ctrl_to_joint_v, so they
        # all use the same finger joint value.
        self.gripper_rot6d_map = {}
        for deg in trange(0, 101):
            joint_values = {**default_joint_values, **map_gripper_ctrl_to_joint_v(deg)}
            link_poses = iiwa_chain.forward_kinematics(joint_values)

            # Position of every skeleton node; nodes backed by several URDF
            # links are placed at the mean of those links' positions.
            robot_joints_values = {}
            for rj_name, parents in robot_joints_from_origin.items():
                if isinstance(parents, list):
                    positions = [link_poses[p].pos for p in parents]
                    robot_joints_values[rj_name] = sum(positions) / len(positions)
                else:
                    robot_joints_values[rj_name] = link_poses[parents].pos

            # Stack positions in robot_joints order so row indices match
            # the indices used by robot_connections_int.
            joint_positions = torch.from_numpy(
                np.array([robot_joints_values[rj_name] for rj_name in robot_joints])
            ).float()
            anim_chain = anim.inverse_kinematic(joint_positions, robot_connections_int)
            rot6d = anim.retrieve_tensor_from_chain(anim_chain, 'rot6d')
            # Keep only the rows after the ee link (the gripper part).
            self.gripper_rot6d_map[deg] = rot6d[self.ee_link_index + 1:].cuda()

        # NOTE(review): not read anywhere in this class; kept in case code
        # outside this file relies on the attribute existing.
        self.buf_rot6d = None

    def initialize(self, motion_module: MotionBlender) -> None:
        """Seed the buffers from a motion module's state at time step 0.

        Stores the end-effector pose, the rot6d parameters of frame 0 and the
        exponentiated link lengths in ``self.buf``, and copies the chain
        description needed later to rebuild the kinematic chain.

        Args:
            motion_module: Source of link poses, rot6d, lengths and chain layout.
        """
        motion_module.compute_link_pose_at_t(0)
        ee_pose = motion_module._links_tensor_cache[0][self.ee_link_index] # this is global pose
        # Remove the cached global transform so the pose is relative to it.
        ee_pose_robot_base = torch.inverse(motion_module._global_T_cache[0]) @ ee_pose
        self.buf['pose'] = ee_pose_robot_base
        self.buf['stale'] = False
        self.buf['rot6d'] = motion_module.rot6d[0].clone()
        # motion_module.length is exponentiated here, i.e. treated as log-length.
        self.buf['length'] = torch.exp(motion_module.length.clone())
        self.hollow_chain = motion_module.hollow_chain
        self.length_linkid2indice = motion_module.length_linkid2indice
        self.rot6d_linkid2indice = motion_module.rot6d_linkid2indice
        self.inited = True

    def get_joint_rotations(self, refresh=False) -> Float32[Tensor, "j 6"]:
        """Return rot6d rotations for all links, re-solving the arm if needed.

        The cached result is returned when the buffer is not stale, a cached
        value exists and ``refresh`` is false.  Otherwise the arm rows are
        optimised with Adam (lr 0.01, at most 300 steps) so that the
        end-effector link matches ``self.buf['pose']`` in translation and in
        rot6d rotation, while the gripper rows are taken from the precomputed
        table for ``self.buf['degree']``.

        NOTE(review): ``self.buf['degree']`` is not written by ``initialize``;
        it is presumably set elsewhere and must be a key of
        ``gripper_rot6d_map`` (an int in [0, 100]) -- confirm.

        Args:
            refresh: Force a new solve even if a fresh cached result exists.

        Returns:
            The arm rows followed by the gripper rows, also stored in
            ``self.buf['rot6d']``.
        """
        if not self.buf['stale'] and self.buf['rot6d'] is not None and not refresh:
            return self.buf['rot6d']
        ee_pose = self.buf['pose']
        ee_degree = self.buf['degree']
        gripper_rot6d = self.gripper_rot6d_map[ee_degree].detach()
        length = self.buf['length'].detach()
        ee_rot6d = anim.rmat_to_cont_6d(ee_pose[:3, :3])

        # Split point between arm rows and gripper rows; derived from the same
        # index used to slice the gripper table so the two always agree.
        n_arm_links = self.ee_link_index + 1

        # Force autograd on for the solve, in case the caller disabled it.
        with torch.enable_grad():
            # Warm-start from the previously stored arm rotations.
            rot = nn.Parameter(self.buf['rot6d'][:n_arm_links].clone())

            opt = optim.Adam([rot], lr=0.01)
            for step in range(300):
                prev_rot = rot.data.clone()
                param_chain = anim.fill_hollow_chain_with_tensor(self.hollow_chain, length, torch.cat([rot, gripper_rot6d]),
                                                                 self.rot6d_linkid2indice, self.length_linkid2indice)
                link_poses = anim.forward_kinematic(param_chain)

                pred_ee_pose = link_poses[self.ee_link_index]
                xyz_loss = F.mse_loss(pred_ee_pose[:3, 3], ee_pose[:3, 3])

                pred_ee_rot6d = anim.rmat_to_cont_6d(pred_ee_pose[:3, :3])
                rot_loss = F.mse_loss(pred_ee_rot6d, ee_rot6d)

                loss = xyz_loss + rot_loss
                # Stop as soon as the pose error is small enough.
                if loss < 1e-5:
                    logger.info(f"[Inverse Kinematic] converged in {step} iterations")
                    break
                opt.zero_grad()
                loss.backward()
                opt.step()
                # Also stop when the parameters have effectively stopped moving.
                if (rot.data - prev_rot).norm() < 1e-3:
                    logger.info(f"[Inverse Kinematics] converged in {step} iterations")
                    break
            logger.info(f"[Inverse Kinematic]: xyz {xyz_loss.item()}, rot {rot_loss.item()}")

        arm_rot6d = rot.detach()
        self.buf['stale'] = False
        self.buf['rot6d'] = torch.cat([arm_rot6d, gripper_rot6d])
        return self.buf['rot6d']
    

# Shared module-level instance; building it loads the URDF and fills the
# gripper rotation table (which places tensors on CUDA).
robot = Kuka()

if __name__ == "__main__":
    # Smoke test: load a checkpoint, perturb the stored end-effector pose with
    # a random offset and a random rotation, then print the re-solved rotations.
    _gs_modules, motion_by_name, _, _gaussian_names = misc.load_cpkl("outputs/mb/robot/okish/toy/ckpt.robot.cpkl")
    robot.initialize(motion_by_name['robot'])

    target_pose = robot.buf['pose']
    # Random translation offset with each component in [0, 0.3).
    target_pose[:3, 3] += (torch.rand(3).cuda() * 0.3)
    # Left-multiply the rotation block by a random rotation matrix.
    target_pose[:3, :3] = roma.random_rotmat().cuda() @ target_pose[:3, :3]

    solved_rot6d = robot.get_joint_rotations(refresh=True)
    print(solved_rot6d)