import numpy as np

import pickle

# Problem data for the least-squares test problem  min_x 0.5*||A x - d||^2.
# 'data2.ser' was generated once with the recipe below; uncomment it to
# regenerate the file:
#
#     D_in, D_out = 500, 100
#     A = np.random.randn(D_in, D_out)
#     x = np.random.randn(D_out, 1)
#     d = A @ x
#     with open('data2.ser', 'wb') as f:
#         pickle.dump((A, x, d), f)

# NOTE: pickle.load can execute arbitrary code stored in the file; only load
# a data2.ser you produced yourself (its contents are not validated here).
with open('data2.ser','rb') as f:
    A,x,d = pickle.load(f)

# Take the dimensions from the loaded operator instead of hard-coding them,
# so they always match the data actually on disk.
D_in,D_out = A.shape

def linear(x):
    """Return the forward map ``A @ x`` using the module-level matrix ``A``."""
    return np.matmul(A, x)

def misfit(x):
    """Return the least-squares misfit ``0.5 * ||linear(x) - d||^2`` as a float.

    ``x`` may be a 1-D vector or an (N, 1) column vector. The residual is
    flattened for the inner product, so the result is always a plain Python
    scalar.
    """
    residual = linear(x) - d
    # np.vdot flattens both operands. A (1, 1) array result (what
    # r.T.dot(r) gives for column vectors) would trigger NumPy's
    # ndim>0-to-scalar deprecation whenever callers format it with %e.
    return 0.5 * float(np.vdot(residual, residual))

def linear_gradient(x):
    """Return ``A^T (A x - d)``, the gradient of ``0.5*||A x - d||^2``."""
    mismatch = A @ x - d
    return A.T @ mismatch

def misfit_gradient(x):
    """Return the gradient of ``misfit`` at ``x``.

    The misfit is the plain least-squares objective of the linear model, so
    its gradient is exactly ``linear_gradient(x)``.
    """
    grad = linear_gradient(x)
    return grad

def misfit_total(x):
    """Return ``(misfit(x), misfit_gradient(x))``, the ``fh`` callback the solvers use."""
    value = misfit(x)
    gradient = misfit_gradient(x)
    return value, gradient

def norm2(x):
    """Return ``numpy.linalg.norm(x, ord=2)``.

    For a 1-D vector this is the Euclidean length. For a 2-D array it is the
    spectral norm (largest singular value); for the (N, 1) column vectors
    the solvers pass in, that equals their Euclidean length.
    """
    norm_order = 2
    return np.linalg.norm(x, norm_order)
    
from numpy.linalg import pinv

