import torch


class ProcrustesSolver:
    
    @staticmethod
    def apply(grads, unit_scale=False):
        
        assert len(grads.shape) == 3, \
            f"Invalid shape of 'grads': {grads.shape}. Only 3D tensors are applicable"
        
        with torch.no_grad():

            cov_grad_matrix_e = torch.matmul(grads.permute(0, 2, 1), grads)
            cov_grad_matrix_e = cov_grad_matrix_e.mean(0)

            singulars, basis = torch.symeig(cov_grad_matrix_e, eigenvectors=True)
            tol = torch.max(singulars) * max(cov_grad_matrix_e.shape[-2:]) * torch.finfo().eps
            rank = sum(singulars > tol)

            order = torch.argsort(singulars, dim=-1, descending=True)
            singulars, basis = singulars[order][:rank], basis[:, order][:, :rank]

            if unit_scale: 
                weights = basis
            else:
                weights = basis * torch.sqrt(singulars[-1]).view(1, -1) 
            weights = weights / torch.sqrt(singulars).view(1, -1)
            # weights = weights / singulars.view(1, -1)
            weights = torch.matmul(weights, basis.T)
            grads = torch.matmul(grads, weights.unsqueeze(0))

            return grads, weights, singulars