import torch
from torch_scatter import scatter_sum, scatter_mean


def kabsch_algorithm_batch(P, Q, batch_id):
    """Compute, per group, the proper rotation and translation aligning P onto Q.

    Solves the Kabsch least-squares superposition problem independently for
    every group of points sharing a ``batch_id`` value. The returned ``R`` and
    ``t`` are such that ``P[i] @ R[b].T + t[b]`` approximates ``Q[i]`` for each
    point ``i`` with ``batch_id[i] == b``.

    Args:
        P: Source points, a 2-D tensor ``(N, D)`` (rank is fixed by the
            ``[:, :, None]`` indexing below).
        Q: Target points, same shape as ``P`` and in one-to-one
            correspondence with it.
        batch_id: Integer group index per point, shape ``(N,)``. Presumably
            contiguous ``0..B-1``. ``torch_scatter`` sizes its output as
            ``batch_id.max() + 1``, so an unused id yields an all-zero
            covariance and an arbitrary rotation in that slot. TODO confirm
            with callers.

    Returns:
        R: ``(B, D, D)`` rotation matrices (determinant +1, no reflections).
        t: ``(B, D)`` translation vectors, ``t = mu_Q - R @ mu_P``.
    """
    mu_P = scatter_mean(P, batch_id, dim=0)
    mu_Q = scatter_mean(Q, batch_id, dim=0)

    P_centered = P - mu_P[batch_id]
    Q_centered = Q - mu_Q[batch_id]

    # Per-group cross-covariance H_b = sum_i p_i q_i^T, shape (B, D, D).
    H = scatter_sum(P_centered[:, :, None] * Q_centered[:, None], batch_id, dim=0)
    U, S, Vt = torch.linalg.svd(H)
    V = Vt.transpose(-2, -1)
    Ut = U.transpose(-2, -1)

    # Reflection correction. If V U^T has det -1, flip the axis belonging to
    # the smallest singular value (SVD returns them in descending order), so
    # that R is a proper rotation. Use the sign of the determinant so the
    # correction factor is exactly +/-1.
    # This is built out of place on purpose: SVD saves U/S/Vt for its
    # backward pass, so mutating Vt in place breaks autograd whenever this
    # function runs with gradients enabled.
    d = torch.sign(torch.det(torch.matmul(V, Ut)))
    diag = torch.cat([torch.ones_like(S[:, :-1]), d[:, None]], dim=-1)
    R = torch.matmul(V * diag[:, None, :], Ut)

    t = mu_Q - torch.matmul(R, mu_P[:, :, None])[:, :, 0]
    return R, t

def _summed_sq_err_rmsd(p, q, batch_id):
    """Return, per group, sqrt(mean over rows of each row's summed squared error) as a list."""
    # Sum squared differences over the atom and coordinate axes of each row,
    # then average those row totals within each group.
    per_row = torch.sum((p - q) ** 2, dim=(1, 2))
    return torch.sqrt(scatter_mean(per_row, batch_id)).tolist()


def batch_rmsd(P, Q, batch_id, MSE=False, bb_sc=False):
    """Superimpose P onto Q per group and report coordinate errors.

    The function first fits a per-group rigid transform with
    ``kabsch_algorithm_batch``, using only atom slot 1 of each row. It then
    applies that transform to every atom of ``P`` and measures errors on
    three atom subsets.

    Args:
        P: Predicted coordinates, a 3-D tensor ``(N, A, D)``. Presumably
            residues x atoms x xyz. TODO confirm.
        Q: Reference coordinates, same shape as ``P``.
        batch_id: Group index per row of ``P``/``Q``, shape ``(N,)``.
        MSE: Currently unused. NOTE(review): the function always applies
            the square root, whatever this flag is set to.
        bb_sc: If True, also compute the errors for atom slots ``4:`` and for
            all atoms. Otherwise those two entries are ``[0.0]`` placeholders.

    Returns:
        A list of three lists of per-group floats: atom slots ``:4``, atom
        slots ``4:``, and all atoms, in that order.

    Raises:
        ValueError: If ``P`` is not 3-D. The previous implementation
            silently returned ``None`` in this case.

    NOTE(review): the squared error is summed over the atom axis before it
    is averaged over rows. When every row holds ``A`` atoms, the result is
    therefore ``sqrt(A)`` times the conventional RMSD (2x for the 4-atom
    slice). Padded or missing atoms are not masked out either. Confirm this
    normalization is intended before comparing against external RMSD values.
    """
    if P.dim() != 3:
        raise ValueError(
            f"batch_rmsd expects 3-D coordinate tensors, got P with shape {tuple(P.shape)}"
        )

    # The alignment is fit on atom slot 1 only, with gradients disabled.
    # Gradients therefore reach P through the transformed coordinates, but
    # not through R or t.
    with torch.no_grad():
        R, t = kabsch_algorithm_batch(P[:, 1], Q[:, 1], batch_id)
    # Per row: P <- P @ R^T + t, using the transform of that row's group.
    P = torch.einsum('nki, nij->nkj', P, R.permute(0, 2, 1)[batch_id]) + t[batch_id][:, None]

    bb_rmsd = _summed_sq_err_rmsd(P[:, :4], Q[:, :4], batch_id)
    if not bb_sc:
        # Keep the 3-entry return shape; the other two slots are placeholders.
        return [bb_rmsd, [0.0], [0.0]]

    sc_rmsd = _summed_sq_err_rmsd(P[:, 4:], Q[:, 4:], batch_id)
    all_rmsd = _summed_sq_err_rmsd(P, Q, batch_id)
    return [bb_rmsd, sc_rmsd, all_rmsd]