import numpy as np

from scipy.linalg import eigvalsh_tridiagonal, eigvalsh


# Fix the RNG seed so the synthetic test problem is reproducible.
np.random.seed(1)

# Problem dimensions. The construction below multiplies Q.T @ D @ Q, so it
# needs D to have one diagonal entry per row of the random matrix.
D_in, D_out = 100, 100
A = np.random.randn(D_in,D_out)
# Orthogonal factor of the random Gaussian matrix.
Q,_ = np.linalg.qr(A)
# Prescribed spectrum: D_in values evenly spaced over [0.1, 1]. The count
# follows D_in instead of a hard-coded 100 so the dimension can be changed
# in one place.
D = np.linspace(0.1,1,D_in)
D = np.diag(D)
# Symmetric operator whose eigenvalues are the diagonal of D.
A = Q.T @ (D @ Q)
# Reference model x and the data d it generates through A.
x = np.random.randn(D_out,1)
d = A @ x
import pickle
# Re-read the dimensions from the operator that is actually used below.
D_in,D_out = A.shape

def linear(x):
    """Apply the forward operator: return the product of the module-level
    matrix ``A`` with ``x``."""
    product = A @ x
    return product

def misfit(x):
    """Evaluate the least-squares misfit ``0.5 * r.T.dot(r)`` at ``x``.

    ``r`` is the difference between ``linear(x)`` and the module-level
    data ``d``. The value is returned exactly as ``.T.dot`` produces it
    (for a column-vector residual that is a 1x1 array, not a Python float).
    """
    residual = linear(x) - d
    return 0.5*residual.T.dot(residual)

def linear_gradient(x):
    """Return ``A.T @ (A @ x - d)`` using the module-level ``A`` and ``d``,
    i.e. the gradient of the quadratic misfit with respect to ``x``."""
    residual = A @ x - d
    return A.T @ residual

def misfit_gradient(x):
    """Gradient of the misfit at ``x``; delegates to ``linear_gradient``."""
    gradient = linear_gradient(x)
    return gradient

def misfit_total(x):
    """Return the pair ``(misfit(x), misfit_gradient(x))``.

    This is the ``fh`` callable shape the solvers in this file expect:
    objective value first, gradient second.
    """
    value = misfit(x)
    gradient = misfit_gradient(x)
    return value, gradient

def norm2(x):
    """Return the 2-norm of ``x`` as computed by ``np.linalg.norm`` with
    ``ord=2``.

    For 1-D input this is the Euclidean length. For 2-D input NumPy's
    ``ord=2`` is the largest singular value, which coincides with the
    Euclidean length when ``x`` is a single column.
    """
    two_norm = np.linalg.norm(x, 2)
    return two_norm
    
from numpy.linalg import pinv