def AM(fh,x0,tol=1e-8,maxit=1e5,alpha=0.001,beta=1.0,m=100,p=1):
    """Anderson mixing (AM) for the stationarity condition ``grad f(x) = 0``.

    Iterates on the residual ``r_k = -grad f(x_k)``. Two kinds of step are
    taken:

    * Plain gradient step ``x <- x + alpha * r``. Used for the first step and
      for every step whose 1-based index is not a multiple of ``p``.
    * Anderson step for all other steps. It uses the last ``m`` differences
      of iterates (``Xk``) and residuals (``Rk``), kept in a circular buffer::

        Gamma = pinv(Rk^T Rk) Rk^T r
        x_bar = x - Xk Gamma
        r_bar = r - Rk Gamma
        x    <- x_bar + beta * r_bar

    Args:
        fh: callable returning ``(f(x), grad f(x))``.
        x0: starting point. It is copied to a float array, never modified.
        tol: stop once ``||grad f(x_k)|| / ||grad f(x0)|| <= tol``.
        maxit: iteration cap (the loop runs while ``cnt <= maxit``).
        alpha: step size of the plain gradient steps.
        beta: mixing parameter of the Anderson steps.
        m: history depth (number of stored difference columns).
        p: take an Anderson step every ``p`` iterations.

    Returns:
        List of relative residual norms, one per iterate after ``x0``.

    Side effects:
        Prints per-iteration progress and pickles ``(relres_v, relres_bar_v)``
        to ``am.ser`` in the current working directory.
    """
    cnt = 0
    # Work on a float copy: the gradient step below updates xk in place and
    # must not overwrite the caller's x0.
    xk = np.array(x0, dtype=float)
    fk,gk = fh(xk)

    N = len(xk)
    Xk = np.zeros((N,m))   # iterate differences  x_k - x_{k-1}
    Rk = np.zeros((N,m))   # residual differences r_k - r_{k-1}

    # Reference norm for the relative residual. Fall back to 1 when x0 is
    # already stationary, so the ratio is 0 rather than 0/0 = nan.
    nd = norm2(gk) or 1.0
    relres = tol+1
    relres_v = []
    relres_bar_v = []

    while cnt <= maxit and relres > tol:
        res = -gk
        relres = norm2(res)/nd
        if cnt > 0:
            relres_v.append(relres)

        if cnt >= 1:
            # Overwrite the oldest history column (circular buffer of size m).
            k = (cnt-1) % m
            Xk[:,k] = (xk-x_prev).reshape(N)
            Rk[:,k] = (res-res_prev).reshape(N)
        x_prev = xk.copy()
        res_prev = res.copy()

        cnt = cnt+1
        if cnt == 1 or cnt % p != 0:
            xk += alpha*res
        else:
            # NOTE(review): going through the normal matrix Rk^T Rk squares its
            # condition number; np.linalg.lstsq(Rk, res) is the numerically
            # safer equivalent if accuracy near convergence matters.
            Gamma = pinv(Rk.T@Rk)@(Rk.T@res)
            xk_bar = xk-Xk @ Gamma
            rk_bar = res-Rk @ Gamma
            xk = xk_bar + beta*rk_bar
            relres_bar_v.append(norm2(rk_bar)/nd)

        fk,gk = fh(xk)
        if cnt == 1:
            rk_bar = res   # no Anderson step yet: report the plain residual
        # .item() turns a size-1 array (or a scalar) into a Python float for %e.
        print("%e %e %e" %(norm2(gk)/nd, np.asarray(fk).item(), norm2(rk_bar)/nd))
    print(relres_v)
    fk,gk = fh(xk)

    with open('am.ser','wb') as f:
        pickle.dump((relres_v,relres_bar_v),f)

    print('AM method converge to a relative residual of %e, misfit value of %e in %d iterations.\n' % (relres, np.asarray(fk).item(), cnt))
    return relres_v

def cr(fh,x0,tol=1e-8,maxit=1e5):
    """Conjugate residual (CR) iteration with operator ``A^T A``.

    The residual starts as ``-grad f(x0)`` from ``fh``. The operator is
    applied through the module-level matrix ``A`` as ``A.T @ (A @ v)``. With
    ``fh = misfit_total`` this solves the normal equations
    ``A^T A x = A^T d``. ``fh`` is also used to evaluate the true gradient
    that drives the stopping test and the progress output.

    Args:
        fh: callable returning ``(f(x), grad f(x))``.
        x0: starting point (not modified).
        tol: stop once ``||grad f(x_k)|| / ||grad f(x0)|| <= tol``.
        maxit: iteration cap (the loop runs while ``cnt <= maxit``).

    Returns:
        List of relative residual norms, one per iteration.

    Side effects:
        Prints per-iteration progress and pickles the returned list to
        ``cr.ser`` in the current working directory.
    """
    xk = x0
    fk,gk = fh(xk)
    r = -gk
    p = r
    nd = norm2(r)
    # If x0 is already stationary there is nothing to do (the CR step size
    # would be 0/0), so start with relres = 0 and skip the loop.
    relres = tol+1 if nd > 0 else 0.0
    nd = nd or 1.0
    relres_v = []

    cnt = 0
    # Carry both (A^T A) r and (A^T A) p through the recurrences, so each
    # iteration needs only one new operator application.
    Ap = Ar = A.T@(A@r)
    while cnt <= maxit and relres > tol:
        alpha = (r.T@Ar) / (Ap.T@Ap)
        xk = xk+alpha*p
        r_next = r-alpha*Ap
        Ar_next = A.T@(A@r_next)
        beta = (r_next.T@Ar_next) / (r.T@Ar)
        r = r_next
        Ar = Ar_next
        p = r+beta*p
        Ap = Ar+beta*Ap
        cnt = cnt+1
        # Measure progress with the true gradient rather than the recursively
        # updated r, which can drift from it in floating point.
        fk,gk = fh(xk)
        relres = norm2(gk)/nd
        relres_v.append(relres)
        print(cnt,"%e %e" %(relres, np.asarray(fk).item()))
    fk,gk = fh(xk)
    print(relres_v)
    with open('cr.ser','wb') as f:
        pickle.dump(relres_v,f)

    print('CR method converge to a relative residual of %e, misfit value of %e in %d iterations.\n' % (relres, np.asarray(fk).item(), cnt))
    return relres_v


