import torch
from pytorch3d import transforms
from utils.quaternion import qbetween, qbetween_safe, qrot


def process_motion(
        motion_dict,         # {'joints': [B, T, J_all, 3], 'body_pose': [B, T, J-1, 6]}
        feet_thre,           # float
        prev_frames,         # int
        n_joints,            # int, number of joints to use (<= motion_dict['joints'].shape[2])
        device,
    ):
    """Canonicalize a batch of joint trajectories and build per-frame features.

    Processing steps:
      1. Remap axes (x, y, z) -> (x, z, -y), so the input Z axis becomes the
         vertical (index 1) axis.
      2. Shift each clip so its lowest joint over all frames is at height 0.
      3. Translate so the root (joint 0) at frame ``prev_frames`` sits at the
         origin in the horizontal (x, z) plane.
      4. Rotate all frames by ``qbetween(forward, +Z)``. Here ``forward`` is
         the horizontal direction perpendicular to the hip axis at frame
         ``prev_frames``.

    Args:
        motion_dict: dict with ``'joints'``, reshapeable to ``[B, T, J_all, 3]``,
            and ``'body_pose'``, reshapeable to ``[B, T, R]``. It is not
            modified.
        feet_thre: threshold on a foot joint's squared per-frame displacement.
            A foot counts as in contact when it is below this threshold and its
            height is below 0.05.
        prev_frames: index of the reference frame used for the origin and the
            facing direction.
        n_joints: number of leading joints to keep (``<= J_all``).
        device: torch device the computation runs on.

    Returns:
        data: ``[B, T-1, J*3 + J*3 + R + 2 + 2]``. Concatenates, for frames
            ``0..T-2``: positions, forward-difference velocities, rotations,
            and left/right foot contacts.
        quat: ``[B, 4]`` quaternion applied in step 4.
        root_init_xz: ``[B, 3]`` offset removed in step 3, with its vertical
            component zeroed.
    """
    face_joint_indx = [2, 1, 17, 16]
    fid_l = [7, 10]
    fid_r = [8, 11]

    B, T, *_ = motion_dict['joints'].shape
    # Reshape with -1 and slice, so that n_joints < J_all is honoured as
    # documented; a direct reshape to n_joints raises in that case.
    joints = motion_dict['joints'].reshape(B, T, -1, 3)[:, :, :n_joints].to(device)  # [B, T, J, 3]
    rotations = motion_dict['body_pose'].reshape(B, T, -1).to(device)                   # [B, T, R]

    # (x, y, z) -> (x, z, -y): the input's Z axis becomes the up (Y) axis.
    trans_matrix = torch.tensor([
        [1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.0, -1.0, 0.0]
    ], dtype=joints.dtype, device=device)

    # The matmul allocates a new tensor, so the in-place update below cannot
    # touch the caller's data.
    joints = joints @ trans_matrix.T  # [B, T, J, 3]

    # Put the lowest joint of each clip (over all frames) on y = 0.
    floor_height = joints[..., 1].amin(dim=(1, 2), keepdim=True)  # [B, 1, 1]
    joints[..., 1] -= floor_height

    # Centre the reference-frame root horizontally; height is left untouched.
    root_init = joints[:, prev_frames, 0, :]  # [B, 3]
    root_init_xz = root_init * torch.tensor([1, 0, 1], dtype=joints.dtype, device=device)  # [B, 3]
    joints = joints - root_init_xz[:, None, None, :]

    # Facing direction at the reference frame: perpendicular to both the
    # up vector and the hip axis, hence horizontal.
    r_hip, l_hip, *_ = face_joint_indx
    across = joints[:, prev_frames, r_hip] - joints[:, prev_frames, l_hip]  # [B, 3]
    across = across / across.norm(dim=-1, keepdim=True)

    up = torch.tensor([0, 1, 0], dtype=joints.dtype, device=device).expand(B, 3)
    forward = torch.cross(up, across, dim=-1)
    forward = forward / forward.norm(dim=-1, keepdim=True)
    target = torch.tensor([0, 0, 1], dtype=joints.dtype, device=device).expand(B, 3)

    quat = qbetween(forward, target)  # [B, 4]
    # qrot expects q and v with matching leading dims, so broadcast explicitly.
    root_quat_init_for_all = quat[:, None, None, :].expand(-1, T, n_joints, 4)
    joints = qrot(root_quat_init_for_all, joints)  # [B, T, J, 3]

    def detect_feet(joints, fid):
        # Contact at frame t: the joint barely moves between t and t+1
        # (squared displacement < feet_thre) and is close to the floor.
        vel = (joints[:, 1:, fid] - joints[:, :-1, fid]) ** 2  # [B, T-1, len(fid), 3]
        vel_sum = vel.sum(-1)  # [B, T-1, len(fid)]
        height = joints[:, :-1, fid, 1]  # [B, T-1, len(fid)]
        contact = ((vel_sum < feet_thre) & (height < 0.05)).float()
        return contact  # [B, T-1, len(fid)]

    feet_l = detect_feet(joints, fid_l)  # [B, T-1, 2]
    feet_r = detect_feet(joints, fid_r)  # [B, T-1, 2]

    pos_flat = joints[:, :-1].reshape(B, T - 1, -1)  # [B, T-1, J*3]
    vel_flat = (joints[:, 1:] - joints[:, :-1]).reshape(B, T - 1, -1)  # [B, T-1, J*3]
    rot_flat = rotations[:, :-1]  # [B, T-1, R]
    data = torch.cat([pos_flat, vel_flat, rot_flat, feet_l, feet_r], dim=-1)  # [B, T-1, D']

    return data, quat, root_init_xz  # [B, T-1, D'], [B, 4], [B, 3]


