import numpy as np
from numpy.linalg import norm as norm
import time

from Tseng import conjugate_resolvent

class Results:
    '''
    Empty attribute container.

    Instances start with no attributes; code that builds a result (for
    example frb in this file) attaches fields such as x, tstamps and
    residuals after construction.
    '''

    def __init__(self) -> None:
        # Nothing to initialise: every field is assigned by the caller.
        return None



def frb(prox1, prox2, eval_vec_field, init, iter=1000, alpha=1.0,
                  delta=0.99, stepIncrease=1.0, stepDecrease=0.7,
                  verbose=False,historyFreq=1,
                  batchsz="full"):
    '''
    FRB (forward-reflected-backward) with a backtracking step size, applied to
    the primal-dual product-space form.

    Paper reference:
    Combettes, P.L., Pesquet, J.C.: Primal-dual splitting algorithm for solving inclusions
    with mixtures of composite, Lipschitzian, and parallel-sum type monotone operators,
    Malitsky, Y., Tam, M.K.: A forward-backward splitting method for monotone
    inclusions without cocoercivity. SIAM Journal on Optimization 30(2), 1451–1472 (2020)

    The iterate has three blocks (w1, w2, x).  The forward operator used is
        B(w1, w2, x) = (-x, -x, w1 + w2 + eval_vec_field(batchsz, x, False))
    and the backward step is theBigProx, which applies conjugate_resolvent to
    the first two blocks and leaves the third block unchanged.

    Parameters
    ----------
    prox1, prox2 : forwarded unchanged to theBigProx for blocks 1 and 2.
    eval_vec_field : callable invoked as eval_vec_field(batchsz, point, False);
        its return value is added to w1 + w2.
    init : object with attributes z (initial x, also copied into w2) and
        w (initial w1).
    iter : number of outer iterations.
    alpha : initial step size.
    delta : line-search constant; a trial point pbar is accepted when
        alpha * ||B(pbar) - B(p)|| <= 0.5 * delta * ||pbar - p||.
    stepIncrease : factor applied to alpha at the start of every iteration.
    stepDecrease : factor applied to alpha after each rejected trial.
    verbose : if truthy, print the iteration counter every 100 iterations.
    historyFreq : record a time stamp and a residual every historyFreq iterations.
    batchsz : forwarded as the first argument of eval_vec_field.

    Returns
    -------
    Results instance with attributes
        x         : final third block,
        tstamps   : numpy array of recording times relative to the first
                    recording (empty when nothing was recorded),
        residuals : list of squared residual norms.
    '''

    x = init.z
    w1 = init.w
    w2 = np.copy(x)

    tstamps = []
    tstart = time.time()
    residuals = []

    # Forward operator at the starting point.
    Ap1 = -x
    Ap2 = -x
    Ap3 = w1 + w2 + eval_vec_field(batchsz, x, False)

    # No previous iterate exists yet, so the reflection term vanishes on the
    # first iteration.
    Ap1old = Ap1
    Ap2old = Ap2
    Ap3old = Ap3
    alphaOld = alpha

    for k in range(iter):
        if verbose and k % 100 == 0:
            print("iter: "+str(k))

        doBackTrack = True
        alpha = alpha * stepIncrease
        while doBackTrack:
            # Forward step with the trial alpha plus the reflection term
            # weighted by the previously accepted step.
            toProx1 = w1 -  alpha * Ap1 -  alphaOld * (Ap1 - Ap1old)
            toProx2 = w2 -  alpha * Ap2 -  alphaOld * (Ap2 - Ap2old)
            toProx3 = x - alpha * Ap3 - alphaOld * (Ap3 - Ap3old)

            pbar = theBigProx(toProx1,toProx2,toProx3,prox1,prox2,alpha,1.0,1.0)

            Apbar1 = -pbar[2]
            Apbar2 = -pbar[2]
            Apbar3 = pbar[0] + pbar[1] + eval_vec_field(batchsz,pbar[2],False)

            normLeft = np.linalg.norm(Apbar1 - Ap1,2) ** 2 + np.linalg.norm(Apbar2 - Ap2,2) ** 2 \
                       + np.linalg.norm(Apbar3 - Ap3,2) ** 2
            normRight = np.linalg.norm(pbar[0] - w1,2) ** 2 \
                        + np.linalg.norm(pbar[1] - w2,2) ** 2 \
                        + np.linalg.norm(pbar[2] - x,2) ** 2

            if alpha * np.sqrt(normLeft) <= 0.5 * delta * np.sqrt(normRight):
                doBackTrack = False
            else:
                alpha = alpha * stepDecrease

        if k%historyFreq==0:
            tstamps.append(time.time() - tstart)
            # Residual at the accepted point: (toProx - pbar) / alpha is the
            # backward-operator element produced by the resolvent step
            # (assuming theBigProx is the resolvent with step alpha; for the
            # third block it is the identity, so that term is exactly zero),
            # and adding B(pbar) gives an element of (A + B)(pbar).  Written
            # this way the value stays correct when alpha differs from
            # alphaOld or from 1, i.e. whenever backtracking changed the step.
            t = norm(alpha ** (-1) * (toProx1 - pbar[0]) + Apbar1, 2) ** 2
            t += norm(alpha ** (-1) * (toProx2 - pbar[1]) + Apbar2, 2) ** 2
            t += norm(alpha ** (-1) * (toProx3 - pbar[2]) + Apbar3, 2) ** 2
            residuals.append(t)

        # Shift the operator history and accept the trial point.
        Ap1old = Ap1
        Ap2old = Ap2
        Ap3old = Ap3

        Ap1 = Apbar1
        Ap2 = Apbar2
        Ap3 = Apbar3

        w1 = pbar[0]
        w2 = pbar[1]
        x = pbar[2]
        alphaOld = alpha

    print(f"FRB final step {alpha}")
    out = Results()
    out.x = x
    # With iter == 0 nothing is recorded; avoid indexing an empty list.
    out.tstamps = np.array(tstamps) - tstamps[0] if tstamps else np.array(tstamps)
    out.residuals = residuals

    return out

