"""
Transformation utils
"""

import numpy as np
import torch

from opencood.utils.common_utils import check_numpy_to_torch


def x_to_world(pose):
    """
    Build the homogeneous transform from a pose's local frame to the carla
    world frame.

    Parameters
    ----------
    pose : list
        [x, y, z, roll, yaw, pitch]; the three angles are in degrees.

    Returns
    -------
    matrix : np.ndarray
        4x4 float64 matrix: the upper-left 3x3 block is the rotation, the
        last column holds the translation (x, y, z), the last row is
        [0, 0, 0, 1].
    """
    x, y, z, roll, yaw, pitch = pose[:]

    cos_r, sin_r = np.cos(np.radians(roll)), np.sin(np.radians(roll))
    cos_y, sin_y = np.cos(np.radians(yaw)), np.sin(np.radians(yaw))
    cos_p, sin_p = np.cos(np.radians(pitch)), np.sin(np.radians(pitch))

    matrix = np.identity(4)

    # Rotation block, one row per output axis.
    matrix[:3, :3] = [
        [cos_p * cos_y,
         cos_y * sin_p * sin_r - sin_y * cos_r,
         -cos_y * sin_p * cos_r - sin_y * sin_r],
        [sin_y * cos_p,
         sin_y * sin_p * sin_r + cos_y * cos_r,
         -sin_y * sin_p * cos_r + cos_y * sin_r],
        [sin_p,
         -cos_p * sin_r,
         cos_p * cos_r],
    ]

    # Translation column.
    matrix[:3, 3] = (x, y, z)

    return matrix


def x1_to_x2(x1, x2):
    """
    Compute the transform that maps coordinates expressed in frame x1 into
    frame x2.

    The result is inv(T_x2_to_world) . T_x1_to_world, where both T matrices
    come from ``x_to_world``.

    Parameters
    ----------
    x1 : list
        Pose of the source frame in world coordinates,
        [x, y, z, roll, yaw, pitch].
    x2 : list
        Pose of the target frame in world coordinates, same layout.

    Returns
    -------
    transformation_matrix : np.ndarray
        4x4 homogeneous transform from x1 to x2.
    """
    source_to_world = x_to_world(x1)
    target_from_world = np.linalg.inv(x_to_world(x2))
    return target_from_world.dot(source_to_world)


def dist_to_continuous(p_dist, displacement_dist, res, downsample_rate):
    """
    Map discretized BEV grid points back to continuous coordinates.

    Parameters
    ----------
    p_dist : numpy.array
        Points in discretized coordinates. The input is copied, not modified.

    displacement_dist : numpy.array
        Discretized coordinates of the bottom-left origin, added to every
        point before scaling.

    res : float
        Discretization resolution.

    downsample_rate : int
        Downsampling rate applied on top of the resolution.

    Returns
    -------
    p_continuous : numpy.array
        Points in continuous coordinates:
        (p_dist + displacement_dist) * res * downsample_rate.
    """
    shifted = np.copy(p_dist) + displacement_dist
    return shifted * res * downsample_rate


def muilt_coord(rotationA2B, translationA2B, rotationB2C, translationB2C):
    """
    Compose two rigid transforms A->B and B->C into a single A->C transform.

    Parameters
    ----------
    rotationA2B, rotationB2C : array-like
        Rotations, anything reshapeable to 3x3.
    translationA2B, translationB2C : array-like
        Translations, anything reshapeable to 3x1.

    Returns
    -------
    rotation : np.ndarray
        3x3 rotation R_BC . R_AB.
    translation : np.ndarray
        3x1 translation R_BC . t_AB + t_BC.
    """
    rot_ab = np.array(rotationA2B).reshape(3, 3)
    rot_bc = np.array(rotationB2C).reshape(3, 3)
    t_ab = np.array(translationA2B).reshape(3, 1)
    t_bc = np.array(translationB2C).reshape(3, 1)

    return rot_bc.dot(rot_ab), rot_bc.dot(t_ab) + t_bc


