import numpy as np
from numpy.linalg import norm as norm
import time


class ProjSplit:
    """Three-block projective-splitting solver driven by a problem object.

    The problem object ``prob`` is only used through the following members
    (their exact semantics live outside this file):

    * ``prob.num_var`` -- length of every primal/dual vector,
    * ``prob.getStochasticUpdate(batchsz, point, flag)`` -- operator
      evaluation for the first block,
    * ``prob.project_conePlusBall(point)`` -- used for the second block,
    * ``prob.prox_L1(point, step)`` -- used for the third block.

    State kept on the instance: the primal iterate ``z``, the dual iterates
    ``w1, w2, w3``, the per-block points ``x1, x2, x3`` with their operator
    values ``y1, y2, y3``, and the recorded histories ``OptCond`` and
    ``tstamp``.
    """

    def __init__(self, prob, randomInit):
        """Store the problem object and the initialisation mode.

        Args:
            prob: problem object (see the class docstring for what is used).
            randomInit: if truthy, ``initialize`` draws ``z`` and the ``x``
                points from a standard normal; otherwise they start at zero.
        """
        self.prob = prob
        self.randomInit = randomInit

    def initialize(self, seed2use):
        """Reset all iterates, reproducibly with respect to ``seed2use``.

        The global NumPy RNG is seeded with ``seed2use`` for the draws and is
        re-seeded without an argument afterwards, so later random draws
        (outside this method) are not tied to ``seed2use``.
        """
        np.random.seed(seed2use)

        make_vec = np.random.randn if self.randomInit else np.zeros
        dim = self.prob.num_var

        # Draw order (z, x1, x2, x3) is kept fixed so that a given seed always
        # reproduces the same starting point.
        self.z = make_vec(dim)

        self.w1 = np.zeros(dim)
        self.w2 = np.zeros(dim)
        self.w3 = np.zeros(dim)

        self.x1 = make_vec(dim)
        self.x2 = make_vec(dim)
        self.x3 = make_vec(dim)

        # Second block: start from the image of the draw under the projection.
        self.x2 = self.prob.project_conePlusBall(self.x2)
        self.y2 = np.zeros(dim)

        # Third block: one prox step with unit stepsize; y3 is its residual.
        prox_point = self.prob.prox_L1(self.x3, 1.0)
        self.y3 = self.x3 - prox_point
        self.x3 = prox_point

        # Re-seed without an argument so the rest of the run is not
        # deterministic in seed2use.
        np.random.seed()

    def getFixedStepsize(self, niter):
        """Return ``(rho1, alpha1, rho_exp, alpha_exp)`` for a constant schedule.

        ``rho1`` is ``niter ** -0.25``; ``alpha1`` is the sentinel string
        ``"paper_const"``, which ``run`` interprets as
        ``alpha = alpha_const_0 * rho ** 2``. Both exponents are zero.
        """
        # A float literal keeps the exponent independent of the meaning of `/`.
        rho = niter ** (-0.25)
        alpha = "paper_const"
        return rho, alpha, 0.0, 0.0

    def run(self, niter, rho1, alpha1, rho_exp, alpha_exp, alpha_const_0,
            tau, gamma_pd, measure_freq, batchsz, reuseBatch, optimality_type,
            deterministic, seed2use, L=1e-15):
        """Run ``niter`` iterations and record optimality and elapsed time.

        Args:
            niter: number of iterations.
            rho1, rho_exp: forward stepsize at iteration ``k`` is
                ``min(rho1 * (k + 1) ** rho_exp, 0.5 / L)``.
            alpha1, alpha_exp: if ``alpha1 == "paper_const"`` the update
                stepsize is ``alpha_const_0 * rho ** 2``; otherwise it is
                ``min(alpha1 * (k + 1) ** alpha_exp, 0.5 / L)``.
            alpha_const_0: constant used in the ``"paper_const"`` rule.
            tau: stepsize used for the second and third blocks.
            gamma_pd: primal-dual scaling; the primal step is divided by it.
            measure_freq: the optimality criterion and a timestamp are
                recorded every ``measure_freq`` iterations.
            batchsz: forwarded to ``prob.getStochasticUpdate``.
            reuseBatch: forwarded as the third argument of the second
                ``getStochasticUpdate`` call of each iteration
                (presumably "reuse the previous minibatch" -- confirm
                against the problem class).
            optimality_type: forwarded to ``get_optimality_criterion``.
            deterministic: if truthy, use the full batch, constant
                stepsizes, measure every iteration, and update by
                ``do_projection`` instead of the fixed-stepsize update.
            seed2use: seed forwarded to ``initialize``.
            L: stepsizes are capped at ``0.5 / L`` (presumably a Lipschitz
                constant -- confirm). The tiny default makes the cap
                effectively inactive.

        Results are left in ``self.OptCond`` (one entry before the first
        iteration plus one per measurement) and ``self.tstamp`` (cumulative
        seconds, excluding the time spent evaluating the criterion).
        """
        if deterministic:
            batchsz = "full"
            rho_exp = 0.0
            alpha_exp = 0.0
            measure_freq = 1

        self.initialize(seed2use)
        self.OptCond = []
        self.get_optimality_criterion(optimality_type)
        self.tstamp = [0]

        # Loop invariants.
        step_cap = 0.5 * L ** (-1)
        inv_tau = tau ** (-1)

        # perf_counter is monotonic, so measured intervals cannot be negative.
        interval_start = time.perf_counter()

        for k in range(niter):
            rho = min(rho1 * (k + 1) ** rho_exp, step_cap)

            if alpha1 == "paper_const":
                alpha = alpha_const_0 * rho ** 2
            else:
                alpha = min(alpha1 * (k + 1) ** alpha_exp, step_cap)

            # First block: two forward steps.
            r = self.prob.getStochasticUpdate(batchsz, self.z, False)
            self.x1 = self.z - rho * (r - self.w1)
            self.y1 = self.prob.getStochasticUpdate(batchsz, self.x1, reuseBatch)

            # Second block: cone and ball constraint.
            t = self.z + tau * self.w2
            self.x2 = self.prob.project_conePlusBall(t)
            self.y2 = inv_tau * (t - self.x2)

            # Third block: L1 regularizer.
            t = self.z + tau * self.w3
            self.x3 = self.prob.prox_L1(t, tau)
            self.y3 = inv_tau * (t - self.x3)

            if deterministic:
                self.do_projection(gamma_pd)
            else:
                # Primal update.
                self.z -= gamma_pd ** (-1) * alpha * (self.y1 + self.y2 + self.y3)

                # Dual update; subtracting the mean keeps w1 + w2 + w3 fixed.
                xbar = (self.x1 + self.x2 + self.x3) / 3.0
                self.w1 -= alpha * (self.x1 - xbar)
                self.w2 -= alpha * (self.x2 - xbar)
                self.w3 -= alpha * (self.x3 - xbar)

            if k % measure_freq == measure_freq - 1:
                interval_end = time.perf_counter()
                self.get_optimality_criterion(optimality_type)
                self.tstamp.append(self.tstamp[-1] + interval_end - interval_start)
                # Restart the clock only now, so criterion time is excluded.
                interval_start = time.perf_counter()

    def do_projection(self, gamma_pd):
        """Project ``(z, w1, w2, w3)`` onto the current separating hyperplane.

        The hyperplane value is
        ``phi = sum_i <z - x_i, y_i - w_i>``. Its gradient is
        ``u = y1 + y2 + y3`` in ``z`` and ``v_i = x_i - mean(x)`` in ``w_i``.
        With the primal step scaled by ``1 / gamma_pd`` the step length is
        ``alpha = max(phi, 0) / (||u||^2 / gamma_pd + sum_i ||v_i||^2)``,
        and the update is ``z -= alpha * u / gamma_pd``, ``w_i -= alpha * v_i``.
        When ``phi > 0`` the new point satisfies ``phi = 0`` (given
        ``w1 + w2 + w3 = 0``); when ``phi <= 0`` nothing moves.

        Args:
            gamma_pd: primal-dual scaling parameter.
        """
        u = self.y1 + self.y2 + self.y3
        xbar = (self.x1 + self.x2 + self.x3) / 3.0
        v1 = self.x1 - xbar
        v2 = self.x2 - xbar
        v3 = self.x3 - xbar

        denom = (gamma_pd ** (-1) * norm(u, 2) ** 2
                 + norm(v1, 2) ** 2 + norm(v2, 2) ** 2 + norm(v3, 2) ** 2)
        if denom > 0:
            hplane = (self.z - self.x1).dot(self.y1 - self.w1)
            hplane += (self.z - self.x2).dot(self.y2 - self.w2)
            hplane += (self.z - self.x3).dot(self.y3 - self.w3)
            # Only move when the current point is on the positive side.
            alpha = 1.0 * (hplane > 0) * hplane / denom
        else:
            alpha = 0.0

        # Primal update.
        self.z -= gamma_pd ** (-1) * alpha * u

        # Dual update. The step is NOT scaled by gamma_pd: `denom` weights
        # the dual directions by 1, so any other factor would miss the
        # hyperplane whenever gamma_pd != 1.
        self.w1 -= alpha * v1
        self.w2 -= alpha * v2
        self.w3 -= alpha * v3

    def get_optimality_criterion(self, type):
        """Append the current optimality measure to ``self.OptCond``.

        Args:
            type: ``"original"`` selects
                ``||Az - w1||^2 + ||y2 - w2||^2 + ||y3 - w3||^2``; any other
                value selects ``||Az + y2 + y3||^2``. In both cases
                ``||z - x2||^2 + ||z - x3||^2`` is added. ``Az`` is
                ``prob.getStochasticUpdate("full", z, None)``.
                (The name shadows the builtin; it is kept so that keyword
                callers keep working.)
        """
        Az = self.prob.getStochasticUpdate("full", self.z, None)
        if type == "original":
            t = norm(Az - self.w1, 2) ** 2
            t += norm(self.y2 - self.w2, 2) ** 2 + norm(self.y3 - self.w3, 2) ** 2
        else:
            # Marked "old incorrect way" by the original author.
            t = norm(Az + self.y2 + self.y3, 2) ** 2

        t += norm(self.z - self.x2, 2) ** 2 + norm(self.z - self.x3, 2) ** 2

        self.OptCond.append(t)
