from sklearn.base import BaseEstimator
from jax import grad, jit, jacfwd, vmap, partial, hessian
import jax.numpy as jnp
import jax.scipy as jsp
import jax.lax as lax
import numpy as np
from jax import jacfwd
import matplotlib.pyplot as plt
from jax.ops import index, index_add, index_update


class QuadraticClassifier(BaseEstimator):
    """Binary classifier with quadratic decision function ``x^T A x + b^T x + c``.

    Training minimises the mean smoothed hinge loss over the full training
    set with accelerated projected gradient descent (FISTA extrapolation with
    a function-value restart). After every gradient step ``A`` is projected
    onto a norm ball of radius ``lmbda``: the nuclear-norm ball for
    ``norm="nuc"``, the Frobenius-norm ball for ``norm="fro"``. Any other
    value leaves ``A`` unconstrained. ``b`` and ``c`` are never projected.

    ``predict`` returns the sign of the decision function and ``score``
    assumes labels in {-1, +1}.
    """

    def __init__(self, dim, lmbda=1, norm="nuc"):
        super().__init__()
        self.lmbda = lmbda  # radius of the norm ball A is projected onto
        self.dim = dim      # number of input features
        self.norm = norm    # "nuc", "fro", or anything else for no projection

        # Decision-function parameters: x^T A x + b^T x + c.
        self.A = jnp.zeros((self.dim, self.dim), dtype=jnp.float32)
        self.b = jnp.zeros(self.dim, dtype=jnp.float32)
        self.c = 0.0

    def projected_GD(self, grad_A, grad_b, grad_c, step_sizes):
        """Take one projected gradient step from the current parameters, in place.

        Args:
            grad_A, grad_b, grad_c: gradients of the loss w.r.t. A, b and c.
            step_sizes: ``(step_A, step_b, step_c)`` step lengths.
        """
        step_A, step_b, step_c = step_sizes

        # Plain gradient step (jax arrays are immutable, so this rebinds).
        self.A = self.A - step_A * grad_A
        self.b = self.b - step_b * grad_b
        self.c = self.c - step_c * grad_c

        # Keep A inside the chosen norm ball.
        if self.norm == "nuc":
            self.A = nuclear_project(self.A, self.lmbda, self.dim)
        elif self.norm == "fro":
            self.A = frobenius_project(self.A, self.lmbda)

    def fit(self, X, y, n_epoch=10000, batch_size=10, plot=False, fname=None):
        """Fit ``A``, ``b`` and ``c`` to ``(X, y)``, warm-starting from their current values.

        Args:
            X: training inputs, one sample per row (presumably shape
                ``(n, dim)`` -- confirm against callers).
            y: labels, expected in {-1, +1}.
            n_epoch: maximum number of iterations; each uses the full data set.
            batch_size: not used by the optimiser; kept for interface
                compatibility.
            plot: if True, save a log-scale plot of the training loss after
                each iteration to ``figs/{fname}.png``.
            fname: file stem for the plot.

        Returns:
            ``self``, following the scikit-learn estimator convention.

        Raises:
            ValueError: if ``X`` contains no samples.
        """
        X_arr = np.asarray(X)
        n = X_arr.shape[0]
        if n == 0:
            raise ValueError("X must contain at least one sample")

        batch_grad = jit(grad(lambda A, b, c: batch_loss(A, b, c, X, y), argnums=(0, 1, 2)))
        b_l = jit(lambda A, b, c: batch_loss(A, b, c, X, y))

        # Step size 1/L with L = mean ||x||^4 (presumably a Lipschitz bound for
        # the gradient w.r.t. A). All-zero inputs give L == 0, which would mean
        # an infinite step; fall back to a unit step in that case.
        L = float(np.mean(np.linalg.norm(X_arr.reshape(n, -1), axis=1) ** 4))
        if L == 0.0:
            L = 1.0
        step = 1.0 / L
        step_sizes = (step, step, step)

        # Current iterate x_k, momentum direction x_k - x_{k-1}, and the
        # FISTA sequence l_k used to compute the extrapolation weight.
        A_prev, b_prev, c_prev = self.A, self.b, jnp.asarray(self.c)
        m_A = m_b = m_c = 0
        l_curr = 1.0
        f_k = b_l(self.A, self.b, self.c)
        losses = []

        for _ in range(n_epoch):
            # Extrapolated point v = x_k + t * (x_k - x_{k-1}).
            l_nxt = 0.5 * (1 + (1 + 4 * l_curr ** 2) ** 0.5)
            t = (l_curr - 1) / l_nxt
            l_curr = l_nxt
            v_A = A_prev + t * m_A
            v_b = b_prev + t * m_b
            v_c = c_prev + t * m_c

            # x_{k+1} = P(v - step * grad f(v)): the step starts from v, the
            # same point at which the gradient is evaluated.
            grad_A, grad_b, grad_c = batch_grad(v_A, v_b, v_c)
            self.A, self.b, self.c = v_A, v_b, v_c
            self.projected_GD(grad_A, grad_b, grad_c, step_sizes)

            f_curr = b_l(self.A, self.b, self.c)
            if plot:
                losses.append(float(f_curr))

            m_A = self.A - A_prev
            m_b = self.b - b_prev
            m_c = self.c - c_prev
            A_prev, b_prev, c_prev = self.A, self.b, self.c

            if L * (jnp.linalg.norm(m_A) ** 2 + jnp.linalg.norm(m_b) ** 2 + m_c ** 2) < 1e-8:
                print("JUST HIT TOLERANCE BOIIII")
                break

            if f_curr > f_k:
                # Function-value restart: the loss went up, so drop momentum.
                print("RESTART BOII")
                m_A = m_b = m_c = 0
                l_curr = 1.0
            # The accepted iterate's loss is next iteration's reference value.
            f_k = f_curr

        if plot:
            plt.yscale('log')
            plt.plot(losses, label="iterates")
            plt.legend()
            plt.savefig("figs/{}.png".format(fname), format='png')
            plt.close()

        return self

    def predict(self, X):
        """Return ``sign(x^T A x + b^T x + c)`` for each row of ``X`` (0 exactly on the boundary)."""
        return jnp.sign(batch_classifier(self.A, self.b, self.c, X))

    def score(self, X, y):
        """Return classification accuracy, assuming labels in {-1, +1}.

        Each wrong +/-1 prediction contributes ``|pred - y| / 2 = 1`` error; a
        0 prediction (point exactly on the boundary) counts as half an error.
        """
        preds = self.predict(X)
        return 1 - jnp.sum(jnp.abs(preds - y)) / (2 * len(X))