def process_motion_refined(
        motion_dict,         # {'joints': [B, T, J_all, 3], 'body_pose': [B, T, J-1, 6]}
        feet_thre,           # float
        prev_frames,         # int
        n_joints,            # int, number of joints to use (<= motion_dict['joints'].shape[2])
        device,
    ):
    """Canonicalize joint trajectories and also return the combined transform.

    Differences from ``process_motion``:
      * The root offset is removed in the INPUT frame, zeroing the input's
        third (z) component, before the axis remap (x, y, z) -> (x, z, -y).
      * No floor alignment is applied.
      * A ``[B, 3, 3]`` matrix combining the axis remap and the facing rotation
        is returned as well.

    Args:
        motion_dict: dict with ``'joints'``, reshapeable to ``[B, T, J_all, 3]``,
            and ``'body_pose'``, reshapeable to ``[B, T, R]``. It is not
            modified.
        feet_thre: threshold on a foot joint's squared per-frame displacement
            used for contact detection (together with height < 0.05).
        prev_frames: index of the reference frame used for the origin and the
            facing direction.
        n_joints: number of leading joints to keep (``<= J_all``).
        device: torch device the computation runs on.

    Returns:
        data: ``[B, T-1, J*3 + J*3 + R + 2 + 2]`` features for frames
            ``0..T-2``: positions, velocities, rotations, and left/right foot
            contacts.
        quat: ``[B, 4]`` facing quaternion.
        root_init: ``[B, 3]`` offset removed, in the input frame, with z zeroed.
        t_matrix: ``[B, 3, 3]`` equal to ``quaternion_to_matrix(quat) @ trans_matrix``.
            This presumably maps ``p - root_init`` to the canonical frame,
            assuming ``qrot`` and pytorch3d share the (w, x, y, z) quaternion
            convention (TODO confirm).
    """
    face_joint_indx = [2, 1, 17, 16]
    fid_l = [7, 10]
    fid_r = [8, 11]

    B, T, *_ = motion_dict['joints'].shape
    # Reshape with -1 and slice, so that n_joints < J_all is honoured as documented.
    joints = motion_dict['joints'].reshape(B, T, -1, 3)[:, :, :n_joints].to(device)  # [B, T, J, 3]
    rotations = motion_dict['body_pose'].reshape(B, T, -1).to(device)                   # [B, T, R]

    # NOTE: unlike process_motion, no floor alignment is applied here.

    # Centre the reference-frame root in the input's x/y plane.
    root_init = joints[:, prev_frames, 0, :]  # [B, 3]
    root_init = root_init * torch.tensor([1, 1, 0], dtype=joints.dtype, device=device)  # [B, 3]
    # Out-of-place on purpose: up to this point `joints` may be a view of
    # motion_dict['joints']. An in-place `-=` would mutate the caller's
    # tensor, and a repeated call would then compute root_init == 0.
    joints = joints - root_init[:, None, None, :]

    # (x, y, z) -> (x, z, -y): the input's Z axis becomes the up (Y) axis.
    trans_matrix = torch.tensor([
        [1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.0, -1.0, 0.0]
    ], dtype=joints.dtype, device=device)

    joints = joints @ trans_matrix.T  # [B, T, J, 3]

    # Horizontal facing direction at the reference frame.
    r_hip, l_hip, *_ = face_joint_indx
    across = joints[:, prev_frames, r_hip] - joints[:, prev_frames, l_hip]  # [B, 3]
    across = across / across.norm(dim=-1, keepdim=True)

    up = torch.tensor([0, 1, 0], dtype=joints.dtype, device=device).expand(B, 3)
    forward = torch.cross(up, across, dim=-1)
    forward = forward / forward.norm(dim=-1, keepdim=True)
    target = torch.tensor([0, 0, 1], dtype=joints.dtype, device=device).expand(B, 3)

    quat = qbetween(forward, target)  # [B, 4]
    root_quat_init_for_all = quat[:, None, None, :].expand(-1, T, n_joints, 4)

    joints = qrot(root_quat_init_for_all, joints)  # [B, T, J, 3]

    def detect_feet(joints, fid):
        # Contact at frame t: small squared displacement to t+1 and low height.
        vel = (joints[:, 1:, fid] - joints[:, :-1, fid]) ** 2  # [B, T-1, len(fid), 3]
        vel_sum = vel.sum(-1)  # [B, T-1, len(fid)]
        height = joints[:, :-1, fid, 1]  # [B, T-1, len(fid)]
        contact = ((vel_sum < feet_thre) & (height < 0.05)).float()
        return contact  # [B, T-1, len(fid)]

    feet_l = detect_feet(joints, fid_l)  # [B, T-1, 2]
    feet_r = detect_feet(joints, fid_r)  # [B, T-1, 2]

    pos_flat = joints[:, :-1].reshape(B, T - 1, -1)  # [B, T-1, J*3]
    vel_flat = (joints[:, 1:] - joints[:, :-1]).reshape(B, T - 1, -1)  # [B, T-1, J*3]
    rot_flat = rotations[:, :-1]  # [B, T-1, R]
    data = torch.cat([pos_flat, vel_flat, rot_flat, feet_l, feet_r], dim=-1)  # [B, T-1, D']

    # Combined linear part of the canonicalization: axis remap, then facing rotation.
    quat_matrix = transforms.quaternion_to_matrix(quat)  # [B, 4] -> [B, 3, 3]
    t_matrix = quat_matrix @ trans_matrix

    return data, quat, root_init, t_matrix  # [B, T-1, D'], [B, 4], [B, 3], [B, 3, 3]


