
import numpy as np
import scipy.special
from scipy.special import logsumexp, softmax
import random
from time import time

import os
# os.chdir('/.../code')


### Define Weight Sequences
def WeightPower(t, power = 1):
    """Polynomial averaging weight ``(1 - 1/(t+1)) ** power`` for iteration ``t``.

    For ``power > 0`` the weight is 0 at ``t = 0`` and grows toward 1.
    """
    base = 1 - 1/(t+1)
    return base**power

def WeightLogPower(t, power = 1):
    """Log-scaled averaging weight for iteration ``t``.

    Returns ``(1 - 1/(t+1)) ** (power*log(t+2)) * (t+1) ** -(power*log(1 + 1/(t+2)))``,
    computed as the product of the two factors below.
    """
    ratio = 1 - 1/(t+1)
    leading = ratio**(power*np.log(t+2))
    correction = 1/(t+1)**(power*np.log(1+1/(t+2)))
    return leading*correction

### Define Oracle Noises
def Gaussian_Noise(matrix, d, scale = 1):
    """Return ``matrix`` plus i.i.d. N(0, scale**2) noise of shape ``(d, d)``."""
    perturbation = np.random.normal(0, scale, (d, d))
    return matrix + perturbation

def Exp_Noise(matrix, d, scale = 1):
    """Return ``matrix`` plus centered exponential noise of shape ``(d, d)``.

    Each entry adds an Exponential(scale) draw minus ``scale`` (its mean), so
    the perturbation has zero mean.
    """
    shifted = matrix + np.random.exponential(scale, (d, d))
    return shifted - scale

### Define Sketching/Subsampling functions
def Gaussian_Sketch(n,matrix,sketch_size,nnz=None):
    """Gaussian sketch estimate of ``matrix.T @ matrix``.

    Draws S with i.i.d. N(0, 1) entries of shape ``(sketch_size, n)`` and
    returns ``(S A)^T (S A) / sketch_size``. ``nnz`` is unused; it exists so all
    sketch functions share one signature.
    """
    gaussian = np.random.randn(sketch_size, n)
    sketched = gaussian @ matrix
    return sketched.T @ sketched / sketch_size

def Sub_Sampling(n,matrix,sub_size,nnz=None):
    """Uniform row-subsampling estimate of ``matrix.T @ matrix``.

    Keeps ``sub_size`` distinct rows chosen uniformly without replacement
    (others are zeroed) and rescales by ``n / sub_size``. ``nnz`` is unused.
    """
    keep = np.zeros((n, 1))
    keep[np.random.choice(n, sub_size, replace=False)] = 1.0
    kept_rows = keep*matrix
    return kept_rows.T @ kept_rows / sub_size*n

def Sparse_Sketch(n,matrix,sketch_size,nnz=None):
    """CountSketch estimate of ``matrix.T @ matrix``.

    Every column j of S (shape ``(sketch_size, n)``) holds a single random +-1
    in a uniformly chosen row. Returns ``(S A)^T (S A)`` without rescaling.
    ``nnz`` is unused.
    """
    # Draw the signs before the buckets: this is the RNG order in which a
    # one-line `S[buckets, cols] = signs` evaluates (right-hand side first).
    signs = np.random.choice(np.array([-1, 1], dtype=np.float64), size=n)
    buckets = np.random.choice(sketch_size, n)
    count_sketch = np.zeros((sketch_size, n))
    count_sketch[buckets, np.arange(n)] = signs
    hashed = count_sketch @ matrix
    return hashed.T @ hashed

def SparRad_Sketch(n,matrix,sketch_size,nnz=None):
    """Sparse Rademacher ("LESS-uniform") estimate of ``matrix.T @ matrix``.

    Each of the ``sketch_size`` rows of S receives ``d_tilde`` entries of +-1
    at uniformly random columns, drawn with replacement. A repeated column in
    the same row is overwritten, not accumulated. The result
    ``(S A)^T (S A)`` is rescaled by ``n / (sketch_size * d_tilde)``.

    Parameters
    ----------
    n : int
        Number of rows of ``matrix``.
    matrix : ndarray of shape (n, d)
    sketch_size : int
        Number of rows of S.
    nnz : float, optional
        Nonzeros per sketch row as a fraction of ``d = matrix.shape[1]``
        (default 0.1). ``d_tilde = max(1, int(nnz * d))``.
    """
    if nnz is None:
        nnz = 0.1
    # Clamp to >= 1: int(nnz*d) is 0 whenever nnz*d < 1 (e.g. the default 0.1
    # with d < 10), which left S empty and divided by zero below (NaN output).
    d_tilde = max(1, int(nnz*matrix.shape[1]))
    row_index = np.repeat(np.arange(sketch_size), d_tilde)
    column_index = np.random.choice(n, sketch_size*d_tilde)
    values = np.random.choice(np.array([-1, 1], dtype=np.float64), sketch_size*d_tilde)
    S = np.zeros((sketch_size, n))
    S[row_index, column_index] = values
    SA = S @ matrix
    return SA.T@SA*n/(sketch_size*d_tilde)
    

# Registries mapping the string options accepted by the solvers to their
# implementations.
Weight = {
    'power': WeightPower,
    'log_power': WeightLogPower,
}

Sto_Hess = {
    'gaussian_noise': Gaussian_Noise,
    'exp_noise': Exp_Noise,
}

Sketch_Func = {
    'Gaussian': Gaussian_Sketch,
    'CountSketch': Sparse_Sketch,
    'Subsampled': Sub_Sampling,
    'LESS-uniform': SparRad_Sketch,
}


### Data Generating
class DataGenerate_HighCond:
    """Synthetic logistic-regression data with a controlled condition number.

    Builds ``Dat = U @ diag(Sigma)``. ``U`` holds the left singular vectors of
    an ``n x d`` Gaussian matrix, and ``Sigma`` is evenly spaced on
    ``[1, d**kap]``, so the singular values of ``Dat`` are exactly ``Sigma``.
    Labels ``Resp`` (shape ``(n, 1)``) are +-1 with
    ``P(+1) = sigmoid(Dat @ x_under)``.

    Side effect: reseeds NumPy's global RNG with 2022. Assumes ``n >= d``
    (otherwise ``U`` has fewer than ``d`` columns and the product fails).
    """

    def __init__(self, n, d, lambd, kap=1., Rep=10):
        self.lambd, self.Rep = lambd, Rep
        # Tags describing the data regime, presumably read by the experiment driver.
        self.IdCond, self.IdReal = 'true', 'Unreal'
        self.kap = kap
        # generate data
        np.random.seed(2022)
        U, _, _ = np.linalg.svd(np.random.randn(n, d), full_matrices=False)
        Sigma = np.linspace(1, d**kap, d)
        self.Dat = U @ np.diag(Sigma)
        x_under = 1./np.sqrt(d)*np.random.randn(d, 1)
        Prob = scipy.special.expit(self.Dat @ x_under)
        # Map Bernoulli(Prob) draws in {0, 1} to labels in {-1, +1}.
        self.Resp = 2*np.random.binomial(1, p=Prob) - 1

