from utils import *


def vr_lgd(total_loss, g, alpha, d, K, T, m, pagg, omega, method="SAGA", psvrg=0, skip_it=1, track_agg=False, vr="full", target_acc=-1*np.inf, X0= None):
    """Run variance-reduced loopless local gradient descent on ``T`` rows of ``X``.

    Each iteration draws one coin. With probability ``pagg`` every row of ``X``
    takes an "aggregation" step driven by ``psi_grad`` (not defined here;
    presumably provided by ``utils``). Otherwise every row takes a stochastic
    "local" step using one uniformly chosen sample among its own ``m`` samples,
    corrected by a table of stored per-sample gradients.

    Args:
        total_loss: Callable ``total_loss(X)`` returning the scalar objective
            for the full ``(T, d)`` iterate.
        g: Callable ``g(x, ind)`` returning the length-``d`` gradient of sample
            ``ind`` (global index ``t*m + i``) evaluated at the row ``x``.
        alpha: Step size.
        d: Number of columns of ``X`` (dimension of each row/model).
        K: Maximum number of iterations.
        T: Number of rows of ``X``.
        m: Number of samples owned by each row.
        pagg: Probability of taking an aggregation step in an iteration.
        omega: Forwarded unchanged to ``psi_grad``.
        method: ``"SAGA"`` refreshes only the sampled table entry after each
            local step. Any other value instead refreshes all ``m`` entries of
            the row with probability ``psvrg`` (an SVRG-style variant).
        psvrg: Refresh probability used when ``method != "SAGA"``.
        skip_it: Record the loss (and print progress) every ``skip_it``
            iterations.
        track_agg: If true, also record the loss after every aggregation step.
        vr: ``"full"`` maintains both the sample table ``J`` and the
            aggregation table ``Psi``. ``"partial"`` maintains only ``Psi``.
            Any other value maintains neither (both stay zero).
        target_acc: Stop early once a recorded loss falls below this value.
        X0: Optional initial iterate of shape ``(T, d)``; defaults to zeros.
            It is copied and never modified. Integer/bool input is promoted
            to float so the updates are not truncated.

    Returns:
        ``(F, X, (aggregations, agg_fvals))``:

        * ``F`` has length ``1 + K // skip_it``, and ``F[0]`` is the initial
          loss. Entries not reached because of early stopping remain 0.
        * ``X`` is the final iterate.
        * ``aggregations`` counts switches from a local step to an
          aggregation step.
        * ``agg_fvals`` starts with the initial loss.

    Raises:
        ValueError: If ``X0`` is given and does not have shape ``(T, d)``.
    """
    n = int(T*m)
    F = np.zeros(1 + (K // skip_it))
    J = np.zeros((n, d))        # stored per-sample gradients, row t*m + i
    mean_J = np.zeros((T, d))   # row t = mean of J over row t's m samples
    Psi = np.zeros((T, d))      # last psi_grad value seen for each row

    if X0 is None:
        X = np.zeros((T, d))
    else:
        # Copy so the caller's array is untouched. Promote non-float dtypes,
        # otherwise the row assignments below would silently truncate.
        X = np.array(X0, copy=True)
        if not np.issubdtype(X.dtype, np.inexact):
            X = X.astype(float)
        if X.shape != (T, d):
            raise ValueError(
                "X0 must have shape {0}, got {1}".format((T, d), X.shape))

    F[0] = total_loss(X)

    aggregations = 0
    agg_fvals = [F[0]]

    # Kind of the previous step: -1 = none yet, 0 = aggregation, 1 = local.
    last_toss = -1

    k = 0
    keep_going = True
    while keep_going and k < K:
        if np.random.rand() < pagg:
            # Aggregation step; x_bar is computed once from X before any row moves.
            x_bar = x_mean(X)
            for t in range(T):
                grad = psi_grad(X, x_bar, omega, t)
                X[t, :] = X[t, :] - alpha*(grad/pagg - (1/pagg-1)*Psi[t, :] + mean_J[t, :]/T)

                if vr in ("full", "partial"):
                    Psi[t, :] = grad

            # Only a local -> aggregation transition is counted.
            if last_toss == 1:
                aggregations += 1
            last_toss = 0

            if track_agg:
                agg_fvals.append(total_loss(X))

        else:
            # Local step: each row samples one of its own m data points.
            for t in range(T):
                i = np.random.randint(m)
                ind = t*m + i
                grad = g(X[t, :], ind)

                X[t, :] = X[t, :] - alpha*(m/n/(1-pagg)*(grad - J[ind, :]) + mean_J[t, :]/T + Psi[t, :])

                if vr == "full":
                    if method == "SAGA":
                        # Keep mean_J consistent with J before overwriting the entry.
                        mean_J[t, :] += (grad - J[ind, :])/m
                        J[ind, :] = grad
                    elif np.random.rand() < psvrg:
                        # Refresh every stored gradient of row t at the new point.
                        for ii in range(m):
                            fresh = g(X[t, :], t*m+ii)
                            mean_J[t, :] += (fresh - J[t*m+ii, :])/m
                            J[t*m+ii, :] = fresh
            last_toss = 1

        # Record the loss / report progress every skip_it iterations.
        if (k + 1) % skip_it == 0:
            idx = 1 + (k // skip_it)
            F[idx] = total_loss(X)
            if F[idx] < target_acc:
                keep_going = False
            print("Progress:{0:.3f}".format((k+1)/K))

        k += 1

    print("__________ FINISHED ______________")

    return F, X, (aggregations, agg_fvals)
