import numpy as np
import scipy.io as io

# Problem data: the linear system A x = b read from 'fidap029.mat'.
# The .mat file must contain the variables 'A' (system matrix) and 'b'
# (right-hand side).
# NOTE(review): b is assumed to be a column vector of shape (n, 1) so that
# A @ x - b keeps the (n, 1) shape of the iterates -- confirm against the file.
matr = io.loadmat('fidap029.mat')
A = matr['A']
b = matr['b']

import pickle  # used by the solvers below to save residual histories

# Problem dimensions come from the matrix itself: D_in rows, D_out columns,
# so the unknown x has D_out entries.
D_in,D_out = A.shape

# Diagonal of A. misfit_total and cr divide by it elementwise, i.e. it acts
# as a Jacobi (diagonal) scaling of the system.
# NOTE(review): a zero on the diagonal would make those divisions blow up.
D = A.diagonal()

def norm2(x):
    """Return the 2-norm of ``x``, i.e. ``np.linalg.norm(x, ord=2)``.

    For a 1-D array, or an (N, 1) column as used throughout this file, this
    is the Euclidean length. For a general 2-D array NumPy computes the
    spectral norm (largest singular value), not the Frobenius norm.
    """
    linalg = np.linalg
    return linalg.norm(x, 2)

def misfit_total(x):
    """Evaluate the linear system A x = b at ``x``.

    Returns a pair ``(misfit, step)``:
      * ``misfit`` -- 2-norm of the raw residual ``A @ x - b``;
      * ``step``   -- that residual divided elementwise by the module-level
        diagonal ``D`` (reshaped to the residual's shape), i.e. the
        Jacobi-scaled residual.
    The solvers in this file call this as ``fk, gk = fh(xk)`` and use
    ``-gk`` as their residual / descent direction.
    """
    residual = A @ x - b
    misfit = norm2(residual)
    scaled = residual / D.reshape(residual.shape)
    return misfit, scaled

    
from numpy.linalg import pinv

def AM(fh,x0,tol=1e-8,maxit=1e5,alpha=0.001,beta=1.0,m=100,p=1):
    """Anderson mixing (AM) with a circular history of the last ``m`` steps.

    Parameters
    ----------
    fh : callable
        ``fh(x) -> (f, g)``. ``f`` is the misfit value that gets printed and
        ``g`` a gradient-like vector with the shape of ``x``; the residual
        used by the iteration is ``-g``.
    x0 : ndarray
        Starting point (the driver passes shape (N, 1)). It is copied to a
        float array and never modified.
    tol : float
        Not used as a stopping criterion: the loop stops when the relative
        residual drops to 1e-15 or ``cnt`` exceeds ``maxit``. Kept for
        signature compatibility.
    maxit : int or float
        Iteration cap; the loop runs while ``cnt <= maxit``.
    alpha : float
        Step length of the plain gradient steps ``x <- x + alpha * r``.
    beta : float
        Mixing parameter applied to the Anderson-corrected residual.
    m : int
        Number of stored difference pairs (columns of ``Xk`` / ``Rk``).
    p : int
        An Anderson step is taken when ``cnt % p == 0`` (never on the first
        step); all other steps are plain gradient steps.

    Returns
    -------
    list of float
        Relative residual ``||r_k|| / ||r_0||`` recorded for iterations 1, 2, ...

    Side effects: prints per-iteration progress and pickles
    ``(relres_v, relres_bar_v)`` to ``am.ser``.
    """
    cnt = 0
    # Work on a float copy: the gradient step below updates xk in place,
    # which would otherwise overwrite the caller's x0 (and fail for int x0).
    xk = np.array(x0, dtype=float)
    fk,gk = fh(xk)

    N = len(xk)
    Xk = np.zeros((N,m))    # circular buffer of iterate differences x_k - x_{k-1}
    Rk = np.zeros((N,m))    # matching residual differences r_k - r_{k-1}

    nd = norm2(gk)          # ||r_0||, normaliser for all relative residuals
    relres = tol+1
    relres_v = []
    relres_bar_v = []

    while cnt <= maxit and relres>1e-15:
        res = -gk
        relres = norm2(res)/nd
        if cnt >0:
            relres_v.append(relres)

        if cnt >= 1:
            # Overwrite the oldest column once the buffer is full.
            k = (cnt-1) % m
            Xk[:,k] = (xk-x_prev).reshape(N)
            Rk[:,k] = (res-res_prev).reshape(N)
        x_prev = xk.copy()
        res_prev = res.copy()

        cnt = cnt+1
        if cnt == 1 or cnt % p != 0:
            # Plain gradient step. No mixing happens, so the "mixed" residual
            # reported below is simply the current residual (previously a
            # stale value from an earlier Anderson step was printed).
            xk += alpha*res
            rk_bar = res
        else:
            # pinv(R^T R) R^T res is the minimum-norm least-squares solution of
            # min ||res - Rk @ Gamma||; pinv copes with the all-zero columns of
            # a not-yet-full buffer.
            Gamma = pinv(Rk.T@Rk)@(Rk.T@res)
            xk_bar = xk-Xk @ Gamma
            rk_bar = res-Rk @ Gamma
            xk = xk_bar + beta*rk_bar
            relres_bar_v.append(norm2(rk_bar)/nd)

        fk,gk = fh(xk)
        print("%e %e %e" %(norm2(gk)/nd,fk, norm2(rk_bar)/nd))
    print(relres_v)
    fk,gk = fh(xk)

    with open('am.ser','wb') as f:
        pickle.dump((relres_v,relres_bar_v),f)

    print('AM method converge to a relative residual of %e, misfit value of %e in %d iterations.\n' % (relres,fk,cnt))
    return relres_v

