import numpy as np

# Placeholder problem sizes. They are only used by the commented-out data
# generation code below and are overwritten from A.shape after loading.
D_in, D_out = 500, 100

# One-off data generation, kept for reference: random A and x, and a
# consistent right-hand side d = A @ x, pickled to data2.ser.
#A = np.random.randn(D_in,D_out)
#x = np.random.randn(D_out,1)
#d = A @ x

#import pickle
#with open('data2.ser','wb') as f:
#    pickle.dump((A,x,d),f)

# Load the stored problem (A, x, d) from data2.ser.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load data files from a trusted source.
import pickle
with open('data2.ser','rb') as f:
    A,x,d = pickle.load(f)

# Use the actual dimensions of the loaded matrix.
D_in,D_out = A.shape

# Weight of the cubic regularisation term M * ||x||^3 / 3 (see misfit);
# also embedded in the output file names of some solvers.
M = 0.1

def norm2(x):
    """Return ``numpy.linalg.norm(x, ord=2)``.

    For a 1-D array, or a 2-D array with a single column, this is the
    Euclidean length. For a general 2-D array NumPy returns the largest
    singular value (spectral norm).
    """
    order = 2
    return np.linalg.norm(x, order)

def linear(x):
    """Apply the forward operator: return the matrix product ``A @ x``.

    ``A`` is the module-level matrix loaded at import time.
    """
    image = A @ x
    return image

def misfit0(x):
    """Return the data misfit ``0.5 * (A x - d)^T (A x - d)``.

    The cubic regularisation term is not included. When ``x`` and ``d`` are
    column vectors, the result is a 1x1 array because it is formed as an
    inner product ``r.T . r``.
    """
    residual = linear(x) - d
    return 0.5 * residual.T.dot(residual)

def misfit(x):
    """Return the regularised objective ``0.5*||A x - d||^2 + M*||x||^3/3``.

    The data term is formed as ``r.T . r`` (a 1x1 array for column vectors).
    ``M`` is the module-level regularisation weight.
    """
    residual = linear(x) - d
    data_term = 0.5 * residual.T.dot(residual)
    cubic_term = M * norm2(x) ** 3 / 3
    return data_term + cubic_term

def cubic_gradient(x):
    """Return the gradient of ``misfit``: ``A^T (A x - d) + M * ||x|| * x``."""
    residual = A @ x - d
    data_grad = A.T @ residual
    reg_grad = M * norm2(x) * x
    return data_grad + reg_grad

def misfit_gradient(x):
    """Return the gradient of the objective (delegates to ``cubic_gradient``)."""
    gradient = cubic_gradient(x)
    return gradient

def misfit_total(x):
    """Return ``(objective value, gradient)`` at ``x``, as the solvers' ``fh`` expects."""
    value = misfit(x)
    gradient = misfit_gradient(x)
    return value, gradient

    
from numpy.linalg import pinv