def AARIter(fh,x0,tol=1e-8,maxit=1e5,alpha=0.001,beta=1.0,m=100,p=1):
    """Minimise ``fh`` with plain gradient steps interleaved with a
    least-squares extrapolation over stored iterate/residual differences.

    The log message of this routine calls it the "AAR method".

    Parameters
    ----------
    fh : callable
        ``fh(x)`` must return ``(value, gradient)``.
    x0 : ndarray
        Initial iterate. It is copied, so the caller's array is left
        untouched. Must flatten to ``len(x0)`` entries (shape ``(N,)`` or
        ``(N, 1)``), because differences are stored via ``reshape(len(xk))``.
    tol : float
        The loop stops once ``norm(gradient) / norm(initial gradient)``
        is no longer greater than ``tol``.
    maxit : number
        Upper bound on the number of iterations.
    alpha : float
        Step length of the plain gradient step ``x += alpha * res``.
    beta : float
        Weight of the extrapolated residual in ``x = x_bar + beta * r_bar``.
    m : int
        Number of difference columns kept; columns are overwritten
        cyclically.
    p : int
        Extrapolation is done on iterations whose 1-based index is a
        multiple of ``p`` (never on the first one); all other iterations
        take a gradient step.

    Returns
    -------
    list
        Relative residuals recorded at the start of every iteration except
        the first.

    Side effects: prints one progress line per iteration after the first,
    and pickles the returned list to ``aar.ser`` in the working directory.
    """
    cnt = 0
    # Work on a private copy: the in-place ``+=`` below would otherwise
    # overwrite the caller's initial guess.
    xk = x0.copy()
    fk,gk = fh(xk)

    N = len(x0)
    # Cyclic buffers holding the last m iterate and residual differences.
    Xk = np.zeros((N,m))
    Rk = np.zeros((N,m))

    nd = norm2(gk)
    # Start above the tolerance so the loop body runs at least once.
    relres = tol+1
    relres_v = []

    while cnt < maxit and relres>tol:
        res = -gk
        relres = norm2(res)/nd
        if cnt >0:
            relres_v.append(relres)

        if cnt >= 1:
            k = (cnt-1) % m
            Xk[:,k] = (xk-x_prev).reshape(len(xk))
            Rk[:,k] = (res-res_prev).reshape(len(xk))
        x_prev = xk.copy()
        res_prev = res.copy()

        cnt = cnt+1
        if cnt == 1 or cnt % p != 0:
            # Plain gradient step (no history yet, or not an extrapolation
            # iteration).
            xk += alpha*res
        else:
            # Least-squares coefficients minimising ||res - Rk @ Gamma||,
            # obtained through the pseudo-inverse of the normal matrix.
            Gamma = pinv(Rk.T@Rk)@(Rk.T@res)
            xk_bar = xk-Xk @ Gamma
            rk_bar = res-Rk @ Gamma
            xk = xk_bar + beta*rk_bar

        fk,gk = fh(xk)
        if cnt == 1:
            # No extrapolated residual exists yet; use the raw one so the
            # progress line below always has something to report.
            rk_bar = res
        if cnt>1:
            print(cnt-1,"%e %e %e" %(fk,norm2(gk)/nd, norm2(rk_bar)/nd))
    fk,gk = fh(xk)

    with open('aar.ser','wb') as f:
        pickle.dump(relres_v,f)

    print('AAR method converges to a relative residual of %e, misfit value of %e in %d iterations.\n' % (relres,fk,cnt))
    return relres_v


def AAR1Iter(fh,x0,tol=1e-8,maxit=1e5,alpha=0.001,beta=1.0,m=100,p=1):
    """Variant of ``AARIter`` whose extrapolation coefficients are
    ``pinv(Xk.T @ Rk) @ (Xk.T @ res)`` instead of the least-squares ones.

    The log message of this routine calls it the "AAR1 method".

    Parameters
    ----------
    fh : callable
        ``fh(x)`` must return ``(value, gradient)``.
    x0 : ndarray
        Initial iterate. It is copied, so the caller's array is left
        untouched. Must flatten to ``len(x0)`` entries.
    tol : float
        Only used to seed the initial relative residual.
        NOTE(review): the loop threshold is hard-coded to ``1e-15`` and
        does not honour ``tol``; kept as is because existing runs rely on
        it - confirm whether that is intended.
    maxit : number
        The loop runs while ``cnt <= maxit`` (so up to ``maxit + 1``
        iterations).
    alpha : float
        Step length of the plain gradient step.
    beta : float
        Weight of the extrapolated residual in ``x = x_bar + beta * r_bar``.
    m : int
        Number of difference columns kept; overwritten cyclically.
    p : int
        Extrapolation happens on iterations whose 1-based index is a
        multiple of ``p`` (never on the first one).

    Returns
    -------
    list
        Relative residuals recorded at the start of every iteration except
        the first.

    Side effects: prints a progress line per iteration and the list of
    extrapolated relative residuals, and pickles
    ``(relres_v, relres_bar_v)`` to ``aar1_m<m>.ser``.
    """
    cnt = 0
    # Work on a private copy: the in-place ``+=`` below would otherwise
    # overwrite the caller's initial guess.
    xk = x0.copy()
    fk,gk = fh(xk)

    N = len(x0)
    # Cyclic buffers holding the last m iterate and residual differences.
    Xk = np.zeros((N,m))
    Rk = np.zeros((N,m))

    nd = norm2(gk)
    relres = tol+1
    relres_v = []
    relres_bar_v = []

    while cnt <= maxit and relres>1e-15:
        res = -gk
        relres = norm2(res)/nd
        if cnt >0:
            relres_v.append(relres)

        if cnt >= 1:
            k = (cnt-1) % m
            Xk[:,k] = (xk-x_prev).reshape(len(xk))
            Rk[:,k] = (res-res_prev).reshape(len(xk))
        x_prev = xk.copy()
        res_prev = res.copy()

        cnt = cnt+1
        if cnt == 1 or cnt % p != 0:
            # Plain gradient step.
            xk += alpha*res
        else:
            # Coefficients from the mixed product Xk.T @ Rk.
            Gamma = pinv(Xk.T@Rk)@(Xk.T@res)
            xk_bar = xk-Xk @ Gamma
            rk_bar = res-Rk @ Gamma
            xk = xk_bar + beta*rk_bar
            relres_bar_v.append(norm2(rk_bar)/nd)

        fk,gk = fh(xk)
        if cnt == 1:
            # No extrapolated residual exists yet; report the raw one.
            rk_bar = res
        print("%e %e %e" %(norm2(gk)/nd,fk, norm2(rk_bar)/nd))
    print(relres_bar_v)
    fk,gk = fh(xk)

    with open('aar1_m'+str(m)+'.ser','wb') as f:
        pickle.dump((relres_v,relres_bar_v),f)

    print('AAR1 method converges to a relative residual of %e, misfit value of %e in %d iterations.\n' % (relres,fk,cnt))
    return relres_v


