import numpy as np
from numpy.linalg import norm as norm
import time

class Results:
    '''
    Empty attribute container.
    tseng_product() creates one instance and attaches its outputs
    (x, f, grad_evals, times, constraints, residuals, residAvs, ...)
    as plain attributes after construction.
    '''

    def __init__(self):
        # nothing to initialise: every field is attached later by the caller
        return

def conjugate_resolvent(t,resolvent,stepsize):
    '''
    Resolvent of the inverse operator, obtained from the resolvent of the
    operator itself via the Moreau decomposition.

    resolvent(point, step) is expected to return (I + step*T)^{-1}(point)
    for a monotone operator T. The value returned here is
        (I + stepsize*T^{-1})^{-1}(t)
          = t - stepsize * (I + T/stepsize)^{-1}(t/stepsize)

    t:         point at which the resolvent is evaluated
    resolvent: callable taking (point, step)
    stepsize:  step applied to T^{-1}; its reciprocal is the step handed
               to resolvent
    '''
    # the reciprocal step is both the scaling of t and the step passed on
    inv_step = stepsize**(-1)
    primal_part = resolvent(inv_step*t, inv_step)
    return t - stepsize*primal_part

def tseng_product(theFunc, prox1, prox2, eval_vec_field, init, iter=1000, alpha=1.0,
                  theta=0.99, stepIncrease=1.0, stepDecrease=0.7,usePast=False,
                  gamma1=1.0,gamma2=1.0,verbose=True,getFuncVals=True,historyFreq=1,print_intval=100,
                  batchsz="full",reuseBatch=False,doBT=True,Cstep=None,dstep=None,getAvResid=False):
    r'''
    Tseng's method applied to the primal-dual product-space form of
        min_x f(x) + g(x) + h(x)
    With p = (w_1, w_2, x) the problem is equivalent to finding 0 \in B p + A p,
    where
        B p = [subf* w_1, subg* w_2, 0]         (handled through theBigProx)
        A p = [-x, -x, w_1 + w_2 + gradh x]     (the Ap1/Ap2/Ap3 variables below)
    A is Lipschitz and monotone but not cocoercive.
    Moreau's decomposition is used to evaluate proxfstar and proxgstar.
    The linesearch exit condition is
        alpha*||A pbar - A p||_P <= theta*||pbar - p||_{P^{-1}}
    with P = diag(gamma1, gamma2, 1).

    Paper reference:
    Combettes, P.L., Pesquet, J.C.: Primal-dual splitting algorithm for solving inclusions
    with mixtures of composite, Lipschitzian, and parallel-sum type monotone operators.

    Parameters
        theFunc:        called as theFunc(x); used only to record objective values
        prox1, prox2:   resolvent callables forwarded to theBigProx
        eval_vec_field: called as eval_vec_field(batch, point, reuse_flag);
                        presumably returns gradh at point -- confirm with caller
        init:           object with attributes z (initial x) and w (initial w1);
                        w2 starts as a copy of init.z
        iter:           number of iterations
        alpha:          initial stepsize; overwritten every iteration by
                        Cstep*(k+1)**(-dstep) when dstep is not None
        theta:          backtracking acceptance factor
        stepIncrease:   alpha is multiplied by this at the start of each iteration
        stepDecrease:   alpha is multiplied by this on each rejected trial step
        usePast:        if True, reuse A evaluated at the previous pbar instead of
                        re-evaluating A at the current iterate
        gamma1, gamma2: scalings of the two dual blocks
        verbose, print_intval: print the iteration counter every print_intval its
        getFuncVals:    record theFunc(x) at history iterations and at the end
        historyFreq:    residuals/times/function values are stored when
                        k % historyFreq == 0
        batchsz:        forwarded to eval_vec_field; any value other than "full"
                        makes the residual use extra "full" evaluations
        reuseBatch:     forwarded as the third argument of the evaluation at pbar
        doBT:           enable the backtracking linesearch
        Cstep, dstep:   stepsize schedule (Cstep must be given when dstep is)
        getAvResid:     also track the step-weighted average iterate and its
                        residual; the residual uses stepsize Cstep, or the
                        current alpha when Cstep is None

    Returns a Results object with attributes x, f, grad_evals, times,
    constraints (always an empty list here), residuals, residAvs and, when
    getFuncVals is set, finalFuncVal.
    '''

    x = init.z
    w1 = init.w
    w2 = np.copy(x)

    if usePast:
        # seed the "previous pbar" evaluation with A at the starting point
        Apbar1 = -x
        Apbar2 = -x
        Apbar3 = w1 + w2 + eval_vec_field(batchsz, x, False)

    xav = np.zeros_like(x)
    w1av = np.zeros_like(w1)
    w2av = np.zeros_like(w2)

    Fx = []
    grad_evals = [0]
    times = [0]
    constraintErr = []
    tstartiter = time.time()
    residuals = []
    residAvs = []
    sum_of_steps = 0.0

    for k in range(iter):
        if dstep is not None:
            alpha = Cstep*(k+1)**(-dstep)

        # logical "and" (not bitwise "&") so any truthy verbose value works
        if verbose and (k % print_intval == 0):
            print("iter: "+str(k))

        if usePast:
            Ap1 = Apbar1
            Ap2 = Apbar2
            Ap3 = Apbar3
        else:
            # compute Ap
            Ap1 = -x
            Ap2 = -x
            Ap3 = w1 + w2 + eval_vec_field(batchsz, x, False)

        # NOTE(review): the counter starts at 1 and is only incremented after the
        # doBT check, so it is not an exact count of eval_vec_field calls for
        # every usePast/doBT combination; left unchanged on purpose.
        newGrads = 1
        alpha = alpha * stepIncrease
        while True:
            pbar = theBigProx(w1 - gamma1 * alpha * Ap1, w2 - gamma2 * alpha * Ap2,
                              x - alpha * Ap3, prox1, prox2, alpha, gamma1, gamma2)

            Apbar1 = -pbar[2]
            Apbar2 = -pbar[2]
            Apbar3 = pbar[0] + pbar[1] + eval_vec_field(batchsz, pbar[2], reuseBatch)

            if not doBT:
                break

            newGrads += 1
            # ||A pbar - A p|| in the P-weighted norm
            totalNorm = np.sqrt(gamma1 * norm(Apbar1 - Ap1) ** 2 +
                                gamma2 * norm(Apbar2 - Ap2) ** 2 +
                                norm(Apbar3 - Ap3) ** 2)
            # ||pbar - p|| in the P^{-1}-weighted norm
            totalNorm2 = np.sqrt(gamma1 ** (-1) * norm(pbar[0] - w1) ** 2 +
                                 gamma2 ** (-1) * norm(pbar[1] - w2) ** 2 +
                                 norm(pbar[2] - x) ** 2)

            if alpha * totalNorm <= theta * totalNorm2:
                break
            alpha = stepDecrease * alpha

        w1old = w1
        w2old = w2
        xold = x

        # correction (second forward) step
        w1 = pbar[0] - gamma1 * alpha * (Apbar1 - Ap1)
        w2 = pbar[1] - gamma2 * alpha * (Apbar2 - Ap2)
        x = pbar[2] - alpha * (Apbar3 - Ap3)

        if getAvResid:
            # Accumulate the step that was actually used, so the weights
            # sum_old/sum_new and alpha/sum_new always add up to one. This also
            # makes averaging work when no stepsize schedule (dstep) is given.
            sum_of_steps_OLD = sum_of_steps
            sum_of_steps += alpha
            keep = sum_of_steps_OLD / sum_of_steps
            add = alpha / sum_of_steps
            w1av = keep * w1av + add * w1
            w2av = keep * w2av + add * w2
            xav = keep * xav + add * x

        if k % historyFreq == 0:
            tenditer = time.time()

            if batchsz == "full":
                resid = norm(w1 - w1old, 2)**2 + norm(w2 - w2old, 2)**2 + norm(x - xold, 2)**2
                resid = alpha**(-2)*resid
            else:
                # compute ApbarFull
                Apbar1Full = -pbar[2]
                Apbar2Full = -pbar[2]
                Apbar3Full = pbar[0] + pbar[1] + eval_vec_field("full", pbar[2], False)

                # compute ApFull
                Ap1Full = -xold
                Ap2Full = -xold
                Ap3Full = w1old + w2old + eval_vec_field("full", xold, False)

                resid = norm(alpha**(-1)*(w1old - pbar[0]) + Apbar1Full - Ap1Full)**2
                resid += norm(alpha ** (-1) * (w2old - pbar[1]) + Apbar2Full - Ap2Full) ** 2
                resid += norm(alpha ** (-1) * (xold - pbar[2]) + Apbar3Full - Ap3Full) ** 2

            residuals.append(resid)

            if getAvResid:
                # compute ApavFull
                Apav1Full = -xav
                Apav2Full = -xav
                Apav3Full = w1av + w2av + eval_vec_field("full", xav, False)
                # one fixed-step Tseng step from the averaged point; fall back to
                # the current step when no Cstep was supplied
                alphaForAv = Cstep if Cstep is not None else alpha
                pbar_forAv = theBigProx(w1av - gamma1 * alphaForAv * Apav1Full,
                                        w2av - gamma2 * alphaForAv * Apav2Full,
                                        xav - alphaForAv * Apav3Full,
                                        prox1, prox2, alphaForAv, gamma1, gamma2)

                Apbar1_forAv = -pbar_forAv[2]
                Apbar2_forAv = -pbar_forAv[2]
                Apbar3_forAv = pbar_forAv[0] + pbar_forAv[1] + eval_vec_field("full", pbar_forAv[2], reuseBatch)

                w1plus_for_av = pbar_forAv[0] - gamma1 * alphaForAv * (Apbar1_forAv - Apav1Full)
                w2plus_for_av = pbar_forAv[1] - gamma2 * alphaForAv * (Apbar2_forAv - Apav2Full)
                xplus_for_av = pbar_forAv[2] - alphaForAv * (Apbar3_forAv - Apav3Full)

                residAv = norm(w1plus_for_av - w1av, 2) ** 2 + \
                          norm(w2plus_for_av - w2av, 2) ** 2 + norm(xplus_for_av - xav, 2) ** 2

                residAv = alphaForAv ** (-2) * residAv
                residAvs.append(residAv)

            if getFuncVals:
                Fx.append(theFunc(x))

            # the clock is restarted after the bookkeeping above, so the time
            # spent computing residuals/function values is not accumulated
            times.append(times[-1] + tenditer - tstartiter)
            tstartiter = time.time()

        grad_evals.append(grad_evals[-1]+newGrads)

    out = Results()
    out.x = x
    out.f = Fx
    out.grad_evals = grad_evals[1:]
    out.times = np.array(times[1:])
    if out.times.size > 0:
        # shift so the first recorded time is zero; skipped when nothing was
        # recorded (iter == 0), which previously raised IndexError
        out.times = out.times - out.times[0]
    out.constraints = constraintErr
    out.residuals = residuals
    out.residAvs = residAvs

    if getFuncVals:
        out.finalFuncVal = theFunc(x)

    return out

def theBigProx(a, b, c, prox1,prox2,alpha,gamma1,gamma2):
    '''
        internal helper for tseng_product() (the original note also names
        for_reflect_back(), which is not part of this section of the file)

        Applies the backward step blockwise:
          a -> conjugate_resolvent of prox1 with step gamma1*alpha
          b -> conjugate_resolvent of prox2 with step gamma2*alpha
          c -> passed through untouched
        and returns the three results as a list [block1, block2, c].
    '''
    dual_block1 = conjugate_resolvent(a, prox1, gamma1 * alpha)
    dual_block2 = conjugate_resolvent(b, prox2, gamma2 * alpha)
    # the third block has no operator to resolve, so c is returned as given
    return [dual_block1, dual_block2, c]






