import torch
import torch.linalg


def normalize_vector(v, dim, eps=0):
    """Scale ``v`` to unit Euclidean (L2) length along ``dim``.

    Args:
        v: input tensor.
        dim: dimension along which the L2 norm is measured.
        eps: constant added to the norm before dividing. With the default of
            0, an all-zero vector along ``dim`` produces NaN (0 / 0).

    Returns:
        Tensor with the same shape as ``v``.
    """
    length = torch.linalg.norm(v, ord=2, dim=dim, keepdim=True)
    return v / (length + eps)


def project_v2v(v, e, dim=-1):
    """Project ``v`` onto the direction ``e`` along ``dim``.

    Computes ``<v, e> * e``, keeping ``dim`` so the coefficient broadcasts back
    over ``e``. This is the true orthogonal projection only when ``e`` has unit
    length along ``dim``; ``e`` is not normalized here.
    """
    coeff = torch.sum(v * e, dim=dim, keepdim=True)
    return coeff * e


def construct_3d_basis(center, p1, p2):
    """Build an orthonormal frame from three points using Gram-Schmidt.

    Args:
        center: (*, 3) origin point of the frame.
        p1: (*, 3) point; ``p1 - center`` fixes the first axis e1.
        p2: (*, 3) point; the component of ``p2 - center`` orthogonal to e1
            fixes the second axis e2.

    Returns:
        (*, 3, 3) tensor whose columns are e1, e2 and e3 = e1 x e2.
        Degenerate (coincident or collinear) points give NaN entries, because
        ``normalize_vector`` is called with its default ``eps``.
    """
    e1 = normalize_vector(p1 - center, dim=-1)

    offset2 = p2 - center
    # Remove the e1 component so that e2 is orthogonal to e1.
    e2 = normalize_vector(offset2 - project_v2v(offset2, e1, dim=-1), dim=-1)

    e3 = torch.cross(e1, e2, dim=-1)

    # Stacking on the last dim makes e1/e2/e3 the matrix columns.
    return torch.stack((e1, e2, e3), dim=-1)


def orthogonalize_matrix(R):
    """Project ``R`` onto a rotation matrix using only its first two columns.

    Columns 0 and 1 (``R[..., 0]`` and ``R[..., 1]``) are concatenated into a
    6D representation and re-orthonormalized. Any third column of ``R`` is
    ignored. Presumably ``R`` is (..., 3, 3); confirm against callers.
    """
    first_col = R[..., 0]
    second_col = R[..., 1]
    six_d = torch.cat([first_col, second_col], dim=-1)  # (..., 6)
    return repr_6d_to_rotation_matrix(six_d)


def local_to_global(p, R, t=None):
    """Map points from per-item local frames to the global frame: ``q = R p + t``.

    Args:
        p: (N, ..., 3) points expressed in each of the N local frames.
           Non-contiguous tensors (e.g. transposes or slices) are accepted.
        R: per-frame rotation applied via ``torch.matmul`` to (N, 3, *) points.
           Presumably (N, 3, 3); TODO confirm other batch layouts with callers.
        t: optional (N, 3) per-frame translation.

    Returns:
        Tensor with the same shape as ``p``.
    """
    assert p.size(-1) == 3
    p_size = p.size()
    N = p_size[0]

    # `reshape` rather than `view`: `view` raises on non-contiguous inputs
    # whose middle dims cannot be merged without a copy.
    p = p.reshape(N, -1, 3).transpose(-1, -2)   # (N, *, 3) -> (N, 3, *)
    q = torch.matmul(R, p)  # (N, 3, *)
    if t is not None:
        q = q + t.unsqueeze(-1)  # (N, 3) -> (N, 3, 1) broadcasts over points
    q = q.transpose(-1, -2).reshape(p_size)     # (N, 3, *) -> (N, *, 3) -> (N, ..., 3)
    return q


def global_to_local(q, R, t=None):
    """Map global points into per-item local frames: ``p = R^T (q - t)``.

    This is the inverse of ``q = R p + t`` only when ``R`` is orthogonal,
    because ``R.transpose(-1, -2)`` stands in for ``R^-1``.

    Args:
        q: (N, ..., 3) points in the global frame. Non-contiguous tensors
           (e.g. transposes or slices) are accepted.
        R: per-frame rotation. Presumably (N, 3, 3); TODO confirm other batch
           layouts with callers.
        t: optional (N, 3) per-frame translation.

    Returns:
        Tensor with the same shape as ``q``.
    """
    assert q.size(-1) == 3
    q_size = q.size()
    N = q_size[0]

    # `reshape` rather than `view`: `view` raises on non-contiguous inputs
    # whose middle dims cannot be merged without a copy.
    q = q.reshape(N, -1, 3).transpose(-1, -2)   # (N, *, 3) -> (N, 3, *)
    if t is not None:
        q = q - t.unsqueeze(-1)  # (N, 3) -> (N, 3, 1) broadcasts over points
    p = torch.matmul(R.transpose(-1, -2), q)  # (N, 3, *)
    p = p.transpose(-1, -2).reshape(q_size)     # (N, 3, *) -> (N, *, 3) -> (N, ..., 3)
    return p


def apply_rotation(R, v):
    """Rotate points ``v`` by ``R`` without translation.

    This is ``local_to_global`` with no translation; that function documents
    the expected shapes.
    """
    rotated = local_to_global(p=v, R=R)
    return rotated


def apply_inverse_rotation(R, v):
    """Apply ``R^T`` to points ``v`` without translation.

    This is ``global_to_local`` with no translation. It is the inverse of
    ``apply_rotation`` only for orthogonal ``R``.
    """
    unrotated = global_to_local(q=v, R=R)
    return unrotated


def repr_6d_to_rotation_matrix(x):
    """Convert a 6D rotation representation into a 3x3 rotation matrix.

    The last dimension of ``x`` holds two 3-vectors, ``x[..., 0:3]`` and
    ``x[..., 3:6]``. Gram-Schmidt turns them into orthonormal ``b1`` and
    ``b2``, and ``b3 = b1 x b2`` completes the frame.

    Returns:
        (..., 3, 3) tensor whose columns are b1, b2, b3. A zero-length first
        vector, or a second vector parallel to it, yields NaN entries, because
        ``normalize_vector`` is called with its default ``eps``.
    """
    first, second = x[..., :3], x[..., 3:6]

    b1 = normalize_vector(first, dim=-1)
    # Strip the b1 component so that b2 is orthogonal to b1.
    residual = second - project_v2v(second, b1, dim=-1)
    b2 = normalize_vector(residual, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)

    # Stacking on the last dim makes b1/b2/b3 the matrix columns.
    return torch.stack([b1, b2, b3], dim=-1)


def compose_rotation(R1, R2):
    """Return the matrix product ``R1 @ R2``.

    Applied to column vectors, the result performs ``R2`` first and then
    ``R1``. Leading batch dimensions broadcast as in ``torch.matmul``.
    """
    return R1 @ R2

