from quad_jax import *
import numpy as np
from sklearn.metrics import hinge_loss as sk_hinge_loss
from utils import anisotropize


try:
    # `profile` is only defined when the script is run under kernprof /
    # line_profiler; without this fallback a plain `python` run dies with
    # NameError at the decorator below.
    profile
except NameError:
    def profile(func):
        """No-op replacement for line_profiler's `profile` decorator."""
        return func


@profile
def test(d=100, *, n_epoch=25000, anisotropy=0.7, lmbda_scale=5.0,
         seed=None, plot=True, fname="frank_wolfe_cond_pgd_ada"):
    """Fit a QuadraticClassifier on synthetic data labelled by a rank-1 truth.

    A rank-1 ground-truth matrix ``A = outer(u, v)`` is drawn, Gaussian
    features are passed through ``anisotropize``, and each sample is labelled
    with ``np.sign(classifier(A, x))``. The classifier is then fit and its
    score on the training data is printed and returned.

    Args:
        d: Feature dimension.
        n_epoch: Number of epochs forwarded to ``QuadraticClassifier.fit``.
        anisotropy: Second argument forwarded to ``anisotropize``.
        lmbda_scale: ``lmbda`` is set to this factor times the nuclear norm
            of the ground-truth matrix.
        seed: If not None, seeds NumPy's global RNG so the run is
            reproducible. The global RNG (rather than a local Generator) is
            seeded so any NumPy-global randomness inside the helpers is
            covered too.
        plot: Forwarded to ``fit``.
        fname: Forwarded to ``fit``.

    Returns:
        The value of ``quad.score(X, y)``.
    """
    if seed is not None:
        np.random.seed(seed)

    # Sample count scales like d * ceil(log d); clamp the log factor to at
    # least 1 so d == 1 (log 1 == 0) does not yield an empty dataset.
    n = d * max(1, int(np.ceil(np.log(d)))) * 5

    u, v = np.random.randn(2, d)
    gd_truth_A = np.outer(u, v)  # rank-1 ground-truth matrix

    X = anisotropize(np.random.randn(n, d), anisotropy)
    # NOTE: np.sign yields 0 if classifier(...) returns exactly 0, so y may
    # contain 0 labels in that case.
    y = np.array([np.sign(classifier(gd_truth_A, x)) for x in X])

    # NOTE(review): presumably lmbda is the nuclear-norm constraint radius
    # used by the Frank-Wolfe solver - TODO confirm in QuadraticClassifier.
    quad = QuadraticClassifier(
        dim=d, lmbda=lmbda_scale * np.linalg.norm(gd_truth_A, ord='nuc'))

    quad.fit(X, y, n_epoch=n_epoch, plot=plot, fname=fname)

    score = quad.score(X, y)
    print(score)
    return score


if __name__ == "__main__":
    # Guarded so importing this module does not launch the long training
    # run as a side effect; running the file as a script still calls test().
    test()
