import torch
import math

def get_rodrigues_matrix(axis, angle):
    '''
    Batched rotation matrices about a fixed axis via Rodrigues' formula:

        R = cos(a) * I + sin(a) * K + (1 - cos(a)) * k k^T

    where k is the unit rotation axis and K is its cross-product (skew) matrix.

    axis   -> length-3 sequence or tensor, e.g. [1,0,0] / [0,1,0] / [0,0,1].
              It is normalized internally, so any non-zero vector works.
    angle  -> [batch_size] or [batch_size, 1], in radians
    return -> [batch_size, 3, 3], same device as angle, in the floating dtype
              produced by torch.cos(angle)

    Raises ValueError if axis is the zero vector.
    '''
    batch_size = angle.shape[0]
    # Explicit [B, 1, 1] view: works for both [B] and [B, 1] inputs. The old
    # unsqueeze chain turned [B, 1] into [B, 1, 1, 1] and broadcast to [B, B, 3, 3].
    cos_angle = torch.cos(angle).reshape(batch_size, 1, 1)
    sin_angle = torch.sin(angle).reshape(batch_size, 1, 1)
    dtype = cos_angle.dtype

    k = torch.as_tensor(axis, dtype=dtype, device=angle.device).reshape(3)
    norm = k.norm()
    if norm == 0:
        raise ValueError("rotation axis must be non-zero")
    # Rodrigues' formula requires a unit axis; dividing a unit axis by 1 is exact.
    k = k / norm

    kx, ky, kz = k.unbind()
    zero = torch.zeros_like(kx)
    skew = torch.stack([                                  # K, so that K @ v == k x v
        torch.stack([zero, -kz, ky]),
        torch.stack([kz, zero, -kx]),
        torch.stack([-ky, kx, zero]),
    ])
    outer = k.unsqueeze(1) * k.unsqueeze(0)               # k k^T -> [3, 3]
    identity = torch.eye(3, dtype=dtype, device=angle.device)

    # [B, 1, 1] coefficients broadcast against the shared [3, 3] matrices -> [B, 3, 3].
    return cos_angle * identity + sin_angle * skew + (1 - cos_angle) * outer

def apply_transformation(vertices, position, rotation, rotation_order="XYZ", offset_first=False):
    '''
    Rigidly transform batched point sets with elementary rotations and a translation.

    Each letter of rotation_order applies the matching elementary rotation (built
    with get_rodrigues_matrix) in turn, left to right. With the default "XYZ",
    v' = Rz @ Ry @ Rx @ v.

    vertices       -> [N, V, 3]  any number V of points per batch item
    position       -> [N, 3]     translation added to every point
    rotation       -> [N, 3]     angles about x, y, z (columns 0, 1, 2), in radians
    rotation_order -> string of 'X' / 'Y' / 'Z'; other letters raise KeyError
    offset_first   -> if True translate, then rotate; otherwise rotate, then translate
    return         -> [N, V, 3]
    '''
    # Broadcast the translation over all V points; the previous repeat(1, 8, 1)
    # assumed exactly 8 vertices and crashed (or silently returned 8 rows) otherwise.
    offset = position.unsqueeze(1)  # [N, 1, 3]

    if offset_first:
        vertices = vertices + offset

    axis_spec = {
        "X": ([1, 0, 0], 0),
        "Y": ([0, 1, 0], 1),
        "Z": ([0, 0, 1], 2),
    }
    for s in rotation_order:
        axis, column = axis_spec[s]
        rot = get_rodrigues_matrix(axis, rotation[:, column])  # [N, 3, 3]
        # Row-vector points: v @ R^T == (R @ v^T)^T
        vertices = torch.matmul(vertices, rot.transpose(1, 2))

    if not offset_first:
        vertices = vertices + offset

    return vertices