class DataGenerate_HighCoher:
    """Synthetic logistic-regression data with highly non-uniform rows.

    The construction matches ``DataGenerate_HighCond`` except for one step.
    Before the SVD, row i of the Gaussian matrix is divided by ``sqrt(g_i)``
    with ``g_i ~ Gamma(1/2, 2)``. This presumably makes the row leverage
    scores of ``Dat`` highly non-uniform ("high coherence").

    The singular values of ``Dat`` are ``linspace(1, d**kap, d)``. Labels
    ``Resp`` (shape ``(n, 1)``) are +-1 with ``P(+1) = sigmoid(Dat @ x_under)``.

    Side effect: reseeds NumPy's global RNG with 2022. Assumes ``n >= d``.
    """

    def __init__(self, n, d, lambd, kap=1., Rep=10):
        self.lambd, self.Rep = lambd, Rep
        # Tags describing the data regime, presumably read by the experiment driver.
        self.IdCond, self.IdReal = 'false', 'Unreal'
        self.kap = kap
        # generate data
        np.random.seed(2022)
        g = np.random.gamma(1/2, 2, n)
        # Broadcast the per-row scale 1/sqrt(g_i) across the d columns.
        U, _, _ = np.linalg.svd(np.random.randn(n, d)/np.sqrt(g)[:, None], full_matrices=False)
        Sigma = np.linspace(1, d**kap, d)
        self.Dat = U @ np.diag(Sigma)
        x_under = 1./np.sqrt(d)*np.random.randn(d, 1)
        # expit is overflow-safe; 1/(1+exp(-z)) raised RuntimeWarnings once
        # |z| exceeded ~709 (large kap).
        Prob = scipy.special.expit(self.Dat @ x_under)
        # Map Bernoulli(Prob) draws in {0, 1} to labels in {-1, +1}.
        self.Resp = 2*np.random.binomial(1, p=Prob) - 1
       

### Problem Solver
        
