import os
import torch
import torch.nn as nn
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
from xitorch.optimize import rootfinder
import numpy as np
from scipy.integrate import odeint
from scipy.fftpack import diff 
from scipy.integrate import solve_ivp
import scipy


def create_matrixl(n):
    """Return the n x n periodic stencil matrix for -x[j+1] + 2*x[j] - x[j-1].

    Row j has 2 on the diagonal and -1 in columns (j+1) % n and (j-1) % n, so
    the matrix is symmetric and every row sums to zero. It is unscaled: divide
    by delta_x**2 to obtain the (negative) discrete second derivative.

    The matrix is assembled from shifted identity matrices so that neighbours
    which coincide for tiny n accumulate instead of overwriting each other.
    For n == 2 both neighbours are the same point; for n == 1 the neighbour is
    the point itself, giving [[0]].
    """
    eye = np.eye(n)
    right = np.roll(eye, 1, axis=1)   # 1 at column (j + 1) % n
    left = np.roll(eye, -1, axis=1)   # 1 at column (j - 1) % n
    return 2 * eye - right - left

def create_matrixs(n):
    """Return the n x n periodic stencil matrix for x[j+1] - x[j-1].

    Row j has +1 in column (j+1) % n and -1 in column (j-1) % n, so the
    matrix is skew-symmetric. It is unscaled: divide by 2*delta_x to obtain
    the central-difference first derivative.

    The matrix is assembled from shifted identity matrices so that coinciding
    neighbours cancel instead of overwriting each other. For n == 1 or n == 2
    both neighbours are the same point, so the stencil is identically zero.
    """
    eye = np.eye(n)
    right = np.roll(eye, 1, axis=1)   # 1 at column (j + 1) % n
    left = np.roll(eye, -1, axis=1)   # 1 at column (j - 1) % n
    return right - left



# Spatial discretisation of the periodic domain [0, x_len).
x_len = 20  # length of the periodic spatial domain
x_num = 40   # number of grid points
# Spacing of the grid `x` below; the difference operators and the energy use this value.
delta_x=x_len/x_num
# NOTE(review): this is NOT the spacing of `x` (that is delta_x = x_len/x_num);
# it looks like a leftover for a non-periodic grid and is not referenced in this file.
dx = x_len / (x_num - 1.0)
# x_num equally spaced points 0, delta_x, ..., x_len - delta_x; the endpoint x_len
# is omitted because it coincides with 0 on the periodic domain.
x = np.linspace(0, (1-1.0/x_num)*x_len, x_num)

def kdv_exact(x, c):
    """Single-soliton profile 0.5*c*sech^2(0.5*sqrt(c)*x).

    This is the t = 0 soliton of u_t = -6*u*u_x - u_xxx (the equation used by
    ``kdv``), centred at x = 0 and travelling with speed ``c``. ``x`` may be a
    scalar or a NumPy array; the result has the same shape.
    """
    scaled_x = 0.5 * np.sqrt(c) * x
    return 0.5 * c * np.cosh(scaled_x) ** (-2)

# Initial condition: two superposed solitons (speeds 0.75 and 0.4) centred at
# 33% and 65% of the domain. A plain sum of solitons is only an approximation
# of a two-soliton solution.
u0 = kdv_exact(x-0.33*x_len, 0.75) + kdv_exact(x-0.65*x_len, 0.4)

# Training horizon and time step.
T = 30
t_stepsize = 0.01
# round() rather than int(): T/t_stepsize is a float quotient that can land just
# below an integer (e.g. 0.3/0.1 == 2.9999999999999996), which int() would truncate.
T_len = round(T / t_stepsize) + 1
t = np.linspace(0, T, T_len)

# Test horizon (the __main__ block starts it from the last training state).
T_test = 30
t_stepsize_test = t_stepsize
T_len_test = round(T_test / t_stepsize_test) + 1
t_test = np.linspace(0, T_test, T_len_test)

# Optimiser settings.
lr = 0.0001
iterations = 10000
print_every = 500

# Bgrad network architecture (number of Linear layers and hidden width).
layers = 5
width = 200

# Tolerance settings; not referenced elsewhere in this file.
dataorder_set = 1e-12
rootorder_set = 1e-12

# Periodic difference operators (float64).
# dis_S: central difference (u[j+1] - u[j-1]) / (2*delta_x).
dis_S = create_matrixs(x_num) / delta_x / 2
dis_S_torch = torch.from_numpy(dis_S)
# dis_L: unscaled stencil -u[j+1] + 2*u[j] - u[j-1]; the AVF gradients apply the
# 1/(2*delta_x**2) factor themselves.
dis_L = create_matrixl(x_num)
dis_L_torch = torch.from_numpy(dis_L)