def cr(fh,x0,tol=1e-8,maxit=1e5):
    """Conjugate residual (CR) iteration on the Jacobi-scaled system.

    The operator applied inside the recurrence is ``v -> (A @ v) / D`` using
    the module-level ``A`` and ``D``. ``fh`` is used only to get the initial
    residual ``-g`` and to monitor ``(misfit, g)`` at each iterate.
    NOTE(review): CR assumes a symmetric operator; D^{-1} A is generally not
    symmetric, so CR's usual convergence guarantees do not apply here.

    Parameters
    ----------
    fh : callable
        ``fh(x) -> (f, g)``; see above.
    x0 : ndarray
        Starting point; not modified.
    tol : float
        Not used as a stopping criterion: the loop stops when the relative
        residual drops to 1e-15 or after ``maxit`` iterations. Kept for
        signature compatibility.
    maxit : int or float
        Maximum number of iterations (loop runs while ``cnt < maxit``).

    Returns
    -------
    list of float
        ``||g_k|| / ||g_0||`` after each iteration k = 1, 2, ... (previously
        the history was only printed/pickled and ``None`` was returned).

    Side effects: prints progress and pickles the history to ``cr.ser``.
    """
    xk = x0
    fk,gk = fh(xk)
    r = -gk                 # initial (Jacobi-scaled) residual
    p = r                   # first search direction
    nd = norm2(r)
    relres = tol+1
    relres_v = []

    cnt =0
    # Ar = D^{-1} A r and Ap = D^{-1} A p are carried by recurrence, so the
    # iteration itself needs one new product with A per step (fh adds another
    # for monitoring).
    Ar = (A@r)/D.reshape(r.shape)
    Ap = Ar
    while cnt < maxit and relres>1e-15:
        alpha = (r.T@Ar) / (Ap.T@Ap)
        xk = xk+alpha*p
        r_next = r-alpha*Ap
        Ar_next = (A@r_next)/D.reshape(r.shape)
        beta = (r_next.T@Ar_next) / (r.T@Ar)
        r = r_next
        Ar = Ar_next
        p = r+beta*p
        Ap = Ar+beta*Ap
        cnt = cnt+1
        # Monitor the residual recomputed by fh rather than the recurrence.
        fk,gk = fh(xk)
        relres = norm2(gk)/nd
        relres_v.append(relres)
        print(cnt,"%e %e" %(norm2(gk)/nd,fk))
    fk,gk = fh(xk)
    print(relres_v)
    with open('cr.ser','wb') as f:
        pickle.dump(relres_v,f)

    print('CR method converge to a relative residual of %e, misfit value of %e in %d iterations.\n' % (relres,fk,cnt))
    return relres_v

