from typing import *
from numbers import Number 

import torch
import torch.nn.functional as F

from ._helpers import batched


# Public API of this module: the names exported by a star-import.
# Several of the listed functions are not defined in the portion of the file
# that precedes this comment's neighbours; presumably they are defined further
# down the file -- verify when adding or removing entries.
__all__ = [
    'perspective',
    'perspective_from_fov',
    'perspective_from_fov_xy',
    'intrinsics_from_focal_center',
    'intrinsics_from_fov',
    'intrinsics_from_fov_xy',
    'view_look_at',
    'extrinsics_look_at',
    'perspective_to_intrinsics',
    'intrinsics_to_perspective',
    'extrinsics_to_view',
    'view_to_extrinsics',
    'normalize_intrinsics',
    'crop_intrinsics',
    'pixel_to_uv',
    'pixel_to_ndc',
    'uv_to_pixel',
    'project_depth',
    'depth_buffer_to_linear',
    'project_gl',
    'project_cv',
    'unproject_gl',
    'unproject_cv',
    'skew_symmetric',
    'rotation_matrix_from_vectors',
    'euler_axis_angle_rotation',
    'euler_angles_to_matrix',
    'matrix_to_euler_angles',
    'matrix_to_quaternion',
    'quaternion_to_matrix',
    'matrix_to_axis_angle',
    'axis_angle_to_matrix',
    'axis_angle_to_quaternion',
    'quaternion_to_axis_angle',
    'slerp',
    'interpolate_extrinsics',
    'interpolate_view',
    'extrinsics_to_essential',
    'to4x4',
    'rotation_matrix_2d',
    'rotate_2d',
    'translate_2d',
    'scale_2d',
    'apply_2d',
]


@batched(0,0,0,0)
def perspective(
        fov_y: Union[float, torch.Tensor],
        aspect: Union[float, torch.Tensor],
        near: Union[float, torch.Tensor],
        far: Union[float, torch.Tensor]
    ) -> torch.Tensor:
    """
    Build an OpenGL perspective projection matrix.

    The body treats every argument as a 1-D batch (it reads `fov_y.shape[0]`
    and writes `matrix[:, i, j]`); presumably the `batched` decorator from
    `._helpers` converts and reshapes the caller's inputs -- confirm there.

    Args:
        fov_y (float | torch.Tensor): field of view along the y axis, in radians
        aspect (float | torch.Tensor): aspect ratio (width / height)
        near (float | torch.Tensor): near clipping plane
        far (float | torch.Tensor): far clipping plane

    Returns:
        (torch.Tensor): [..., 4, 4] perspective matrix
    """
    tan_half_fov = torch.tan(fov_y / 2)
    depth_range = near - far

    matrix = torch.zeros((fov_y.shape[0], 4, 4), dtype=fov_y.dtype, device=fov_y.device)
    # Focal terms on the diagonal.
    matrix[:, 0, 0] = 1. / (tan_half_fov * aspect)
    matrix[:, 1, 1] = 1. / tan_half_fov
    # Depth remapping terms.
    matrix[:, 2, 2] = (near + far) / depth_range
    matrix[:, 2, 3] = 2. * near * far / depth_range
    # Copies -z into the clip-space w component.
    matrix[:, 3, 2] = -1.
    return matrix


def perspective_from_fov(
        fov: Union[float, torch.Tensor],
        width: Union[int, torch.Tensor],
        height: Union[int, torch.Tensor],
        near: Union[float, torch.Tensor],
        far: Union[float, torch.Tensor]
    ) -> torch.Tensor:
    """
    Get OpenGL perspective matrix from field of view in largest dimension

    Python scalars are accepted for `fov`, `width` and `height`: they are
    converted to tensors before use, because `torch.tan` and `torch.maximum`
    only operate on tensors. Tensor arguments are used unchanged.

    Args:
        fov (float | torch.Tensor): field of view in largest dimension
        width (int | torch.Tensor): image width
        height (int | torch.Tensor): image height
        near (float | torch.Tensor): near plane to clip
        far (float | torch.Tensor): far plane to clip

    Returns:
        (torch.Tensor): [..., 4, 4] perspective matrix
    """
    # Put converted scalars on the device of the first tensor argument (if any)
    # so that they can be combined with it.
    device = next(
        (arg.device for arg in (fov, width, height) if isinstance(arg, torch.Tensor)),
        None,
    )
    fov, width, height = (
        arg if isinstance(arg, torch.Tensor) else torch.as_tensor(arg, device=device)
        for arg in (fov, width, height)
    )
    # Scale the half-angle tangent from the largest dimension down to the y axis.
    fov_y = 2 * torch.atan(torch.tan(fov / 2) * height / torch.maximum(width, height))
    aspect = width / height
    return perspective(fov_y, aspect, near, far)


def perspective_from_fov_xy(
        fov_x: Union[float, torch.Tensor],
        fov_y: Union[float, torch.Tensor],
        near: Union[float, torch.Tensor],
        far: Union[float, torch.Tensor]
    ) -> torch.Tensor:
    """
    Get OpenGL perspective matrix from field of view in x and y axis

    Python scalars are accepted for `fov_x` and `fov_y`: they are converted to
    tensors before use, because `torch.tan` only operates on tensors. Tensor
    arguments are used unchanged.

    Args:
        fov_x (float | torch.Tensor): field of view in x axis
        fov_y (float | torch.Tensor): field of view in y axis
        near (float | torch.Tensor): near plane to clip
        far (float | torch.Tensor): far plane to clip

    Returns:
        (torch.Tensor): [..., 4, 4] perspective matrix
    """
    # Put converted scalars on the device of the first tensor argument (if any).
    device = next(
        (arg.device for arg in (fov_x, fov_y) if isinstance(arg, torch.Tensor)),
        None,
    )
    fov_x, fov_y = (
        arg if isinstance(arg, torch.Tensor) else torch.as_tensor(arg, device=device)
        for arg in (fov_x, fov_y)
    )
    # The aspect ratio is the ratio of the half-angle tangents.
    aspect = torch.tan(fov_x / 2) / torch.tan(fov_y / 2)
    return perspective(fov_y, aspect, near, far)


