import torch

def Jtest(X, p):
    """Return the mean absolute deviation of ``X`` from J-orthogonality.

    With ``d = X.shape[0]`` and ``J = diag(I_p, -I_{d-p})``, this computes
    ``mean(|X^T J X - J|)``. The result is zero exactly when
    ``X^T J X == J``, i.e. when ``X`` preserves the indefinite form ``J``.

    Args:
        X: Square 2-D tensor of shape ``(d, d)``. Any dtype/device is
            accepted; computation happens on ``X``'s device in ``X``'s dtype
            promoted to at least float32.
        p: Number of leading ``+1`` entries on the diagonal of ``J``
            (expected ``0 <= p <= d``; the remaining ``d - p`` are ``-1``).

    Returns:
        0-dim tensor with the mean absolute entry of ``X^T J X - J``.

    Raises:
        ValueError: If ``X`` is not a square 2-D tensor (otherwise the final
            subtraction could silently broadcast to a meaningless result).
    """
    if X.dim() != 2 or X.shape[0] != X.shape[1]:
        raise ValueError(
            f"X must be a square 2-D tensor, got shape {tuple(X.shape)}"
        )
    d = X.shape[0]
    # Build J on X's device and in a dtype compatible with X; a hard-coded
    # CPU float32 J makes torch.mm fail for float64/complex/integer/CUDA X.
    dtype = torch.promote_types(X.dtype, torch.float32)
    X = X.to(dtype)
    J = torch.eye(d, dtype=dtype, device=X.device)
    J[p:, p:] = -torch.eye(d - p, dtype=dtype, device=X.device)
    err = torch.mm(X.T, torch.mm(J, X)) - J
    return torch.mean(torch.abs(err))