def theBigProx(a, b, c, prox1,prox2,alpha,gamma1,gamma2):
    '''
        Block-wise backward step shared by the solvers in this file.

        Blocks a and b are passed through conjugate_resolvent together with
        prox1 / prox2 and the scaled steps gamma1 * alpha / gamma2 * alpha;
        block c is returned untouched.  The result is the list
        [new_a, new_b, c].
    '''
    return [
        conjugate_resolvent(a, prox1, gamma1 * alpha),
        conjugate_resolvent(b, prox2, gamma2 * alpha),
        c,
    ]


def frb_var_reduced(prox1, prox2, eval_vec_field, init, probability,iter=1000, tau=1.0,
                    verbose=True,print_intval=100, batchsz="full",historyFreq=1):
    '''
    Variance-reduced forward-reflected-backward style iteration on the
    three-block iterate (w1, w2, x).

    Each iteration, with probability `probability`, the anchor point and the
    full operator value at it (eval_vec_field called with "full") are
    refreshed from the current iterate.  The step direction is then
        full_op(anchor) + batch_op(current) - batch_op(anchor),
    where the batch operator calls eval_vec_field(batchsz, point, flag) with
    flag False at the current point and True at the anchor
    (NOTE(review): the flag presumably asks for the same mini-batch to be
    reused -- confirm against eval_vec_field).  The backward step is
    theBigProx with step tau.

    Parameters
    ----------
    prox1, prox2 : forwarded to theBigProx / get_resid.
    eval_vec_field : callable as described above.
    init : object with attributes z (initial x, also copied into w2) and
        w (initial w1).
    probability : chance per iteration of refreshing the anchor.
    iter : number of iterations.
    tau : fixed step size.
    verbose, print_intval : print the iteration counter every print_intval
        iterations when verbose is set.
    batchsz : first argument of the stochastic eval_vec_field calls.
    historyFreq : record elapsed time and residual every historyFreq iterations.

    Returns
    -------
    (times, resids) : two numpy arrays.  Entry 0 holds time 0.0 and the
        residual from get_resid (evaluated with step 1.0) at the initial
        point; later entries come from get_resid_fbf_var.  The time spent
        computing residuals is not counted.
    '''
    x = init.z
    w1 = init.w
    w2 = np.copy(x)

    def snapshot(d1, d2, p):
        # Full operator value at (d1, d2, p) followed by a copy of the point.
        full_op = (
            np.copy(-p),
            np.copy(-p),
            np.copy(d1 + d2 + eval_vec_field("full", p, False)),
        )
        point = (np.copy(d1), np.copy(d2), np.copy(p))
        return full_op, point

    # History starts with the residual at the initial point.
    resids = [get_resid(w1,w2,x,1.0,eval_vec_field,prox1,prox2)]
    times = [0.0]

    tstartiter = time.time()

    full_op, anchor = snapshot(w1, w2, x)

    for k in range(iter):
        report_now = (k % print_intval == 0) & verbose
        if report_now:
            print("iter: "+str(k))

        if np.random.rand() < probability:
            # Move the anchor to the current iterate and recompute the
            # full operator there.
            full_op, anchor = snapshot(w1, w2, x)

        # Stochastic operator at the current iterate ...
        cur_neg = -x
        cur_third = w1 + w2 + eval_vec_field(batchsz, x, False)

        # ... and at the anchor.
        anc_neg = -anchor[2]
        anc_third = anchor[0] + anchor[1] + eval_vec_field(batchsz, anchor[2], True)

        zhat1 = w1 - tau * (full_op[0] + cur_neg - anc_neg)
        zhat2 = w2 - tau * (full_op[1] + cur_neg - anc_neg)
        zhat3 = x - tau * (full_op[2] + cur_third - anc_third)

        w1, w2, x = theBigProx(zhat1, zhat2, zhat3, prox1, prox2, tau, 1.0, 1.0)

        if k % historyFreq != 0:
            continue

        # Stop the clock before the residual evaluation so that its cost is
        # excluded from the reported running time.
        tenditer = time.time()
        resids.append(get_resid_fbf_var(w1,w2,x,tau,eval_vec_field,zhat1,zhat2,zhat3))
        times.append(times[-1] + tenditer - tstartiter)
        tstartiter = time.time()

    return np.array(times), np.array(resids)


