#!/usr/bin/python

import numpy as np
import torch
from torch import Tensor

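# rigid_transform_3D(A, B)
# Input: A and B are 3xN matrices of corresponding points (one point per column)
# Returns R, t such that R @ A + t best fits B in the least-squares sense
# R = 3x3 rotation matrix
# t = 3x1 column vector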

def rigid_transform_3D(A, B):
    """Estimate the rigid transform (rotation + translation) mapping A onto B.

    Uses the SVD-based least-squares (Kabsch) method: both point sets are
    centred on their centroids, the 3x3 cross-covariance ``H`` is decomposed
    with an SVD, and the rotation is assembled from its singular vectors. If
    that would produce a reflection (``det(R) < 0``), the last right singular
    vector is negated so ``R`` is a proper rotation.

    Note: when ``H`` is rank-deficient (e.g. fewer than three non-collinear
    points) the rotation is not uniquely determined; no check is made here.

    Args:
        A: 3xN array-like of source points, one point per column.
        B: 3xN array-like of target points; column ``i`` corresponds to
            column ``i`` of ``A``.

    Returns:
        ``(R, t)`` where ``R`` is a 3x3 rotation matrix and ``t`` a 3x1 column
        vector such that ``R @ A + t`` approximates ``B``.

    Raises:
        ValueError: if ``A`` and ``B`` differ in shape or are not 3xN.
    """
    # Accept any array-like (lists, ndarrays); no-op for ndarrays.
    A = np.asarray(A)
    B = np.asarray(B)

    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    if A.shape != B.shape:
        raise ValueError(
            f"A and B must have the same shape, got {A.shape} and {B.shape}"
        )
    # Shapes are equal, so validating A also validates B.
    if A.ndim != 2 or A.shape[0] != 3:
        dims = "x".join(str(d) for d in A.shape)
        raise ValueError(f"matrix A is not 3xN, it is {dims}")

    # Column-wise centroids, kept as 3x1 so they broadcast across columns.
    centroid_A = A.mean(axis=1, keepdims=True)
    centroid_B = B.mean(axis=1, keepdims=True)

    # Centre both point sets.
    Am = A - centroid_A
    Bm = B - centroid_B

    # 3x3 cross-covariance of the centred sets.
    H = Am @ Bm.T

    # Rotation from the SVD of H.
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    # Special reflection case: flip the last singular vector to get det(R) = +1.
    if np.linalg.det(R) < 0:
        Vt[2, :] *= -1
        R = Vt.T @ U.T

    # Translation that carries the rotated A-centroid onto the B-centroid.
    t = centroid_B - R @ centroid_A

    return R, t

def align(B, A):
    """Rigidly align point set ``A`` onto point set ``B``.

    Note the argument order: the *target* ``B`` comes first and the points
    to be moved, ``A``, second.

    Args:
        B: (N, 3) target points, one point per row (numpy array or torch
            tensor).
        A: (N, 3) points to transform; row ``i`` corresponds to row ``i``
            of ``B`` (numpy array or torch tensor).

    Returns:
        ``(B, A_aligned)`` as (N, 3) numpy arrays: ``B`` itself (converted to
        numpy if it was a tensor) and ``A`` after applying the best-fit
        rotation and translation from ``rigid_transform_3D``.
    """
    # detach()/cpu() let tensors that require grad or live on an accelerator
    # be converted; for a plain CPU tensor this is the same as .numpy().
    if isinstance(A, Tensor):
        A = A.detach().cpu().numpy()
    if isinstance(B, Tensor):
        B = B.detach().cpu().numpy()
    # rigid_transform_3D works on 3xN (points as columns).
    B, A = B.T, A.T
    ret_R, ret_t = rigid_transform_3D(A, B)
    B2 = (ret_R @ A) + ret_t
    return B.T, B2.T

def align_com(x_gen, x_scaff):
    """Translate ``x_scaff`` so its centre of mass coincides with ``x_gen``'s.

    x_gen, x_scaff: torch.Tensor (n, 3)

    The centre of mass is the unweighted mean over rows (dim 0). Only a
    translation is applied to ``x_scaff``; ``x_gen`` is used solely for its
    centroid.
    """
    centre_scaff = x_scaff.mean(dim=0, keepdim=True)
    centre_gen = x_gen.mean(dim=0, keepdim=True)

    # Move the scaffold to the origin first, then onto the generated centroid.
    centred_scaff = x_scaff - centre_scaff
    return centred_scaff + centre_gen