from quad_jax import *
import numpy as np
from pytest import approx
from sklearn.metrics import hinge_loss as sk_hinge_loss

def test_classifier():
    """classifier(A, b, c, x) must equal the quadratic form x^T A x + b^T x + c."""
    # Local seeded generator: failures are reproducible and global RNG state
    # used by other tests is left untouched.
    rng = np.random.default_rng(0)
    A = rng.standard_normal((2, 2))
    b = rng.standard_normal(2)
    c = rng.standard_normal(1)
    x = rng.standard_normal(2)

    expected = np.dot(x, A @ x) + np.dot(b, x) + c
    actual = np.asarray(classifier(A, b, c, x))

    # A purely relative tolerance collapses when the value is near zero, and the
    # JAX result may be computed in float32 while the reference is float64
    # (JAX's default dtype — TODO confirm for quad_jax), so allow a small
    # absolute slack as well; approx passes if either tolerance is met.
    assert actual == approx(expected, rel=1e-6, abs=1e-5)

def test_project():
    """project(point, 1, 2) must map each input to the expected 2-vector.

    Every expected output has non-negative entries summing to 1, which is
    consistent with a Euclidean projection onto the unit simplex (presumably
    what `project` implements — TODO confirm). A point already satisfying
    that must come back unchanged.
    """
    cases = [
        # (input, expected output, extra approx() keyword arguments)
        # An expected 0 entry makes a relative tolerance useless, so an
        # absolute one is used for this case.
        (np.array([1.4, 0.4]), np.array([1.0, 0.0]), {"abs": 1e-7}),
        (np.array([0.9, 0.5]), np.array([0.7, 0.3]), {}),
        (np.array([0.5, 0.5]), np.array([0.5, 0.5]), {}),
    ]

    for point, expected, tolerance in cases:
        projected = np.array(project(point, 1, 2))
        assert projected == approx(expected, **tolerance), (point, projected)

def test_nuclear_project():
    """nuclear_project(mat, radius, 2) must return a matrix of nuclear norm <= radius."""
    # Local seeded generator so a failing matrix can be reproduced.
    rng = np.random.default_rng(0)
    mat = rng.standard_normal((2, 2))

    for radius in (0.5, 10, 0.1):
        projected = np.asarray(nuclear_project(mat, radius, 2))
        nuc = np.linalg.norm(projected, ord='nuc')
        # Scale the slack with the radius: a fixed 1e-6 is roughly one float32
        # ulp at radius 10, so rounding alone could trip the check if the
        # result lands exactly on the boundary (assumes float32 math — TODO confirm).
        assert nuc <= radius + 1e-6 * max(1.0, radius), (radius, nuc)

def test_gradient():
    """At A=0, b=0, c=0 the gradient must be -y * (x x^T, x, 1).

    The A component is pinned to -y * outer(x, x). The classifier evaluates
    x^T A x + b^T x + c, so if `gradient` differentiates a loss of the margin
    y * classifier(A, b, c, x) (assumed — TODO confirm against quad_jax),
    the chain rule fixes the b and c components to -y * x and -y.
    """
    # Local seeded generator so failures are reproducible.
    rng = np.random.default_rng(0)
    d = 5

    for y in (1, -1):
        x = rng.standard_normal(d)
        ga, gb, gc = gradient(np.zeros((d, d)), np.zeros(d), 0.0, x, y)

        assert np.outer(x, x) == approx(np.asarray(-y * ga))
        assert x == approx(np.asarray(-y * gb))
        # gc may be a 0-d or 1-element array; approx checks every element.
        assert -y * np.asarray(gc) == approx(1.0)


def test_quad_fit():
    """QuadraticClassifier must reach >= 96% training accuracy on labels
    generated by a rank-one quadratic form with no linear or constant term."""
    n, d = 500, 30
    # Seeded data makes the accuracy threshold deterministic instead of flaky.
    rng = np.random.default_rng(0)

    u, v = rng.standard_normal((2, d))
    gd_truth_A = np.outer(u, v)
    zero_b = np.zeros(d, dtype=np.float32)  # hoisted: identical for every row

    X = rng.standard_normal((n, d))
    y = np.array([np.sign(classifier(gd_truth_A, zero_b, 0.0, row)) for row in X])

    # lmbda is set to the ground truth's nuclear norm (presumably the radius of
    # the model's nuclear-norm constraint — TODO confirm against quad_jax).
    quad = QuadraticClassifier(dim=d, lmbda=np.linalg.norm(gd_truth_A, ord='nuc'))
    quad.fit(X, y, n_epoch=10)

    assert quad.score(X, y) >= 0.96