def get_resid_fbf_var(w1,w2,x,tau,eval_vec_field,zhat1,zhat2,zhat3):
    '''
    Squared residual norm at the point (w1, w2, x).

    For each block the residual is (zhat - point) / tau plus the operator
    value at the point, where the operator is
        (-x, -x, w1 + w2 + eval_vec_field("full", x, False)).
    Returns the sum of the three squared norms.
    '''
    neg_x = -x
    third_op = w1 + w2 + eval_vec_field("full", x, False)

    # (pre-prox point, point after the step, operator value) per block.
    blocks = (
        (zhat1, w1, neg_x),
        (zhat2, w2, neg_x),
        (zhat3, x, third_op),
    )

    squared = [
        np.linalg.norm(tau ** (-1) * (before - after) + op_val) ** 2
        for before, after, op_val in blocks
    ]
    return squared[0] + squared[1] + squared[2]

def get_resid(w1,w2,x,tau,eval_vec_field,prox1,prox2):
    '''
    Residual for frb_var_reduced at its starting point, obtained from one
    ordinary Tseng step so that it is consistent with the later values.
    Only used for the first history entry; afterwards get_resid_fbf_var is
    used.

    Performs a forward step, the backward step theBigProx, and a correction
    step from (w1, w2, x), all with step tau and the full operator
        (-x, -x, w1 + w2 + eval_vec_field("full", x, False)),
    and returns the squared distance moved divided by tau**2.
    '''
    # Operator at the starting point.
    op_start = (-x, -x, w1 + w2 + eval_vec_field("full", x, False))

    # Forward step followed by the backward step.
    fwd1 = w1 - tau * op_start[0]
    fwd2 = w2 - tau * op_start[1]
    fwd3 = x - tau * op_start[2]
    mid = theBigProx(fwd1, fwd2, fwd3, prox1, prox2, tau, 1.0, 1.0)

    # Operator at the intermediate point.
    op_mid = (
        -mid[2],
        -mid[2],
        mid[0] + mid[1] + eval_vec_field("full", mid[2], False),
    )

    # Correction step, then squared displacement block by block.
    start = (w1, w2, x)
    moved_sq = [
        norm((mid[i] - tau * (op_mid[i] - op_start[i])) - start[i], 2) ** 2
        for i in range(3)
    ]
    total = moved_sq[0] + moved_sq[1] + moved_sq[2]
    return tau ** (-2) * total








