import torch

def Jtest(X, p):
    """Return the mean absolute deviation of ``X^T J X`` from ``J``.

    ``J`` is the ``d x d`` diagonal signature matrix whose first ``p``
    diagonal entries are ``+1`` and whose remaining ``d - p`` entries are
    ``-1``, with ``d = X.shape[-1]``. For every matrix ``M`` in ``X`` the
    residual ``M^T J M - J`` is formed in float64 and the mean of its
    absolute entries is taken; the function returns the mean of those
    per-matrix errors. A value of zero means every matrix satisfies
    ``M^T J M = J`` exactly.

    Args:
        X: Tensor of shape ``(..., d, d)``. Typically ``(Rnum, d, d)``; any
            number of leading batch dimensions (including none) is accepted.
            It is converted to float64 before the computation.
        p: Number of leading ``+1`` entries on the diagonal of ``J``
            (``0 <= p <= d``). Out-of-range values raise ``RuntimeError``
            from the underlying ``torch.eye`` / slice assignment.

    Returns:
        A 0-dim tensor in PyTorch's default dtype and on the default device
        (the same kind of result this function has always returned). It is
        NaN when ``X`` contains no matrices.
    """
    X = X.to(torch.float64)
    d = X.shape[-1]

    # Build J directly in float64 on X's device so no cross-device copy or
    # dtype conversion is needed afterwards.
    J = torch.eye(d, dtype=torch.float64, device=X.device)
    J[p:, p:] = -torch.eye(d - p, dtype=torch.float64, device=X.device)

    # One batched product for all matrices instead of a Python loop of
    # torch.mm calls; J broadcasts over the leading batch dimensions.
    residual = torch.matmul(X.transpose(-2, -1), torch.matmul(J, X)) - J
    per_matrix_err = residual.abs().mean(dim=(-2, -1))

    # Move the per-matrix errors to the default dtype/device in a single
    # transfer before averaging, so callers keep receiving the same kind of
    # tensor as before (and only one device sync happens).
    default_like = torch.empty(0)
    return per_matrix_err.to(default_like).mean()
