"""Some utilities for low rank positive semi-definite matrices."""
import torch


def compute_sq_frobenius_distance(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Return the squared Frobenius distance ||A^T A - B^T B||_F^2.

    The n x n matrices A^T A and B^T B are never materialised. Expanding
    the square gives

        ||A^T A||_F^2 + ||B^T B||_F^2 - 2 <A^T A, B^T B>_F

    and, by cyclicity of the trace, each term is the squared Frobenius
    norm of a small product: ||A A^T||_F^2, ||B B^T||_F^2 and
    ||A B^T||_F^2 respectively.

    Args:
        A: factor with A.shape = [rk_A, n].
        B: factor with B.shape = [rk_B, n].

    Returns:
        A 0-dim tensor holding the squared distance. The value is the raw
        result of the expansion; no clamping is applied to it.
    """

    def _sq_norm_of_product(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        # ||X Y^T||_F^2: contract the shared length-n axis, then sum squares.
        product = torch.einsum('ij,kj->ik', X, Y)
        return torch.einsum('ij,ij->', product, product)

    sq_norm_a = _sq_norm_of_product(A, A)
    sq_norm_b = _sq_norm_of_product(B, B)
    cross_term = _sq_norm_of_product(A, B)

    return sq_norm_a + sq_norm_b - 2.0 * cross_term
