import numpy as np
from scipy import sparse, special
from scipy.sparse import linalg as splinalg
from sklearn.utils.extmath import safe_sparse_dot

from copt.utils import safe_sparse_add, njit, prange

class HingeLoss:
    """Mean hinge loss of a linear model, with an optional minibatch mode.

    For a data matrix ``A`` (one sample per row), labels ``b`` and
    coefficients ``x`` the objective is::

        mean(max(0, 1 - b * (A @ x)))

    The hinge loss is usually used with labels in {-1, +1}; this is not
    enforced here.

    Parameters
    ----------
    A : ndarray or scipy sparse matrix, shape (n_samples, n_features)
        Data matrix.
    b : array-like, shape (n_samples,)
        Labels.
    stochastic : bool, default False
        If True, every evaluation (including ``__call__``) uses a freshly
        drawn random minibatch instead of the full data set.
    batchsize : int or None, default None
        Number of rows drawn (with replacement) per minibatch. Required
        when ``stochastic`` is True.
    """

    def __init__(self, A, b, stochastic=False, batchsize=None):
        self.A = A
        self.b = b
        self.stochastic = stochastic
        self.batchsize = batchsize

    def __call__(self, x):
        """Return only the loss at ``x`` (see ``f_grad``)."""
        return self.f_grad(x, return_gradient=False)

    def get_loss_and_grad(self, A, b, x, return_gradient=True):
        """Evaluate the mean hinge loss of ``(A, b)`` at ``x``.

        Parameters
        ----------
        A : ndarray or scipy sparse matrix, shape (n_samples, n_features)
        b : array-like with n_samples entries
        x : ndarray with n_features entries
        return_gradient : bool, default True

        Returns
        -------
        loss : float
            Returned alone when ``return_gradient`` is False.
        grad : ndarray, shape (n_features,)
            A subgradient of the loss at ``x``; samples sitting exactly on
            the margin (``1 - z == 0``) contribute zero.
        """
        n_samples = A.shape[0]

        # Flatten the labels so that a column-shaped ``b`` cannot silently
        # broadcast against the predictions into an (n, n) array.
        labels = np.asarray(b).ravel()
        margins = labels * np.asarray(A @ x).ravel()
        slack = 1 - margins
        loss = np.maximum(0, slack).mean()
        if not return_gradient:
            return loss

        # Only samples violating the margin contribute, each with -b_i * a_i.
        # Summing them as one matrix-vector product avoids materialising the
        # per-sample gradient matrix (and an n_samples x n_samples diagonal).
        weights = (slack > 0) * labels
        # ``ravel`` (not ``squeeze``) keeps shape (n_features,) even when
        # there is a single feature.
        grad = -np.asarray(A.T @ weights).ravel() / n_samples
        return loss, grad

    def f_grad(self, x, return_gradient=True):
        """Dispatch to the stochastic or the full-data evaluation."""
        if self.stochastic:
            return self.f_sto_grad(x, return_gradient)
        return self.f_det_grad(x, return_gradient)

    def f_det_grad(self, x, return_gradient=True):
        """Evaluate loss (and gradient) on the full data set."""
        return self.get_loss_and_grad(self.A, self.b, x, return_gradient)

    def f_sto_grad(self, x, return_gradient=True):
        """Evaluate loss (and gradient) on a random minibatch.

        ``batchsize`` row indices are drawn uniformly with replacement
        using the global NumPy random state.

        Raises
        ------
        TypeError
            If ``batchsize`` is None.
        """
        if self.batchsize is None:
            raise TypeError(
                "batchsize must be an integer when stochastic=True, got None"
            )
        n_samples = self.A.shape[0]
        ind = np.random.choice(np.arange(n_samples), (self.batchsize,))
        return self.get_loss_and_grad(
            self.A[ind], self.b[ind], x, return_gradient
        )
