import torch

def OrthProj(X):
    """Return the orthogonal (polar) factor of ``X``.

    With the thin SVD ``X = U diag(S) Vh``, the result is ``R = U @ Vh``.
    This is the matrix with orthonormal columns (or orthonormal rows, when
    ``X`` is wide) closest to ``X`` in Frobenius norm. ``X = R @ (R^H X)`` is
    its polar decomposition.

    Args:
        X: Tensor of shape ``(..., m, n)``. Leading dimensions are treated
            as a batch.

    Returns:
        Tensor of shape ``(..., m, n)`` with the same dtype as ``X``.

    Note:
        If ``X`` is rank-deficient, the polar factor is not unique. The
        returned ``R`` then depends on the singular vectors LAPACK picks.
    """
    # torch.linalg.svd returns Vh (= V^H), not V, so the polar factor is U @ Vh.
    # Transposing Vh again (as the old torch.svd-style code did) yields a
    # different, wrong orthogonal matrix.
    # full_matrices=False keeps U as (m, k) and Vh as (k, n) with k = min(m, n),
    # so the product is (m, n) even for non-square inputs.
    U, _, Vh = torch.linalg.svd(X, full_matrices=False)
    # `@` (matmul) also supports batched inputs, unlike torch.mm.
    return U @ Vh
