from quad_jax import nuclear_project, project
import numpy as np 

# Sanity check for ``nuclear_project``: project a simple rank-one matrix and
# report how far the result moved from the input.

dim = 50

# Test matrix with a single nonzero entry. It equals 10 * e_0 e_0^T, so its
# only nonzero singular value (and hence its nuclear norm) is 10.
A = np.zeros((dim, dim))
A[0, 0] = 10

# Requested radius: half of A's nuclear norm (i.e. 5 for this A).
radius = 0.5 * np.linalg.norm(A, ord="nuc")

# NOTE(review): if nuclear_project is the Euclidean projection onto the
# nuclear-norm ball of this radius, A lies outside that ball, so mat would be
# expected to be about 0.5 * A and the allclose check below to print False.
# Confirm which outcome this script is meant to demonstrate.
mat = nuclear_project(A, radius, dim)

# Element-wise absolute deviation of the projection from the input.
residual = np.abs(mat - A)
print(residual)

print(np.allclose(mat, A))