def process_motion_blended(
        motion_dict,         # {'joints': [B, T, J_all, 3], 'joints_delta': [B, T, J_all, 3], 'body_pose': [B, T, J-1, 6], 'feet_contact': [B, T, F]}
        feet_thre,           # float (unused: contacts come from motion_dict['feet_contact'])
        prev_frames,         # int
        n_joints,            # int, number of joints to use (<= motion_dict['joints'].shape[2])
        device,
    ):
    """Canonicalize pre-computed positions/deltas and pack them per frame.

    No axis remap or floor alignment is performed. Index 1 is treated as the
    vertical axis (see ``up`` and the root mask below). The reference-frame
    root is centred horizontally, and positions and deltas are rotated by
    ``qbetween(forward, +Z)``. Deltas are only rotated, never translated.

    Args:
        motion_dict: dict with ``'joints'`` and ``'joints_delta'`` (each
            reshapeable to ``[B, T, J_all, 3]``), ``'body_pose'`` (reshapeable
            to ``[B, T, R]``), and ``'feet_contact'`` (reshapeable to
            ``[B, T, F]``). It is not modified.
        feet_thre: not used by this function.
        prev_frames: index of the reference frame used for the origin and the
            facing direction.
        n_joints: number of leading joints to keep (``<= J_all``).
        device: torch device the computation runs on.

    Returns:
        data: ``[B, T, J*3 + J*3 + R + F]`` — positions, deltas, rotations,
            foot contacts.
        quat: ``[B, 4]`` facing quaternion.
        root_init: ``[B, 3]`` offset removed, with its index-1 component zeroed.
        quat_matrix: ``[B, 3, 3]``, ``quaternion_to_matrix(quat)``.
    """
    face_joint_indx = [2, 1, 17, 16]

    B, T, *_ = motion_dict['joints'].shape
    # Reshape with -1 and slice, so that n_joints < J_all is honoured as documented.
    joints = motion_dict['joints'].reshape(B, T, -1, 3)[:, :, :n_joints].to(device)              # [B, T, J, 3]
    joints_delta = motion_dict['joints_delta'].reshape(B, T, -1, 3)[:, :, :n_joints].to(device)  # [B, T, J, 3]
    rotations = motion_dict['body_pose'].reshape(B, T, -1).to(device)                             # [B, T, R]
    feet_contact = motion_dict['feet_contact'].reshape(B, T, -1).to(device)                       # [B, T, F]

    # Centre the reference-frame root horizontally (index 1 kept).
    root_init = joints[:, prev_frames, 0, :]  # [B, 3]
    root_init = root_init * torch.tensor([1, 0, 1], dtype=joints.dtype, device=device)  # [B, 3]
    # Out-of-place on purpose: `joints` may be a view of motion_dict['joints'],
    # and an in-place `-=` would mutate the caller's tensor.
    joints = joints - root_init[:, None, None, :]

    # Horizontal facing direction at the reference frame.
    r_hip, l_hip, *_ = face_joint_indx
    across = joints[:, prev_frames, r_hip] - joints[:, prev_frames, l_hip]  # [B, 3]
    across = across / across.norm(dim=-1, keepdim=True)

    up = torch.tensor([0, 1, 0], dtype=joints.dtype, device=device).expand(B, 3)
    forward = torch.cross(up, across, dim=-1)
    forward = forward / forward.norm(dim=-1, keepdim=True)
    target = torch.tensor([0, 0, 1], dtype=joints.dtype, device=device).expand(B, 3)

    quat = qbetween(forward, target)  # [B, 4]
    root_quat_init_for_all = quat[:, None, None, :].expand(-1, T, n_joints, 4)

    joints = qrot(root_quat_init_for_all, joints)  # [B, T, J, 3]
    joints_delta = qrot(root_quat_init_for_all, joints_delta)  # [B, T, J, 3]

    pos_flat = joints.reshape(B, T, -1)  # [B, T, J*3]
    vel_flat = joints_delta.reshape(B, T, -1)  # [B, T, J*3]
    rot_flat = rotations  # [B, T, R]
    data = torch.cat([pos_flat, vel_flat, rot_flat, feet_contact], dim=-1)  # [B, T, D']

    quat_matrix = transforms.quaternion_to_matrix(quat)  # [B, 4] -> [B, 3, 3]

    return data, quat, root_init, quat_matrix  # [B, T, D'], [B, 4], [B, 3], [B, 3, 3]