def logsig(x):
    """
    Compute the log-sigmoid ``log(1 / (1 + exp(-x)))`` component-wise.

    Piecewise evaluation keeps the result accurate and overflow-free across the
    real line. See http://fa.bianp.net/blog/2019/evaluate_logistic/ for more
    details.

    Parameters
    ----------
    x : array_like
        Input values. Non-floating input (e.g. integers) is promoted to
        float64 so the output buffer can hold the fractional results.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x``.
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.inexact):
        x = x.astype(np.float64)
    out = np.zeros_like(x)
    # Very negative x: log(sigmoid(x)) is effectively x.
    idx0 = x < -33
    out[idx0] = x[idx0]
    # Moderately negative x: first-order correction x - exp(x).
    idx1 = (x >= -33) & (x < -18)
    out[idx1] = x[idx1] - np.exp(x[idx1])
    # Middle range: the direct formula is stable here.
    idx2 = (x >= -18) & (x < 37)
    out[idx2] = -np.log1p(np.exp(-x[idx2]))
    # Large x: log(1 - e) ~= -e with e = exp(-x).
    idx3 = x >= 37
    out[idx3] = -np.exp(-x[idx3])
    return out

class LogisticRegression:
    """L2-regularized logistic regression together with a family of solvers.

    Objective::

        f(x) = mean_i log(1 + exp(-b_i * a_i^T x)) + lambd/2 * ||x||^2

    Here ``a_i`` are the rows of ``A`` (shape ``(n, d)``) and ``x`` is a
    ``(d, 1)`` column. ``b`` is expected to be an ``(n, 1)`` column of +-1
    labels, as produced by the ``DataGenerate_*`` classes. ``Hess`` and
    ``sqrt_hess`` rely on ``b**2 == 1``.

    ``solve_exactly`` must be called before any solver, because every solver
    reports errors relative to ``self.x_true`` / ``self.loss_true``. Every
    solver returns ``(Err_g, Err_x, loss_gaps, Time)`` as ndarrays plus a float:

    - ``Err_g``: per-iterate gradient norms. For the NPE variants this is the
      smaller of the norms at ``x`` and ``x_hat``.
    - ``Err_x``: Euclidean distances to ``x_true``.
    - ``loss_gaps``: ``f(x_t) - loss_true``.
    - ``Time``: wall-clock seconds of the main loop.

    Solvers stop when that norm drops below ``EPS`` or after
    ``Max_Iter + 1`` (accepted) iterations.
    """

    def __init__(self, A, b, lambd):
        self.A, self.b, self.lambd = A, b, lambd
        self.n, self.d = A.shape
        # Fixed seeds make the starting point (and later sketches) reproducible.
        np.random.seed(2022)
        random.seed(2022)
        self.x_0 = 1./np.sqrt(self.d)*np.random.randn(self.d, 1)

    # ------------------------------------------------------------------
    # Objective and derivatives
    # ------------------------------------------------------------------
    def logistic_loss(self, x):
        """Return f(x), evaluated through the numerically stable ``logsig``."""
        return np.mean(-logsig(self.b*(self.A@x))) + self.lambd/2*(x**2).sum()

    def grad(self, x):
        """Return the gradient of ``logistic_loss`` at ``x`` (same shape as ``x``)."""
        activation = scipy.special.expit(-self.b*(self.A@x))
        return -1./self.n*self.A.T@(self.b*activation) + self.lambd*x

    def Hess(self, x):
        """Return the ``(d, d)`` Hessian of ``logistic_loss`` at ``x``."""
        # sigma(-b z) * (1 - sigma(-b z)) is the curvature weight only because b**2 == 1.
        activation = scipy.special.expit(-self.b*(self.A@x))
        D = activation*(1-activation)/self.n
        return self.A.T@(D*self.A) + self.lambd*np.identity(self.d)

    def sqrt_hess(self, x):
        """Return ``M`` (shape ``(n, d)``) with ``M.T @ M == Hess(x) - lambd*I``."""
        activation = scipy.special.expit(-self.b*(self.A@x))
        D = np.sqrt(activation*(1-activation)/self.n)
        return D*self.A

    def line_search(self, x, f_x, NewDir, Del, beta=0.3, rho=0.8, mu=1.):
        """Armijo backtracking starting from step ``mu``.

        Shrinks ``mu`` by ``rho`` until
        ``f(x + mu*NewDir) <= f_x + beta*mu*Del`` and returns it. ``Del`` is
        expected to be the directional derivative ``<grad f(x), NewDir>``.
        """
        x_1 = x + mu*NewDir
        while self.logistic_loss(x_1) > f_x + beta*mu*Del:
            mu = mu*rho
            x_1 = x + mu*NewDir
        return mu

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    def _finish(self, Xarray, Err_g, Losses, Time):
        """Package a solver trace into ``(Err_g, Err_x, loss_gaps, Time)``."""
        Xarray = np.hstack(Xarray) - self.x_true
        Err_x = np.sqrt((Xarray*Xarray).sum(axis=0))
        return np.asarray(Err_g), Err_x, np.asarray(Losses) - self.loss_true, Time

    def _damped_step(self, x, f_x, g, H):
        """Return ``x + mu*p`` with ``p = -H^{-1} g`` and ``mu`` from Armijo backtracking."""
        # solve() is cheaper and more accurate than forming inv(H) explicitly.
        NewDir = -np.linalg.solve(H, g)
        Inner = (g*NewDir).sum()
        Alp = self.line_search(x, f_x, NewDir, Inner)
        return x + Alp*NewDir

    def _sketched_hess(self, x, sketch_size, sketch_method, nnz):
        """Return a randomized Hessian estimate: a sketch of ``sqrt_hess(x)`` plus ``lambd*I``."""
        sketched = Sketch_Func[sketch_method](self.n, self.sqrt_hess(x), sketch_size, nnz=nnz)
        return sketched + self.lambd*np.identity(self.d)

    def _npe_prox_point(self, x, grad_x, H, sigma, alpha, beta):
        """Approximate prox step shared by the NPE methods, with step-size backtracking.

        Starting from ``eta = sigma``, computes
        ``x_hat = x - (H + I/eta)^{-1} grad_x``. It then multiplies ``eta`` by
        ``beta`` until
        ``||x_hat - x + eta*grad(x_hat)|| <= alpha*sqrt(gamma)*||x_hat - x||``
        with ``gamma = 1 + 2*eta*lambd``.

        Returns ``(x_hat, grad_x_hat, eta, gamma)``.
        """
        Id = np.identity(self.d)
        eta = sigma
        while True:
            x_hat = x - np.linalg.solve(H + Id/eta, grad_x)
            grad_x_hat = self.grad(x_hat)
            gamma = 1 + 2*eta*self.lambd
            lhs = np.linalg.norm(x_hat - x + eta*grad_x_hat)
            rhs = alpha*np.sqrt(gamma)*np.linalg.norm(x_hat - x)
            # `not lhs > rhs` (rather than `lhs <= rhs`) also stops on NaN.
            if not lhs > rhs:
                return x_hat, grad_x_hat, eta, gamma
            eta = eta*beta

    # ------------------------------------------------------------------
    # Reference solution
    # ------------------------------------------------------------------
    def solve_exactly(self, Max_Iter=10**3, EPS=1e-10):
        """Compute a high-accuracy reference solution with damped Newton.

        Sets and returns ``(self.x_true, self.Hess_x_true, self.loss_true)``.
        """
        x_0, grad_x_0 = self.x_0, self.grad(self.x_0)
        eps, t = np.linalg.norm(grad_x_0), 0
        while eps >= EPS and t <= Max_Iter:
            x_0 = self._damped_step(x_0, self.logistic_loss(x_0), grad_x_0, self.Hess(x_0))
            grad_x_0 = self.grad(x_0)
            eps, t = np.linalg.norm(grad_x_0), t+1
        self.x_true = x_0
        self.Hess_x_true = self.Hess(x_0)
        self.loss_true = self.logistic_loss(x_0)
        return self.x_true, self.Hess_x_true, self.loss_true

    # ------------------------------------------------------------------
    # First-order and quasi-Newton methods
    # ------------------------------------------------------------------
    def GD(self, Max_Iter=10**3, EPS=1e-8):
        """Gradient descent with warm-started Armijo backtracking."""
        x_0, grad_x_0 = self.x_0, self.grad(self.x_0)
        eps, t = np.linalg.norm(grad_x_0), 0
        Xarray, Losses, Err_g = [x_0], [self.logistic_loss(x_0)], [eps]

        start = time()
        Alp = 1
        while eps >= EPS and t <= Max_Iter:
            NewDir = -grad_x_0
            Inner = (grad_x_0*NewDir).sum()
            # Start backtracking slightly above the previously accepted step.
            Alp = self.line_search(x_0, Losses[-1], NewDir, Inner, mu=Alp/0.8)
            x_0 = x_0 + Alp*NewDir
            grad_x_0 = self.grad(x_0)
            eps, t = np.linalg.norm(grad_x_0), t+1
            Xarray.append(x_0)
            Losses.append(self.logistic_loss(x_0))
            Err_g.append(eps)
        return self._finish(Xarray, Err_g, Losses, time() - start)

    def AGD(self, Max_Iter=10**3, EPS=1e-8):
        """Accelerated gradient descent (strongly convex FISTA, see "Acceleration Methods").

        The smoothness estimate ``L`` is increased by backtracking. A trial
        step that fails the sufficient-decrease test only enlarges ``L`` and
        does not count as an iteration.
        """
        x = self.x_0
        z = x
        eps, t = np.linalg.norm(self.grad(x)), 0
        Xarray, Losses, Err_g = [x], [self.logistic_loss(x)], [eps]

        start = time()
        L = 1
        # The momentum recursion divides by (1 - q) with q = lambd/L, so L must
        # exceed lambd. Grow L with the backtracking factor until it does
        # (a no-op whenever lambd < 1).
        while self.lambd >= L:
            L = L/0.8
        A = 0

        while eps >= EPS and t <= Max_Iter:
            q = self.lambd/L
            A_new = (2*A + 1 + np.sqrt(4*A + 4*q*A**2 + 1))/(2*(1 - q))
            tau = (A_new - A)*(1 + q*A)/(A_new + 2*q*A_new*A - q*A**2)
            delta = (A_new - A)/(1 + q*A_new)
            y = x + tau*(z - x)

            grad_y = self.grad(y)
            x_new = y - grad_y/L
            # Sufficient decrease: f(x_new) <= f(y) - 0.3/L * ||grad_y||^2.
            Inner = -(grad_y*grad_y).sum()
            if self.logistic_loss(x_new) <= self.logistic_loss(y) + 0.3/L*Inner:
                x = x_new
                z = (1 - q*delta)*z + q*delta*y + delta*(x - y)
                A = A_new
                eps, t = np.linalg.norm(self.grad(x)), t+1
                Xarray.append(x)
                Losses.append(self.logistic_loss(x))
                Err_g.append(eps)
            else:
                L = L/0.8
        return self._finish(Xarray, Err_g, Losses, time() - start)

    def BFGS(self, Max_Iter=10**3, EPS=1e-8):
        """BFGS with Armijo line search, maintaining an inverse-Hessian estimate ``B_inv``."""
        x_0, grad_x_0 = self.x_0, self.grad(self.x_0)
        B_inv = np.identity(self.d)
        eps, t = np.linalg.norm(grad_x_0), 0
        Xarray, Losses, Err_g = [x_0], [self.logistic_loss(x_0)], [eps]

        start = time()
        while eps >= EPS and t <= Max_Iter:
            NewDir = -B_inv@grad_x_0
            Inner = (grad_x_0*NewDir).sum()
            Alp = self.line_search(x_0, Losses[-1], NewDir, Inner)
            s = Alp*NewDir
            x_0 = x_0 + s
            grad_new = self.grad(x_0)
            y = grad_new - grad_x_0
            grad_x_0 = grad_new
            eps, t = np.linalg.norm(grad_x_0), t+1
            Xarray.append(x_0)
            Losses.append(self.logistic_loss(x_0))
            Err_g.append(eps)
            sy_inner = (s*y).sum()
            # Skip the update when the curvature condition s^T y > 0 fails
            # (e.g. a zero step): dividing by it would fill B_inv with inf/NaN.
            if sy_inner > 0:
                ss_outer = s@s.T
                B_1 = (sy_inner + (y*(B_inv@y)).sum())/sy_inner**2*ss_outer
                b_2 = B_inv@(y@s.T)
                B_2 = (b_2 + b_2.T)/sy_inner
                B_inv = B_inv + B_1 - B_2
        return self._finish(Xarray, Err_g, Losses, time() - start)

    # ------------------------------------------------------------------
    # (Sketched) Newton methods
    # ------------------------------------------------------------------
    def sketch_Newton(self, sketch_size, sketch_method='Gaussian', nnz=None, Max_Iter=10**3, EPS=1e-8):
        """Stochastic Newton using a fresh sketched Hessian at every iterate."""
        x_0, grad_x_0 = self.x_0, self.grad(self.x_0)
        eps, t = np.linalg.norm(grad_x_0), 0
        Xarray, Losses, Err_g = [x_0], [self.logistic_loss(x_0)], [eps]

        start = time()
        while eps >= EPS and t <= Max_Iter:
            H_hat_x_0 = self._sketched_hess(x_0, sketch_size, sketch_method, nnz)
            x_0 = self._damped_step(x_0, Losses[-1], grad_x_0, H_hat_x_0)
            grad_x_0 = self.grad(x_0)
            eps, t = np.linalg.norm(grad_x_0), t+1
            Xarray.append(x_0)
            Losses.append(self.logistic_loss(x_0))
            Err_g.append(eps)
        return self._finish(Xarray, Err_g, Losses, time() - start)

    def sto_weight_Sket_Newton(self, sketch_size, wei_set='power', power=1, sketch_method='Gaussian', nnz=None, Max_Iter=10**3, EPS=1e-8):
        """Stochastic Newton with Hessian averaging.

        The averaged Hessian is updated as
        ``w_H <- ratio*w_H + (1-ratio)*H_hat`` with
        ``ratio = Weight[wei_set](t, power)``, starting from the identity.
        """
        x_0, grad_x_0, w_H_0 = self.x_0, self.grad(self.x_0), np.identity(self.d)
        eps, t = np.linalg.norm(grad_x_0), 0
        Xarray, Losses, Err_g = [x_0], [self.logistic_loss(x_0)], [eps]

        start = time()
        while eps >= EPS and t <= Max_Iter:
            H_hat_x_0 = self._sketched_hess(x_0, sketch_size, sketch_method, nnz)
            ratio = Weight[wei_set](t, power)
            w_H_0 = ratio*w_H_0 + (1 - ratio)*H_hat_x_0
            x_0 = self._damped_step(x_0, Losses[-1], grad_x_0, w_H_0)
            grad_x_0 = self.grad(x_0)
            eps, t = np.linalg.norm(grad_x_0), t+1
            Xarray.append(x_0)
            Losses.append(self.logistic_loss(x_0))
            Err_g.append(eps)
        return self._finish(Xarray, Err_g, Losses, time() - start)

    def Newton(self, Max_Iter=10**3, EPS=1e-8):
        """Damped Newton with the exact Hessian."""
        x_0, grad_x_0 = self.x_0, self.grad(self.x_0)
        eps, t = np.linalg.norm(grad_x_0), 0
        Xarray, Losses, Err_g = [x_0], [self.logistic_loss(x_0)], [eps]

        start = time()
        while eps >= EPS and t <= Max_Iter:
            x_0 = self._damped_step(x_0, Losses[-1], grad_x_0, self.Hess(x_0))
            grad_x_0 = self.grad(x_0)
            eps, t = np.linalg.norm(grad_x_0), t+1
            Xarray.append(x_0)
            Losses.append(self.logistic_loss(x_0))
            Err_g.append(eps)
        return self._finish(Xarray, Err_g, Losses, time() - start)

    # ------------------------------------------------------------------
    # Newton Proximal Extragradient (NPE) methods
    # ------------------------------------------------------------------
    def sketch_NPE(self, sketch_size, sketch_method='Gaussian', nnz=None, Max_Iter=10**3, EPS=1e-8, alpha=1, beta=0.5, sigma_0=1):
        """Stochastic NPE using a fresh sketched Hessian, without averaging, plus an extragradient step."""
        x, grad_x = self.x_0, self.grad(self.x_0)
        sigma = sigma_0
        eps, t = np.linalg.norm(grad_x), 0
        Xarray, Losses, Err_g = [x], [self.logistic_loss(x)], [eps]

        start = time()
        while eps >= EPS and t <= Max_Iter:
            H_hat_x = self._sketched_hess(x, sketch_size, sketch_method, nnz)
            x_hat, grad_x_hat, eta, gamma = self._npe_prox_point(x, grad_x, H_hat_x, sigma, alpha, beta)
            # Extragradient correction, mixing the gradient step and x_hat.
            x = (x - eta*grad_x_hat)/gamma + (1 - 1/gamma)*x_hat
            # Next iteration starts one backtracking factor above the accepted eta.
            sigma = eta/beta
            grad_x = self.grad(x)
            eps, t = min(np.linalg.norm(grad_x), np.linalg.norm(grad_x_hat)), t+1
            Xarray.append(x)
            Losses.append(self.logistic_loss(x))
            Err_g.append(eps)
        return self._finish(Xarray, Err_g, Losses, time() - start)

    def sto_weight_Sket_NPE(self, sketch_size, wei_set='power', power=1, sketch_method='Gaussian', nnz=None, Max_Iter=10**3, EPS=1e-8, alpha=1, beta=0.5, sigma_0=1):
        """Stochastic NPE with Hessian averaging; the next iterate is ``x_hat`` itself."""
        x, grad_x, w_H = self.x_0, self.grad(self.x_0), np.identity(self.d)
        sigma = sigma_0
        eps, t = np.linalg.norm(grad_x), 0
        Xarray, Losses, Err_g = [x], [self.logistic_loss(x)], [eps]

        start = time()
        while eps >= EPS and t <= Max_Iter:
            H_hat_x = self._sketched_hess(x, sketch_size, sketch_method, nnz)
            ratio = Weight[wei_set](t, power)
            w_H = ratio*w_H + (1 - ratio)*H_hat_x
            x_hat, grad_x_hat, eta, _ = self._npe_prox_point(x, grad_x, w_H, sigma, alpha, beta)
            x = x_hat
            sigma = eta/beta
            grad_x = self.grad(x)
            eps, t = min(np.linalg.norm(grad_x), np.linalg.norm(grad_x_hat)), t+1
            Xarray.append(x)
            Losses.append(self.logistic_loss(x))
            Err_g.append(eps)
        return self._finish(Xarray, Err_g, Losses, time() - start)

    def NPE(self, Max_Iter=10**3, EPS=1e-8, alpha=1, beta=0.5, sigma_0=1):
        """Newton Proximal Extragradient with the exact Hessian; the next iterate is ``x_hat``."""
        x, grad_x = self.x_0, self.grad(self.x_0)
        sigma = sigma_0
        eps, t = np.linalg.norm(grad_x), 0
        Xarray, Losses, Err_g = [x], [self.logistic_loss(x)], [eps]

        start = time()
        while eps >= EPS and t <= Max_Iter:
            H = self.Hess(x)
            x_hat, grad_x_hat, eta, _ = self._npe_prox_point(x, grad_x, H, sigma, alpha, beta)
            x = x_hat
            sigma = eta/beta
            grad_x = self.grad(x)
            eps, t = min(np.linalg.norm(grad_x), np.linalg.norm(grad_x_hat)), t+1
            Xarray.append(x)
            Losses.append(self.logistic_loss(x))
            Err_g.append(eps)
        return self._finish(Xarray, Err_g, Losses, time() - start)




# Reference: https://github.com/konstmish/opt_methods
class LogSumExp:    
    def __init__(self, A, b, lambd, max_smoothing = 1.):
        self.A, self.b, self.lambd = A, b, lambd
        self.n, self.d = A.shape
        self.max_smoothing = max_smoothing
        np.random.seed(2022)
        random.seed(2022)
        self.x_0 = 1./np.sqrt(self.d)*np.random.randn(self.d,1)
#        self.x_0 = 0*np.ones((self.d,1))
        
    def loss(self, x):
        return self.max_smoothing*logsumexp((self.A@x - self.b)/self.max_smoothing) + self.lambd/2*(x**2).sum()
    # np.log(1+np.exp(-self.b*self.A@x)).mean()+self.lambd/2*(x**2).sum()
        
    def grad(self, x):
        return self.A.T @ softmax((self.A@x - self.b)/self.max_smoothing) + self.lambd*x
        # return -1./self.n*self.A.T@(self.b*1./(1+np.exp(self.b*self.A@x)))+self.lambd*x
        
    def Hess(self, x):
        softmax_A = softmax((self.A@x - self.b)/self.max_smoothing)
        hess1 = self.A.T * (softmax_A.squeeze()/self.max_smoothing) @ self.A
        grad =  self.A.T @ softmax_A
        hess2 = -np.outer(grad,grad) / self.max_smoothing
        return hess1 + hess2 + self.lambd*np.identity(self.d)
        # v = np.exp(self.b*self.A@x)
        # D = (v/(1+v)**2)/self.n
        # return self.A.T@(D*self.A)+self.lambd*np.identity(self.d)

    def sqrt_hess(self, x):
        softmax_A = softmax((self.A@x - self.b)/self.max_smoothing)
        # See https://scicomp.stackexchange.com/questions/33469/computing-square-root-of-diagu-uu
        # we compute a decomposition of diag(v) - v*v^\top
        return (np.sqrt(softmax_A) * self.A - np.outer(np.sqrt(softmax_A), self.A.T @ softmax_A))/np.sqrt(self.max_smoothing)
        # v = np.exp(self.b*self.A@x)
        # D = np.sqrt(v)/(1+v)/np.sqrt(self.n)
        # return D*self.A

    def line_search(self, x, f_x, NewDir, Del, beta=0.3, rho=0.8):
        mu = 1
        x_1 = x + mu*NewDir
        while self.loss(x_1) > f_x + beta*mu*Del:
            mu = mu*rho
            x_1 = x + mu*NewDir
        return mu

    def solve_exactly(self, Max_Iter=10**3, EPS=1e-10):
        # use Newton method to solve exactly
        x_0, grad_x_0 = self.x_0, self.grad(self.x_0)
        eps, t = np.linalg.norm(grad_x_0), 0
        while eps >= EPS and t <= Max_Iter:
            Hess_x_0 = self.Hess(x_0)
            NewDir = -np.linalg.inv(Hess_x_0)@grad_x_0
            Inner = (grad_x_0*NewDir).sum()
            Alp = self.line_search(x_0,self.loss(x_0),NewDir,Inner)
            x_0 = x_0 + Alp*NewDir
            grad_x_0 = self.grad(x_0)
            eps, t = np.linalg.norm(grad_x_0), t+1
        self.x_true = x_0 
        self.Hess_x_true = self.Hess(x_0)
        self.loss_true = self.loss(x_0)
        return self.x_true, self.Hess_x_true, self.loss_true

    def BFGS(self,Max_Iter=10**3,EPS = 1e-8):
        # implement BFGS
        Xarray, Losses = [], []
        x_0, grad_x_0 = self.x_0, self.grad(self.x_0)
        B_inv = np.identity(self.d)
        eps, t = np.linalg.norm(grad_x_0), 0
        Xarray.append(x_0)
        Losses.append(self.loss(x_0))
        
        start = time()
        while eps>=EPS and t<= Max_Iter:
            NewDir = -B_inv@grad_x_0
            Inner = (grad_x_0*NewDir).sum() 
            Alp = self.line_search(x_0,Losses[-1],NewDir,Inner)
            s = Alp*NewDir
            x_0 = x_0 + s
            grad_x_0_ = self.grad(x_0)
            y = grad_x_0_ - grad_x_0 
            grad_x_0 = grad_x_0_.copy()
            eps, t = np.linalg.norm(grad_x_0), t+1
            Xarray.append(x_0)
            Losses.append(self.loss(x_0))
            # update B
            sy_inner, sy_outer, ss_outer = (s*y).sum(), s@y.T, s@s.T
            B_1 = (sy_inner+(y*(B_inv@y)).sum())/sy_inner**2 * ss_outer
            b_2 = B_inv@sy_outer.T
            B_2 = (b_2+b_2.T)/sy_inner
            B_inv = B_inv + B_1 - B_2
        Time = time()-start
        Xarray = np.hstack(Xarray)-self.x_true
        # Err = np.sqrt(((self.Hess_x_true@Xarray)*Xarray).sum(axis=0))
        Err = np.sqrt((Xarray*Xarray).sum(axis=0))
        return Err, Losses - self.loss_true, Time, Xarray    


    def sketch_Newton(self,sketch_size,sketch_method='Gaussian',nnz=None,Max_Iter=10**3,EPS=1e-8):
        """Stochastic Newton using a fresh sketched Hessian at every iterate.

        The Hessian estimate is ``Sketch_Func[sketch_method]`` applied to
        ``sqrt_hess(x)``, plus ``lambd*I``. Requires ``solve_exactly`` first.
        Returns ``(Err, loss_gaps, Time, Xarray)`` like ``BFGS``.
        """
        Xarray, Losses = [], []
        x_0, grad_x_0 = self.x_0, self.grad(self.x_0)
        eps, t = np.linalg.norm(grad_x_0), 0
        Xarray.append(x_0)
        Losses.append(self.loss(x_0))

        start = time()
        while eps>=EPS and t<=Max_Iter:
            H_hat_x_0 = Sketch_Func[sketch_method](self.n,self.sqrt_hess(x_0),sketch_size,nnz=nnz)+ self.lambd*np.identity(self.d)
            # solve() avoids forming the explicit inverse (cheaper, more accurate).
            NewDir = -np.linalg.solve(H_hat_x_0, grad_x_0)
            Inner = (grad_x_0*NewDir).sum()
            Alp = self.line_search(x_0,Losses[-1],NewDir,Inner)
            x_0 = x_0 + Alp*NewDir
            grad_x_0 = self.grad(x_0)
            eps, t = np.linalg.norm(grad_x_0), t+1
            Xarray.append(x_0)
            Losses.append(self.loss(x_0))
        Time = time()-start
        Xarray = np.hstack(Xarray)-self.x_true
        Err = np.sqrt((Xarray*Xarray).sum(axis=0))
        return Err, np.asarray(Losses) - self.loss_true, Time, Xarray
    
    def sto_weight_Sket_Newton(self,sketch_size,wei_set='power',power=1,sketch_method='Gaussian',nnz=None,Max_Iter=10**3,EPS=1e-8):
        """Stochastic Newton with Hessian averaging.

        The averaged Hessian is updated as
        ``w_H <- ratio*w_H + (1-ratio)*H_hat`` with
        ``ratio = Weight[wei_set](t, power)``, starting from the identity.
        Requires ``solve_exactly`` first. Returns
        ``(Err, loss_gaps, Time, Xarray)`` like ``BFGS``.
        """
        Xarray, Losses = [], []
        x_0, grad_x_0, w_H_0 = self.x_0, self.grad(self.x_0), np.identity(self.d)
        eps, t = np.linalg.norm(grad_x_0), 0
        Xarray.append(x_0)
        Losses.append(self.loss(x_0))

        start = time()
        while eps>=EPS and t<=Max_Iter:
            H_hat_x_0 = Sketch_Func[sketch_method](self.n,self.sqrt_hess(x_0),sketch_size,nnz=nnz)+ self.lambd*np.identity(self.d)
            ratio = Weight[wei_set](t,power)
            w_H_0 = ratio*w_H_0 + (1-ratio)*H_hat_x_0
            # solve() avoids forming the explicit inverse (cheaper, more accurate).
            NewDir = -np.linalg.solve(w_H_0, grad_x_0)
            Inner = (grad_x_0*NewDir).sum()
            Alp = self.line_search(x_0,Losses[-1],NewDir,Inner)
            x_0 = x_0 + Alp*NewDir
            grad_x_0 = self.grad(x_0)
            eps, t = np.linalg.norm(grad_x_0), t+1
            Xarray.append(x_0)
            Losses.append(self.loss(x_0))
        Time = time()-start
        Xarray = np.hstack(Xarray)-self.x_true
        # the unweighted (Euclidean) norm of x_t - x_true
        Err = np.sqrt((Xarray*Xarray).sum(axis=0))
        return Err, np.asarray(Losses) - self.loss_true, Time, Xarray
       



    ###############################################################################################################################
    # Stochastic Newton Proximal Extragradient (no Hessian averaging)
    def sketch_NPE(self, sketch_size, sketch_method='Gaussian',nnz=None, Max_Iter=10**3, EPS=1e-8, alpha=1, beta=0.5, sigma_0 = 1):
        """Stochastic Newton Proximal Extragradient using a fresh sketched Hessian, without averaging.

        Each iteration proceeds as follows:
        - Compute ``x_hat = x - (H_hat + I/eta)^{-1} grad(x)``.
        - Shrink ``eta`` by ``beta`` until
          ``||x_hat - x + eta*grad(x_hat)|| <= alpha*sqrt(gamma)*||x_hat - x||``,
          where ``gamma = 1 + 2*eta*lambd``.
        - Take the extragradient step.

        Requires ``solve_exactly`` first. Returns
        ``(Err, loss_gaps, Time, Xarray)`` like ``BFGS``.
        """
        Xarray, Losses = [], []
        x_0, grad_x_0 = self.x_0, self.grad(self.x_0)
        sigma = sigma_0
        eps, t = np.linalg.norm(grad_x_0), 0
        Xarray.append(x_0)
        Losses.append(self.loss(x_0))

        start = time()

        x = x_0
        grad_x = grad_x_0
        Id = np.identity(self.d)
        while eps>=EPS and t<=Max_Iter:
            # Construct the sketched Hessian
            H_hat_x = Sketch_Func[sketch_method](self.n,self.sqrt_hess(x),sketch_size,nnz=nnz)+ self.lambd*Id
            eta = sigma
            # solve() avoids forming the explicit inverse (cheaper, more accurate).
            x_hat = x - np.linalg.solve(H_hat_x + Id/eta, grad_x)
            grad_x_hat = self.grad(x_hat)
            gamma = 1+2*eta*self.lambd
            while np.linalg.norm(x_hat - x + eta*grad_x_hat) > alpha*np.sqrt(gamma)*np.linalg.norm(x_hat-x):
                eta = eta*beta
                x_hat = x - np.linalg.solve(H_hat_x + Id/eta, grad_x)
                grad_x_hat = self.grad(x_hat)
                gamma = 1+2*eta*self.lambd
            # Extragradient correction, mixing the gradient step and x_hat.
            x = (x - eta*grad_x_hat)/gamma + (1-1/gamma)*x_hat
            # Next iteration starts one backtracking factor above the accepted eta.
            sigma = eta/beta

            grad_x = self.grad(x)
            eps, t = np.linalg.norm(grad_x), t+1
            Xarray.append(x)
            Losses.append(self.loss(x))
        Time = time()-start
        Xarray = np.hstack(Xarray)-self.x_true
        Err = np.sqrt((Xarray*Xarray).sum(axis=0))
        return Err, np.asarray(Losses) - self.loss_true, Time, Xarray

    
    # Stochastic Newton Proximal Extragradient with Hessian averaging
    def sto_weight_Sket_NPE(self, sketch_size, wei_set='power', power=1, sketch_method='Gaussian',nnz=None, Max_Iter=10**3, EPS=1e-8, alpha=1, beta=0.5, sigma_0 = 1):
        """Run stochastic Newton Proximal Extragradient with Hessian averaging.

        Each iteration does three things:

        1. Draws a sketched Hessian
           ``H_hat = Sketch_Func[sketch_method](self.n, self.sqrt_hess(x), sketch_size, nnz=nnz) + lambd*I``.
        2. Folds it into a running average ``w_H = ratio*w_H + (1 - ratio)*H_hat``.
           Here ``ratio = Weight[wei_set](t, power)``, and ``w_H`` starts at the
           identity.
        3. Takes the proximal Newton step ``x_hat = x - (w_H + I/eta)^{-1} grad(x)``.
           ``eta`` is multiplied by ``beta`` until
           ``||x_hat - x + eta*grad(x_hat)|| <= alpha*sqrt(gamma)*||x_hat - x||``,
           where ``gamma = 1 + 2*eta*lambd``.

        The iterate is then set directly to ``x_hat`` (no extragradient
        correction is applied). The next trial step size is ``eta/beta``.

        Args:
            sketch_size: sketch dimension forwarded to the sketching routine.
            wei_set: key into the module-level ``Weight`` table.
            power: exponent passed to the weight sequence.
            sketch_method: key into the module-level ``Sketch_Func`` table.
            nnz: forwarded to the sketching routine (only the 'LESS-uniform'
                sketch reads it).
            Max_Iter: iteration cap; the loop runs while ``t <= Max_Iter``.
            EPS: stop once the gradient norm drops below this value.
            alpha: line-search acceptance constant.
            beta: backtracking factor applied to ``eta``.
            sigma_0: initial trial step size.

        Returns:
            ``(Err, gaps, Time, Xarray)``:

            - ``Xarray`` holds the iterates minus ``self.x_true``, stacked
              column-wise with ``np.hstack``.
            - ``Err`` is the Euclidean norm of each column.
            - ``gaps`` is the array of recorded losses minus ``self.loss_true``.
            - ``Time`` is the wall-clock time of the main loop.
        """
        sketch = Sketch_Func[sketch_method]
        weight = Weight[wei_set]
        identity = np.identity(self.d)
        Xarray, Losses = [], []

        x = self.x_0
        grad_x = self.grad(x)
        w_H = np.identity(self.d)
        sigma = sigma_0
        eps, t = np.linalg.norm(grad_x), 0
        Xarray.append(x)
        Losses.append(self.loss(x))

        start = time()
        while eps >= EPS and t <= Max_Iter:
            # Fresh sketched Hessian, then blend it into the weighted running average.
            H_hat_x = sketch(self.n, self.sqrt_hess(x), sketch_size, nnz=nnz) + self.lambd*identity
            ratio = weight(t, power)
            w_H = ratio*w_H + (1 - ratio)*H_hat_x
            eta = sigma
            # Solve the regularized Newton system instead of forming an explicit inverse.
            x_hat = x - np.linalg.solve(w_H + identity/eta, grad_x)
            grad_x_hat = self.grad(x_hat)
            gamma = 1 + 2*eta*self.lambd
            # Backtrack on eta until the inexact proximal condition holds.
            while (np.linalg.norm(x_hat - x + eta*grad_x_hat)
                   > alpha*np.sqrt(gamma)*np.linalg.norm(x_hat - x)):
                eta = eta*beta
                x_hat = x - np.linalg.solve(w_H + identity/eta, grad_x)
                grad_x_hat = self.grad(x_hat)
                gamma = 1 + 2*eta*self.lambd
            # Accept the proximal point directly (no extragradient correction).
            x = x_hat
            # Optimistically enlarge the next trial step size.
            sigma = eta/beta

            grad_x = self.grad(x)
            eps, t = np.linalg.norm(grad_x), t + 1
            Xarray.append(x)
            Losses.append(self.loss(x))
        Time = time() - start
        Xarray = np.hstack(Xarray) - self.x_true
        Err = np.sqrt((Xarray*Xarray).sum(axis=0))
        # asarray: a plain list minus a Python float would raise TypeError.
        return Err, np.asarray(Losses) - self.loss_true, Time, Xarray
    


    # Newton Proximal Extragradient
    def NPE(self, Max_Iter=10**3, EPS=1e-8, alpha=1, beta=0.5, sigma_0 = 1):
        """Run the Newton Proximal Extragradient method with the exact Hessian.

        Each iteration does three things:

        1. Evaluates ``H = self.Hess(x)``.
        2. Takes the proximal Newton step ``x_hat = x - (H + I/eta)^{-1} grad(x)``.
           ``eta`` is multiplied by ``beta`` until
           ``||x_hat - x + eta*grad(x_hat)|| <= alpha*sqrt(gamma)*||x_hat - x||``,
           where ``gamma = 1 + 2*eta*lambd``.
        3. Sets the iterate to ``x_hat``. The next trial step size is ``eta/beta``.

        Args:
            Max_Iter: iteration cap; the loop runs while ``t <= Max_Iter``.
            EPS: stop once the gradient norm drops below this value.
            alpha: line-search acceptance constant.
            beta: backtracking factor applied to ``eta``.
            sigma_0: initial trial step size.

        Returns:
            ``(Err, gaps, Time, Xarray)``:

            - ``Xarray`` holds the iterates minus ``self.x_true``, stacked
              column-wise with ``np.hstack``.
            - ``Err`` is the Euclidean norm of each column.
            - ``gaps`` is the array of recorded losses minus ``self.loss_true``.
            - ``Time`` is the wall-clock time of the main loop.
        """
        identity = np.identity(self.d)
        Xarray, Losses = [], []

        x = self.x_0
        grad_x = self.grad(x)
        sigma = sigma_0
        eps, t = np.linalg.norm(grad_x), 0
        Xarray.append(x)
        Losses.append(self.loss(x))

        start = time()
        while eps >= EPS and t <= Max_Iter:
            # Exact Hessian at the current iterate.
            H = self.Hess(x)
            eta = sigma
            # Solve the regularized Newton system instead of forming an explicit inverse.
            x_hat = x - np.linalg.solve(H + identity/eta, grad_x)
            grad_x_hat = self.grad(x_hat)
            gamma = 1 + 2*eta*self.lambd
            # Backtrack on eta until the inexact proximal condition holds.
            while (np.linalg.norm(x_hat - x + eta*grad_x_hat)
                   > alpha*np.sqrt(gamma)*np.linalg.norm(x_hat - x)):
                eta = eta*beta
                x_hat = x - np.linalg.solve(H + identity/eta, grad_x)
                grad_x_hat = self.grad(x_hat)
                gamma = 1 + 2*eta*self.lambd
            # Accept the proximal point directly (no extragradient correction).
            x = x_hat
            # Optimistically enlarge the next trial step size.
            sigma = eta/beta

            grad_x = self.grad(x)
            eps, t = np.linalg.norm(grad_x), t + 1
            Xarray.append(x)
            Losses.append(self.loss(x))
        Time = time() - start
        Xarray = np.hstack(Xarray) - self.x_true
        Err = np.sqrt((Xarray*Xarray).sum(axis=0))
        # asarray: a plain list minus a Python float would raise TypeError.
        return Err, np.asarray(Losses) - self.loss_true, Time, Xarray
    

    def Newton(self,Max_Iter=10**3,EPS=1e-8):
        """Run Newton's method with the exact Hessian and a line search.

        Each iteration computes the direction ``d = -Hess(x)^{-1} grad(x)``.
        It then asks ``self.line_search(x, f(x), d, <grad(x), d>)`` for a step
        length and moves to ``x + step*d``. ``line_search`` is defined
        elsewhere in the class; presumably it returns a backtracked step
        size — confirm there.

        Args:
            Max_Iter: iteration cap; the loop runs while ``t <= Max_Iter``.
            EPS: stop once the gradient norm drops below this value.

        Returns:
            ``(Err, gaps, Time, Xarray)``:

            - ``Xarray`` holds the iterates minus ``self.x_true``, stacked
              column-wise with ``np.hstack``.
            - ``Err`` is the Euclidean norm of each column.
            - ``gaps`` is the array of recorded losses minus ``self.loss_true``.
            - ``Time`` is the wall-clock time of the main loop.
        """
        Xarray, Losses = [], []
        x = self.x_0
        grad_x = self.grad(x)
        eps, t = np.linalg.norm(grad_x), 0
        Xarray.append(x)
        Losses.append(self.loss(x))

        start = time()
        while eps >= EPS and t <= Max_Iter:
            # Newton direction: solve H d = -g rather than inverting H.
            NewDir = -np.linalg.solve(self.Hess(x), grad_x)
            # Directional derivative <g, d>, handed to the line search.
            Inner = (grad_x*NewDir).sum()
            Alp = self.line_search(x, Losses[-1], NewDir, Inner)
            x = x + Alp*NewDir
            grad_x = self.grad(x)
            eps, t = np.linalg.norm(grad_x), t + 1
            Xarray.append(x)
            Losses.append(self.loss(x))
        Time = time() - start
        Xarray = np.hstack(Xarray) - self.x_true
        Err = np.sqrt((Xarray*Xarray).sum(axis=0))
        # asarray: a plain list minus a Python float would raise TypeError.
        return Err, np.asarray(Losses) - self.loss_true, Time, Xarray

    def AGD(self,Max_Iter=10**3,EPS = 1e-8, L_0=1):
        """Run accelerated gradient descent (strongly convex FISTA) with backtracking on L.

        Implements the strongly convex accelerated gradient scheme from the
        monograph "Acceleration Methods", with ``q = self.lambd / L``.

        The smoothness estimate ``L`` starts at ``L_0``. It is divided by 0.8
        whenever the sufficient-decrease test
        ``f(y - grad(y)/L) <= f(y) - 0.3/L * ||grad(y)||^2`` fails. It is
        also divided by 0.8 whenever ``L <= lambd``, which would make
        ``q >= 1``.

        Args:
            Max_Iter: cap on accepted iterations; the loop runs while
                ``t <= Max_Iter``.
            EPS: stop once ``||grad(x)||`` drops below this value.
            L_0: initial smoothness estimate (default 1, the previously
                hard-coded value).

        Returns:
            ``(Err_g, Err_x, gaps, Time)``, one entry per accepted iterate
            (including the starting point):

            - ``Err_g`` holds the gradient norms.
            - ``Err_x`` holds the Euclidean distances to ``self.x_true``.
            - ``gaps`` is the array of recorded losses minus ``self.loss_true``.
            - ``Time`` is the wall-clock time of the main loop.
        """
        Xarray, Losses, Err_g = [], [], []

        x = self.x_0
        z = x
        Xarray.append(x)
        Losses.append(self.loss(x))
        eps, t = np.linalg.norm(self.grad(x)), 0
        Err_g.append(eps)

        start = time()
        L = L_0
        A = 0

        while eps >= EPS and t <= Max_Iter:
            if self.lambd >= L:
                # q = lambd/L must stay below 1. At q == 1 the A_new update
                # divides by zero, and for q > 1 it turns negative. Treat L as
                # too small and enlarge it, as a failed decrease test does.
                L = L/0.8
                continue
            q = self.lambd/L
            A_new = (2*A + 1 + np.sqrt(4*A + 4*q*A**2 + 1))/(2*(1 - q))
            tau = (A_new - A)*(1 + q*A)/(A_new + 2*q*A_new*A - q*A**2)
            delta = (A_new - A)/(1 + q*A_new)
            y = x + tau*(z - x)

            grad_y = self.grad(y)
            x_new = y - grad_y/L
            loss_new = self.loss(x_new)

            # Sufficient-decrease (Armijo-type) test for the current estimate L.
            if loss_new <= self.loss(y) - 0.3/L*(grad_y*grad_y).sum():
                x = x_new
                z = (1 - q*delta)*z + q*delta*y + delta*(x - y)
                A = A_new
                eps, t = np.linalg.norm(self.grad(x)), t + 1
                Xarray.append(x)
                Losses.append(loss_new)
                Err_g.append(eps)
            else:
                L = L/0.8

        Time = time() - start
        Xarray = np.hstack(Xarray) - self.x_true
        Err_g = np.array(Err_g)
        Err_x = np.sqrt((Xarray*Xarray).sum(axis=0))
        # asarray: a plain list minus a Python float would raise TypeError.
        return Err_g, Err_x, np.asarray(Losses) - self.loss_true, Time