def gd(fh,x0,tol=1e-8,maxit=1e5,beta=0.001):
    """Fixed-step gradient descent ``x <- x - beta * g``.

    Parameters
    ----------
    fh : callable
        ``fh(x) -> (f, g)``; the step direction is ``r = -g``.
    x0 : ndarray
        Starting point; not modified.
    tol : float
        Not used as a stopping criterion: the loop stops when the relative
        residual drops to 1e-15 or ``cnt`` exceeds ``maxit``. Kept for
        signature compatibility.
    maxit : int or float
        Iteration cap; the loop runs while ``cnt <= maxit``.
    beta : float
        Step length.

    Returns
    -------
    list of float
        ``||g_k|| / ||g_0||`` for iterations 1, 2, ... (previously the history
        was only printed/pickled and ``None`` was returned).

    Side effects: prints progress and pickles the history to ``gd.ser``.
    """
    xk = x0
    fk,gk = fh(xk)
    r = -gk
    nd = norm2(r)
    relres = tol+1
    relres_v = []

    cnt =0
    while cnt <= maxit and relres> 1e-15:
        r = -gk
        # Relative residual of the current iterate, measured before stepping
        # (computed once; the original evaluated it twice).
        relres = norm2(r)/nd
        xk = xk+beta*r
        if cnt > 0:
            relres_v.append(relres)
        cnt = cnt+1
        fk,gk = fh(xk)
        print(cnt,"%e %e" %(norm2(gk)/nd,fk))
    fk,gk = fh(xk)
    print(relres_v)
    with open('gd.ser','wb') as f:
        pickle.dump(relres_v,f)

    print('gd method converge to a relative residual of %e, misfit value of %e in %d iterations.\n' % (relres,fk,cnt))
    return relres_v


