import numpy as np
from function import quadraticWithHVP, logistic
import copy
from data import loada9a

# Load the train/test splits once, at import time. The four arrays are also
# bound as default arguments of the optimizers defined in this module.
X_train, Y_train, X_test, Y_test = loada9a()
# n, d are the two entries of X_train.shape (presumably samples x features --
# TODO confirm against loada9a). d sizes every weight vector in this module.
n, d = X_train.shape

# Inexact damped newton method as in the DiSCO paper
# We solve the quadratic subproblem associated
# with hessian inversion using Local SGD
# F is the objective function's stochastic oracle
# X, Y are the data matrices
# T is the horizon for the newton algorithms
# alpha is a step-size parameter for the Newton steps
# M, K, R are the local SGD parameters (see below)
# lr is the learning rate for local SGD
# mu is the regularization constant for LR


def inexact_newton(F, T, alpha, M, K, R, lr, momentum=0, mu=1e-3, damp=True, gap=5, X_train=X_train, Y_train=Y_train, X_test=X_test, Y_test=Y_test):
    """Run T inexact Newton steps, solving each Newton system with local SGD.

    Every step calls ``localSGD`` with ``forHVP=True`` to obtain an
    approximate Newton direction ``Delta_t`` (the last synchronized iterate
    it returns) and then updates ``W[t+1] = W[t] - eta_t * Delta_t``.

    Args:
        F: stochastic oracle for the objective; its behaviour is selected by
            the ``order`` keyword. This function uses ``order=5`` (the result
            is recorded as the loss) and ``order=7`` (the result is used as a
            Hessian-vector product -- NOTE(review): confirm the meaning of
            these orders against the oracle's definition).
        T: number of Newton steps.
        alpha: step-size parameter; used as the full step when ``damp`` is
            False.
        M, K, R: machines, local steps and communication rounds, handed to
            ``localSGD`` unchanged.
        lr: learning rate handed to ``localSGD``.
        momentum: momentum coefficient handed to ``localSGD``.
        mu: regularization constant passed to ``F`` and ``localSGD``.
        damp: if True, the step is ``alpha / (1 + sqrt(Delta_t^T . hvpF))``.
        gap: a loss is recorded after every ``gap``-th Newton step.
        X_train, Y_train: training data. They are forwarded untouched to
            ``localSGD``; a shuffled subsample of at most 20000 rows is used
            here for the ``order=7`` oracle call.
        X_test, Y_test: data on which the recorded losses are evaluated.

    Returns:
        ``(W, losses)`` where ``W`` is a ``(T + 1, d)`` array whose row ``t``
        is the iterate after ``t`` steps (row 0 is all zeros) and ``losses``
        is a list holding the initial loss followed by one entry per ``gap``
        steps.
    """
    # Remember the caller's arrays: the inner solver must work on the same
    # data set instead of silently falling back to the module-level defaults.
    solver_X, solver_Y = X_train, Y_train

    # Shuffle a subsample of at most 20000 rows. The cap reproduces the
    # historical behaviour on large data sets while no longer indexing past
    # the end of data sets with fewer rows.
    indices = np.arange(min(20000, X_train.shape[0]))
    np.random.shuffle(indices)
    X_train = X_train[indices]
    Y_train = Y_train[indices]

    # Running sample index: passed to localSGD and replaced by the value it
    # returns, so consecutive Newton steps keep advancing it.
    i = 0

    # Row t holds the Newton iterate after t steps.
    W = np.zeros((T + 1, d))

    # Losses are evaluated with the regularization argument set to 0.
    losses = [F(X_test, Y_test, W[0].reshape((d, 1)), 0, order=5)]

    for t in range(T):
        # Current iterate as a (d, 1) column matrix.
        u = np.asmatrix(W[t]).T

        iterates, i = localSGD(F, M, K, R, mu, u, lr,
                               momentum=momentum, forHVP=True, index=i,
                               X_train=solver_X, Y_train=solver_Y,
                               X_test=X_test, Y_test=Y_test)
        # The last synchronized iterate is the approximate Newton direction.
        Delta_t = np.asmatrix(iterates[-1]).T

        # Evaluated on every step (even when damp is False) because the
        # oracle is opaque here and may have side effects.
        hvpF = F(X_train, Y_train, u, mu, Delta_t, order=7, index=i - 1)
        if damp:
            # Damped step: scale alpha down by the estimated Newton decrement.
            decrement = np.sqrt(Delta_t.T.dot(hvpF).item(0))
            eta_t = alpha / (1 + decrement)
        else:
            eta_t = alpha

        W[t + 1] = W[t] - eta_t * Delta_t.T

        if (t + 1) % gap == 0:
            losses.append(
                F(X_test, Y_test, W[t + 1].reshape((d, 1)), 0, order=5))

    return W, losses

# local SGD algorithm which allows for using Hessian Vector products
# F is the objective function's stochastic oracle
# X, Y are the data matrices
# M (machines), K (local steps), R (communication rounds)
# are the local SGD parameters as usual
# u is the input for Q(v, u), i.e., the input for F
# lr is the constant learning rate for local SGD
# forHVP decides if a Hessian-vector product oracle has to be used


def localSGD(F, M, K, R, mu, u=None, lr=0.1, momentum=0, forHVP=True, gap=5, index=0, X_train=X_train, Y_train=Y_train, X_test=X_test, Y_test=Y_test):
    """Simulate local SGD with M machines, K local steps and R rounds.

    In every round each machine starts from the latest synchronized iterate,
    takes K stochastic (momentum) steps, and the machines' final iterates are
    averaged into the next synchronized iterate.

    Args:
        F: stochastic oracle for the objective. With ``forHVP=False`` it is
            called with ``order=1`` and the second element of its result is
            used as the gradient; ``order=5`` results are recorded as losses.
        M: number of simulated machines.
        K: number of local steps per machine and round.
        R: number of communication rounds.
        mu: regularization constant passed to the gradient oracle.
        u: point handed to ``quadraticWithHVP`` when ``forHVP`` is True;
            unused otherwise.
        lr: constant learning rate.
        momentum: coefficient of the ``(current - previous)`` iterate term
            added on every local step after the first.
        forHVP: if True, gradients come from ``quadraticWithHVP`` and no
            losses are recorded; if False, gradients come from ``F`` directly.
        gap: with ``forHVP=False``, a loss is recorded after every
            ``gap``-th round.
        index: starting value of the sample index passed to the oracle as
            ``index=``; it is incremented after every gradient evaluation.
        X_train, Y_train: training data; a shuffled subsample of at most
            20000 rows is used.
        X_test, Y_test: data on which losses are evaluated
            (``forHVP=False`` only).

    Returns:
        ``(W, i)`` when ``forHVP`` is True, where ``i`` is the sample index
        after the last gradient evaluation; otherwise ``(W, losses)``.
        ``W`` is a ``(R + 1, d)`` array of synchronized iterates (row 0 is
        all zeros).
    """
    # Shuffle a subsample of at most 20000 rows. The cap reproduces the
    # historical behaviour on large data sets while no longer indexing past
    # the end of data sets with fewer rows.
    indices = np.arange(min(20000, X_train.shape[0]))
    np.random.shuffle(indices)
    X_train = X_train[indices]
    Y_train = Y_train[indices]

    i = index

    # Current local iterate of every machine between communication rounds.
    V = np.zeros((M, d))

    # Synchronized iterate after every communication round.
    W = np.zeros((R + 1, d))

    if not forHVP:
        # Losses are evaluated with the regularization argument set to 0.
        losses = [F(X_test, Y_test, W[0].reshape((d, 1)), 0, order=5)]

    for r in range(R):
        # Machines are simulated one after another.
        for m in range(M):
            # Each machine restarts from the latest synchronized iterate.
            V[m] = W[r]
            for k in range(K):
                point = np.asmatrix(V[m]).T
                if forHVP:
                    grad = quadraticWithHVP(
                        F, X_train, Y_train, u, mu, point, order=1, index=i)
                else:
                    _, grad = F(X_train, Y_train, point, mu, order=1, index=i)
                i += 1

                # Snapshot the pre-update iterate; it becomes `prev` for the
                # momentum term of the next local step.
                before_step = V[m].copy()
                if k == 0:
                    # No previous iterate yet: plain gradient step.
                    V[m] = V[m] - lr * grad.T
                else:
                    V[m] = V[m] - lr * grad.T + momentum * (V[m] - prev)
                prev = before_step

        # Communication: average the machines' iterates.
        W[r + 1] = np.mean(V, axis=0)

        if (not forHVP) and ((r + 1) % gap == 0):
            # Use a dedicated name so the `u` argument is not overwritten.
            eval_point = W[r + 1].reshape((d, 1)).copy()
            losses.append(F(X_test, Y_test, eval_point, 0, order=5))

    if forHVP:
        return W, i
    return W, losses


# Federated Accelerated SGD from the paper by Yuan and Ma'20
# This is a parallel variant of Ghadimi and Lan's famous
# mini-max optimal algorithm
# F is the custom order stochastic oracle for the objective function
# X, Y are the data matrices
# M, K, R are the local SGD problem characterizers
# mu is the L2 regularization parameter
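# alpha, beta, eta and gamma are the hyper-parameters for the optimizer;
# eta is passed in as lr, while gamma, alpha and beta are derived from
# lr, mu and K inside the function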
# ver decides if Fedac1 is used or fedac2 is used

def fedAC(F, M, K, R, mu, lr, ver=1, u=None, gap=5, i=0, X_train=X_train, Y_train=Y_train, X_test=X_test, Y_test=Y_test):
    """Simulate Federated Accelerated SGD with M machines, K steps, R rounds.

    Each machine keeps a primary iterate ``W[m]`` and an aggregated iterate
    ``W_ag[m]``; both are averaged across machines after every round, and the
    averaged aggregated iterate is stored as that round's output.

    Args:
        F: stochastic oracle for the objective. It is called with ``order=1``
            (the second element of the result is used as the gradient) and
            with ``order=5`` (the result is recorded as the loss).
        M: number of simulated machines.
        K: number of local steps per machine and round.
        R: number of communication rounds.
        mu: L2 regularization parameter; ``mu * K`` must be non-zero because
            ``gamma`` divides by it.
        lr: learning rate (eta).
        ver: 1 or 2, selecting how ``alpha`` and ``beta`` are derived from
            ``gamma`` and ``mu``.
        u: unused; kept so existing call sites keep working.
        gap: a loss is recorded after every ``gap``-th round.
        i: starting value of the sample index passed to the oracle as
            ``index=``; it is incremented after every gradient evaluation.
        X_train, Y_train: training data; a shuffled subsample of at most
            20000 rows is used.
        X_test, Y_test: data on which losses are evaluated.

    Returns:
        ``(W_avg, losses)`` where ``W_avg`` is a ``(R + 1, d)`` array of
        averaged aggregated iterates (row 0 is all zeros) and ``losses`` is a
        list holding the initial loss followed by one entry per ``gap``
        rounds.

    Raises:
        ValueError: if ``ver`` is neither 1 nor 2.
    """
    # Reject unknown variants up front instead of failing later with an
    # UnboundLocalError on alpha/beta inside the training loop.
    if ver not in (1, 2):
        raise ValueError(f"ver must be 1 or 2, got {ver!r}")

    # Shuffle a subsample of at most 20000 rows. The cap reproduces the
    # historical behaviour on large data sets while no longer indexing past
    # the end of data sets with fewer rows.
    indices = np.arange(min(20000, X_train.shape[0]))
    np.random.shuffle(indices)
    X_train = X_train[indices]
    Y_train = Y_train[indices]

    gamma = max(np.sqrt(lr/(mu*K)), lr)

    if ver == 1:
        alpha = 1/(gamma*mu)
        beta = alpha + 1
    else:
        alpha = 1.5/(gamma*mu) - 0.5
        beta = (2 * alpha**2 - 1) / (alpha - 1)

    # Per-machine iterates: primary (W), aggregated (W_ag), the point at
    # which the gradient is taken (W_md), and the freshly updated
    # primary/aggregated values (V, V_ag).
    W = np.zeros((M, d))
    W_ag = np.zeros((M, d))
    W_md = np.zeros((M, d))
    V = np.zeros((M, d))
    V_ag = np.zeros((M, d))

    # Averaged aggregated iterate after every communication round.
    W_avg = np.zeros((R+1, d))

    # The list starts with the loss at the initial point and then grows by
    # one entry every `gap` rounds. Losses are evaluated with the
    # regularization argument set to 0.
    losses = [F(X_test, Y_test, W_avg[0].reshape((d, 1)), 0, order=5)]

    for r in range(R):
        # Machines are simulated one after another.
        for m in range(M):
            for k in range(K):
                # Gradient is taken at a convex combination of both iterates.
                W_md[m] = W[m]/float(beta) + (1-1.0/beta)*W_ag[m]
                _, grad = F(X_train, Y_train, np.asmatrix(
                    W_md[m]).T, mu, order=1, index=i)
                # Aggregated iterate takes a step of size lr, the primary
                # iterate a step of size gamma.
                V_ag[m] = W_md[m] - lr*grad.T
                V[m] = (1-1/alpha)*W[m] + W_md[m]/alpha - gamma*grad.T
                W[m] = V[m]
                W_ag[m] = V_ag[m]
                i += 1

        # Communication: every machine restarts from the cross-machine means.
        W = np.tile(np.mean(V, axis=0), (M, 1))
        W_ag = np.tile(np.mean(V_ag, axis=0), (M, 1))
        W_avg[r+1] = np.mean(W_ag, axis=0)

        if (r+1) % gap == 0:
            losses.append(
                F(X_test, Y_test, W_avg[r+1].reshape((d, 1)), 0, order=5))

    return W_avg, losses
