import numpy as np
from data.dataPreprocess import dataPreprocess
import argparse
import os, sys
import cvxpy as cp
from preSolve import solve_slater_point, solve_opt
import warnings


def valid_range_zero_one(value:float):
    """argparse ``type=`` callable accepting only numbers strictly inside (0, 1).

    Args:
        value: the raw command-line token (argparse passes a string) or a number.

    Returns:
        The parsed float.

    Raises:
        argparse.ArgumentTypeError: if the number is <= 0, >= 1, or NaN.
        ValueError: if ``value`` cannot be parsed as a float (argparse reports
            this as an invalid value).
    """
    parsed = float(value)
    # The chained comparison is False for NaN as well, so NaN is rejected too.
    if not 0 < parsed < 1:
        raise argparse.ArgumentTypeError("Value must be between 0 and 1 exclusive")
    return parsed
# --------------------------------------------------------------------------- #
# Parse command line arguments:
# --------------------------------------------------------------------------- #
parser = argparse.ArgumentParser(
    description='Primal-dual solvers (APDPro, MSAPD, mirror prox, APD) for the '
                'ell_1 penalized personalized page rank problem')
parser.add_argument('--data', default='./data/bio-CE-LC.npz', type=str, help='path of data')
parser.add_argument('--tau', default=0.01, type=float, help='stepsize for x')
parser.add_argument('--sigma', default=0.01, type=float, help='stepsize for y')
parser.add_argument('--gamma', default=0.01, type=float, help='stepsize for mirror prox algorithm')
parser.add_argument('--alpha', default=0.4, type=valid_range_zero_one, help='A value between 0 and 1 for the alpha parameter')
parser.add_argument('--b', default=-0.05, type=float, help='upper bound of constraint')
# main() only knows algorithms 1-4; restricting the choices makes argparse
# reject anything else instead of the script silently doing nothing.
parser.add_argument('--alg', default=2, type=int, choices=[1, 2, 3, 4],
                    help='algorithm: 1 = APDPro, 2 = MSAPD, 3 = mirror prox, 4 = APD')
parser.add_argument('--K', default=5000, type=int, help='total number of iterations')
parser.add_argument('--print_freq', default=1000, type=int, help='write a CSV row every print_freq iterations')
# NOTE(review): args.seed is not read by main(); kept for CLI compatibility.
parser.add_argument('--seed', default=0, type=int, help='seed for random number generators')
parser.add_argument('--out_fname', default='./res/bio-CE-LC-{}.csv', type=str,
                    help='path of output CSV; "{}" is replaced by the algorithm name')
# --------------------------------------------------------------------------- #