def gd(fh,x0,tol=1e-8,maxit=1e5,beta=0.001):
    """Fixed-step gradient descent on ``fh``.

    Parameters
    ----------
    fh : callable
        ``fh(x)`` must return ``(value, gradient)``.
    x0 : ndarray
        Initial iterate; not modified (each update builds a new array).
    tol : float
        Only used to seed the initial relative residual.
        NOTE(review): the loop threshold is hard-coded to ``1e-15`` and
        does not honour ``tol`` - confirm whether that is intended.
    maxit : number
        The loop runs while ``cnt <= maxit``.
    beta : float
        Step length of the update ``x = x + beta * r`` with ``r = -gradient``.

    Returns
    -------
    list
        Relative residuals recorded at the start of every iteration except
        the first (returned for parity with the other solvers in this
        file, which all return their history).

    Side effects: prints a progress line per iteration and the history,
    and pickles the history to ``gd.ser``.
    """
    xk = x0
    fk,gk = fh(xk)
    r = -gk
    # Norm of the initial residual; every relative residual is scaled by it.
    nd = norm2(r)
    relres = tol+1
    relres_v = []

    cnt = 0
    while cnt <= maxit and relres> 1e-15:
        r = -gk
        relres = norm2(r)/nd
        xk = xk+beta*r

        if cnt > 0:
            relres_v.append(relres)
        cnt = cnt+1
        fk,gk = fh(xk)
        print(cnt,"%e %e" %(norm2(gk)/nd,fk))
    fk,gk = fh(xk)
    print(relres_v)
    with open('gd.ser','wb') as f:
        pickle.dump(relres_v,f)

    print('gd method converges to a relative residual of %e, misfit value of %e in %d iterations.\n' % (relres,fk,cnt))
    return relres_v

def Dpinv(d):
    """Element-wise pseudo-inverse of ``d``.

    Returns a float array with the same shape as ``d`` holding ``1/d``
    where ``d`` is non-zero and ``0`` where it is zero. Works for arrays of
    any shape as well as plain sequences.
    """
    d = np.asarray(d)
    y = np.zeros(d.shape)
    # ``where`` restricts the division to non-zero entries; the remaining
    # positions keep the zeros that ``y`` was initialised with.
    np.divide(1., d, out=y, where=(d != 0))
    return y