@batched(0,0,0,0)
def intrinsics_from_focal_center(
    fx: Union[float, torch.Tensor],
    fy: Union[float, torch.Tensor],
    cx: Union[float, torch.Tensor],
    cy: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Get OpenCV intrinsics matrix

    The matrix is laid out as::

        [[fx,  0, cx],
         [ 0, fy, cy],
         [ 0,  0,  1]]

    Args:
        fx (float | torch.Tensor): focal length in x axis
        fy (float | torch.Tensor): focal length in y axis
        cx (float | torch.Tensor): principal point in x axis
        cy (float | torch.Tensor): principal point in y axis

    Returns:
        (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
    """
    # The arguments are used as tensors here; presumably the `batched`
    # decorator has already converted scalars -- confirm in `._helpers`.
    zeros = torch.zeros_like(fx)
    ones = torch.ones_like(fx)
    # Stack the nine entries in row-major order, then fold them into 3x3.
    return torch.stack(
        [fx, zeros, cx, zeros, fy, cy, zeros, zeros, ones], dim=-1
    ).unflatten(-1, (3, 3))


@batched(0, 0, 0, 0, 0, 0)
def intrinsics_from_fov(
    fov_max: Union[float, torch.Tensor] = None,
    fov_min: Union[float, torch.Tensor] = None,
    fov_x: Union[float, torch.Tensor] = None,
    fov_y: Union[float, torch.Tensor] = None,
    width: Union[int, torch.Tensor] = None,
    height: Union[int, torch.Tensor] = None,
) -> torch.Tensor:
    """
    Get normalized OpenCV intrinsics matrix from given field of view.
    You can provide either fov_max, fov_min, fov_x or fov_y

    The options are checked in the order fov_max, fov_min, (fov_x and fov_y),
    fov_x, fov_y; the first one provided is used. The principal point is
    always placed at (0.5, 0.5).

    Args:
        fov_max (float | torch.Tensor): field of view in largest dimension
        fov_min (float | torch.Tensor): field of view in smallest dimension
        fov_x (float | torch.Tensor): field of view in x axis
        fov_y (float | torch.Tensor): field of view in y axis
        width (int | torch.Tensor): image width. Required unless both fov_x and fov_y are given
        height (int | torch.Tensor): image height. Required unless both fov_x and fov_y are given

    Returns:
        (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix

    Raises:
        ValueError: if no field of view is given, or if width/height are
            missing when the chosen option needs them.
    """
    # Every option except "both fov_x and fov_y" relies on the image size.
    needs_size = (
        fov_max is not None
        or fov_min is not None
        or (fov_x is None) != (fov_y is None)
    )
    if needs_size and (width is None or height is None):
        raise ValueError("width and height are required unless both fov_x and fov_y are given")

    if fov_max is not None:
        fx = torch.maximum(width, height) / width / (2 * torch.tan(fov_max / 2))
        fy = torch.maximum(width, height) / height / (2 * torch.tan(fov_max / 2))
    elif fov_min is not None:
        fx = torch.minimum(width, height) / width / (2 * torch.tan(fov_min / 2))
        fy = torch.minimum(width, height) / height / (2 * torch.tan(fov_min / 2))
    elif fov_x is not None and fov_y is not None:
        fx = 1 / (2 * torch.tan(fov_x / 2))
        fy = 1 / (2 * torch.tan(fov_y / 2))
    elif fov_x is not None:
        fx = 1 / (2 * torch.tan(fov_x / 2))
        fy = fx * width / height
    elif fov_y is not None:
        fy = 1 / (2 * torch.tan(fov_y / 2))
        fx = fy * height / width
    else:
        # Previously this fell through and failed on the unbound `fx`.
        raise ValueError("One of fov_max, fov_min, fov_x or fov_y must be provided")
    cx = 0.5
    cy = 0.5
    ret = intrinsics_from_focal_center(fx, fy, cx, cy)
    return ret



def intrinsics_from_fov_xy(
        fov_x: Union[float, torch.Tensor],
        fov_y: Union[float, torch.Tensor]
    ) -> torch.Tensor:
    """
    Get OpenCV intrinsics matrix from field of view in x and y axis

    The focal lengths are `0.5 / tan(fov / 2)` and the principal point is
    (0.5, 0.5). Python scalars are accepted: they are converted to tensors
    before use, because `torch.tan` only operates on tensors.

    Args:
        fov_x (float | torch.Tensor): field of view in x axis
        fov_y (float | torch.Tensor): field of view in y axis

    Returns:
        (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
    """
    # Put converted scalars on the device of the first tensor argument (if any).
    device = next(
        (arg.device for arg in (fov_x, fov_y) if isinstance(arg, torch.Tensor)),
        None,
    )
    fov_x, fov_y = (
        arg if isinstance(arg, torch.Tensor) else torch.as_tensor(arg, device=device)
        for arg in (fov_x, fov_y)
    )
    focal_x = 0.5 / torch.tan(fov_x / 2)
    focal_y = 0.5 / torch.tan(fov_y / 2)
    cx = cy = 0.5
    return intrinsics_from_focal_center(focal_x, focal_y, cx, cy)


@batched(1,1,1)
def view_look_at(
        eye: torch.Tensor,
        look_at: torch.Tensor,
        up: torch.Tensor
    ) -> torch.Tensor:
    """
    Build an OpenGL view matrix for a camera at `eye` looking towards `look_at`.

    The camera z axis points from the target back to the eye; x and y are
    derived from `up` with cross products, then all three are normalized.

    Args:
        eye (torch.Tensor): [..., 3] the eye position
        look_at (torch.Tensor): [..., 3] the position to look at
        up (torch.Tensor): [..., 3] head up direction (y axis in screen space). Not necessarily orthogonal to view direction

    Returns:
        (torch.Tensor): [..., 4, 4], view matrix
    """
    backward = eye - look_at
    right = torch.cross(up, backward, dim=-1)
    true_up = torch.cross(backward, right, dim=-1)

    # Rows of the rotation are the unit camera axes.
    rotation = torch.stack(
        [axis / axis.norm(dim=-1, keepdim=True) for axis in (right, true_up, backward)],
        dim=-2,
    )
    translation = -torch.matmul(rotation, eye[..., None])

    view = torch.zeros((eye.shape[0], 4, 4), dtype=eye.dtype, device=eye.device)
    view[:, :3, :3] = rotation
    view[:, :3, 3] = translation[:, :, 0]
    view[:, 3, 3] = 1.
    return view


@batched(1, 1, 1)
def extrinsics_look_at(
    eye: torch.Tensor,
    look_at: torch.Tensor,
    up: torch.Tensor
) -> torch.Tensor:
    """
    Build an OpenCV extrinsics matrix for a camera at `eye` looking towards `look_at`.

    The camera z axis points from the eye to the target; x is the cross
    product of `-up` and z, and y is the cross product of z and x.

    Args:
        eye (torch.Tensor): [..., 3] the eye position
        look_at (torch.Tensor): [..., 3] the position to look at
        up (torch.Tensor): [..., 3] head up direction (-y axis in screen space). Not necessarily orthogonal to view direction

    Returns:
        (torch.Tensor): [..., 4, 4], extrinsics matrix
    """
    batch_size = eye.shape[0]

    forward = look_at - eye
    right = torch.cross(-up, forward, dim=-1)
    down = torch.cross(forward, right, dim=-1)

    right = right / right.norm(dim=-1, keepdim=True)
    down = down / down.norm(dim=-1, keepdim=True)
    forward = forward / forward.norm(dim=-1, keepdim=True)

    # World-to-camera rotation has the unit camera axes as rows.
    rotation = torch.stack([right, down, forward], dim=-2)
    translation = -torch.matmul(rotation, eye[..., None])

    extrinsics = torch.zeros((batch_size, 4, 4), dtype=eye.dtype, device=eye.device)
    extrinsics[:, 3, 3] = 1.
    extrinsics[:, :3, 3] = translation[:, :, 0]
    extrinsics[:, :3, :3] = rotation
    return extrinsics


@batched(2)
def perspective_to_intrinsics(
    perspective: torch.Tensor
) -> torch.Tensor:
    """
    OpenGL perspective matrix to OpenCV intrinsics

    Args:
        perspective (torch.Tensor): [..., 4, 4] OpenGL perspective matrix

    Returns:
        (torch.Tensor): shape [..., 3, 3] OpenCV intrinsics

    Raises:
        AssertionError: if the last-column entries of rows 0, 1 and 3 are not
            all (close to) zero.
    """
    # `torch.allclose` needs a tensor on both sides, so compare against an
    # explicit zero tensor instead of the Python literal 0.
    last_column = perspective[:, [0, 1, 3], 3]
    assert torch.allclose(last_column, torch.zeros_like(last_column)), "The perspective matrix is not a projection matrix"
    # Maps NDC x/y in [-1, 1] to uv in [0, 1], flipping the y direction.
    uv_from_ndc = torch.tensor(
        [[0.5, 0., 0.5], [0., -0.5, 0.5], [0., 0., 1.]],
        dtype=perspective.dtype, device=perspective.device,
    )
    # Negates the y and z columns (OpenGL vs OpenCV camera axes).
    flip_yz = torch.diag(torch.tensor([1, -1, -1], dtype=perspective.dtype, device=perspective.device))
    ret = uv_from_ndc @ perspective[:, [0, 1, 3], :3] @ flip_yz
    # Normalize so that the bottom-right entry is 1.
    return ret / ret[:, 2, 2, None, None]


@batched(2,0,0)
def intrinsics_to_perspective(
        intrinsics: torch.Tensor,
        near: Union[float, torch.Tensor],
        far: Union[float, torch.Tensor],
    ) -> torch.Tensor:
    """
    Convert OpenCV intrinsics into an OpenGL perspective matrix.

    Only the focal lengths (entries [0, 0], [1, 1]) and the principal point
    (entries [0, 2], [1, 2]) of the intrinsics are read.

    Args:
        intrinsics (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
        near (float | torch.Tensor): [...] near plane to clip
        far (float | torch.Tensor): [...] far plane to clip
    Returns:
        (torch.Tensor): [..., 4, 4] OpenGL perspective matrix
    """
    focal_x = intrinsics[:, 0, 0]
    focal_y = intrinsics[:, 1, 1]
    center_x = intrinsics[:, 0, 2]
    center_y = intrinsics[:, 1, 2]
    depth_range = near - far

    matrix = torch.zeros(
        (intrinsics.shape[0], 4, 4), dtype=intrinsics.dtype, device=intrinsics.device
    )
    matrix[:, 0, 0] = 2 * focal_x
    matrix[:, 0, 2] = -2 * center_x + 1
    matrix[:, 1, 1] = 2 * focal_y
    matrix[:, 1, 2] = 2 * center_y - 1
    matrix[:, 2, 2] = (near + far) / depth_range
    matrix[:, 2, 3] = 2. * near * far / depth_range
    matrix[:, 3, 2] = -1.
    return matrix


@batched(2)
def extrinsics_to_view(
        extrinsics: torch.Tensor
    ) -> torch.Tensor:
    """
    Convert OpenCV camera extrinsics to an OpenGL view matrix.

    Rows 1 and 2 of the matrix (the camera y and z axes) are negated.

    Args:
        extrinsics (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix

    Returns:
        (torch.Tensor): [..., 4, 4] OpenGL view matrix
    """
    row_signs = torch.tensor(
        [1, -1, -1, 1], dtype=extrinsics.dtype, device=extrinsics.device
    )
    # Shape (4, 1) so that each sign scales a whole row.
    return extrinsics * row_signs.unsqueeze(-1)


@batched(2)
def view_to_extrinsics(
        view: torch.Tensor
    ) -> torch.Tensor:
    """
    Convert an OpenGL view matrix to OpenCV camera extrinsics.

    Rows 1 and 2 of the matrix (the camera y and z axes) are negated.

    Args:
        view (torch.Tensor): [..., 4, 4] OpenGL view matrix

    Returns:
        (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix
    """
    # Column vector of per-row signs, broadcast across the columns of `view`.
    flip = torch.tensor([1, -1, -1, 1], dtype=view.dtype, device=view.device).reshape(4, 1)
    return view * flip


@batched(2,0,0)
def normalize_intrinsics(
        intrinsics: torch.Tensor,
        width: Union[int, torch.Tensor],
        height: Union[int, torch.Tensor]
    ) -> torch.Tensor:
    """
    Normalize camera intrinsics(s) to uv space

    The intrinsics are left-multiplied by::

        [[1/W,   0, 0.5/W],
         [  0, 1/H, 0.5/H],
         [  0,   0,     1]]

    Args:
        intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to normalize
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)

    Returns:
        (torch.Tensor): [..., 3, 3] normalized camera intrinsics(s)
    """
    # Work in the dtype/device of the intrinsics so that integer sizes do not
    # produce mixed-dtype entries below.
    width = width.to(intrinsics)
    height = height.to(intrinsics)
    zeros = torch.zeros_like(width)
    ones = torch.ones_like(width)
    # Stack along the LAST axis: each batch element then holds its own nine
    # entries contiguously, so the reshape yields one correct 3x3 per element.
    # Stacking along axis 0 would interleave entries across the batch.
    transform = torch.stack([
        1 / width, zeros, 0.5 / width,
        zeros, 1 / height, 0.5 / height,
        zeros, zeros, ones
    ], dim=-1).reshape(*zeros.shape, 3, 3).to(intrinsics)
    return transform @ intrinsics



@batched(2,0,0,0,0,0,0)
def crop_intrinsics(
    intrinsics: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor],
    left: Union[int, torch.Tensor],
    top: Union[int, torch.Tensor],
    crop_width: Union[int, torch.Tensor],
    crop_height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Evaluate the new intrinsics(s) after crop the image: cropped_img = img[top:top+crop_height, left:left+crop_width]

    The intrinsics are left-multiplied by::

        [[W/cw,    0, -left/cw],
         [   0, H/ch,  -top/ch],
         [   0,    0,        1]]

    Args:
        intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to crop
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)
        left (int | torch.Tensor): [...] left crop boundary
        top (int | torch.Tensor): [...] top crop boundary
        crop_width (int | torch.Tensor): [...] crop width
        crop_height (int | torch.Tensor): [...] crop height

    Returns:
        (torch.Tensor): [..., 3, 3] cropped camera intrinsics(s)
    """
    # Work in the dtype/device of the intrinsics so that integer sizes do not
    # produce mixed-dtype entries below.
    width, height, left, top, crop_width, crop_height = (
        value.to(intrinsics)
        for value in (width, height, left, top, crop_width, crop_height)
    )
    zeros = torch.zeros_like(width)
    ones = torch.ones_like(width)
    # Stack along the LAST axis: each batch element then holds its own nine
    # entries contiguously, so the reshape yields one correct 3x3 per element.
    # Stacking along axis 0 would interleave entries across the batch.
    transform = torch.stack([
        width / crop_width, zeros, -left / crop_width,
        zeros, height / crop_height, -top / crop_height,
        zeros, zeros, ones
    ], dim=-1).reshape(*zeros.shape, 3, 3).to(intrinsics)
    return transform @ intrinsics


@batched(1,0,0)
def pixel_to_uv(
    pixel: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Convert pixel coordinates to uv coordinates: `uv = (pixel + 0.5) / (width, height)`.

    Args:
        pixel (torch.Tensor): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)

    Returns:
        (torch.Tensor): [..., 2] pixel coordinates defined in uv space, the range is (0, 1)
    """
    # Integer pixel indices are promoted to float32 before the division.
    pixel = pixel if torch.is_floating_point(pixel) else pixel.float()
    image_size = torch.stack((width, height), dim=-1).to(pixel)
    # The +0.5 moves from the pixel index to the pixel center.
    return (pixel + 0.5) / image_size


@batched(1,0,0)
def uv_to_pixel(
    uv: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Convert uv coordinates to pixel coordinates: `pixel = uv * (width, height) - 0.5`.

    Args:
        uv (torch.Tensor): [..., 2] pixel coordinates defined in uv space, the range is (0, 1)
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)

    Returns:
        (torch.Tensor): [..., 2] pixel coordinates defined in image space (not rounded)
    """
    image_size = torch.stack((width, height), dim=-1).to(uv)
    # The -0.5 moves from the pixel center back to the pixel index.
    return uv * image_size - 0.5


@batched(1,0,0)
def pixel_to_ndc(
    pixel: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Convert pixel coordinates to normalized device coordinates.

    With `uv = (pixel + 0.5) / (width, height)`, the result is
    `(2 * u - 1, 1 - 2 * v)`: x grows to the right and y grows upwards, so the
    top-left pixel maps close to (-1, 1) and the bottom-right close to (1, -1).

    Args:
        pixel (torch.Tensor): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)

    Returns:
        (torch.Tensor): [..., 2] pixel coordinates defined in ndc space, the range is (-1, 1)
    """
    if not torch.is_floating_point(pixel):
        pixel = pixel.float()
    image_size = torch.stack([width, height], dim=-1).to(pixel)
    uv = (pixel + 0.5) / image_size
    # Scale uv by (2, -2) -- multiply, not divide -- so the full image spans
    # the whole [-1, 1] range on both axes, with y flipped.
    scale = torch.tensor([2, -2], dtype=pixel.dtype, device=pixel.device)
    offset = torch.tensor([-1, 1], dtype=pixel.dtype, device=pixel.device)
    ndc = uv * scale + offset
    return ndc


@batched(0,0,0)
def project_depth(
        depth: torch.Tensor,
        near: Union[float, torch.Tensor],
        far: Union[float, torch.Tensor]
    ) -> torch.Tensor:
    """
    Map linear depth to a screen-space depth value.

    Computes `(far - near * far / depth) / (far - near)`, which is 0 at
    `depth == near` and 1 at `depth == far`.

    Args:
        depth (torch.Tensor): [...] depth value
        near (float | torch.Tensor): [...] near plane to clip
        far (float | torch.Tensor): [...] far plane to clip

    Returns:
        (torch.Tensor): [...] depth value in screen space, value ranging in [0, 1]
    """
    clip_range = far - near
    numerator = far - near * far / depth
    return numerator / clip_range


@batched(0,0,0)
def depth_buffer_to_linear(
        depth: torch.Tensor,
        near: Union[float, torch.Tensor],
        far: Union[float, torch.Tensor]
    ) -> torch.Tensor:
    """
    Map a screen-space depth value back to linear depth.

    Computes `near * far / (far - (far - near) * depth)`, which is `near` at
    `depth == 0` and `far` at `depth == 1`.

    Args:
        depth (torch.Tensor): [...] screen depth value, ranging in [0, 1]
        near (float | torch.Tensor): [...] near plane to clip
        far (float | torch.Tensor): [...] far plane to clip

    Returns:
        (torch.Tensor): [...] linear depth
    """
    denominator = far - (far - near) * depth
    return near * far / denominator


@batched(2, 2, 2, 2)
def project_gl(
    points: torch.Tensor,
    model: torch.Tensor = None,
    view: torch.Tensor = None,
    perspective: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Project 3D points to screen space following the OpenGL convention (except for row major matrices).

    The combined transform is `perspective @ view @ model`, where `view` and
    `model` are skipped when they are None.

    Args:
        points (torch.Tensor): [..., N, 3 or 4] 3D points to project, if the last
            dimension is 4, the points are assumed to be in homogeneous coordinates
        model (torch.Tensor): [..., 4, 4] model matrix
        view (torch.Tensor): [..., 4, 4] view matrix
        perspective (torch.Tensor): [..., 4, 4] perspective matrix

    Returns:
        scr_coord (torch.Tensor): [..., N, 3] screen space coordinates, value ranging in [0, 1].
            The origin (0., 0., 0.) is corresponding to the left & bottom & nearest
        linear_depth (torch.Tensor): [..., N] linear depth (the clip-space w component)
    """
    assert perspective is not None, "perspective matrix is required"

    # Append w = 1 to plain 3D points.
    if points.shape[-1] == 3:
        homogeneous_one = torch.ones_like(points[..., :1])
        points = torch.cat([points, homogeneous_one], dim=-1)

    # The identity fallback is only reachable when assertions are disabled.
    transform = torch.eye(4).to(points) if perspective is None else perspective
    for matrix in (view, model):
        if matrix is not None:
            transform = transform @ matrix

    # Row-vector convention: p' = p @ M^T.
    clip = points @ transform.transpose(-1, -2)
    w = clip[..., 3:]
    ndc = clip[..., :3] / w
    return ndc * 0.5 + 0.5, clip[..., 3]


@batched(2, 2, 2)
def project_cv(
    points: torch.Tensor,
    extrinsics: torch.Tensor = None,
    intrinsics: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Project 3D points to uv coordinates following the OpenCV convention.

    Points are optionally moved into camera space with `extrinsics`, then
    multiplied by `intrinsics` and divided by their z component.

    Args:
        points (torch.Tensor): [..., N, 3] or [..., N, 4] 3D points to project, if the last
            dimension is 4, the points are assumed to be in homogeneous coordinates
        extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
        intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix

    Returns:
        uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1].
            The origin (0., 0.) is corresponding to the left & top
        linear_depth (torch.Tensor): [..., N] linear depth
    """
    assert intrinsics is not None, "intrinsics matrix is required"

    # Append w = 1 to plain 3D points.
    if points.shape[-1] == 3:
        homogeneous_one = torch.ones_like(points[..., :1])
        points = torch.cat([points, homogeneous_one], dim=-1)

    # World space -> camera space (row-vector convention).
    if extrinsics is not None:
        points = points @ extrinsics.transpose(-1, -2)

    projected = points[..., :3] @ intrinsics.transpose(-2, -1)
    depth_column = projected[..., 2:]
    return projected[..., :2] / depth_column, projected[..., 2]


@batched(2, 2, 2, 2)
def unproject_gl(
        screen_coord: torch.Tensor,
        model: torch.Tensor = None,
        view: torch.Tensor = None,
        perspective: torch.Tensor = None
    ) -> torch.Tensor:
    """
    Unproject screen space coordinates to 3D following the OpenGL convention (except for row major matrices).

    The inverse of `perspective @ view @ model` is applied (`view` and
    `model` are skipped when None), followed by the homogeneous divide.

    Args:
        screen_coord (torch.Tensor): [... N, 3] screen space coordinates, value ranging in [0, 1].
            The origin (0., 0., 0.) is corresponding to the left & bottom & nearest
        model (torch.Tensor): [..., 4, 4] model matrix
        view (torch.Tensor): [..., 4, 4] view matrix
        perspective (torch.Tensor): [..., 4, 4] perspective matrix

    Returns:
        points (torch.Tensor): [..., N, 3] 3d points
    """
    assert perspective is not None, "perspective matrix is required"

    # [0, 1] screen range -> [-1, 1] NDC range, then append w = 1.
    ndc = screen_coord * 2 - 1
    clip = torch.cat([ndc, torch.ones_like(ndc[..., :1])], dim=-1)

    forward = perspective
    for matrix in (view, model):
        if matrix is not None:
            forward = forward @ matrix
    backward = torch.inverse(forward)

    homogeneous = clip @ backward.transpose(-1, -2)
    return homogeneous[..., :3] / homogeneous[..., 3:]
    

@batched(2, 1, 2, 2)
def unproject_cv(
    uv_coord: torch.Tensor,
    depth: torch.Tensor,
    extrinsics: torch.Tensor = None,
    intrinsics: torch.Tensor = None
) -> torch.Tensor:
    """
    Unproject uv coordinates and depth to 3D points following the OpenCV convention.

    The result is in camera space when `extrinsics` is None; otherwise the
    inverse extrinsics are applied as well.

    Args:
        uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1].
            The origin (0., 0.) is corresponding to the left & top
        depth (torch.Tensor): [..., N] depth value
        extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
        intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix

    Returns:
        points (torch.Tensor): [..., N, 3] 3d points
    """
    assert intrinsics is not None, "intrinsics matrix is required"

    # Lift uv to (u, v, 1), undo the intrinsics, then scale by the depth.
    uv_homogeneous = torch.cat([uv_coord, torch.ones_like(uv_coord[..., :1])], dim=-1)
    rays = uv_homogeneous @ torch.inverse(intrinsics).transpose(-2, -1)
    points = rays * depth[..., None]

    if extrinsics is None:
        return points

    # Undo the extrinsics on homogeneous points and drop w again.
    points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
    camera_to_world = torch.inverse(extrinsics).transpose(-2, -1)
    return (points @ camera_to_world)[..., :3]


def euler_axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """
    Build the elementary rotation matrix about one coordinate axis for every
    angle in `angle`.

    Args:
        axis: Axis label "X", "Y" or "Z".
        angle: any shape tensor of Euler angles in radians

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: if `axis` is not one of "X", "Y", "Z".
    """
    c, s = torch.cos(angle), torch.sin(angle)
    one, zero = torch.ones_like(angle), torch.zeros_like(angle)

    # Row-major entries of the rotation matrix for each axis.
    layouts = {
        "X": (one, zero, zero, zero, c, -s, zero, s, c),
        "Y": (c, zero, s, zero, one, zero, -s, zero, c),
        "Z": (c, -s, zero, s, c, zero, zero, zero, one),
    }
    if axis not in layouts:
        raise ValueError("letter must be either X, Y or Z.")

    flat = torch.stack(layouts[axis], -1)
    return flat.reshape(angle.shape + (3, 3))


def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str = 'XYZ') -> torch.Tensor:
    """
    Convert rotations given as Euler angles in radians to rotation matrices.

    The angle for axis "X"/"Y"/"Z" is always read from index 0/1/2 of the
    last dimension, regardless of `convention`. The elementary rotations are
    composed so that the first letter of `convention` is applied first, i.e.
    the result is `R[c2] @ R[c1] @ R[c0]`.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3), XYZ
        convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: if the angle tensor or the convention string is invalid.
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    # The middle axis must differ from both of its neighbours.
    if convention[1] == convention[0] or convention[1] == convention[2]:
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")

    first, second, third = (
        euler_axis_angle_rotation(letter, euler_angles[..., 'XYZ'.index(letter)])
        for letter in convention
    )
    return third @ second @ first


def skew_symmetric(v: torch.Tensor):
    """
    Skew symmetric (cross product) matrix from a 3D vector.

    Args:
        v (torch.Tensor): shape (..., 3)

    Returns:
        torch.Tensor: shape (..., 3, 3), the matrix
            [[0, -z, y], [z, 0, -x], [-y, x, 0]] for each vector (x, y, z)
    """
    assert v.shape[-1] == 3, "v must be 3D"
    x, y, z = v.unbind(dim=-1)
    o = torch.zeros_like(x)
    # Build each row separately, then stack the rows.
    rows = [
        torch.stack(row, dim=-1)
        for row in ((o, -z, y), (z, o, -x), (-y, x, o))
    ]
    return torch.stack(rows, dim=-2)


def rotation_matrix_from_vectors(v1: torch.Tensor, v2: torch.Tensor):
    """
    Rotation matrix that rotates v1 to v2

    Both vectors are normalized first. The matrix is
    `I + K + K @ K / (1 + c)`, where `K` is the skew symmetric matrix of
    `v1 x v2` and `c = v1 . v2`. The formula divides by `1 + c`, so it is
    undefined when the two vectors point in exactly opposite directions.

    Args:
        v1 (torch.Tensor): shape (..., 3), source direction(s)
        v2 (torch.Tensor): shape (..., 3), target direction(s)

    Returns:
        torch.Tensor: shape (..., 3, 3) rotation matrices
    """
    I = torch.eye(3).to(v1)
    v1 = F.normalize(v1, dim=-1)
    v2 = F.normalize(v2, dim=-1)
    v = torch.cross(v1, v2, dim=-1)
    c = torch.sum(v1 * v2, dim=-1)
    K = skew_symmetric(v)
    # Append the two matrix axes AFTER the batch axes so the per-vector scale
    # lines up with its own 3x3 matrix for batched inputs.
    R = I + K + (1 / (1 + c))[..., None, None] * (K @ K)
    return R


def _angle_from_tan(
    axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
) -> torch.Tensor:
    """
    Extract the first or third Euler angle from the two members of
    the matrix which are positive constant times its sine and cosine.

    Args:
        axis: Axis label "X" or "Y or "Z" for the angle we are finding.
        other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
            convention.
        data: Rotation matrices as tensor of shape (..., 3, 3).
        horizontal: Whether we are looking for the angle for the third axis,
            which means the relevant entries are in the same row of the
            rotation matrix. If not, they are in the same column.
        tait_bryan: Whether the first and third axes in the convention differ.

    Returns:
        Euler Angles in radians for each matrix in data as a tensor
        of shape (...).
    """

    i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
    if horizontal:
        i2, i1 = i1, i2
    even = (axis + other_axis) in ["XY", "YZ", "ZX"]
    if horizontal == even:
        return torch.atan2(data[..., i1], data[..., i2])
    if tait_bryan:
        return torch.atan2(-data[..., i2], data[..., i1])
    return torch.atan2(data[..., i2], -data[..., i1])


def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to Euler angles in radians.
    NOTE: The composition order eg. `XYZ` means `Rz * Ry * Rx` (like blender), instead of `Rx * Ry * Rz` (like pytorch3d)

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of uppercase letters; it must contain
            each of "X", "Y" and "Z" and no other letter.

    Returns:
        Euler angles in radians as tensor of shape (..., 3), in the order of XYZ (like blender), instead of convention (like pytorch3d)

    Raises:
        ValueError: if the convention or the matrix shape is invalid.
    """
    if not (all(c in 'XYZ' for c in convention) and all(c in convention for c in 'XYZ')):
        raise ValueError(f"Invalid convention {convention}.")
    if matrix.shape[-2:] != (3, 3):
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    first_axis, middle_axis, last_axis = convention[0], convention[1], convention[2]
    i0 = 'XYZ'.index(first_axis)
    i2 = 'XYZ'.index(last_axis)
    tait_bryan = i0 != i2

    # Middle angle: asin of a single entry (sign chosen from the axis order)
    # when the outer axes differ, acos of a diagonal entry otherwise.
    if tait_bryan:
        sign = -1.0 if i2 - i0 in [-1, 2] else 1.0
        central_angle = torch.asin(matrix[..., i2, i0] * sign)
    else:
        central_angle = torch.acos(matrix[..., i2, i2])

    # Outer angles come from row i2 and column i0 of the matrix.
    first_angle = _angle_from_tan(first_axis, middle_axis, matrix[..., i2, :], True, tait_bryan)
    last_angle = _angle_from_tan(last_axis, middle_axis, matrix[..., i0], False, tait_bryan)

    # Re-order from composition order to fixed X, Y, Z order.
    in_composition_order = [first_angle, central_angle, last_angle]
    return torch.stack([in_composition_order[convention.index(c)] for c in 'XYZ'], -1)


def axis_angle_to_matrix(axis_angle: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert axis-angle representation (rotation vector) to rotation matrix, whose direction is the axis of rotation and length is the angle of rotation

    Uses the Rodrigues formula `R = I + sin(a) K + (1 - cos(a)) K @ K`, where
    `K` is the skew symmetric matrix of the unit rotation axis.

    Args:
        axis_angle (torch.Tensor): shape (..., 3), axis-angle vectors
        eps (float): small offset added to the vector before taking its norm,
            so that a zero rotation vector does not cause a division by zero

    Returns:
        torch.Tensor: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters
    """
    batch_shape = axis_angle.shape[:-1]
    device, dtype = axis_angle.device, axis_angle.dtype

    angle = torch.norm(axis_angle + eps, dim=-1, keepdim=True)
    axis = axis_angle / angle

    # Shape (..., 1, 1) so that they broadcast against (..., 3, 3) matrices.
    cos = torch.cos(angle)[..., None, :]
    sin = torch.sin(angle)[..., None, :]

    # Split into three (..., 1) components. The split size must be 1: a split
    # size of 3 returns a single chunk and the unpacking below would fail.
    rx, ry, rz = torch.split(axis, 1, dim=-1)
    zeros = torch.zeros((*batch_shape, 1), dtype=dtype, device=device)
    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1).reshape((*batch_shape, 3, 3))

    ident = torch.eye(3, dtype=dtype, device=device)
    rot_mat = ident + sin * K + (1 - cos) * torch.matmul(K, K)
    return rot_mat


def matrix_to_axis_angle(rot_mat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert a batch of 3x3 rotation matrices to axis-angle representation (rotation vector)

    The conversion goes through the quaternion representation.

    Args:
        rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert
        eps (float): forwarded to `quaternion_to_axis_angle`

    Returns:
        torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given rotation matrices
    """
    return quaternion_to_axis_angle(matrix_to_quaternion(rot_mat), eps=eps)


def quaternion_to_axis_angle(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert a batch of quaternions (w, x, y, z) to axis-angle representation (rotation vector)

    The angle is `2 * atan2(|xyz|, w)` and the axis is `xyz / |xyz|`.

    Args:
        quaternion (torch.Tensor): shape (..., 4), the quaternions to convert
        eps (float): lower bound on `|xyz|` when normalizing the axis

    Returns:
        torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given quaternions
    """
    assert quaternion.shape[-1] == 4
    scalar_part = quaternion[..., :1]
    vector_part = quaternion[..., 1:]
    vector_norm = torch.norm(vector_part, dim=-1, keepdim=True)
    angle = 2 * torch.atan2(vector_norm, scalar_part)
    # Clamping keeps the division finite for (near-)identity rotations.
    unit_axis = vector_part / vector_norm.clamp(min=eps)
    return angle * unit_axis


def axis_angle_to_quaternion(axis_angle: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert axis-angle representation (rotation vector) to quaternion (w, x, y, z)

    The quaternion is `(cos(a / 2), sin(a / 2) * axis)`, where `a` is the
    length of the rotation vector and `axis` its normalized direction.

    Args:
        axis_angle (torch.Tensor): shape (..., 3), axis-angle vectors
        eps (float): passed to `F.normalize` as the lower bound of the norm

    Returns:
        torch.Tensor: shape (..., 4) The quaternions for the given axis-angle parameters
    """
    half_angle = torch.norm(axis_angle, dim=-1, keepdim=True) / 2
    unit_axis = F.normalize(axis_angle, dim=-1, eps=eps)
    scalar_part = torch.cos(half_angle)
    vector_part = torch.sin(half_angle) * unit_axis
    return torch.cat([scalar_part, vector_part], dim=-1)


def matrix_to_quaternion(rot_mat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert 3x3 rotation matrix to quaternion (w, x, y, z)

    The component magnitudes come from the matrix diagonal. The signs are
    fixed relative to the largest component, using the sign of the matching
    off-diagonal sums and differences.

    Args:
        rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert
        eps (float): passed to `F.normalize` for the final normalization

    Returns:
        torch.Tensor: shape (..., 4), the quaternions corresponding to the given rotation matrices
    """
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = rot_mat.flatten(-2).unbind(dim=-1)

    # |w|, |x|, |y|, |z| = 0.5 * sqrt(1 +/- m00 +/- m11 +/- m22).
    diagonal = torch.diagonal(rot_mat, dim1=-2, dim2=-1)
    diag_signs = torch.tensor([
        [1, 1, 1],
        [1, -1, -1],
        [-1, 1, -1],
        [-1, -1, 1]
    ], dtype=rot_mat.dtype, device=rot_mat.device)
    magnitudes = (1 + diagonal @ diag_signs.transpose(-1, -2)).clamp_(0).sqrt().mul(0.5)
    largest = magnitudes.max(dim=-1)[1]

    # Signs of the pairwise component products.
    xw = torch.sign(m21 - m12)
    yw = torch.sign(m02 - m20)
    zw = torch.sign(m10 - m01)
    yz = torch.sign(m21 + m12)
    xz = torch.sign(m02 + m20)
    xy = torch.sign(m01 + m10)
    ones = torch.ones_like(xw)

    # Row k holds the signs of (w, x, y, z) relative to component k.
    sign_table = torch.stack([
        torch.stack([ones, xw, yw, zw], dim=-1),
        torch.stack([xw, ones, xy, xz], dim=-1),
        torch.stack([yw, xy, ones, yz], dim=-1),
        torch.stack([zw, xz, yz, ones], dim=-1),
    ], dim=-2)
    # Pick the row belonging to the largest component of each matrix.
    row_index = largest[..., None, None].expand(*largest.shape, 1, 4)
    signs = torch.gather(sign_table, -2, row_index).squeeze(-2)

    return F.normalize(signs * magnitudes, dim=-1, eps=eps)


def quaternion_to_matrix(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert quaternions (w, x, y, z) to 3x3 rotation matrices

    The input is normalized first, so non-unit quaternions are accepted.

    Args:
        quaternion (torch.Tensor): shape (..., 4), the quaternions to convert
        eps (float): epsilon passed to `F.normalize` when normalizing the input

    Returns:
        torch.Tensor: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions
    """
    assert quaternion.shape[-1] == 4
    unit = F.normalize(quaternion, dim=-1, eps=eps)
    w, x, y, z = unit.unbind(dim=-1)
    vec = unit[..., 1:]
    eye = torch.eye(3, dtype=unit.dtype, device=unit.device)

    # Skew-symmetric cross-product matrix of the vector part.
    o = torch.zeros_like(w)
    cross = torch.stack([o, -z, y, z, o, -x, -y, x, o], dim=-1).unflatten(-1, (3, 3))

    # v v^T - |v|^2 I, which equals the square of the cross-product matrix.
    outer = vec[..., :, None] * vec[..., None, :]
    sq_norm = (vec ** 2).sum(dim=-1)[..., None, None]
    cross_sq = outer - eye * sq_norm

    return eye + 2 * (cross_sq + w[..., None, None] * cross)


def slerp(rot_mat_1: torch.Tensor, rot_mat_2: torch.Tensor, t: Union[Number, torch.Tensor]) -> torch.Tensor:
    """Spherical linear interpolation between two rotation matrices

    The interpolation follows the geodesic from `rot_mat_1` to `rot_mat_2`:
    the relative rotation `rot_mat_1^T @ rot_mat_2` is converted to a rotation
    vector, scaled by `t`, and composed back onto `rot_mat_1`. Blending the two
    absolute rotation vectors instead would not be a constant-speed rotation
    about a fixed axis and could take the long way round.

    Args:
        rot_mat_1 (torch.Tensor): shape (..., 3, 3), the first rotation matrix
        rot_mat_2 (torch.Tensor): shape (..., 3, 3), the second rotation matrix
        t (torch.Tensor): scalar or shape (...,), the interpolation factor
            (0 gives `rot_mat_1`, 1 gives `rot_mat_2`)

    Returns:
        torch.Tensor: shape (..., 3, 3), the interpolated rotation matrix
    """
    assert rot_mat_1.shape[-2:] == (3, 3)
    assert rot_mat_2.shape[-2:] == (3, 3)
    if isinstance(t, Number):
        t = torch.tensor(t, dtype=rot_mat_1.dtype, device=rot_mat_1.device)
    # Relative rotation such that rot_mat_2 == rot_mat_1 @ relative.
    relative = rot_mat_1.transpose(-1, -2) @ rot_mat_2
    # NOTE(review): assumes matrix_to_axis_angle returns the principal rotation
    # vector (angle in [0, pi]) so the shortest arc is taken - confirm.
    relative_vec = matrix_to_axis_angle(relative)
    return rot_mat_1 @ axis_angle_to_matrix(t[..., None] * relative_vec)


def interpolate_extrinsics(ext1: torch.Tensor, ext2: torch.Tensor, t: Union[Number, torch.Tensor]) -> torch.Tensor:
    """Interpolate between two extrinsics matrices.

    Both matrices are inverted, the inverses are blended with
    `interpolate_transform` (linear for translation, `slerp` for rotation),
    and the blended result is inverted back.

    Args:
        ext1 (torch.Tensor): shape (..., 4, 4), the first camera pose
        ext2 (torch.Tensor): shape (..., 4, 4), the second camera pose
        t (torch.Tensor): scalar or shape (...,), the interpolation factor

    Returns:
        torch.Tensor: shape (..., 4, 4), the interpolated camera pose
    """
    inv1 = torch.inverse(ext1)
    inv2 = torch.inverse(ext2)
    blended = interpolate_transform(inv1, inv2, t)
    return torch.inverse(blended)


def interpolate_view(view1: torch.Tensor, view2: torch.Tensor, t: Union[Number, torch.Tensor]):
    """Interpolate between two view matrices.

    Forwards its arguments unchanged to `interpolate_extrinsics`.

    Args:
        view1 (torch.Tensor): shape (..., 4, 4), the first view matrix
        view2 (torch.Tensor): shape (..., 4, 4), the second view matrix
        t (torch.Tensor): scalar or shape (...,), the interpolation factor

    Returns:
        torch.Tensor: shape (..., 4, 4), the interpolated view matrix
    """
    interpolated = interpolate_extrinsics(view1, view2, t)
    return interpolated


def interpolate_transform(transform1: torch.Tensor, transform2: torch.Tensor, t: Union[Number, torch.Tensor]):
    """Interpolate between two 4x4 rigid transforms.

    The translation column is interpolated linearly and the upper-left 3x3
    rotation block is interpolated with `slerp`. The result is reassembled
    with a `[0, 0, 0, 1]` bottom row.

    Args:
        transform1 (torch.Tensor): shape (..., 4, 4), the first transform
        transform2 (torch.Tensor): shape (..., 4, 4), the second transform
        t (torch.Tensor): scalar or shape (...,), the interpolation factor

    Returns:
        torch.Tensor: shape (..., 4, 4), the interpolated transform
    """
    assert transform1.shape[-2:] == (4, 4) and transform2.shape[-2:] == (4, 4)
    if isinstance(t, Number):
        t = torch.tensor(t, dtype=transform1.dtype, device=transform1.device)
    pos = (1 - t[..., None]) * transform1[..., :3, 3] + t[..., None] * transform2[..., :3, 3]
    rot = slerp(transform1[..., :3, :3], transform2[..., :3, :3], t)
    # Upper 3x4 block [R | p].
    top = torch.cat([rot, pos[..., None]], dim=-1)
    # Homogeneous bottom row, expanded to the batch shape of the top block.
    bottom = torch.tensor([0, 0, 0, 1], dtype=top.dtype, device=top.device).expand_as(top[..., :1, :])
    return torch.cat([top, bottom], dim=-2)


def extrinsics_to_essential(extrinsics: torch.Tensor):
    """
    extrinsics matrix `[[R, t] [0, 0, 0, 1]]` such that `x' = R (x - t)` to essential matrix such that `x' E x = 0`

    The result is computed as `E = R @ [t]_x`, where `[t]_x` is the
    skew-symmetric matrix satisfying `[t]_x @ v = cross(t, v)`.

    Args:
        extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix

    Returns:
        (torch.Tensor): [..., 3, 3] essential matrix
    """
    assert extrinsics.shape[-2:] == (4, 4)
    R = extrinsics[..., :3, :3]
    tx, ty, tz = extrinsics[..., :3, 3].unbind(dim=-1)
    # Zeros must match a single component's shape (...,) to be stackable with the others.
    zeros = torch.zeros_like(tx)
    # Stack along a new last axis so batch dimensions stay in front, then fold into 3x3.
    t_x = torch.stack([
        zeros, -tz, ty,
        tz, zeros, -tx,
        -ty, tx, zeros,
    ], dim=-1).unflatten(-1, (3, 3))
    return R @ t_x


def to4x4(R: torch.Tensor, t: torch.Tensor):
    """
    Assemble a 4x4 homogeneous transformation matrix `[[R, t], [0, 0, 0, 1]]`

    Args:
        R (torch.Tensor): [..., 3, 3] rotation matrix
        t (torch.Tensor): [..., 3] translation vector, with the same batch shape as `R`

    Returns:
        (torch.Tensor): [..., 4, 4] transformation matrix
    """
    assert R.shape[-2:] == (3, 3)
    assert t.shape[-1] == 3
    assert R.shape[:-2] == t.shape[:-1]
    batch_shape = R.shape[:-2]
    # Upper 3x4 block [R | t].
    upper = torch.cat([R, t[..., None]], dim=-1)
    # Constant homogeneous row, broadcast over the batch dimensions.
    last_row = torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device)
    last_row = last_row.expand(*batch_shape, 1, 4)
    return torch.cat([upper, last_row], dim=-2)


def rotation_matrix_2d(theta: Union[float, torch.Tensor]):
    """
    Build the 2x2 rotation matrix `[[cos, -sin], [sin, cos]]`

    Args:
        theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,)

    Returns:
        (torch.Tensor): (..., 2, 2) rotation matrix
    """
    if isinstance(theta, float):
        theta = torch.tensor(theta)
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    first_row = torch.stack([cos_t, -sin_t], dim=-1)
    second_row = torch.stack([sin_t, cos_t], dim=-1)
    return torch.stack([first_row, second_row], dim=-2)


def rotate_2d(theta: Union[float, torch.Tensor], center: torch.Tensor = None):
    """
    3x3 matrix for 2D rotation around a center
    ```
       [[Rxx, Rxy, tx],
        [Ryx, Ryy, ty],
        [0,     0,  1]]
    ```
    `theta` and the batch dimensions of `center` are broadcast against each
    other, so a single angle can be combined with a batch of centers and
    vice versa.

    Args:
        theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,)
        center (torch.Tensor): rotation center, arbitrary shape (..., 2). Default to (0, 0)

    Returns:
        (torch.Tensor): (..., 3, 3) transformation matrix
    """
    if isinstance(theta, float):
        theta = torch.tensor(theta)
        if center is not None:
            theta = theta.to(center)
    if center is None:
        center = torch.zeros(2).to(theta)

    # Bring the angle and the center to one common batch shape.
    batch_shape = torch.broadcast_shapes(theta.shape, center.shape[:-1])
    theta = theta.expand(batch_shape)
    cx, cy = center.expand(*batch_shape, 2).unbind(dim=-1)

    cos, sin = torch.cos(theta), torch.sin(theta)
    zeros, ones = torch.zeros_like(cos), torch.ones_like(cos)
    # Translation t = c - R @ c keeps the center fixed under the rotation.
    return torch.stack([
        cos, -sin, cx - (cos * cx - sin * cy),
        sin, cos, cy - (sin * cx + cos * cy),
        zeros, zeros, ones,
    ], dim=-1).unflatten(-1, (3, 3))


def translate_2d(translation: torch.Tensor):
    """
    3x3 homogeneous matrix for a 2D translation
    ```
       [[1, 0, tx],
        [0, 1, ty],
        [0, 0,  1]]
    ```
    Args:
        translation (torch.Tensor): translation vector, arbitrary shape (..., 2)

    Returns:
        (torch.Tensor): (..., 3, 3) transformation matrix
    """
    batch_shape = translation.shape[:-1]
    like = dict(dtype=translation.dtype, device=translation.device)
    # Identity linear part, broadcast over the batch dimensions.
    linear = torch.eye(2, **like).expand(*batch_shape, -1, -1)
    offset = translation[..., None]
    upper = torch.cat([linear, offset], dim=-1)
    last_row = torch.tensor([[0, 0, 1]], **like).expand(*batch_shape, -1, -1)
    return torch.cat([upper, last_row], dim=-2)


def scale_2d(scale: Union[float, torch.Tensor], center: torch.Tensor = None):
    """
    Scale matrix for 2D scaling
    ```
       [[s, 0, tx],
        [0, s, ty],
        [0, 0,  1]]
    ```
    `scale` and the batch dimensions of `center` are broadcast against each
    other, so a single factor can be combined with a batch of centers and
    vice versa.

    Args:
        scale (float | torch.Tensor): scale factor, arbitrary shape (...,)
        center (torch.Tensor): scale center, arbitrary shape (..., 2). Default to (0, 0)

    Returns:
        (torch.Tensor): (..., 3, 3) transformation matrix
    """
    if isinstance(scale, float):
        scale = torch.tensor(scale)
        if center is not None:
            scale = scale.to(center)
    if center is None:
        center = torch.zeros(2, dtype=scale.dtype, device=scale.device)

    # Bring the factor and the center to one common batch shape.
    batch_shape = torch.broadcast_shapes(scale.shape, center.shape[:-1])
    scale = scale.expand(batch_shape)
    cx, cy = center.expand(*batch_shape, 2).unbind(dim=-1)

    zeros, ones = torch.zeros_like(scale), torch.ones_like(scale)
    # Translation t = c - s * c keeps the center fixed under the scaling.
    return torch.stack([
        scale, zeros, cx - cx * scale,
        zeros, scale, cy - cy * scale,
        zeros, zeros, ones,
    ], dim=-1).unflatten(-1, (3, 3))


def apply_2d(transform: torch.Tensor, points: torch.Tensor):
    """
    Apply (3x3 or 2x3) 2D affine transformation to points
    ```
        p = R @ p + t
    ```
    where `R` is the upper-left 2x2 block of `transform` and `t` is the first
    two entries of its last column.

    Args:
        transform (torch.Tensor): (..., 2 or 3, 3) transformation matrix
        points (torch.Tensor): (..., N, 2) points to transform

    Returns:
        (torch.Tensor): (..., N, 2) transformed points
    """
    assert transform.shape[-2:] == (3, 3) or transform.shape[-2:] == (2, 3), "transform must be 3x3 or 2x3"
    assert points.shape[-1] == 2, "points must be 2D"
    linear = transform[..., :2, :2]
    # Shape (..., 1, 2): the translation is broadcast over the N points, not
    # over the coordinate axis.
    offset = transform[..., None, :2, 2]
    # Points are row vectors, so R @ p for each point is points @ R^T.
    return points @ linear.transpose(-1, -2) + offset