class pageRank:
    """ell_1-penalized personalized PageRank problem together with its solvers.

    After the rescaling done in ``__init__`` (Q and s divided by ``|b|``) the
    problem handled here is::

        minimize    ||D_half @ x||_1
        subject to  g(x) = 0.5 * x^T Q x - alpha * s^T (Dinv_half @ x) - sign(b) <= 0

    ``apdpro`` (alg 1), ``msapd`` (alg 2), ``mirror`` (alg 3) and ``apd``
    (alg 4) are alternative primal-dual solvers; each writes a CSV trace
    through :meth:`log`.
    """

    def __init__(self, data, alpha, b) -> None:
        """Load the problem data, a Slater point and a reference solution.

        Args:
            data: path of the data file handed to ``dataPreprocess``. Its
                basename (up to the first '.') also selects
                ``./slaterPoint/<name>-Slater.npy`` and ``./cvxpySol/<name>-Opt.npy``.
            alpha: value forwarded to ``dataPreprocess``; the alpha that
                ``dataPreprocess`` returns is the one stored on the instance.
            b: constraint bound. Q and s are divided by ``|b|`` and only
                ``sign(b)`` is kept, so the constraint is measured in units of
                ``|b|`` (``log`` multiplies it back).
        """
        Q, alpha, s, Dinv_half, D_half = dataPreprocess(data, alpha)
        self.alpha = alpha
        self.Dinv_half = Dinv_half
        # Assumed to be a scipy.sparse matrix: .diagonal() and .data are used below.
        self.D_half = D_half
        self.absb = np.abs(b)
        # Normalize the constraint by |b| so its right-hand side becomes sign(b).
        self.Q = Q / np.abs(b)
        self.s = s / np.abs(b)
        self.b = np.sign(b)
        # Constant part of the constraint gradient (see grad()).
        self.offset = self.alpha * self.Dinv_half @ self.s

        dataset = os.path.basename(data).split('.')[0]
        self.xSlater = np.load('./slaterPoint/{}-Slater.npy'.format(dataset))
        self.xOpt = np.load('./cvxpySol/{}-Opt.npy'.format(dataset))
        # Reference objective, taken before tiny entries are zeroed below.
        self.optval = self.obj(self.xOpt)
        # Treat |x| < 1e-8 as exact zeros so identify_rate compares zero patterns.
        self.xOpt[np.abs(self.xOpt) < 1e-8] = 0

    def constraint(self, x):
        """Return g(x) = 0.5 x^T Q x - alpha s^T (Dinv_half x) - b (scaled units)."""
        return 0.5 * x.T @ self.Q @ x - self.alpha * self.s @ (self.Dinv_half @ x) - self.b

    def obj(self, x):
        """Return the weighted ell_1 objective ||D_half @ x||_1."""
        return np.linalg.norm(self.D_half @ x, 1)

    def grad(self, x):
        """Return the constraint gradient Q @ x - alpha * Dinv_half @ s."""
        return self.Q @ x - self.offset

    def weight_soft_threshold(self, x, y, tau, xgrad = None):
        """Proximal step for the weighted ell_1 objective.

        Takes the step ``x - tau * y * grad(xgrad)`` (``xgrad`` defaults to
        ``x``) and soft-thresholds each coordinate by ``tau * D_half[i, i]``.
        """
        if xgrad is None:
            xgrad = x
        x = x - tau * y * self.grad(xgrad)
        # Only the diagonal of D_half is needed; reading it directly avoids
        # densifying the n x n matrix (O(n^2) time and memory) every iteration.
        threshold = tau * self.D_half.diagonal()
        return np.sign(x) * np.maximum(np.abs(x) - threshold, 0)

    def guess_Dx(self):
        """Estimate the primal radius ``Dx`` from the Slater point; cache and return it.

        Calls :meth:`guess_condition` first (sets ``mu`` and ``Lx``), then
        ``Dx = 2 * sqrt(-2 * g(x_slater) / mu) + 0.001``.

        Raises:
            ValueError: if the stored Slater point is not strictly feasible.
        """
        self.guess_condition()
        conSlater = self.constraint(self.xSlater)
        # Explicit check (not assert) so it also runs under ``python -O``; the
        # negated form rejects NaN as well.
        if not conSlater < 0:
            raise ValueError(f'Slater point is not feasible, conSlater:{conSlater}')
        self.Dx = 2 * np.sqrt(-2 * conSlater / self.mu) + 0.001
        return self.Dx

    def guess_Dy(self):
        """Estimate the dual bound ``Dy = obj(x_slater) / -g(x_slater)`` as a length-1 array."""
        self.Dy = np.array([self.obj(self.xSlater) / (-self.constraint(self.xSlater))])
        return self.Dy

    def guess_condition(self):
        """Set ``mu`` / ``Lx`` to the smallest / largest eigenvalue of Q and return them.

        Densifies Q, so this costs O(n^2) memory and O(n^3) time.
        """
        eigvalue = np.linalg.eigvalsh(self.Q.toarray())  # ascending order
        self.mu = eigvalue[0]
        self.Lx = eigvalue[-1]
        return self.mu, self.Lx

    def guess_Delta(self, tau, sigma):
        """Set and return ``Delta = Dx^2 / (2 tau) + Dy^2 / (2 sigma)``."""
        self.Delta = (self.Dx ** 2)/(2 * tau) + (self.Dy ** 2)/(2 * sigma)
        return self.Delta

    def proj_y(self, y, rho):
        """Clamp the dual variable: raise it to at least ``rho / mu``, then cap it at ``Dy``."""
        if y < rho / self.mu:
            y = rho / self.mu
        if y > self.Dy:
            y = self.Dy
        return y

    def update_rho(self, x, rho, beta):
        """Return max(rho, r * mu / (Lx * sqrt(2 beta) + ||grad(x)||_2)); rho never decreases."""
        rho_temp = self.r * self.mu / (self.Lx * np.sqrt(2 * beta) + np.linalg.norm(self.grad(x), 2))
        return np.maximum(rho_temp, rho)

    def improve(self, x_mean, k, tau, sigma, rho):
        """Return a non-decreasing rho estimate from the averaged iterate after k+1 steps.

        ``betaBar = (Dx^2/(2 tau) + Dy^2/(2 sigma)) / (k + 1)`` and the candidate
        is ``(Lx/r * sqrt(betaBar/(2 mu)) + sqrt(Lx^2 betaBar/(2 mu r^2) + ||grad||/r))^-2``.
        """
        betaBar = (self.Dx ** 2 / (2 * tau) + self.Dy ** 2 / (2 * sigma)) / (k+1)
        gradnorm = np.linalg.norm(self.grad(x_mean))
        rhonew = np.power(self.Lx / self.r * np.sqrt(betaBar / (2* self.mu)) + np.sqrt(self.Lx ** 2 * betaBar / (2 * self.mu * self.r ** 2) + gradnorm/self.r), -2)
        return np.maximum(rhonew, rho)

    def identify_rate(self, x, xOpt):
        """Fraction of coordinates where ``x`` and ``xOpt`` agree on being zero / nonzero."""
        return np.count_nonzero((x == 0) == (xOpt == 0)) / len(x)

    def apdpro(self, tau, sigma, K, print_freq, out_fname):
        """APDPro (``--alg 1``): APD with rho-driven step-size updates and restarts.

        Runs K iterations. Step sizes are reset to their initial values every
        ``K // 20`` iterations (including k = 0); for K < 20 no resets happen.
        A CSV row is written every ``print_freq`` iterations to
        ``out_fname.format('apdpro')``.
        """
        alg = 'apdpro'
        self.guess_Dx()
        self.guess_Dy()
        self.guess_Delta(tau, sigma)
        self.r = np.amax(self.D_half.data)
        n = self.Q.shape[0]
        x = np.zeros(n)
        y = np.ones(1)
        gamma = sigma / tau
        sigma0 = np.copy(sigma)
        tau0 = np.copy(tau)
        rho = np.zeros(1)
        sigma_lag = sigma
        x_lag = np.copy(x)
        # Guarded so that K < 20 means "no restarts" rather than a modulo by zero.
        restart_period = K // 20
        for k in range(K):
            theta = sigma_lag / sigma
            sigma_lag = sigma
            # Extrapolated constraint value using the previous iterate.
            z = (1 + theta) * self.constraint(x) - theta * self.constraint(x_lag)
            y = self.proj_y(y + sigma * z, rho)
            x_lag = np.copy(x)
            x = self.weight_soft_threshold(x, y, tau)
            beta = sigma0 * self.Delta / gamma
            rho = self.update_rho(x, rho, beta)
            gamma = gamma * (1 + rho * tau)
            tau = tau * np.sqrt(1 / (1 + rho * tau))
            sigma = tau * gamma
            if restart_period > 0 and k % restart_period == 0:
                tau = tau0
                sigma = sigma0
                gamma = sigma / tau
            if k % print_freq == 0:
                self.log(out_fname.format(alg), x, self.obj(x), self.constraint(x), y,
                         np.count_nonzero(x), np.linalg.norm(x, 2),
                         epoch=k, tau=tau, sigma=sigma,
                         create=(k == 0), rho=rho)

    def msapd(self, tau, sigma, K, print_freq, out_fname):
        """MSAPD (``--alg 2``): multi-stage APD.

        The K iterations are split into stages by :meth:`msapd_freq`; after each
        stage tau is divided and sigma multiplied by sqrt(2). The running
        average of the current stage is what gets logged, to
        ``out_fname.format('msapd')`` every ``print_freq`` iterations.
        """
        alg = 'msapd'
        n = self.Q.shape[0]
        x = np.zeros(n)
        y = np.ones(1)
        self.guess_Dx()  # also runs guess_condition(), setting mu and Lx
        self.guess_Dy()
        self.guess_Delta(tau, sigma)
        self.r = np.amax(self.D_half.data)
        Tlist = self.msapd_freq(K)
        x_lag = np.copy(x)
        rho = 0
        count = 0
        for T in Tlist:
            x_mean = np.copy(x)
            for k in range(int(T)):
                z = 2 * self.constraint(x) - self.constraint(x_lag)
                y = self.proj_y(y + sigma * z, 0)
                # Save the current iterate BEFORE overwriting it so the next
                # extrapolation really uses the previous iterate.
                x_lag = np.copy(x)
                x = self.weight_soft_threshold(x, y, tau)
                x_mean = (x_mean * k + x) / (k + 1)
                rho = self.improve(x_mean, k, tau, sigma, rho)
                if count % print_freq == 0:
                    self.log(out_fname.format(alg), x_mean, self.obj(x_mean), self.constraint(x_mean), y,
                             np.count_nonzero(x_mean), np.linalg.norm(x_mean, 2),
                             epoch=count, tau=tau, sigma=sigma,
                             create=(count == 0), rho=rho)
                count += 1
            tau = tau / np.sqrt(2)
            sigma = sigma * np.sqrt(2)

    def msapd_freq(self, K, period = 5):
        """Split K iterations into ``period`` stages whose lengths grow by about sqrt(2).

        The last stage absorbs the remainder so the lengths sum to K.
        NOTE: because earlier stages are rounded up, for small K the remainder
        can be negative; that stage then runs zero iterations and the total
        exceeds K.
        """
        total = 0
        for i in range(1, period + 1):
            total += np.sqrt(2) ** i
        T = np.ceil(K / total)
        Tlist = [np.ceil(T * np.sqrt(2) ** i) for i in range(1, period + 1)]
        Tlist[-1] = K - sum(Tlist[:-1])
        return Tlist

    def mirror(self, gamma, K, print_freq, out_fname):
        """Mirror-prox / extragradient method (``--alg 3``) with step ``gamma``.

        Returns the last primal and dual iterates; the running average is what
        gets logged to ``out_fname.format('mirror')``.
        """
        alg = 'mirror'
        n = self.Q.shape[0]
        x = np.zeros(n)
        y = np.ones(1)
        self.guess_Dx()  # also runs guess_condition()
        self.guess_Dy()
        x_mean = np.copy(x)
        for k in range(int(K)):
            # Extrapolation (half) step evaluated at the current point ...
            x_half = self.weight_soft_threshold(x, y, gamma)
            y_half = self.proj_y(y + gamma * self.constraint(x), 0)
            # ... then the actual step using gradients taken at the half point.
            x = self.weight_soft_threshold(x, y_half, gamma, xgrad=x_half)
            y = self.proj_y(y + gamma * self.constraint(x_half), 0)
            x_mean = (x_mean * k + x) / (k + 1)
            if k % print_freq == 0:
                self.log(out_fname.format(alg), x_mean, self.obj(x_mean), self.constraint(x_mean), y,
                         np.count_nonzero(x_mean), np.linalg.norm(x_mean, 2),
                         epoch=k, create=(k == 0))
        return x, y

    def apd(self, tau, sigma, K, print_freq, out_fname, alg = None, x=None, y=None):
        """Accelerated primal-dual method with fixed steps (``--alg 4``).

        ``alg``, ``x`` and ``y`` are accepted for interface compatibility but
        ignored: the run always starts from x = 0, y = 1 and logs to
        ``out_fname.format('apd')``. Returns the last primal and dual iterates.
        """
        alg = 'apd'
        n = self.Q.shape[0]
        x = np.zeros(n)
        y = np.ones(1)
        self.guess_Dx()
        self.guess_Dy()
        x_lag = np.copy(x)
        x_mean = np.copy(x)
        for k in range(int(K)):
            z = 2 * self.constraint(x) - self.constraint(x_lag)
            y = self.proj_y(y + sigma * z, 0)
            # Save the current iterate BEFORE overwriting it so the next
            # extrapolation really uses the previous iterate.
            x_lag = np.copy(x)
            x = self.weight_soft_threshold(x, y, tau)
            x_mean = (x_mean * k + x) / (k + 1)
            if k % print_freq == 0:
                self.log(out_fname.format(alg), x_mean, self.obj(x_mean), self.constraint(x_mean), y,
                         np.count_nonzero(x_mean), np.linalg.norm(x_mean),
                         epoch=k, create=(k == 0))
        return x, y

    def log(self, out_fname, x, obj, constr, dual_var, nnz, normx, epoch = 0, tau = None, sigma = None, create=False, rho = None):
        """Write one CSV row of solver statistics to ``out_fname``.

        Columns: epoch, obj, constr (rescaled by ``|b|``), dual_var, nnz, norm(x),
        then tau and sigma when ``tau`` is given, then relative_obj,
        identify_rate and x_gap (all against the reference solution), then rho
        and rho/mu when ``rho`` is given.

        Args:
            create: when True the file is truncated and a header row is written
                first; otherwise the row is appended.
        """
        constr = constr * self.absb
        header = ['epoch', 'obj', 'constr', 'dual_var', 'nnz', 'norm(x)']
        row = [
            '{:d}'.format(int(epoch)),
            '{:.12f}'.format(float(obj)),
            '{:.12f}'.format(constr),
            '{:.12f}'.format(dual_var[0]),
            '{:d}'.format(nnz),
            '{:.12f}'.format(normx),
        ]
        if tau is not None:
            header += ['tau', 'sigma']
            row += ['{:.12f}'.format(float(tau)), '{:.12f}'.format(float(sigma))]
        header += ['relative_obj', 'identify_rate', 'x_gap']
        row += [
            '{:.12f}'.format(float(np.abs(obj - self.optval) / np.abs(self.optval))),
            '{:.12f}'.format(self.identify_rate(x, self.xOpt)),
            '{:.12f}'.format(np.linalg.norm(x - self.xOpt) / np.linalg.norm(self.xOpt)),
        ]
        if rho is not None:
            # rho may be a length-1 array; take its element explicitly instead of
            # relying on the deprecated float(array) conversion.
            rho_value = float(np.ravel(rho)[0])
            header += ['rho', 'rho_x_y']
            row += ['{:.12f}'.format(rho_value), '{:.12f}'.format(float(rho_value / self.mu))]
        with open(out_fname, 'w' if create else 'a') as f:
            if create:
                print(','.join(header), file=f)
            print(','.join(row), file=f)
def main():
    """Parse the command line, build the problem and run the selected solver.

    ``--alg`` selects the solver: 1 = APDPro, 2 = MSAPD, 3 = mirror prox,
    4 = APD. Any other value is reported as a command-line error (exit status
    2) before any data is loaded, instead of the script silently doing nothing.
    """
    args = parser.parse_args()
    if args.alg not in (1, 2, 3, 4):
        parser.error('--alg must be one of 1 (APDPro), 2 (MSAPD), 3 (mirror prox), '
                     '4 (APD); got {}'.format(args.alg))
    model = pageRank(args.data, args.alpha, args.b)

    if args.alg == 1:
        model.apdpro(args.tau, args.sigma, args.K, args.print_freq, args.out_fname)
    elif args.alg == 2:
        model.msapd(args.tau, args.sigma, args.K, args.print_freq, args.out_fname)
    elif args.alg == 3:
        model.mirror(args.gamma, args.K, args.print_freq, args.out_fname)
    else:
        model.apd(args.tau, args.sigma, args.K, args.print_freq, args.out_fname)


# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()