def gd(fh,x0,tol=1e-8,maxit=1e5,beta=0.001):
    """Fixed-step gradient descent ``x <- x - beta * grad f(x)``.

    Args:
        fh: callable returning ``(f(x), grad f(x))``.
        x0: starting point (not modified).
        tol: stop once ``||grad f(x_k)|| / ||grad f(x0)|| <= tol``.
        maxit: iteration cap (the loop runs while ``cnt <= maxit``).
        beta: step size.

    Returns:
        List of relative residual norms, one per iterate after ``x0``.

    Side effects:
        Prints per-iteration progress and pickles the returned list to
        ``gd.ser`` in the current working directory.
    """
    xk = x0
    fk,gk = fh(xk)
    # Fall back to 1 when x0 is already stationary, so the ratio is 0 rather
    # than 0/0 = nan.
    nd = norm2(gk) or 1.0
    relres = tol+1
    relres_v = []

    cnt = 0
    while cnt <= maxit and relres > tol:
        r = -gk
        relres = norm2(r)/nd
        xk = xk+beta*r
        if cnt > 0:
            relres_v.append(relres)
        cnt = cnt+1
        fk,gk = fh(xk)
        print(cnt,"%e %e" %(norm2(gk)/nd, np.asarray(fk).item()))
    fk,gk = fh(xk)
    print(relres_v)
    with open('gd.ser','wb') as f:
        pickle.dump(relres_v,f)

    print('gd method converge to a relative residual of %e, misfit value of %e in %d iterations.\n' % (relres, np.asarray(fk).item(), cnt))
    return relres_v