def cg(fh,x0,tol=1e-8,maxit=1e5):
    """Conjugate-gradient iteration using the operator ``A.T @ (A @ .)``.

    The operator is built from the module-level matrix ``A`` rather than
    from ``fh``; ``fh`` only supplies the initial residual and the values
    that are logged. This is consistent when the gradient returned by
    ``fh`` is ``A.T @ (A @ x - d)`` - assumed here, verify for other ``fh``.

    Parameters
    ----------
    fh : callable
        ``fh(x)`` must return ``(value, gradient)``.
    x0 : ndarray
        Initial iterate; not modified (each update builds a new array).
    tol : float
        Stop once ``norm(r) / norm(initial r)`` is no longer greater than
        ``tol``.
    maxit : number
        Upper bound on the number of iterations.

    Returns
    -------
    list
        Relative residual after every iteration (returned for parity with
        the other solvers in this file, which all return their history).

    Side effects: prints a progress line per iteration and the history,
    and pickles the history to ``cg.ser``.
    """
    xk = x0
    fk,gk = fh(xk)
    r = -gk
    p = r
    nd = norm2(r)
    relres = tol+1
    relres_v = []

    cnt = 0
    Ap = A.T@(A@r)
    while relres>tol and cnt < maxit:
        rr = np.vdot(r,r)
        alpha = rr / np.vdot(Ap,p)
        xk = xk+alpha*p
        r_next = r-alpha*Ap
        Ar_next = A.T@(A@r_next)
        beta = np.vdot(r_next,r_next) / rr
        r = r_next
        p = r+beta*p
        # The operator applied to the new direction follows from linearity,
        # so no extra product with p is needed.
        Ap = Ar_next+beta*Ap
        relres = norm2(r)/nd
        relres_v.append(relres)
        cnt = cnt+1
        fk,gk = fh(xk)
        print(cnt,"%e %e" %(fk,norm2(gk)/nd))
    fk,gk = fh(xk)
    print(relres_v)
    with open('cg.ser','wb') as f:
        pickle.dump(relres_v,f)
    print('CG method converges to a relative residual of %e, misfit value of %e in %d iterations.\n' % (relres,fk,cnt))
    return relres_v
                         

def dPinv(d):
    """Scalar pseudo-inverse: ``0.`` when ``abs(d) <= 0``, otherwise ``1/d``."""
    if abs(d)<=0:
        return 0.
    return 1./d

    
