import torch
import numpy as np

def catesian_to_spherical(xyz):
    """Convert Cartesian points to (elevation, azimuth) angles in radians.

    The (misspelled) function name is kept as-is for existing callers.
    The angle convention matches ``spherical_to_cartesian`` in this module:
    elevation is measured from the x-y plane, azimuth around the z axis.

    Args:
        xyz: Tensor whose last dimension holds (x, y, z), e.g. shape (N, 3).

    Returns:
        A tuple ``(theta, phi)``:
        - ``theta`` is the elevation ``asin(z / r)``, in [-pi/2, pi/2].
        - ``phi`` is the azimuth ``atan2(y, x)``, in [-pi, pi].

        Each has shape ``xyz.shape[:-1] + (1,)``, e.g. (N, 1) for an
        (N, 3) input. The radius is not returned. A zero-length vector
        yields a NaN elevation.
    """
    r = torch.norm(xyz, dim=-1)
    # Clamp: rounding can push |z / r| marginally past 1 for points on or
    # near the z axis, which would make asin return NaN. NaN (r == 0) is
    # propagated unchanged by clamp.
    sin_ele = torch.clamp(xyz[..., 2] / r, -1.0, 1.0)
    # unsqueeze(-1) keeps a trailing singleton axis for any batch shape
    # (a fixed [:, None] only worked for 2-D input).
    theta = torch.asin(sin_ele).unsqueeze(-1)
    phi = torch.atan2(xyz[..., 1], xyz[..., 0]).unsqueeze(-1)
    return theta, phi

def spherical_to_cartesian(ele, azi, r=4.0311):
    """Map elevation/azimuth angles (radians) to Cartesian points at radius ``r``.

    ``ele`` is measured from the x-y plane and ``azi`` around the z axis,
    starting from +x. The three components are concatenated along the last
    dimension. Inputs with a trailing singleton axis, e.g. shape (N, 1),
    therefore produce shape (N, 3).
    """
    # Length of the point's projection onto the x-y plane.
    horizontal = r * torch.cos(ele)
    components = (
        horizontal * torch.cos(azi),
        horizontal * torch.sin(azi),
        r * torch.sin(ele),
    )
    return torch.cat(components, dim=-1)

def pose_to_matrix(pose, rot_only=True):
    """Build a 4x4 homogeneous matrix whose rotation looks along ``pose``.

    The rotation's columns form an orthonormal frame (x, y, z):
    - z is ``pose`` normalized;
    - x is ``(0, 0, 1) x z``, normalized;
    - y is ``z x x``, normalized.

    Args:
        pose: Tensor of shape (..., 3), one 3-vector per item.
        rot_only: If False, ``pose`` itself is also written into the
            translation column.

    Returns:
        Tensor of shape (..., 4, 4) with ``pose``'s dtype and device.
        The bottom row is (0, 0, 0, 1).

    Note:
        A ``pose`` parallel to the z axis, or a zero ``pose``, makes x
        degenerate, and the result then contains NaN.
    """
    zz = pose / pose.norm(dim=-1, keepdim=True)
    # World up-vector, matched to zz's dtype/device and broadcast to its shape
    # so single (3,) poses and arbitrary batch shapes both work.
    up = torch.tensor([0.0, 0.0, 1.0], dtype=zz.dtype, device=zz.device).expand_as(zz)
    # dim=-1 is explicit: without it torch.cross picks the FIRST dimension of
    # size 3, which is the batch axis when exactly three poses are passed.
    xx = torch.cross(up, zz, dim=-1)
    xx = xx / xx.norm(dim=-1, keepdim=True)
    yy = torch.cross(zz, xx, dim=-1)
    yy = yy / yy.norm(dim=-1, keepdim=True)
    rot = torch.stack([xx, yy, zz], dim=-1)
    # dtype follows the input so float64 poses are not silently downcast.
    result = torch.zeros(
        (*rot.shape[:-2], 4, 4), dtype=rot.dtype, device=pose.device
    )
    result[..., :3, :3] = rot
    if not rot_only:
        result[..., :3, 3] = pose
    result[..., 3, 3] = 1
    return result


def ang_to_matrix(ang, rot_only=True, r=4.0311):
    """Convert (elevation, azimuth) angles to 4x4 matrices via ``pose_to_matrix``.

    Args:
        ang: Tensor whose last dimension has size 2 and holds
            [elevation, azimuth] in radians, e.g. shape (N, 2) or (B, N, 2).
        rot_only: Forwarded to ``pose_to_matrix``. If False, the Cartesian
            point is written into the translation column.
        r: Radius of the Cartesian point. The rotation uses only the
            normalized direction, so ``r`` affects only the translation
            column (when ``rot_only`` is False).

    Returns:
        Matrices of shape ``ang.shape[:-1] + (4, 4)`` for inputs of rank 2
        or higher.
    """
    ele, azi = ang.split([1, 1], dim=-1)
    coords = spherical_to_cartesian(ele, azi, r)
    if coords.dim() <= 2:
        return pose_to_matrix(coords, rot_only=rot_only)
    # Flatten every leading batch dimension so pose_to_matrix receives
    # (M, 3) points, then restore those dimensions on the output.
    batch_shape = coords.shape[:-1]
    matrix = pose_to_matrix(coords.reshape(-1, 3), rot_only=rot_only)
    return matrix.reshape(*batch_shape, 4, 4)


def matrix_to_ang(matrix):
    """Recover [elevation, azimuth] from the third column of a rotation matrix.

    Reads the column ``matrix[..., 0:3, 2]`` (the look direction z in
    ``pose_to_matrix``'s convention):
    - elevation = asin(m[2, 2]);
    - azimuth = atan2(m[1, 2], m[0, 2]).

    Args:
        matrix: Tensor of shape (..., R, C) with R, C >= 3, e.g. (..., 4, 4).

    Returns:
        Tensor of shape (..., 2) holding [elevation, azimuth]. Both angles
        are wrapped into [0, 2*pi), so negative angles come back as
        ``angle + 2*pi``.

    Raises:
        AssertionError: If either angle is NaN, e.g. because the matrix
            contains NaN.
    """
    # Clamp: a normalized component can round to slightly above 1 in
    # magnitude, which would make asin return NaN.
    ele = torch.asin(torch.clamp(matrix[..., 2, 2], -1.0, 1.0))
    azi = torch.atan2(matrix[..., 1, 2], matrix[..., 0, 2])
    # Explicit raise (not `assert`) so the check survives `python -O`;
    # AssertionError is kept for compatibility with existing callers.
    if torch.isnan(ele).any() or torch.isnan(azi).any():
        raise AssertionError(f"ele: {ele}, azi: {azi}, matrix: {matrix}")
    return torch.stack([ele, azi], dim=-1) % (2 * np.pi)