def pose_to_tfm(pose):
    """Convert a batch of poses to 4x4 homogeneous transformation matrices.

    Args:
        pose: torch.Tensor or np.ndarray
            [N, 3]: x, y, yaw, with yaw in degrees. The result is a planar
                transform (rotation about z only, zero z translation).
            [N, 6]: x, y, z, roll, yaw, pitch, angles in degrees. Roll and
                pitch follow the carla coordinate convention; the rotation
                uses the same formula as x_to_world.

    Returns:
        tfm: torch.Tensor
            [N, 4, 4], created with torch's default dtype on pose's device.
            Converted back with ``.numpy()`` when check_numpy_to_torch
            reports that the input was a numpy array.

    Raises:
        ValueError: if pose is not 2-D with 3 or 6 columns.
    """
    pose, is_np = check_numpy_to_torch(pose)

    # Previously an unsupported width fell through both branches and crashed
    # with UnboundLocalError on `tfm` (or IndexError for 1-D input).
    if len(pose.shape) != 2 or pose.shape[1] not in (3, 6):
        raise ValueError(
            "pose must have shape [N, 3] or [N, 6], got {}".format(
                tuple(pose.shape)))

    N = pose.shape[0]
    # No dtype is passed, so tfm uses torch's default dtype; pose values
    # are cast on assignment.
    tfm = torch.eye(4, device=pose.device).view(1, 4, 4).repeat(N, 1, 1)

    if pose.shape[1] == 3:
        x = pose[:, 0]
        y = pose[:, 1]
        yaw_rad = torch.deg2rad(pose[:, 2])
        c_y = torch.cos(yaw_rad)
        s_y = torch.sin(yaw_rad)

        # Rotation about z only; z, roll and pitch stay at identity.
        tfm[:, 0, 0] = c_y
        tfm[:, 0, 1] = -s_y
        tfm[:, 1, 0] = s_y
        tfm[:, 1, 1] = c_y
        tfm[:, 0, 3] = x
        tfm[:, 1, 3] = y
    else:
        x = pose[:, 0]
        y = pose[:, 1]
        z = pose[:, 2]
        roll = pose[:, 3]
        yaw = pose[:, 4]
        pitch = pose[:, 5]

        c_y = torch.cos(torch.deg2rad(yaw))
        s_y = torch.sin(torch.deg2rad(yaw))
        c_r = torch.cos(torch.deg2rad(roll))
        s_r = torch.sin(torch.deg2rad(roll))
        c_p = torch.cos(torch.deg2rad(pitch))
        s_p = torch.sin(torch.deg2rad(pitch))

        # translation
        tfm[:, 0, 3] = x
        tfm[:, 1, 3] = y
        tfm[:, 2, 3] = z

        # rotation (same element formulas as x_to_world)
        tfm[:, 0, 0] = c_p * c_y
        tfm[:, 0, 1] = c_y * s_p * s_r - s_y * c_r
        tfm[:, 0, 2] = -c_y * s_p * c_r - s_y * s_r
        tfm[:, 1, 0] = s_y * c_p
        tfm[:, 1, 1] = s_y * s_p * s_r + c_y * c_r
        tfm[:, 1, 2] = -s_y * s_p * c_r + c_y * s_r
        tfm[:, 2, 0] = s_p
        tfm[:, 2, 1] = -c_p * s_r
        tfm[:, 2, 2] = c_p * c_r

    if is_np:
        tfm = tfm.numpy()

    return tfm


def tfm_to_pose(tfm: np.ndarray):
    """
    Recover [x, y, z, roll, yaw, pitch] from a 4x4 transformation matrix.

    The angle formulas are derived from the rotation built by x_to_world,
    so this inverts that function. Angles are returned in DEGREES (not
    radians), the same unit x_to_world and pose_to_tfm consume.

    Parameters
    ----------
    tfm : np.ndarray
        4x4 homogeneous transform: rotation in the upper-left 3x3 block,
        translation in the last column.

    Returns
    -------
    pose : list
        [x, y, z, roll, yaw, pitch], angles in degrees. Pitch is always in
        [-90, 90] because the arctan2 denominator below is non-negative.
        At pitch = +/-90 degrees, cos(pitch) is 0 and the roll and yaw
        formulas degenerate to arctan2(0, 0), so those two values are not
        meaningful there.
    """
    # With R the rotation from x_to_world (c_* = cos, s_* = sin):
    #   R[1,0] = s_y c_p,  R[0,0] = c_p c_y          -> yaw
    #   R[2,1] = -c_p s_r, R[2,2] = c_p c_r          -> roll
    #   R[2,0] = s_p, sqrt(R[2,1]^2 + R[2,2]^2) = |c_p| -> pitch
    # Compared with the usual ZYX Euler decomposition (R[2,0] = -sin(pitch),
    # R[2,1] = cos(pitch) sin(roll)), roll and pitch carry the opposite sign
    # here; this matches x_to_world's carla convention.
    yaw = np.degrees(np.arctan2(tfm[1, 0], tfm[0, 0]))
    roll = np.degrees(np.arctan2(-tfm[2, 1], tfm[2, 2]))
    pitch = np.degrees(np.arctan2(tfm[2, 0], ((tfm[2, 1] ** 2 + tfm[2, 2] ** 2) ** 0.5)))

    x, y, z = tfm[:3, 3]
    return [x, y, z, roll, yaw, pitch]