def ca(fh,x0,tol=1e-8,maxit=1e5,alpha=0.001,beta=1.0):
    """Conjugate Anderson (CA) acceleration with a two-column orthonormal history.

    The two most recent residual differences are kept in ``Q``,
    Gram-Schmidt orthonormalised. ``P`` holds the matching iterate
    differences, transformed by the same combinations. Each step removes the
    span(Q) component of the residual and mixes::

        Gamma = Q^T r
        x_bar = x - P Gamma
        r_bar = r - Q Gamma
        x    <- x_bar + beta * r_bar

    On the first step ``P`` and ``Q`` are zero, so the update is
    ``x <- x + beta * r``.

    Args:
        fh: callable returning ``(f(x), grad f(x))``.
        x0: starting point (not modified).
        tol: stop once ``||grad f(x_k)|| / ||grad f(x0)|| <= tol``.
        maxit: iteration cap (the loop runs while ``cnt <= maxit``).
        alpha: unused. Accepted for signature parity with ``AM``; note that
            the first step uses ``beta``, not ``alpha``, as its step size.
        beta: mixing parameter.

    Returns:
        List of relative residual norms, one per iterate after ``x0``.

    Side effects:
        Prints per-iteration progress and pickles ``(relres_v, relres_bar_v)``
        to ``ca.ser`` in the current working directory.
    """
    cnt = 0
    xk = x0
    fk,gk = fh(xk)

    N = len(x0)
    # Fall back to 1 when x0 is already stationary, so the ratio is 0 rather
    # than 0/0 = nan.
    nd = norm2(gk) or 1.0
    relres = tol+1
    relres_v = []
    relres_bar_v = []
    # Column 0 holds the older direction, column 1 the newer one.
    P = np.zeros((N,2))
    Q = np.zeros((N,2))

    while cnt <= maxit and relres > tol:
        res = -gk
        relres = norm2(res)/nd
        if cnt > 0:
            relres_v.append(relres)

        if cnt >= 1:
            p = xk-x_prev
            q = res-res_prev
            # Orthogonalise q against the stored directions and apply the same
            # combination to p. Unfilled history columns are zero, so this
            # is a no-op on the first update.
            xi = -Q.T @ q
            q = q + Q@xi
            p = p + P@xi
            # NOTE(review): nq == 0 (q already in span(Q)) would divide by zero;
            # this case is not guarded.
            nq = norm2(q)
            q = q/nq
            p = p/nq
            # Slide the window: drop the oldest column, append the new one.
            P = np.c_[P[:,1],p]
            Q = np.c_[Q[:,1],q]

        x_prev = xk.copy()
        res_prev = res.copy()
        cnt = cnt+1

        Gamma = Q.T @ res
        x_bar = xk - P@Gamma
        r_bar = res-Q@Gamma
        xk = x_bar + beta*r_bar

        if cnt > 1:
            relres_bar_v.append(norm2(r_bar)/nd)

        fk,gk = fh(xk)
        print(cnt,'%e %e %e' % (norm2(gk)/nd, np.asarray(fk).item(), norm2(r_bar)/nd))

    with open('ca.ser','wb') as f:
        pickle.dump((relres_v,relres_bar_v),f)

    print('Conjugate Anderson method converge to a relative residual of %e, misfit value of %e in %d iterations.\n' % (relres, np.asarray(fk).item(), cnt))
    return relres_v
            

# Run every solver from the same zero starting point, building a fresh x0 for
# each call. tol=1e-15 asks each solver to keep going down to a 1e-15
# relative residual (or until maxit), so the full convergence history is
# recorded.
# Signature: AM(fh, x0, tol, maxit, alpha, beta, m, p)
x0 = np.zeros((D_out,1))
rv0 = AM(misfit_total,x0,tol=1e-15,maxit=50,alpha=1e-3,beta=0.001,m=50,p=1)

x0 = np.zeros((D_out,1))
rv1 = ca(misfit_total,x0,tol=1e-15,maxit=50,alpha=1,beta=0.001)

x0 = np.zeros((D_out,1))
rv2 = cr(misfit_total,x0,tol=1e-15,maxit=50)

# Store the gradient-descent history separately: the plotting code below
# uses rv2 as the CR curve, so it must not be overwritten.
x0 = np.zeros((D_out,1))
rv3 = gd(misfit_total,x0,tol=1e-15,maxit=50,beta=0.001)


# Plotting is disabled (kept as an inert string literal). It compares rv1 (ca)
# with rv2 (cr).
'''import matplotlib.pyplot as plt

name = 'Anderson_CR'

plt.clf()
plt.xlabel('epochs')
plt.ylabel('Train Loss')
plt.yscale('log')

plt.plot(np.arange(1,len(rv1)+1),rv1,label='Anderson',color='crimson')
plt.plot(np.arange(1,len(rv2)+1),rv2,label='CR',color='blue')

plt.legend()
plt.savefig(name+'_'+'.pdf',format='pdf',dpi=120,bbox_inches='tight')'''