def AM(fh,x0,tol=1e-8,maxit=1e5,alpha=0.001,beta=1.0,m=100,p=1):
    """Minimise via Anderson mixing (AM) with a sliding history window.

    Iteration 1, and every iteration whose 1-based count is not a multiple
    of ``p``, takes a gradient step ``x <- x + alpha * r`` with
    ``r = -grad f(x)``. The remaining iterations take an Anderson step:
    ``Gamma`` is the least-squares fit of ``r`` by the stored residual
    differences ``Rk``, and ``x <- (x - Xk Gamma) + beta * (r - Rk Gamma)``.

    Parameters
    ----------
    fh : callable
        ``fh(x)`` returns ``(f(x), grad f(x))``.
    x0 : numpy.ndarray
        Starting point. A float copy is taken, so the caller's array is
        not modified.
    tol : float
        Unused; the loop always runs ``maxit`` iterations. Kept for
        interface compatibility.
    maxit : int or float
        Number of iterations.
    alpha : float
        Step size of the gradient steps.
    beta : float
        Mixing weight applied to the Anderson-corrected residual.
    m : int
        Number of difference columns kept. The oldest is overwritten
        cyclically.
    p : int
        Period of the Anderson steps.

    Returns
    -------
    list
        Relative residuals ``||grad f(x_k)|| / ||grad f(x_0)||`` for every
        iterate except ``x_0`` and the final one. The final iterate's
        residual appears only in the printed summary.

    Side effects: prints per-iteration progress and the history, and
    pickles ``(relres_v, relres_bar_v, theta_v)`` to ``am_<M>.ser``, where
    ``M`` is the module-level regularisation weight.
    """
    cnt = 0
    # Float copy: the gradient step below updates ``xk`` in place, which
    # would otherwise overwrite the caller's ``x0``.
    xk = np.array(x0, dtype=float)
    fk, gk = fh(xk)

    N = len(x0)
    Xk = np.zeros((N, m))  # iterate differences  x_{j+1} - x_j
    Rk = np.zeros((N, m))  # residual differences r_{j+1} - r_j

    nd = norm2(gk)  # reference norm: gradient at the starting point
    relres_v = []
    relres_bar_v = []
    theta_v = []

    while cnt < maxit:
        res = -gk
        relres = norm2(res) / nd
        if cnt > 0:
            relres_v.append(relres)

        if cnt >= 1:
            # Overwrite the oldest history column (ring buffer of size m).
            k = (cnt - 1) % m
            Xk[:, k] = (xk - x_prev).reshape(len(xk))
            Rk[:, k] = (res - res_prev).reshape(len(xk))
        x_prev = xk.copy()
        res_prev = res.copy()

        cnt = cnt + 1
        if cnt == 1 or cnt % p != 0:
            # Plain gradient step (in place, on our private copy).
            xk += alpha * res
        else:
            # Columns not filled yet are exactly zero, so the pseudo-inverse
            # gives them zero weight in Gamma.
            Gamma = pinv(Rk.T @ Rk) @ (Rk.T @ res)
            xk_bar = xk - Xk @ Gamma
            rk_bar = res - Rk @ Gamma
            xk = xk_bar + beta * rk_bar
            relres_bar_v.append(norm2(rk_bar) / nd)
            theta_v.append(norm2(rk_bar) / norm2(res))

        fk, gk = fh(xk)
        if cnt == 1:
            rk_bar = res
        # Third column: residual of the most recent Anderson step (the plain
        # residual before any Anderson step has been taken).
        print("%e %e %e" % (norm2(gk) / nd, fk, norm2(rk_bar) / nd))
    print(relres_v)
    fk, gk = fh(xk)
    # Residual of the final iterate, so the summary agrees with ``fk``. The
    # loop's ``relres`` refers to the iterate before the last update.
    relres = norm2(gk) / nd

    with open('am_{}.ser'.format(M), 'wb') as f:
        pickle.dump((relres_v, relres_bar_v, theta_v), f)

    print('AM method converge to a relative residual of %e, misfit value of %e, %e in %d iterations.\n' % (relres, fk, misfit0(xk), cnt))
    return relres_v

def cr(fh,x0,tol=1e-8,maxit=1e5):
    """Run a conjugate residual (CR) iteration and return its residual history.

    The search directions and the recursively updated residual ``r`` use
    the operator ``A.T @ A`` (module-level ``A``), i.e. only the quadratic
    data term. ``fh`` is evaluated at each iterate to report its gradient
    and objective value.

    Parameters
    ----------
    fh : callable
        ``fh(x)`` returns ``(f(x), grad f(x))``.
    x0 : numpy.ndarray
        Starting point (not modified).
    tol : float
        Unused; the loop always runs ``maxit`` iterations.
    maxit : int or float
        Number of iterations.

    Returns
    -------
    list
        ``||grad f(x_k)|| / ||grad f(x_0)||`` for each iterate a step
        started from, including ``x_0`` (first entry 1).

    Side effects: prints progress and the history, and pickles the history
    to ``cr.ser``.
    NOTE(review): if ``r`` becomes exactly zero, the next ``alpha``/``beta``
    divisions are 0/0.
    """
    xk = x0
    fk, gk = fh(xk)
    r = -gk
    p = r
    nd = norm2(r)
    relres_v = []

    cnt = 0
    Ap = Ar = A.T @ (A @ r)
    while cnt < maxit:
        alpha = (r.T @ Ar) / (Ap.T @ Ap)
        xk = xk + alpha * p
        r_next = r - alpha * Ap
        Ar_next = A.T @ (A @ r_next)
        beta = (r_next.T @ Ar_next) / (r.T @ Ar)
        r = r_next
        Ar = Ar_next
        p = r + beta * p
        Ap = Ar + beta * Ap
        # gk still belongs to the iterate this step started from.
        relres = norm2(gk) / nd
        relres_v.append(relres)
        cnt = cnt + 1
        fk, gk = fh(xk)
        print(cnt, "%e %e" % (norm2(gk) / nd, fk))
    fk, gk = fh(xk)
    # Report the residual of the final iterate, matching ``fk``.
    relres = norm2(gk) / nd
    print(relres_v)
    with open('cr.ser', 'wb') as f:
        pickle.dump(relres_v, f)

    print('CR method converge to a relative residual of %e, misfit value of %e, %e in %d iterations.\n' % (relres, fk, misfit0(xk), cnt))
    return relres_v