def ca(fh,x0,tol=1e-8,maxit=1e5,alpha=0.001,beta=1.0):
    """Conjugate Anderson (CA) iteration keeping a two-column history.

    Each iteration forms the newest difference pair ``p = x_k - x_{k-1}``,
    ``q = r_k - r_{k-1}``. It then orthogonalises ``q`` against the stored
    ``Q`` columns, applies the same combination to ``p``, and normalises both
    by ``||q||``. Only the two most recent pairs are kept in ``P`` / ``Q``.
    The update is ``x <- (x - P @ Gamma) + beta * (r - Q @ Gamma)`` with
    ``Gamma = Q.T @ r``.

    Parameters
    ----------
    fh : callable
        ``fh(x) -> (f, g)``; the residual is ``-g``.
    x0 : ndarray
        Starting point; not modified.
    tol : float
        Not used as a stopping criterion: the loop stops when the relative
        residual drops to 1e-15 or ``cnt`` exceeds ``maxit``.
    maxit : int or float
        Iteration cap; the loop runs while ``cnt <= maxit``.
    alpha : float
        NOTE(review): unused. ``P`` / ``Q`` start as zeros, so the first step
        is ``x0 + beta * r0``, not an alpha-scaled gradient step as in ``AM``.
    beta : float
        Mixing parameter.

    Returns
    -------
    list of float
        Relative residual ``||r_k|| / ||r_0||`` for iterations 1, 2, ...

    Side effects: prints progress and pickles ``(relres_v, relres_bar_v)`` to
    ``ca.ser``.
    """
    cnt = 0
    xk = x0
    fk,gk = fh(xk)

    N = len(x0)
    nd = norm2(gk)
    relres = tol+1
    relres_v = []
    relres_bar_v = []
    # Two most recent normalised difference pairs; column 1 is the newest.
    # All-zero columns contribute nothing until the history fills up.
    # (The original also grew unused Xk/Rk copies of every pair with np.c_,
    # costing O(N * iterations) memory; they were never read and are gone.)
    P = np.zeros((N,2))
    Q = np.zeros((N,2))

    while cnt <=maxit and relres>1e-15:
        res = -gk
        relres = norm2(res)/nd
        if cnt >0:
            relres_v.append(relres)

        if cnt >=1:
            p = xk-x_prev
            q = res-res_prev

            if cnt == 1:
                # First pair: just normalise and store it as the newest column.
                nq = norm2(q)
                p /= nq
                q /= nq
                P[:,1] = p.reshape(N)
                Q[:,1] = q.reshape(N)

            elif cnt == 2:
                # Only one real column so far: orthogonalise against it alone.
                xi = -np.dot(q.reshape(N),Q[:,1])
                q = q+xi*Q[:,1].reshape(N,1)
                p = p+xi*P[:,1].reshape(N,1)
                nq = norm2(q)
                q /= nq
                p /= nq
                P = np.c_[P[:,1],p]
                Q = np.c_[Q[:,1],q]
            else:
                # Orthogonalise against both stored columns, then drop the oldest.
                xi = -Q.T @ q
                q += Q@xi
                p += P@xi
                nq = norm2(q)
                q /= nq
                p /= nq
                P = np.c_[P[:,1],p]
                Q = np.c_[Q[:,1],q]

        x_prev = xk.copy()
        res_prev = res.copy()
        cnt = cnt+1

        Gamma = Q.T @ res
        x_bar = xk - P@Gamma
        r_bar = res-Q@Gamma
        xk = x_bar + beta*r_bar

        if cnt > 1:
            relres_bar_v.append(norm2(r_bar)/nd)

        fk,gk = fh(xk)
        print(cnt,'%e %e %e' % (norm2(gk)/nd,fk, norm2(r_bar)/nd))

    with open('ca.ser','wb') as f:
        pickle.dump((relres_v,relres_bar_v),f)

    print('Conjugate Anderson method converge to a relative residual of %e, misfit value of %e in %d iterations.\n' % (relres,fk,cnt))
    return relres_v
            

# Run every solver from the zero initial guess. A fresh x0 is built for each
# run so that a solver that modifies its starting point cannot affect the next.
x0 = np.zeros((D_out,1))
# Signature: AM(fh,x0,tol=1e-8,maxit=1e5,alpha=0.001,beta=1.0,m=100,p=1)
rv0 = AM(misfit_total,x0,tol=1e-8,maxit=50,alpha=1e-3,beta=0.1,m=50,p=1)

x0 = np.zeros((D_out,1))
rv1 = ca(misfit_total,x0,tol=1e-8,maxit=50,alpha=1,beta=0.1)

x0 = np.zeros((D_out,1))
rv2 = cr(misfit_total,x0,maxit=50)

# Gradient descent gets its own name: it previously reused rv2 and clobbered
# the CR history that the plot below labels 'CR'.
x0 = np.zeros((D_out,1))
rv3 = gd(misfit_total,x0,maxit=50,beta=0.1)


'''import matplotlib.pyplot as plt

name = 'Anderson_CR'

plt.clf()
plt.xlabel('epochs')
plt.ylabel('Train Loss')
plt.yscale('log')

plt.plot(np.arange(1,len(rv1)+1),rv1,label='Anderson',color='crimson')
plt.plot(np.arange(1,len(rv2)+1),rv2,label='CR',color='blue')

plt.legend()
plt.savefig(name+'_'+'.pdf',format='pdf',dpi=120,bbox_inches='tight')'''