def cal_rel_rot(jts1, jts2):
    """Compute per-frame rotation matrices between the facings of two skeletons.

    For each skeleton, the facing direction is the normalized cross product of
    the up vector (0, 1, 0) with the normalized hip axis (joint 2 minus
    joint 1).

    Args:
        jts1: joint positions indexed as ``[B, T, J, 3]``.
        jts2: joint positions with the same layout. The up vector is sized
            from ``jts1``, so the leading dims are presumably expected to
            match — TODO confirm.

    Returns:
        A pair of ``[B, T, 3, 3]`` matrices:
          * ``quaternion_to_matrix(qbetween_safe(facing2, facing1))`` ("b2a");
          * ``quaternion_to_matrix(qbetween_safe(facing1, facing2))`` ("a2b").
    """
    # Right / left hip: the first two entries of [2, 1, 17, 16].
    right_hip, left_hip = 2, 1
    n_batch, n_frames = jts1.shape[:2]
    up_axis = torch.tensor(
        [0, 1, 0], dtype=jts1.dtype, device=jts1.device
    ).expand(n_batch, n_frames, 3)

    def _unit(vec):
        return vec / vec.norm(dim=-1, keepdim=True)

    def _facing(jts):
        hip_axis = _unit(jts[:, :, right_hip] - jts[:, :, left_hip])  # [B, T, 3]
        return _unit(torch.cross(up_axis, hip_axis, dim=-1))

    facing_a = _facing(jts1)
    facing_b = _facing(jts2)

    rot_b2a = transforms.quaternion_to_matrix(qbetween_safe(facing_b, facing_a))
    rot_a2b = transforms.quaternion_to_matrix(qbetween_safe(facing_a, facing_b))
    return rot_b2a, rot_a2b