def gd(fh,x0,tol=1e-8,maxit=1e5,beta=0.001):
    """Fixed-step gradient descent ``x <- x - beta * grad f(x)``.

    Parameters
    ----------
    fh : callable
        ``fh(x)`` returns ``(f(x), grad f(x))``.
    x0 : numpy.ndarray
        Starting point (not modified).
    tol : float
        Only used to initialise the loop guard (``tol + 1``). It is not a
        stopping tolerance: the loop stops after ``maxit`` steps, or after
        the first step taken from an iterate whose relative residual is
        <= 1e-15.
    maxit : int or float
        Maximum number of steps.
    beta : float
        Step size.

    Returns
    -------
    list
        Relative residuals ``||grad f(x_k)|| / ||grad f(x_0)||`` of the
        iterates each step started from, excluding ``x_0``.

    Side effects: prints progress and the history, and pickles the history
    to ``gd.ser``.
    """
    xk = x0
    fk, gk = fh(xk)
    r = -gk
    nd = norm2(r)
    relres = tol + 1
    relres_v = []

    cnt = 0
    while cnt < maxit and relres > 1e-15:
        r = -gk
        relres = norm2(r) / nd
        xk = xk + beta * r
        if cnt > 0:
            relres_v.append(relres)
        cnt = cnt + 1
        fk, gk = fh(xk)
        print(cnt, "%e %e" % (norm2(gk) / nd, fk))
    fk, gk = fh(xk)
    # Report the residual of the final iterate, matching ``fk``. The loop's
    # ``relres`` refers to the iterate before the last step.
    relres = norm2(gk) / nd
    print(relres_v)
    with open('gd.ser', 'wb') as f:
        pickle.dump(relres_v, f)

    print('gd method converge to a relative residual of %e, misfit value of %e in %d iterations.\n' % (relres, fk, cnt))
    return relres_v