def transformation_matrix(position, rotation):
    '''
    Get RT: batched 4x4 homogeneous transforms [[R, t], [0, 0, 0, 1]].

    R = Rz @ Ry @ Rx, i.e. a point is rotated about x first, then y, then z.

    position -> [N, 3]  translation t
    rotation -> [N, 3]  angles (x, y, z) in radians
    return   -> [N, 4, 4] on rotation's device, in the floating dtype of
                torch.cos(rotation) (float64 inputs are no longer truncated to float32)
    '''
    cos = torch.cos(rotation)
    sin = torch.sin(rotation)
    cx, cy, cz = cos[:, 0], cos[:, 1], cos[:, 2]
    sx, sy, sz = sin[:, 0], sin[:, 1], sin[:, 2]

    # Constant entries share the trig dtype/device so nothing is silently down-cast.
    ones = torch.ones_like(cx)
    zeros = torch.zeros_like(cx)

    Rx_matrix = torch.stack([
        torch.stack([ones, zeros, zeros], dim=-1),
        torch.stack([zeros, cx, -sx], dim=-1),
        torch.stack([zeros, sx, cx], dim=-1),
    ], dim=-2)
    Ry_matrix = torch.stack([
        torch.stack([cy, zeros, sy], dim=-1),
        torch.stack([zeros, ones, zeros], dim=-1),
        torch.stack([-sy, zeros, cy], dim=-1),
    ], dim=-2)
    Rz_matrix = torch.stack([
        torch.stack([cz, -sz, zeros], dim=-1),
        torch.stack([sz, cz, zeros], dim=-1),
        torch.stack([zeros, zeros, ones], dim=-1),
    ], dim=-2)
    rotation_matrix = Rz_matrix @ Ry_matrix @ Rx_matrix

    # Embed R and t into identity 4x4 matrices of the same dtype.
    RT = torch.eye(4, dtype=rotation_matrix.dtype, device=rotation_matrix.device)
    RT = RT.unsqueeze(0).repeat(rotation.shape[0], 1, 1)
    RT[:, :3, :3] = rotation_matrix
    RT[:, :3, -1] = position
    return RT

def rotation_matrix_to_euler_angles(R, eps=1e-6):
    '''
    Recover (x, y, z) Euler angles from batched rotation matrices R = Rz @ Ry @ Rx.

    R      -> [N, 3, 3]
    eps    -> threshold on sqrt(R00^2 + R10^2) (= |cos y| for this convention)
              below which the matrix is treated as gimbal-locked
    return -> [N, 3] angles (x, y, z) in radians
    '''
    sy = torch.sqrt(R[:, 0, 0] ** 2 + R[:, 1, 0] ** 2)
    singular = sy < eps

    y = torch.atan2(-R[:, 2, 0], sy)

    x_regular = torch.atan2(R[:, 2, 1], R[:, 2, 2])
    z_regular = torch.atan2(R[:, 1, 0], R[:, 0, 0])

    # Gimbal lock (cos y ~ 0): x and z act about the same axis and the entries used
    # above are ~0, so atan2(0, 0) would discard the rotation. Fix z = 0 and recover
    # the combined angle into x from the second row instead.
    x_singular = torch.atan2(-R[:, 1, 2], R[:, 1, 1])

    x = torch.where(singular, x_singular, x_regular)
    z = torch.where(singular, torch.zeros_like(z_regular), z_regular)

    euler = torch.stack([x, y, z], dim=1)
    return euler

def Slide(para):
    '''
    Compute a gripper pose at the handle of the part described by ``para``.

    para   -> [B, >=9]: columns 0:3 size, 3:6 position, 6:9 rotation expressed in
              units of pi (scaled by pi here before use)
    return -> [B, 6]: gripper position (3 values) followed by the (x, y, z) Euler
              angles recovered from the gripper rotation matrix
    '''
    batch = para.shape[0]
    opts = dict(dtype=para.dtype, device=para.device)

    size = para[:, 0:3]
    position = para[:, 3:6]
    rotation = para[:, 6:9] * math.pi

    # Handle point in the part's local frame: half of size[:, 2] along local +z.
    zeros = torch.zeros(batch, **opts)
    local_handle = torch.stack([zeros, zeros, size[:, 2] / 2], dim=1).unsqueeze(1)  # [B, 1, 3]
    world_handle = apply_transformation(
        local_handle, position, rotation, rotation_order="XYZ", offset_first=False
    )
    gripper_pos = world_handle[:, 0, :]  # [B, 3]

    # Gripper orientation: part rotation right-multiplied by a fixed pi/2 turn about z.
    quarter_turn_z = torch.tensor([[0.0, 0.0, math.pi / 2]], **opts).expand(batch, -1)
    part_rot = transformation_matrix(position, rotation)[:, :3, :3]
    offset_rot = transformation_matrix(torch.zeros_like(position), quarter_turn_z)[:, :3, :3]
    gripper_rot = part_rot @ offset_rot

    euler = rotation_matrix_to_euler_angles(gripper_rot)
    return torch.cat([gripper_pos, euler], dim=1)