import torch
import torch.functional as F

def quaternion_rotation(rotation, x, right=False, transpose=False):
    """Apply a unit-quaternion (Hamilton) product to a batch of quaternion embeddings.

    Each run of four consecutive entries in ``rotation`` is read as one quaternion
    ``p + q i + r j + s k`` and scaled to unit length. ``x`` is instead split along
    dim 1 into four equal contiguous blocks (real, i, j, k parts), so the n-th
    quaternion of ``x`` is paired with the n-th 4-tuple of ``rotation``.

    Args:
        rotation (torch.Tensor): rotation parameters [batch_size, rank]; rank must
            be divisible by 4
        x (torch.Tensor): quaternion vectors to rotate [batch_size, rank]
        right (bool): if True compute ``x * rotation`` (right multiplication; the
            original QuatE formulation corresponds to this case), otherwise
            ``rotation * x``
        transpose (bool): if True use the conjugate of each rotation quaternion,
            which transposes the corresponding orthogonal rotation matrix

    Returns:
        torch.Tensor: rotated vectors in the same block layout as ``x``
        [batch_size, rank]
    """
    # NOTE: the original author observed this element-wise formulation may be slow.
    quats = rotation.view(rotation.shape[0], -1, 4)
    # Normalize every quaternion to unit length; the clamp avoids division by zero.
    quats = quats / quats.norm(dim=-1, keepdim=True).clamp_min(1e-15)
    p, q, r, s = quats.unbind(dim=-1)
    if transpose:
        # Conjugate = negated imaginary part; for a unit quaternion this is the
        # transpose of its multiplication matrix.
        q, r, s = -q, -r, -s

    real, imag_i, imag_j, imag_k = torch.chunk(x, 4, dim=1)

    # The real component is identical for left and right multiplication.
    out_real = real * p - imag_i * q - imag_j * r - imag_k * s
    if right:
        # x * rotation
        out_i = real * q + p * imag_i + imag_j * s - r * imag_k
        out_j = real * r + p * imag_j + imag_k * q - s * imag_i
        out_k = real * s + p * imag_k + imag_i * r - q * imag_j
    else:
        # rotation * x
        out_i = real * q + imag_i * p - imag_j * s + imag_k * r
        out_j = real * r + imag_i * s + imag_j * p - imag_k * q
        out_k = real * s - imag_i * r + imag_j * q + imag_k * p

    return torch.cat((out_real, out_i, out_j, out_k), dim=-1)

def quaternion_rotation_v2(rotation, ori):
    """Rotate 3-D vectors with the rotation matrix of a quaternion given by its imaginary part.

    Instead of taking quaternion (Hamilton) products, this builds the 3x3 rotation
    matrix of the quaternion ``(w, x, y, z)`` element-wise and applies it to ``ori``.
    Only ``(x, y, z)`` are supplied; the real part is recovered as
    ``w = sqrt(1 - x^2 - y^2 - z^2)``, clamped so it never becomes imaginary.

    Args:
        rotation (torch.Tensor): imaginary quaternion parts, split along dim 1 into
            three equal blocks ``x | y | z`` (expected 2-D, e.g. [batch_size, 3 * dim])
        ori (torch.Tensor): vectors to rotate, split along dim 1 into three equal
            blocks ``x | y | z``; expected to match ``rotation``'s shape

    Returns:
        torch.Tensor: rotated vectors in the same ``x | y | z`` block layout as ``ori``

    Note:
        If ``x^2 + y^2 + z^2 > 1`` the recovered ``w`` is clamped to ~0 and the
        quaternion is no longer unit length. The matrix is therefore divided by the
        squared quaternion norm, which makes it the rotation matrix of the
        normalized quaternion (always orthogonal). When ``x^2 + y^2 + z^2 <= 1`` the
        squared norm is 1 and this is the standard unit-quaternion formula.
    """
    ori_x, ori_y, ori_z = ori.chunk(dim=1, chunks=3)
    x, y, z = rotation.chunk(dim=1, chunks=3)

    # Recover the real part; the clamp keeps w real and sqrt's gradient finite.
    imag_sq = x ** 2 + y ** 2 + z ** 2
    w_sq = (1 - imag_sq).clamp_min(1e-15)
    w = torch.sqrt(w_sq)
    # 2 / |q|^2: equals 2 for a unit quaternion; rescales the matrix so it stays
    # orthogonal when the clamp above had to kick in (imag_sq > 1).
    scale = 2 / (imag_sq + w_sq)

    # Rotation matrix of a (possibly non-unit) quaternion.
    matrix_11 = 1 - scale * (y ** 2 + z ** 2)
    matrix_12 = scale * (x * y - z * w)
    matrix_13 = scale * (x * z + y * w)
    matrix_21 = scale * (x * y + z * w)
    matrix_22 = 1 - scale * (x ** 2 + z ** 2)
    matrix_23 = scale * (y * z - x * w)
    matrix_31 = scale * (x * z - y * w)
    matrix_32 = scale * (y * z + x * w)
    matrix_33 = 1 - scale * (x ** 2 + y ** 2)

    rotated_x = matrix_11 * ori_x + matrix_12 * ori_y + matrix_13 * ori_z
    rotated_y = matrix_21 * ori_x + matrix_22 * ori_y + matrix_23 * ori_z
    rotated_z = matrix_31 * ori_x + matrix_32 * ori_y + matrix_33 * ori_z

    return torch.cat([rotated_x, rotated_y, rotated_z], dim=-1)

    