def ca(fh,x0,tol=1e-8,maxit=1e5,alpha=0.001,beta=1.0):
    """Minimise with a modified conjugate Anderson method (two-column history).

    Each iteration after the first forms the new differences
    ``p = x_k - x_{k-1}`` and ``q = r_k - r_{k-1}`` (``r = -grad f``). It
    computes the least-squares coefficients ``xi`` of ``q`` on the current
    ``Q`` columns and subtracts ``Q xi`` from ``q`` and ``P xi`` from ``p``.
    The updated pair is shifted into ``P``/``Q`` (discarding the oldest
    column) only if ``||P xi|| < ||p||`` and ``||Q xi|| < ||q||``. The step
    is ``x <- (x - P Gamma) + beta * (r - Q Gamma)``, where ``Gamma`` is the
    least-squares fit of ``r`` by ``Q``.

    Parameters
    ----------
    fh : callable
        ``fh(x)`` returns ``(f(x), grad f(x))``.
    x0 : numpy.ndarray
        Starting point (not modified).
    tol : float
        Unused; the loop always runs ``maxit`` iterations.
    maxit : int or float
        Number of iterations.
    alpha : float
        Unused, accepted for signature parity with ``AM``. With the
        all-zero initial history, the first step reduces to ``x + beta * r``.
    beta : float
        Mixing weight applied to the corrected residual.

    Returns
    -------
    list
        Relative residuals ``||grad f(x_k)|| / ||grad f(x_0)||`` for every
        iterate except ``x_0`` and the final one.

    Side effects: prints progress and pickles
    ``(relres_v, relres_bar_v, np_v, nq_v, theta_v)`` to ``mca_<M>.ser``,
    where ``M`` is the module-level regularisation weight.
    """
    cnt = 0
    xk = x0
    fk, gk = fh(xk)

    N = len(x0)
    nd = norm2(gk)  # reference norm: gradient at the starting point
    relres_v = []
    relres_bar_v = []
    theta_v = []
    # Two most recent (projected) iterate / residual differences.
    P = np.zeros((N, 2))
    Q = np.zeros((N, 2))

    # Relative sizes of the removed projections, recorded for analysis.
    np_v = []
    nq_v = []

    cp = 1
    cq = 1

    while cnt < maxit:
        res = -gk
        relres = norm2(res) / nd
        if cnt > 0:
            relres_v.append(relres)

        if cnt >= 1:
            delta_xk = xk - x_prev
            delta_rk = res - res_prev
            p = delta_xk
            q = delta_rk

            # Least-squares coefficients of q on the columns of Q.
            xi = pinv(Q.T @ Q) @ (Q.T @ q)
            np_v.append(norm2(P @ xi) / norm2(delta_xk))
            nq_v.append(norm2(Q @ xi) / norm2(delta_rk))

            # Accept the projected pair only if the removed parts are
            # shorter than the original differences.
            if norm2(P @ xi) < cp * norm2(p) and norm2(Q @ xi) < cq * norm2(q):
                p = p - P @ xi
                q = q - Q @ xi
                # Drop the oldest column, append the new one.
                P = np.c_[P[:, 1], p]
                Q = np.c_[Q[:, 1], q]

        x_prev = xk.copy()
        res_prev = res.copy()
        cnt = cnt + 1

        Gamma = pinv(Q.T @ Q) @ (Q.T @ res)
        x_bar = xk - P @ Gamma
        r_bar = res - Q @ Gamma
        xk = x_bar + beta * r_bar

        if cnt > 1:
            relres_bar_v.append(norm2(r_bar) / nd)
            theta_v.append(norm2(r_bar) / norm2(res))

        fk, gk = fh(xk)
        print(cnt, '%e %e %e' % (norm2(gk) / nd, fk, norm2(r_bar) / nd))

    with open('mca_{}.ser'.format(M), 'wb') as f:
        pickle.dump((relres_v, relres_bar_v, np_v, nq_v, theta_v), f)

    # Report the residual of the final iterate, matching ``fk``. The loop's
    # ``relres`` refers to the iterate before the last update.
    relres = norm2(gk) / nd
    print('Modified Conjugate Anderson method converge to a relative residual of %e, misfit value of %e, %e in %d iterations.\n' % (relres, fk, misfit0(xk), cnt))
    return relres_v
            

# Driver: run each solver from a fresh zero starting point so that no run can
# see another's (possibly in-place-modified) starting array. Each call prints
# progress and pickles its residual history to disk.
x0 = np.zeros((D_out,1))
# Parameter reference for AM (AARIter appears to be an older name for it):
#AARIter(fh,x0,tol=1e-8,maxit=1e5,alpha=0.001,beta=1.0,m=100,p=1):
rv0 = AM(misfit_total,x0,tol=1e-8,maxit=51,alpha=1e-3,beta=0.001,m=51,p=1)

x0 = np.zeros((D_out,1))
# ca does not use alpha; its first step is a gradient step of size beta.
rv1 = ca(misfit_total,x0,tol=1e-8,maxit=51,alpha=1,beta=0.001)

# Disabled conjugate-residual run.
#cr(misfit_total,x0)
#x0 = np.zeros((D_out,1))
#rv2 = cr(misfit_total,x0,maxit=51)

x0 = np.zeros((D_out,1))
rv2 = gd(misfit_total,x0,maxit=51,beta=0.001)


# Disabled plotting snippet, kept as an unused string expression.
# NOTE(review): rv2 now comes from gd (gradient descent), not CR, so the 'CR'
# label is stale. Also confirm gd returns its residual list before
# re-enabling, since len(rv2) requires a list.
'''import matplotlib.pyplot as plt

name = 'Anderson_CR'

plt.clf()
plt.xlabel('epochs')
plt.ylabel('Train Loss')
plt.yscale('log')

plt.plot(np.arange(1,len(rv1)+1),rv1,label='Anderson',color='crimson')
plt.plot(np.arange(1,len(rv2)+1),rv2,label='CR',color='blue')

plt.legend()
plt.savefig(name+'_'+'.pdf',format='pdf',dpi=120,bbox_inches='tight')'''