def AVFgrad_torch(u,v):
    """Batched average-vector-field (AVF) discrete gradient, torch version.

    Computes L(u + v) / (2*delta_x**2) - (u - v)**2 - 3*u*v, where L is the
    periodic stencil ``dis_L_torch``. The last two terms together equal
    -(u**2 + u*v + v**2). ``u`` and ``v`` are expected to be (batch, N)
    tensors; the result is returned as (batch, N, 1) column vectors.
    """
    # (batch, N) -> (batch, 1, N) -> (batch, N, 1): one column vector per sample,
    # so the (N, N) matrix dis_L_torch broadcasts over the batch inside matmul.
    col_u = u.unsqueeze(1).permute(0, 2, 1)
    col_v = v.unsqueeze(1).permute(0, 2, 1)
    scale = 1 / (2 * delta_x**2)
    stencil_term = scale * torch.matmul(dis_L_torch, col_u + col_v)
    return stencil_term - (col_u - col_v)**2 - 3 * col_u * col_v






class Bgrad(torch.nn.Module):
    """Learned vector field for the discrete KdV equation, used inside an implicit step.

    For a current state ``v`` and a next state ``u`` (both (batch, dim)),
    ``forward`` returns

        S @ AVFgrad(u, v) + S @ (A - A^T) @ (u - v)

    where ``S`` is ``dis_S_torch`` (periodic central difference), ``AVFgrad`` is
    ``AVFgrad_torch``, and ``A`` is a (dim, dim) matrix produced per sample by an
    MLP from the concatenation ``[u, v]``. The training loop fits this to
    ``(u - v) / h``; ``predict`` then solves ``h * forward(u, v) = u - v`` for
    ``u`` one step at a time.

    Args:
        dim: state dimension (number of grid points).
        layers: number of Linear layers in the MLP: ``layers - 1`` Linear+Tanh
            hidden layers plus the output layer (only the output layer if
            ``layers <= 1``).
        ind: MLP input width (``2 * dim`` for the pair ``[u, v]``).
        outd: MLP output width (``dim ** 2`` entries of ``A``).
        width: hidden-layer width.
        initializer: stored on the instance but not used by this class.
    """

    def __init__(self, dim=x_num,layers=layers,ind=x_num*2, outd=(x_num)**2, width=width, initializer='default'):
        super(Bgrad, self).__init__()
        self.dim = dim
        self.ind = ind
        self.outd = outd
        self.layers = layers
        self.width = width
        self.initializer = initializer
        self.modus = self.__init_modules()

    def forward(self, u, v):
        """Evaluate the learned vector field for next states ``u`` and current states ``v``.

        ``u`` and ``v`` are (batch, dim) float64 tensors; the result is (batch, dim).
        """
        # MLP on the concatenated pair, shaped (batch, 1, 2*dim).
        z = torch.hstack([u, v]).unsqueeze(1)
        for i in range(1, self.layers):
            linear = self.modus['LinM{}'.format(i)]
            activation = self.modus['NonM{}'.format(i)]
            z = activation(linear(z))
        z = self.modus['LinMout'](z)

        # One (dim, dim) matrix per sample; only its skew-symmetric part is used.
        A = torch.reshape(z, (len(z), self.dim, self.dim))
        skew = A - A.permute(0, 2, 1)

        # S @ AVFgrad(u, v); AVFgrad_torch returns (batch, dim, 1) columns.
        s_avf = torch.matmul(dis_S_torch, AVFgrad_torch(u, v)).permute(0, 2, 1).squeeze(1)
        # Learned term S @ (A - A^T) @ (u - v).
        du = (u - v).unsqueeze(2)
        correction = torch.bmm(torch.matmul(dis_S_torch, skew), du).squeeze(2)
        return s_avf + correction

    def Bfun(self, u, v, h):
        """Residual ``h * forward(u, v) - (u - v)`` of the implicit step; zero at the next state ``u``."""
        return self.forward(u, v) * h - (u - v)

    def predict(self, xh, steps=1, keepinitx=True, returnnp=False, h=None):
        """Roll the learned implicit scheme forward from ``xh``.

        Args:
            xh: starting state(s), a (batch, dim) tensor.
            steps: number of implicit steps to take.
            keepinitx: if True, the returned trajectory starts with ``xh``.
            returnnp: if True, return a NumPy array instead of a tensor.
            h: time step; defaults to the module-level ``t_stepsize`` (read at call time).

        Returns:
            For 2-D ``xh``, a (batch, steps + 1, dim) trajectory when ``keepinitx``
            is True, else (batch, steps, dim).
        """
        size = len(xh.size())
        if h is None:
            h = t_stepsize
        pred = [xh]
        for _ in range(steps):
            prev = pred[-1]
            # Starting point for the root solve: prev + h * AVFgrad(prev, prev).
            # NOTE(review): this guess does not apply S (dis_S_torch); it only
            # seeds rootfinder, so it affects convergence rather than the equation solved.
            guess = prev + h * AVFgrad_torch(prev, prev).permute(0, 2, 1).squeeze(1)
            u_next = rootfinder(self.Bfun, guess,
                                params=(prev, h), f_tol=1e-10, method="broyden1")
            pred.append(u_next)

        # Drop the initial state when it is not requested; previously it was still
        # concatenated, so the .view() below failed for keepinitx=False.
        if not keepinitx:
            pred = pred[1:]
        if not pred:
            # keepinitx=False with steps=0: an empty trajectory.
            res = xh.new_zeros([xh.shape[0], 0, self.dim][2 - size:])
        else:
            res = torch.cat(pred, dim=-1).view([-1, len(pred), self.dim][2 - size:])
        return res.cpu().detach().numpy() if returnnp else res

    def __init_modules(self):
        """Build the MLP as a ModuleDict: LinM1/NonM1 ... LinM{layers-1}/NonM{layers-1}, then LinMout."""
        modules = nn.ModuleDict()
        if self.layers > 1:
            modules['LinM1'] = nn.Linear(self.ind, self.width).double()
            modules['NonM1'] = torch.nn.Tanh()
            for i in range(2, self.layers):
                modules['LinM{}'.format(i)] = nn.Linear(self.width, self.width).double()
                modules['NonM{}'.format(i)] = torch.nn.Tanh()
            modules['LinMout'] = nn.Linear(self.width, self.outd).double()
        else:
            modules['LinMout'] = nn.Linear(self.ind, self.outd).double()
        return modules




