import numpy as np


def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), evaluated elementwise.

    Accepts scalars as well as NumPy arrays/matrices; the computation is plain
    NumPy arithmetic, so the result has whatever type that arithmetic yields.
    """
    exp_neg = np.exp(-x)
    return 1 / (exp_neg + 1)

# Stochastic oracle for the quadratic model
#     Q(v, w) = 0.5 * v^T \nabla^2 F(w) v - \nabla F(w)^T v
# F is the underlying objective oracle (e.g. `logistic` below).
# X, Y are the data matrices passed through to F.
# w is the point at which F's gradient and Hessian are taken.
# mu is the regularization parameter of F.
# v is the vector for the Hessian-vector product (HVP) of F.
# index is forwarded to F (for `logistic`, it selects the data sample).
# If order == 0, return Q(v, w).
# If order == 1, return \nabla_v Q(v, w) = \nabla^2 F(w) v - \nabla F(w),
# computed from F's HVP rather than from the full Hessian, because an HVP
# can often be computed in O(d) instead of O(d^2) (e.g. for logistic
# regression).


def quadraticWithHVP(F, X, Y, w, mu, v, order, index=0):
    """Evaluate Q(v, w) = 0.5 * v^T H v - g^T v, or its gradient in v.

    H and g are the Hessian and gradient of F at w. Both enter only through a
    single oracle call ``F(X, Y, w, mu, v, order=3, index=index)``, which must
    return the pair ``(g, H v)``; the full Hessian is never formed.

    Args:
        F: oracle callable; called with ``order=3`` and expected to return
            (gradient, Hessian-vector product).
        X, Y: data matrices, forwarded to F unchanged.
        w: point at which F's derivatives are taken.
        mu: regularization parameter, forwarded to F.
        v: vector for the Hessian-vector product. It must support
            ``v.T.dot(...)`` producing a 1x1 result, i.e. presumably a (d, 1)
            column vector -- TODO confirm against callers.
        order: 0 -> return the scalar Q(v, w);
               1 -> return grad_v Q(v, w) = H v - g.
        index: forwarded to F (selects the sample for a stochastic F).

    Returns:
        A float for order 0; ``H v - g`` (same container type as F's outputs)
        for order 1.

    Raises:
        ValueError: if ``order`` is neither 0 nor 1. This is checked before F
            is called, so no oracle work is wasted on an invalid request.
    """
    if order not in (0, 1):
        raise ValueError("The argument \"order\" should be between 0 and 1.")

    # order=3 asks F for (gradient, Hessian-vector product) in one call.
    gradF, hvpF = F(X, Y, w, mu, v, order=3, index=index)

    if order == 0:
        return 0.5 * v.T.dot(hvpF).item(0) - gradF.T.dot(v).item(0)

    # order == 1: gradient of the quadratic model with respect to v.
    return hvpF - gradF

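# Custom-order stochastic oracle for L2-regularized logistic regression,
# which also provides a Hessian-vector product (HVP) oracle.
# X, Y are the data matrices (one sample per row of X; labels in Y[:, 0]).
# w is the logistic regression parameter; mu is the regularization parameter.
# v is the vector for the Hessian-vector product.
# index selects the sample used by the single-sample orders.
# order = 0, 1, 2 are the usual stochastic oracles (value; value and
#   gradient; value, gradient and Hessian) for sample `index`.
# order = 3 returns the sample gradient and HVP; order = 7 returns the HVP only.
# order = 4, 5, 6 return the full Hessian, loss and gradient, respectively,
#   averaged over all samples.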


def logistic(X, Y, w, mu, v=None, order=3, index=0):
    """Oracle for L2-regularized logistic regression, selected by ``order``.

    With sample i = ``index`` and margin m_i = y_i * w^T x_i, the per-sample
    objective is

        f_i(w) = log(1 + exp(-m_i)) + (mu / 2) * ||w||^2.

    The Hessian formulas use x_i x_i^T * s * (1 - s) with s = sigmoid(m_i),
    i.e. they drop a y_i**2 factor, so labels are presumably in {-1, +1}
    (TODO confirm with the data loader).

    Args:
        X: data matrix with one sample per row (``n, d = X.shape``).
        Y: labels as a 2-D column, indexed ``Y[i, 0]``.
        w: parameter vector; presumably a (d, 1) column -- TODO confirm.
        mu: L2 regularization strength.
        v: vector for the Hessian-vector product (orders 3 and 7).
        order: what to return (see below).
        index: sample used by the single-sample orders 0, 1, 2, 3, 7.

    Returns, by ``order``:
        0: f_i(w)
        1: (f_i(w), grad f_i(w))
        2: (f_i(w), grad f_i(w), Hessian of f_i at w, a d x d matrix)
        3: (grad f_i(w), Hessian-vector product H_i v)
        4: Hessian of the sample-averaged objective (d x d matrix)
        5: sample-averaged loss plus (mu / 2) * ||w||^2
        6: gradient of the sample-averaged objective
        7: Hessian-vector product H_i v only

    Raises:
        ValueError: for any other ``order``.

    Losses are evaluated as ``np.logaddexp(0, -m)`` == log(1 + exp(-m)),
    which stays finite for large |m|; the naive forms log(1 + exp(m)) and
    log(sigmoid(m)) overflow to +/-inf once |m| exceeds roughly 709.
    """
    n, d = X.shape

    if order in (0, 1, 2, 3, 7):
        x_i = np.asmatrix(X[index]).T  # sample `index` as a column vector
        y_i = Y[index, 0]
        margin = y_i * w.T.dot(x_i).item(0)

        def sample_value():
            # Numerically stable log(1 + exp(-margin)) plus the L2 penalty.
            return np.logaddexp(0, -margin) + mu * 0.5 * (np.linalg.norm(w) ** 2)

        if order == 0:
            return sample_value()

        prob = sigmoid(margin)

        def sample_hvp():
            # (x_i x_i^T v) * s * (1 - s) + mu * v, computed in O(d) by
            # contracting x_i^T v first instead of forming x_i x_i^T.
            return x_i * (x_i.T.dot(v).item(0)) * prob * (1 - prob) + mu * v

        if order == 7:
            return sample_hvp()

        grad = -y_i * x_i * (1 - prob) + mu * w
        if order == 3:
            return grad, sample_hvp()
        if order == 1:
            return sample_value(), grad

        # order == 2
        hess = x_i.dot(x_i.T) * prob * (1 - prob) + mu * np.identity(d)
        return sample_value(), grad, hess

    if order == 4:
        # Vectorized X^T diag(s * (1 - s)) X / n + mu * I, equal (up to
        # rounding) to averaging the per-sample Hessians.
        X_mat = np.asmatrix(X)
        probs = sigmoid(np.multiply(X_mat.dot(w), Y))
        weights = np.multiply(probs, 1 - probs)
        return X_mat.T.dot(np.multiply(X_mat, weights)) / n + mu * np.identity(d)

    if order == 5:
        margins = np.multiply(X.dot(w), Y)
        # -log(sigmoid(m)) == log(1 + exp(-m)); logaddexp avoids log(0).
        return np.mean(np.logaddexp(0, -margins)) + mu * 0.5 * np.linalg.norm(w) ** 2

    if order == 6:
        margins = np.multiply(X.dot(w), Y)
        weighted = np.multiply(np.multiply(X, Y), 1 - sigmoid(margins))
        # Mean over samples gives a 1 x d row; transpose to a column.
        return -np.asmatrix(np.mean(weighted, axis=0)).T + mu * w

    raise ValueError("The argument \"order\" should be between 0 and 7.")
