from utils import *





def dist_GD(loss, g, blocks, A, mu, vr, importance, K, skip_it, tau, simple_spars, acc=False,  x0=None, pre_text=''):
    """Run block-wise sparsified gradient descent, optionally accelerated.

    Every iteration loops over ``blocks``; for block ``i`` the gradient over
    indices ``range(bl[0], bl[1])`` is computed with ``full_grad``, shifted by
    the per-block vector ``H[i]``, passed through ``spars_step`` and averaged
    into a single direction.

    Parameters
    ----------
    loss : callable
        Called as ``loss(x)``; used only to record/print progress values.
    g : object
        Forwarded unchanged as the first argument of ``full_grad``.
    blocks : sequence of pairs
        Each entry ``bl`` defines the index range ``range(bl[0], bl[1])``.
    A : 2-D array
        Only its second dimension ``d`` is used directly here; the array is
        also forwarded to ``get_Cis`` and ``get_alpha_probs``.
    mu : float
        Forwarded to the helpers and used in the acceleration parameters
        (presumably the strong-convexity constant -- confirm against utils).
    vr : int
        Controls the shift vectors ``H``.  ``0``: ``H`` stays zero.
        ``vr > 0``: ``H[i]`` is initialised with the block gradient at the
        starting point.  ``1``: ``H[i]`` is incremented by the sparsified
        gradient times ``min(probs)``.  ``2``: ``H[i]`` is incremented by the
        second value returned by ``spars_step``.
    importance, tau, simple_spars :
        Forwarded to ``get_alpha_probs`` (``simple_spars`` also goes to
        ``spars_step``).
    K : int
        Number of iterations.
    skip_it : int
        ``loss(x)`` is recorded every ``skip_it`` iterations and printed
        every ``10 * skip_it`` iterations.
    acc : bool, optional
        If true, use the accelerated update with the sequences
        ``x, y, z, w``; otherwise use the plain step ``x -= alpha * gr``.
    x0 : array, optional
        Starting point (copied).  Defaults to the zero vector of length ``d``.
    pre_text : str, optional
        Prefix for the printed progress lines.

    Returns
    -------
    F : ndarray of length ``1 + K // skip_it``
        Recorded loss values; ``F[0]`` is the loss at the starting point.
    x : ndarray
        Final ``x`` iterate.  In accelerated mode this is the interpolation
        point at which the gradient was last evaluated.
    """
    _, d = A.shape
    C = get_Cis(A, blocks)

    num_blocks = len(blocks)
    nmu = num_blocks * mu

    probs, alpha, L, LLL = get_alpha_probs(A, blocks, mu, importance, tau, simple_spars, nmu)

    print("alpha: {}, L: {}, LLL: {}".format(alpha, L, LLL))

    F = np.zeros(1 + (K // skip_it))

    # All four sequences start from independent copies of the same point.
    start = np.zeros(d) if x0 is None else x0
    x = 1 * start
    z = 1 * start
    w = 1 * start
    y = 1 * start

    block_ranges = [range(bl[0], bl[1]) for bl in blocks]

    H = np.zeros((num_blocks, d))
    if vr > 0:
        for i, idx in enumerate(block_ranges):
            H[i] = full_grad(g, x, idx)

    F[0] = loss(x)
    print(pre_text + " Progress:{0:.3f}".format(0) + "Val {} ".format(F[0]))

    min_prob = np.min(probs)

    # Acceleration params (computed unconditionally, as before).
    rho = min(1.0, min_prob*(max(1.0, np.sqrt(L/(LLL+0.001*L)))))
    eta = min(1/L/2, 1/(LLL+0.00001*L)*(min_prob**2)/(rho*rho))
    teta1 = min(0.25, np.sqrt(eta*mu/rho))
    teta2 = 0.5
    gamma = eta/(2*(teta1+eta*mu))
    beta = 1-gamma*mu

    for k in range(K):
        gr = np.zeros(d)

        if acc:
            x = teta1*z + teta2*w + (1-teta1 - teta2)*y
            for i, idx in enumerate(block_ranges):
                grad1 = full_grad(g, x, idx) - H[i]
                grad2 = full_grad(g, w, idx) - H[i]
                spars_grad, _ = spars_step(C[i], grad1, mu, probs[i], simple_spars)
                spars_grad_H, spars_update_H = spars_step(C[i], grad2, mu, probs[i], simple_spars)
                # Use H[i] in the direction before it is updated below.
                gr += (spars_grad + H[i])/num_blocks
                if vr == 1:
                    # Previously a bare expression whose result was
                    # discarded, so H was never updated in this mode.
                    H[i] += spars_grad_H*min_prob
                elif vr == 2:
                    H[i] += spars_update_H
            y = x - eta*gr
            z = beta*z+(1-beta)*x + gamma/eta*(y-x)
            # With probability rho, move the reference point w to y.
            if np.random.rand() < rho:
                w = y*1.0
        else:
            for i, idx in enumerate(block_ranges):
                grad = full_grad(g, x, idx) - H[i]
                spars_grad, spars_update_H = spars_step(C[i], grad, mu, probs[i], simple_spars)
                gr += (spars_grad + H[i])/num_blocks
                if vr == 1:
                    H[i] += spars_grad*min_prob
                elif vr == 2:
                    H[i] += spars_update_H
            x = x - alpha*gr

        # verbose
        if ((k+1) % skip_it) == 0:
            slot = 1 + (k // skip_it)
            F[slot] = loss(x)

            if ((k+1) % (10*skip_it)) == 0:
                print(pre_text + " Progress:{0:.3f}".format((k+1)/K) + "Val {} ".format(F[slot]))

    print("__________ FINISHED ______________")

    return F, x





def dist_GD_quad(A, alpha, K, skip_it, simple_spars, imp=False, x0=None, pre_text=''):
    """Sparsified gradient descent on the objective ``sqnorm(A @ x) / 2``.

    In every iteration, for each row ``A[i]`` the term
    ``A[i] * dot(A[i], x)`` is formed, one coordinate ``j`` is sampled from
    the (row-normalised) distribution ``PROBS[i]``, and a rescaled
    contribution is added to the step direction.

    Parameters
    ----------
    A : 2-D array of shape ``(n, d)``
    alpha : float
        Step size.
    K : int
        Number of iterations.
    skip_it : int
        The objective is recorded every ``skip_it`` iterations and printed
        every ``10 * skip_it`` iterations.
    simple_spars : bool
        If true, only the sampled coordinate of the row gradient is kept,
        scaled by ``1 / PROBS[i][j]``.  Otherwise the whole row gradient is
        scaled by ``A[i, j]**2 / (PROBS[i][j] * sqnorm(A[i]))``.
    imp : bool, optional
        If true, ``PROBS`` is proportional to ``A * A``; otherwise every
        entry is ``1 / d``.
    x0 : array, optional
        Starting point (copied).  Defaults to the zero vector, matching
        ``dist_GD``; note that every row gradient is zero there, so the
        iterates never leave it.
    pre_text : str, optional
        Prefix for the printed progress lines.

    Returns
    -------
    F : ndarray of length ``1 + K // skip_it``
        Recorded objective values; ``F[0]`` is the value at the start.
    x : ndarray
        Final iterate.
    """
    n, d = A.shape

    if imp:
        PROBS = A*A/np.sum(A*A)*n
    else:
        # The shape must be a tuple; np.ones(n, d) treats d as the dtype.
        PROBS = np.ones((n, d)) / d

    F = np.zeros(1 + (K // skip_it))

    x = np.zeros(d) if x0 is None else 1*x0

    F[0] = sqnorm(np.dot(A, x))/2

    print(pre_text + " Progress:{0:.3f}".format(0) + "Val {} ".format(F[0]))

    for k in range(K):
        gr = np.zeros(d)
        for i in range(n):
            row = A[i, :]
            grad = row*np.dot(row, x)
            # One coordinate per row, drawn from the normalised row of PROBS.
            j = np.random.choice(d, None, False, PROBS[i] / np.sum(PROBS[i]))
            if simple_spars:
                sv = np.zeros(d)
                sv[j] = 1.0/PROBS[i][j]
                gr += grad * sv
            else:
                gr += grad / PROBS[i][j] * A[i, j] * A[i, j]/sqnorm(row)
        x = x - alpha*gr

        # verbose
        if ((k+1) % skip_it) == 0:
            slot = 1 + (k // skip_it)
            F[slot] = sqnorm(np.dot(A, x))/2

            if ((k+1) % (10*skip_it)) == 0:
                print(pre_text + " Progress:{0:.3f}".format((k+1)/K) + "Val {} ".format(F[slot]))

    print("__________ FINISHED ______________")

    return F, x



def find_xstar(g, mu, L, d, n, K, pre_text='', loss = None, x0 = None):
    """Approximate the minimiser with accelerated full-gradient descent.

    Runs ``K`` iterations of a gradient step of size ``1 / L`` taken at the
    extrapolated point ``y``, followed by momentum extrapolation with the
    constant coefficient ``(sqrt(L/mu) - 1) / (sqrt(L/mu) + 1)``.

    Parameters
    ----------
    g : object
        Forwarded unchanged as the first argument of ``full_grad``.
    mu, L : float
        Constants that determine the momentum coefficient and the step size
        (presumably strong-convexity and smoothness constants -- confirm).
    d : int
        Dimension of the iterate; used when ``x0`` is not given.
    n : int
        The gradient is evaluated over the indices ``0 .. n-1``.
    K : int
        Number of iterations.
    pre_text : str, optional
        Prefix for the printed progress lines.
    loss : callable, optional
        If given, ``loss(x)`` is appended to each progress line.
    x0 : array, optional
        Starting point (copied).  Defaults to the zero vector.

    Returns
    -------
    x : ndarray
        The last gradient-step iterate.
    """
    if x0 is None:
        x = np.zeros(d)
        y = np.zeros(d)
    else:
        x = x0*1.0
        y = x0*1.0

    root_kappa = np.sqrt(L/mu)
    beta = (root_kappa-1)/(root_kappa+1)

    all_idx = np.asarray(range(n))
    # Roughly ten progress lines; the max() guards against K < 10, where
    # K // 10 == 0 would make the modulo below divide by zero.
    report_every = max(1, K//10)

    for i in range(K):
        xn = y - full_grad(g, y, all_idx)/L
        # Extrapolate along the direction of travel (new minus previous
        # iterate); the earlier minus sign pulled y back toward the old point.
        y = xn + beta*(xn-x)
        x = xn

        if ((i+1) % report_every) == 0:
            msg = pre_text +" Progress:{0:.2f}".format((i+1)/K)
            if loss is not None:
                msg += "   Loss:{}".format(loss(x))
            print(msg)

    return x

