import torch
import math

def get_rodrigues_matrix(axis, angle):
    '''
    Build a batch of rotation matrices about one fixed axis (Rodrigues' formula):

        R = cos(t) * I + sin(t) * [k]_x + (1 - cos(t)) * k k^T

    axis   -> 3-element axis k, e.g. [1, 0, 0] / [0, 1, 0] / [0, 0, 1].
              It is NOT normalized here; a non-unit axis yields a non-rotation matrix.
    angle  -> [batch_size] or [batch_size, 1] tensor of angles in radians
    return -> [batch_size, 3, 3], same dtype and device as `angle`
    '''
    dtype, device = angle.dtype, angle.device

    # as_tensor accepts lists or tensors without the torch.tensor(tensor) copy warning.
    k = torch.as_tensor(axis, dtype=dtype, device=device).reshape(3)
    zero = k.new_zeros(())

    # Cross-product (skew-symmetric) matrix [k]_x, built with tensor ops so it
    # stays on `device` without converting elements back to Python scalars.
    skew = torch.stack([
        torch.stack([zero, -k[2], k[1]]),
        torch.stack([k[2], zero, -k[0]]),
        torch.stack([-k[1], k[0], zero]),
    ])                                                  # [3, 3]
    outer = k.unsqueeze(1) * k.unsqueeze(0)             # k k^T -> [3, 3]
    identity = torch.eye(3, dtype=dtype, device=device) # [3, 3]

    # reshape(-1, 1, 1) maps both [B] and [B, 1] to [B, 1, 1]; the previous
    # unsqueeze(1).unsqueeze(1) turned [B, 1] into [B, 1, 1, 1], which
    # broadcast to a wrong [B, B, 3, 3] result.
    theta = angle.reshape(-1, 1, 1)
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)

    # Broadcasting [B, 1, 1] against [3, 3] yields [B, 3, 3] without repeat().
    rodrigues_matrix = cos_t * identity + sin_t * skew + (1 - cos_t) * outer
    return rodrigues_matrix

def apply_transformation(vertices, position, rotation, rotation_order="XYZ", offset_first=False):
    '''
    Rotate and translate a batch of point sets.

    vertices       -> [B, N, 3] points as row vectors (any N; previously fixed to 8)
    position       -> [B, 3] translation, broadcast over all N points
    rotation       -> [B, 3] per-axis angles in radians, columns are (X, Y, Z)
    rotation_order -> string of 'X' / 'Y' / 'Z'. Rotations are applied in string
                      order, so "XYZ" gives v' = Rz @ Ry @ Rx @ v.
    offset_first   -> if True, translate before rotating; otherwise rotate, then translate
    return         -> [B, N, 3]

    Raises KeyError if rotation_order contains a character other than X, Y, Z.
    '''
    # [B, 1, 3] broadcasts over any number of points. The former
    # repeat(1, 8, 1) failed for N not in {1, 8} and silently tiled
    # single-point input into 8 identical rows.
    offset = position.unsqueeze(1)

    # process position first
    if offset_first:
        vertices = vertices + offset

    # process rotation: axis letter -> (unit axis, column of `rotation`)
    axis_spec = {
        "X": ([1, 0, 0], 0),
        "Y": ([0, 1, 0], 1),
        "Z": ([0, 0, 1], 2),
    }
    for s in rotation_order:
        axis, col = axis_spec[s]
        rot = get_rodrigues_matrix(axis, rotation[:, col])   # [B, 3, 3]
        # Row-vector convention: v_row @ R^T == (R @ v_col)^T
        vertices = torch.matmul(vertices, rot.transpose(1, 2))

    # process position second
    if not offset_first:
        vertices = vertices + offset

    return vertices

def Drawer(para):
    '''
    Compute a world-space point offset from the drawer center (its "handle").

    para   -> [B, >=9] tensor whose columns are read as:
                0:3 size, 3:6 position, 6:9 rotation
              The rotation columns are multiplied by pi before use, so they are
              presumably given in units of pi radians (TODO confirm with callers).
    return -> [B, 3] handle point: the local offset (0, 0, size_z / 2),
              rotated in "XYZ" order and then translated by position.
    '''
    sizes = para[:, 0:3]
    centers = para[:, 3:6]
    angles = para[:, 6:9] * math.pi

    # Local offset: zero in x and y, half of the size's z component in z.
    offset_xy = torch.zeros_like(sizes[:, :2])          # [B, 2]
    offset_z = sizes[:, 2:3] / 2                        # [B, 1]
    local_offset = torch.cat([offset_xy, offset_z], dim=1)   # [B, 3]

    # Treat the offset as a one-point set so apply_transformation can rotate and translate it.
    transformed = apply_transformation(
        local_offset.unsqueeze(1),
        centers,
        angles,
        rotation_order="XYZ",
        offset_first=False,
    )

    # NOTE: an earlier variant also built gripper_rot = rotation + [pi, 0, 0]
    # and returned cat([gripper_pos, gripper_rot]); only the position is returned now.
    return transformed[:, 0, :]