def kdv(u, t, L):
    """Right-hand side du/dt = -6*u*u_x - u_xxx of the KdV equation on a periodic grid.

    Spatial derivatives are taken spectrally with ``scipy.fftpack.diff`` over
    the period ``L``. ``t`` is unused; it is present because ``odeint`` calls
    ``f(y, t, *args)``.
    """
    first_deriv = diff(u, period=L)
    third_deriv = diff(u, period=L, order=3)
    return -6 * u * first_deriv - third_deriv

def kdv_solution(u0, t, L):
    """Integrate the spectral KdV system ``kdv`` from ``u0`` over the time points ``t``.

    Uses ``odeint`` with rtol = atol = 1e-10 and up to 5000 internal steps per
    output interval. Returns odeint's array with one row per entry of ``t``.
    """
    tolerances = dict(rtol=1e-10, atol=1e-10)
    return odeint(kdv, u0, t, args=(L,), mxstep=5000, **tolerances)


if __name__ == "__main__":
    import warnings

    # ------------------------------------------------------------------
    # Reference trajectories from the spectral KdV solver.
    # ------------------------------------------------------------------
    sol = kdv_solution(u0, t, x_len)
    u_train = sol
    # The test trajectory starts where the training trajectory ends.
    sol_test = kdv_solution(u_train[-1], t_test, x_len)
    u_test = sol_test

    def energy(u):
        """Discrete energy sum_j(0.5*ux_j**2 - u_j**3) of every row (time step) of ``u``.

        ux is the periodic forward difference (u[j+1] - u[j]) / delta_x.
        Returns a 1-D array with one value per row.
        """
        u = np.asarray(u)
        ux = 1 / delta_x * (np.roll(u, -1, axis=1) - u)
        return np.sum(0.5 * ux**2 - u**3, axis=1)

    def AVFgrad(u, v):
        """AVF discrete gradient of ``energy`` for single 1-D states ``u`` and ``v``."""
        return 1 / (2 * delta_x**2) * dis_L @ (u + v) - (u - v)**2 - 3 * u * v

    def avffun(u, v, h):
        """Residual of one AVF step v -> u with time step ``h``; zero at the next state ``u``."""
        return dis_S @ AVFgrad(u, v) * h - (u - v)

    u_traintorch = torch.from_numpy(u_train)
    u_testtorch = torch.from_numpy(u_test)

    # Consecutive snapshot pairs (x_k, y_k) = (u_k, u_{k+1}) for every k,
    # including the final pair of each trajectory.
    u_traintorch_x = u_traintorch[:-1]
    u_traintorch_y = u_traintorch[1:]
    u_testtorch_x = u_testtorch[:-1]
    u_testtorch_y = u_testtorch[1:]

    # ------------------------------------------------------------------
    # Baseline: classical AVF integrator solved with scipy.optimize.root.
    # ------------------------------------------------------------------
    avf_states = [np.array(u_train[-1])]
    initial_guess = np.zeros_like(avf_states[0])

    t1 = time.time()
    for step in range(T_len_test - 1):
        current = avf_states[-1]
        solution = scipy.optimize.root(lambda z, cur=current: avffun(z, cur, t_stepsize),
                                       initial_guess, method='broyden2', tol=1e-10)
        if not solution.success:
            warnings.warn('AVF root solve did not converge at step {}: {}'.format(step + 1, solution.message))
        avf_states.append(solution.x)
    t2 = time.time()
    t_AVF = t2 - t1
    # Stack once at the end instead of np.append per step (quadratic copying).
    AVFpred = np.stack(avf_states)

    # ------------------------------------------------------------------
    # Train the learned vector field on finite-difference targets.
    # ------------------------------------------------------------------
    Bnet = Bgrad()
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(Bnet.parameters(), lr=lr)

    # The targets do not change during training; compute them once.
    target_train = (u_traintorch_y - u_traintorch_x) / t_stepsize
    target_test = (u_testtorch_y - u_testtorch_x) / t_stepsize

    loss_history = []
    for epoch in range(iterations):
        grad_pred = Bnet(u_traintorch_y, u_traintorch_x)
        loss = criterion(grad_pred, target_train)

        if epoch % print_every == 0 or epoch + 1 == iterations:
            # The test loss is only reported, so no autograd graph is needed.
            with torch.no_grad():
                loss_test = criterion(Bnet(u_testtorch_y, u_testtorch_x), target_test)
            print('{:<9}Train loss: {:<25}Test loss: {:<25}'.format(epoch, loss.item(), loss_test.item()), flush=True)
            loss_history.append([epoch, loss.item(), loss_test.item()])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    loss_history = np.array(loss_history)

    # ------------------------------------------------------------------
    # Roll out the learned scheme and compare with the reference.
    # ------------------------------------------------------------------
    t1 = time.time()
    flow_pred = Bnet.predict(u_testtorch[0].unsqueeze(0), steps=T_len_test - 1).squeeze(0).cpu().detach().numpy()
    t2 = time.time()
    t_net = t2 - t1

    flow_energy = energy(flow_pred)
    true_energy = energy(u_test)
    AVF_energy = energy(AVFpred)

    # Squared L2 error of each time step against the reference trajectory.
    AVF_error = np.sum((AVFpred - u_test)**2, axis=1)
    flow_error = np.sum((flow_pred - u_test)**2, axis=1)

    # The energy panel is zoomed to +/-0.001 around the initial reference energy.
    en_up = true_energy[0] + 0.001
    en_down = true_energy[0] - 0.001

    vmin = 0     # lowest value on the colour scale
    vmax = 0.4   # highest value on the colour scale

    print("t_net", t_net)
    print("t_AVF", t_AVF)

    def _heatmap(position, data, t_end, title):
        """Draw one space-time panel (time increasing upwards) on the shared colour scale."""
        plt.subplot(position)
        plt.imshow(data[::-1, :], extent=[0, x_len, 0, t_end], vmin=vmin, vmax=vmax)
        plt.colorbar()
        plt.xlabel('x')
        plt.ylabel('t')
        plt.axis('auto')
        plt.title(title)

    plt.figure(figsize=[8, 8])

    _heatmap(331, sol, T, 'Traindata')
    _heatmap(334, sol_test, T_test, 'Ground truth')
    _heatmap(335, flow_pred, T_test, 'Predicted solution')
    _heatmap(336, AVFpred, T_test, 'AVF solution')

    plt.subplot(337)
    plt.plot(t_test, flow_error, color='r', label='Proposed method')
    plt.plot(t_test, AVF_error, color='green', label='AVF')
    plt.title('Global error')
    plt.legend()

    plt.subplot(338)
    plt.plot(t_test, flow_energy, color='r', label='Proposed method')
    plt.plot(t_test, AVF_energy, color='green', label='AVF')
    plt.ylim([en_down, en_up])
    plt.title('Energy')
    plt.legend()

    plt.tight_layout()

    plt.savefig('./kdv.pdf')

    plt.show()
