import numpy as np
from scipy.spatial.transform import Rotation
from torch import Tensor


def compute_optimal_transformation(A, B):
    """
    Finds the rigid 3D transformation (rotation + translation) that maps B onto A
    with minimal RMSD, i.e. A ~= (R @ B.T).T + t.

    Parameters:
        A, B (Nx3 matrices): two sets of corresponding data points to be aligned;
            numpy arrays or torch Tensors (Tensors are converted to numpy arrays)

    Returns:
        R (3x3 matrix): optimal rotation matrix (rotates centered B onto centered A)
        t (1x3 vector): optimal translation, centroid_A - R @ centroid_B

    Raises:
        AssertionError: if A and B differ in shape or do not have 3 columns
    """
    # detach/cpu so tensors that track gradients or live on another device
    # can be converted as well
    if isinstance(A, Tensor):
        A = A.detach().cpu().numpy()
    if isinstance(B, Tensor):
        B = B.detach().cpu().numpy()

    assert A.shape == B.shape
    assert A.shape[1] == B.shape[1] == 3

    # find mean row-wise: 1x3
    centroid_A = np.mean(A, axis=0, keepdims=True)
    centroid_B = np.mean(B, axis=0, keepdims=True)

    # center points
    A_centered = A - centroid_A
    B_centered = B - centroid_B

    # align_vectors(a, b) estimates the rotation that takes b onto a, so R
    # rotates the centered B points onto the centered A points.
    R_object, _ = Rotation.align_vectors(A_centered, B_centered)
    R = R_object.as_matrix()

    # The translation has to match the direction of R (B -> A): rotate B's
    # centroid and move it onto A's centroid. Row-vector form keeps t as 1x3.
    t = centroid_A - centroid_B @ R.T

    return R, t


def align_point_sets(A, B):
    """
    Aligns two sets of points by rigidly moving B onto A.

    Parameters:
        A, B (Nx3 matrices): two sets of corresponding data points to be aligned;
            numpy arrays or torch Tensors (Tensors are converted to numpy arrays)

    Returns:
        A, B_aligned (Nx3 matrices): A (as a numpy array) and B after the
            optimal rotation + translation onto A
    """
    # detach/cpu so tensors that track gradients or live on another device
    # can be converted as well
    if isinstance(A, Tensor):
        A = A.detach().cpu().numpy()
    if isinstance(B, Tensor):
        B = B.detach().cpu().numpy()

    # Only the rotation is taken from the helper; it rotates centered B onto
    # centered A.
    R, _ = compute_optimal_transformation(A, B)

    # find mean row-wise: 1x3
    centroid_A = np.mean(A, axis=0, keepdims=True)
    centroid_B = np.mean(B, axis=0, keepdims=True)

    # Rotate B about its own centroid, then move it onto A's centroid.
    # Points are rows, hence the right-multiplication by R.T.
    B_aligned = (B - centroid_B) @ R.T + centroid_A

    return A, B_aligned


def compute_rmsd_aligned(A, B):
    """
    Returns the RMSD between A and B after optimally superimposing them
    (centroids removed, best-fit rotation applied).

    Parameters:
        A, B (Nx3 matrices): two sets of corresponding data points;
            numpy arrays or torch Tensors

    Returns:
        rmsd (float): root-mean-square deviation after alignment
    """
    # torch tensors are handled as numpy arrays from here on
    A, B = (pts.numpy() if isinstance(pts, Tensor) else pts for pts in (A, B))

    assert A.shape == B.shape
    assert A.shape[1] == B.shape[1] == 3

    n_points = A.shape[0]

    # subtract the row-wise mean (1x3) so both sets are centered at the origin
    a_zero_mean = A - A.mean(axis=0, keepdims=True)
    b_zero_mean = B - B.mean(axis=0, keepdims=True)

    # second return value is the root of the summed squared distances after
    # alignment; the rotation itself is not needed here
    _, root_sum_sq = Rotation.align_vectors(a_zero_mean, b_zero_mean)

    # sqrt(sum / N) == sqrt(1/N) * sqrt(sum)
    scale = np.sqrt(1 / n_points)
    return scale * root_sum_sq
