import os
import sys

import einops
import torch

# Make modules under the current working directory importable. Note that "."
# is resolved against the process CWD at import time, not this file's location.
root = os.path.abspath(".")
sys.path.insert(0, root)


def mean_w_mask(a, mask, keepdim=True):
    """Average ``a`` over its second-to-last axis, counting only selected rows.

    Args:
        a: Tensor of shape (..., N, D).
        mask: Tensor of shape (..., N); True / non-zero entries are included.
        keepdim: If True the reduced axis is kept, giving (..., 1, D);
            otherwise the result has shape (..., D).

    Returns:
        The masked mean. Slices whose mask selects no rows yield zeros.
        The division is done against a count in the default float dtype, so
        e.g. half-precision ``a`` produces a float32 result.
    """
    mask = mask.to(torch.bool)[..., None]
    # Count in the default float dtype (same promotion as before). The clamp
    # turns 0/0 into 0/1 for fully masked slices; their masked sum is already
    # 0, so the mean is 0 without a separate fix-up pass.
    num_elements = torch.sum(
        mask, dim=-2, keepdim=True, dtype=torch.get_default_dtype()
    )
    a_masked = torch.masked_fill(a, ~mask, 0.0)
    mean = torch.sum(a_masked, dim=-2, keepdim=True) / num_elements.clamp(min=1.0)
    if not keepdim:
        # Drop the singleton reduced axis: (..., 1, D) -> (..., D).
        mean = mean.squeeze(-2)
    return mean


def kabsch_align_ind(mobile, target, mask=None, ret_both=False):
    """Kabsch-align a single, unbatched point cloud ``mobile`` onto ``target``.

    Args:
        mobile: Points to move, shape (N, D).
        target: Reference points, same shape as ``mobile``.
        mask: Optional (N,) boolean tensor selecting the points used for the
            fit; defaults to all points.
        ret_both: If True, also return ``target``.

    Returns:
        The aligned ``mobile`` (masked-out points zeroed by ``kabsch_align``),
        or ``(aligned_mobile, target)`` when ``ret_both`` is True.
    """
    if mask is None:
        mask = torch.ones(mobile.shape[:-1], dtype=torch.bool, device=mobile.device)

    # Add a leading batch axis so the batched routine can be reused; the mask
    # must be batched the same way (it was previously dropped entirely).
    mobile_aligned = kabsch_align(mobile[None, ...], target[None, ...], mask[None, ...])[0]

    if ret_both:
        return mobile_aligned, target
    return mobile_aligned


def kabsch_align(mobile, target, mask=None):
    """Rigidly align ``mobile`` onto ``target`` (rotation + translation).

    Args:
        mobile: Points to move, shape (..., N, D).
        target: Reference points, same shape as ``mobile``.
        mask: Optional boolean tensor of shape (..., N) selecting the points
            that take part in the fit; defaults to all points.

    Returns:
        ``mobile`` rotated about its masked centroid and translated onto the
        masked centroid of ``target``. Masked-out points are set to 0.
    """
    if mask is None:
        # Create on the inputs' device so masked_fill works for CUDA tensors.
        mask = torch.ones(mobile.shape[:-1], dtype=torch.bool, device=mobile.device)
    point_mask = mask[..., None]

    mean_mobile = mean_w_mask(mobile, mask, keepdim=True)
    mean_target = mean_w_mask(target, mask, keepdim=True)

    mobile_centered = torch.masked_fill(mobile - mean_mobile, ~point_mask, 0.0)
    target_centered = torch.masked_fill(target - mean_target, ~point_mask, 0.0)

    R = _find_rot_alignment(mobile_centered, target_centered, mask)
    # The rotation may be computed at a different precision than the points;
    # torch.matmul does not promote dtypes, so match them explicitly.
    R = R.to(mobile_centered.dtype)

    # Apply R to each point (R @ p_i) and move onto the target centroid.
    mobile_aligned = (
        torch.matmul(R, mobile_centered.transpose(-2, -1)).transpose(-2, -1)
        + mean_target
    )

    return torch.masked_fill(mobile_aligned, ~point_mask, 0.0)


def _find_rot_alignment(A, B, mask=None):
    """Return the proper rotation ``R`` minimising sum_i ||R a_i - b_i||^2 (Kabsch).

    Args:
        A: Centred source points, shape (..., N, D).
        B: Centred target points, same shape as ``A``.
        mask: Optional boolean tensor of shape (..., N); masked-out points are
            zeroed and ignored. Defaults to all points.

    Returns:
        Rotation matrices of shape (..., D, D) with det = +1. They are computed
        in at least float32 (float64 inputs stay float64).

    Raises:
        AssertionError: if shapes differ or either cloud is not centred on its
            masked points.
    """
    if mask is None:
        mask = torch.ones(A.shape[:-1], dtype=torch.bool, device=A.device)

    assert A.shape == B.shape
    # Both clouds must already be centred over the selected points.
    for points in (A, B):
        centroid = mean_w_mask(points, mask, keepdim=True)
        # zeros_like: torch.allclose requires matching dtypes (fails for float64).
        assert torch.allclose(centroid, torch.zeros_like(centroid), atol=1e-4, rtol=1e-4)

    point_mask = mask[..., None]
    A = torch.masked_fill(A, ~point_mask, 0.0)
    B = torch.masked_fill(B, ~point_mask, 0.0)

    # Run the SVD in at least float32: low precision is unreliable for SVD,
    # but float64 inputs should keep their precision.
    compute_dtype = torch.promote_types(A.dtype, torch.float32)
    H = torch.matmul(A.transpose(-2, -1).to(compute_dtype), B.to(compute_dtype))

    U, _, Vh = torch.linalg.svd(H, full_matrices=True)
    V = Vh.transpose(-2, -1)
    Ut = U.transpose(-2, -1)

    # V @ U^T may be a reflection (det = -1). Flip the sign of the last
    # singular direction (singular values are returned in descending order)
    # to obtain a proper rotation. Built per batch element for any batch
    # shape and any dimension D.
    det_R = torch.linalg.det(torch.matmul(V, Ut))
    diag = torch.ones((*det_R.shape, H.shape[-1]), dtype=H.dtype, device=H.device)
    diag[..., -1] = torch.where(det_R < 0, -diag[..., -1], diag[..., -1])

    return torch.matmul(torch.matmul(V, torch.diag_embed(diag)), Ut)