def st_bfgs_v4(fh,x0,tol=1e-8,maxit=1e5,tao=1e-16,beta=1.0,m=100):
    """Short-recurrence accelerated iteration that keeps a single
    direction pair ``(p, q)`` and adapts the step ``beta`` from the
    eigenvalues of a small tridiagonal matrix.

    The final log line calls this the "Minimal memory size Anderson
    method"; the plot labels its eigenvalue estimates "Min-AM".

    Parameters
    ----------
    fh : callable
        ``fh(x)`` must return ``(value, gradient)``.
    x0 : ndarray
        Initial iterate. It is copied, so the caller's array is left
        untouched (all iterate updates below are in place).
    tol : float
        Stop once ``norm(gradient) / norm(initial gradient)`` is no longer
        greater than ``tol``.
    maxit : number
        Upper bound on the number of iterations.
    tao : float
        Restart when ``|p.q|`` drops to ``tao`` times its first value since
        the last restart.
    beta : float
        Initial step length; replaced by
        ``2 / (min|eig| + max|eig|)`` once eigenvalue estimates exist.
    m : int
        Restart when the number of steps since the last restart equals
        ``m``.

    Returns
    -------
    list
        Relative residuals recorded at the start of every iteration except
        the first.

    Side effects: prints progress, shows a matplotlib scatter plot of the
    eigenvalues of the module-level ``A.T @ A`` against the last
    tridiagonal eigenvalues (``plt.show()`` may block), and writes
    ``minam_eig_<mk>.ser`` and ``st_bfgs.ser`` to the working directory.
    """
    cnt = 0
    # Work on a private copy: every update of xk below is in place and
    # would otherwise overwrite the caller's initial guess.
    xk = x0.copy()
    fk,gk = fh(xk)

    nd = norm2(gk)
    relres = tol+1
    relres_v = []
    relres_bar_v = []

    p = np.zeros(xk.shape)
    q = np.zeros(gk.shape)
    # Super-diagonal, diagonal and sub-diagonal entries of the tridiagonal
    # matrix Tk, in that order.
    # NOTE(review): these lists are not cleared on restart, so Tk keeps
    # growing across restarts - confirm this is intended.
    etas = ([],[],[])
    phi = 0.
    pqs = []
    # Steps since the last restart; -1 means "no previous iterate yet".
    mk = -1
    beta_prev = beta
    # Defined up front so the reporting after the loop still works when the
    # loop ends before any eigenvalue estimate has been formed.
    eig = np.array([])
    while relres>tol and cnt <maxit:
        rk = -gk
        relres = norm2(rk)/nd
        if cnt >0:
            relres_v.append(relres)

        if mk >=0:
            delta_x = xk-x_prev
            delta_r = rk-r_prev
            if mk == 0:
                p = delta_x
                q = delta_r
            else:
                # Remove the component along the previous pair.
                zeta = np.vdot(p,delta_r)/pq
                p = delta_x-p*zeta
                q = delta_r-q*zeta

            pq = np.vdot(p,q)
            pqs.append(abs(pq))
            if pqs[-1] <= (tao)*pqs[0] or mk == m:
                # Restart: the next pass takes a plain step, and mk = -1
                # makes the eigenvalue update below skip this pass.
                mk = -1
                pqs = []
                beta_prev = beta
                print('restart!')

            if mk >= 1 and pqs[-1]/pqs[0]>1e-33:  # for numerical consideration
                # gamma1/gamma2 come from the previous pass, zeta from this
                # one.
                phi_prev = phi
                phi = gamma1+gamma2+zeta
                eta0 = phi_prev/(beta_prev*(1-gamma1))
                eta1 = (1./beta_prev-phi/beta)/(1-gamma1)
                eta2 = -1./(beta*(1-gamma1))
                etas[0].append(eta0)
                etas[1].append(eta1)
                etas[2].append(eta2)
                Tk = np.diag(etas[1])+np.diag(etas[0][1:],k=1)+np.diag(etas[2][:-1],k=-1)
                eig = np.linalg.eigvals(Tk)
                beta_prev = beta
                beta = 2./(min(abs(eig))+max(abs(eig)))

        x_prev = xk.copy()
        r_prev = rk.copy()

        mk = mk+1
        if mk == 0:
            xk += beta*rk
        else:
            gamma1 = np.vdot(rk,p)/pq
            xk -= p*gamma1
            # rk is modified in place; r_prev above already holds the
            # unprojected residual.
            rk -= q*gamma1
            xk += beta*rk
            gamma2 = beta*np.vdot(rk,q)/pq
            xk -= p*gamma2

        cnt = cnt+1
        fk,gk = fh(xk)
        if cnt>1:
            print(cnt-1,"%e %e %e" %(fk,norm2(gk)/nd, norm2(rk)/nd))
            relres_bar_v.append(norm2(rk)/nd)
    fk,gk = fh(xk)

    # Reference spectrum to compare the tridiagonal eigenvalues against.
    eig_tmp = np.linalg.eigvalsh(A.T@A)
    print('****************')

    import matplotlib.pyplot as plt
    print(len(eig_tmp),len(eig))
    plt.tick_params(labelsize='large')
    plt.xlabel('real part',fontsize='x-large')
    plt.ylabel('imaginary part',fontsize='x-large')
    plt.scatter(eig_tmp.real,eig_tmp.imag,label='eigenvalue',color='darkorange',marker='o')
    plt.scatter(eig.real,eig.imag,label='Min-AM',color='blue',marker='+')
    plt.legend(fontsize='x-large')
    plt.show()
    # Context managers make sure both output files are flushed and closed.
    with open('minam_eig_' + str(mk) + '.ser', 'wb') as f:
        pickle.dump((eig,eig_tmp),f)
    with open('st_bfgs.ser','wb') as f:
        pickle.dump((relres_v,relres_bar_v),f)
    print('Minimal memory size Anderson method converges to a relative residual of %e, misfit value of %e in %d iterations.\n' % (norm2(gk)/nd,fk,cnt))
    return relres_v

        


# Run the experiment: start from the zero column vector and allow two more
# iterations than there are unknowns, with the tao-based restart test
# effectively disabled (tao=0 only triggers on an exactly zero product).
x0 = np.zeros(shape=(D_out,1))
rv0 = st_bfgs_v4(
    misfit_total,
    x0,
    tol=1e-15,
    maxit=D_out+2,
    tao=0,
    beta=1.,
)


#
