from quad_jax import *
import numpy as np
from sklearn.metrics import hinge_loss as sk_hinge_loss


# ``profile`` is not imported by name in this file; it is presumably injected
# by a profiler runner (e.g. ``kernprof -l`` or ``python -m memory_profiler``)
# or exported through the ``quad_jax`` star import -- TODO confirm. Fall back
# to a no-op decorator so that a plain ``python`` run does not fail with
# NameError when ``test`` is defined.
try:
    profile
except NameError:
    def profile(func):
        """Return *func* unchanged; stand-in used when no profiler is active."""
        return func


@profile
def test():
    """Fit a ``QuadraticClassifier`` on synthetic data and print its score.

    A ground-truth matrix is built as the outer product of two random
    Gaussian vectors. Labels are the sign of ``classifier`` evaluated on each
    random sample with that matrix, a zero vector and ``0.0`` as the remaining
    arguments (presumably the linear and constant terms of a quadratic form --
    verify against ``quad_jax``). The model is regularised with the nuclear
    norm of the ground-truth matrix, trained for 5 epochs with batches of 20,
    and scored on the same data it was trained on.

    The random generator is not seeded, so every run uses different data.
    """
    n = 500  # number of samples
    d = 50   # feature dimension

    # Ground-truth matrix: outer product of two random Gaussian vectors.
    u, v = np.random.randn(2, d)
    gd_truth_A = np.outer(u, v)

    X = np.random.randn(n, d)

    # The zero vector is the same for every sample, so build it once instead
    # of re-allocating it on each iteration.
    zero_vec = np.zeros(d, dtype=np.float32)
    y = np.array([np.sign(classifier(gd_truth_A, zero_vec, 0.0, x)) for x in X])

    quad = QuadraticClassifier(dim=d, lmbda=np.linalg.norm(gd_truth_A, ord='nuc'))
    quad.fit(X, y, n_epoch=5, batch_size=20)

    # Score is computed on the training set itself (no held-out split).
    print(quad.score(X, y))


# Called unconditionally at module level, so the experiment runs whenever this
# file is executed or imported.
# NOTE(review): consider an ``if __name__ == "__main__":`` guard if import-time
# execution is not intended -- confirm how the profiler runner invokes the file.
test()