@jit
def classifier(A, b, c, x):
    """Evaluate the quadratic decision function ``x^T A x + b^T x + c`` for one sample ``x``."""
    quadratic_term = jnp.dot(x, jnp.dot(A, x))
    linear_term = jnp.dot(b, x)
    return quadratic_term + linear_term + c


# Batched `classifier`: (A, b, c) are passed unchanged to every call
# (in_axes None) while x is mapped over its leading axis, one sample per row.
batch_classifier = jit(
    vmap(
        classifier,
        in_axes=(None, None, None, 0),
    )
)


@jit
def hinge_loss(y, x):
    """Elementwise hinge loss ``max(0, 1 - y * x)`` for label ``y`` and score ``x``."""
    margin = y * x
    return jnp.maximum(0, 1 - margin)


@jit
def smoothed_hinge_loss(y, x):
    """Elementwise smoothed hinge loss in terms of the margin ``z = y * x``.

    Returns ``0.5 - z`` for ``z <= 0``, ``0.5 * (1 - z)**2`` for ``0 < z < 1``
    and ``0`` for ``z >= 1``.
    """
    z = y * x
    # Quadratic piece on (0, 1), zero beyond the margin.
    tail = jnp.where(z < 1, 0.5 * (1 - z) ** 2, 0)
    # Linear piece for non-positive margins.
    return jnp.where(z <= 0, 0.5 - z, tail)


@jit
def loss(A, b, c, x, y):
    """Smoothed hinge loss of the quadratic classifier on a single sample ``(x, y)``."""
    score = classifier(A, b, c, x)
    return smoothed_hinge_loss(y, score)


@jit
def batch_loss(A, b, c, x, y):
    """Mean smoothed hinge loss over a batch: rows of ``x`` are samples, ``y`` their labels."""
    return jnp.mean(smoothed_hinge_loss(y, batch_classifier(A, b, c, x)))


@jit
def gradient(A, b, c, x, y):
    """Gradient of the single-sample ``loss`` w.r.t. ``(A, b, c)``, returned as a 3-tuple."""
    loss_grad = grad(loss, argnums=(0, 1, 2))
    return loss_grad(A, b, c, x, y)


@partial(jit, static_argnums=(1, 2))
def project(v, radius, dim):
    """Project the length-``dim`` vector ``v`` onto ``{w >= 0, sum(w) = radius}``.

    Sort-based simplex projection: find the shift ``theta`` from the largest
    ``k`` whose k-th largest entry exceeds ``(sum of top k - radius) / k``,
    then return ``max(v - theta, 0)``.
    """
    ascending = lax.sort(v)
    # Suffix sums of the ascending order are sums of the k largest entries,
    # with k = dim, dim-1, ..., 1 along the array.
    suffix_sums = lax.cumsum(ascending, reverse=True)
    counts = jnp.arange(dim, 0, -1)
    thresholds = (suffix_sums - radius) / counts
    # The smallest qualifying index corresponds to the largest k.
    positions = jnp.arange(dim)
    candidates = jnp.where(ascending > thresholds, positions, dim)
    theta = thresholds[jnp.min(candidates)]
    return jnp.maximum(v - theta, 0)


@jit
def svd(mat):
    """Economy-size SVD of ``mat`` via ``jax.scipy.linalg.svd``; returns ``(U, s, Vt)``."""
    return jsp.linalg.svd(mat, full_matrices=False, compute_uv=True)


@partial(jit, static_argnums=(1, 2))
def nuclear_project(mat, radius, dim):
    """Project the ``dim x dim`` matrix ``mat`` onto the nuclear-norm ball of radius ``radius``.

    The singular values (whose sum is the nuclear norm) are replaced by their
    projection from ``project`` when that sum exceeds ``radius``, and the
    matrix is rebuilt from the unchanged singular vectors.
    """
    U, s, Vt = svd(mat)
    s = jnp.where(jnp.sum(s) > radius, project(s, radius, dim), s)
    # Rebuild U @ diag(s) @ Vt by scaling U's columns by s. (Broadcasting
    # `s * Vt` would scale Vt's columns, not its rows, which is wrong.)
    return jnp.dot(U * s, Vt)


@partial(jit, static_argnums=1)
def frobenius_project(mat, radius):
    """Project ``mat`` onto the Frobenius-norm ball of radius ``radius``.

    Matrices already inside the ball are returned unchanged; others are
    rescaled to have Frobenius norm ``radius``.
    """
    fro = jnp.linalg.norm(mat, ord='fro')
    shrunk = (radius / fro) * mat
    return jnp.where(fro > radius, shrunk, mat)
