import os
import torch
import torch.nn as nn
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
from xitorch.optimize import rootfinder
from scipy.integrate import odeint


# Henon-Heiles system
# Initial point setting (q1, q2, p1, p2, h): [0.3, 0.3, -0.2, -0.2, 0.3]
x0h_set = torch.tensor([[0.3, 0.3, -0.2, -0.2, 0.3]], dtype=torch.float64)
# Step size h: last entry of the initial point, kept as a 0-d NumPy array.
h = x0h_set[0, 4].detach().cpu().numpy()

# Number of integration steps in the training / test trajectories.
train_num_set = 100
test_num_set = 300
# Uniform time grids with spacing h (n steps -> n + 1 sample times).
ttrain = np.linspace(0, train_num_set * h, train_num_set + 1)
ttest = np.linspace(0, test_num_set * h, test_num_set + 1)

# Tolerances: rtol/atol for the odeint reference data, f_tol for rootfinder.
dataorder_set = 1e-14
rootorder_set = 1e-12

# Optimiser settings: Adam learning rate, epochs, logging interval.
lr_set = 0.00001
iterations_set = 30000
print_every_set = 500

# Bgrad network depth (number of Linear layers) and hidden width.
layers_set = 10
width_set = 100


def ENERGY(x):
    """Evaluate the Henon-Heiles Hamiltonian at one state or a batch of states.

    The last axis of ``x`` is read as (q1, q2, p1, p2); leading axes are kept,
    so a (T, 4) trajectory yields T energies.

        H = (p1^2 + p2^2)/2 + (q1^2 + q2^2)/2 + q1^2 q2 - q2^3/3
    """
    pos1, pos2, mom1, mom2 = (x[..., k] for k in range(4))
    kinetic = 0.5 * (mom1 ** 2 + mom2 ** 2)
    # Same left-to-right summation order as the textbook formula above.
    return kinetic + 0.5 * (pos1 ** 2 + pos2 ** 2) + pos2 * (pos1 ** 2) - 1 / 3 * (pos2 ** 3)


def SAVFgrad(u, v):
    """Skew-applied AVF discrete gradient of the Henon-Heiles Hamiltonian.

    Each averaged partial derivative is the exact integral
    int_0^1 dH/dz(v + tau * (u - v)) dtau along the segment between ``v`` and
    ``u``.  The result is returned in the order
    (avg dH/dp1, avg dH/dp2, -avg dH/dq1, -avg dH/dq2), so an AVF step of size
    h solves u - v = h * SAVFgrad(u, v).

    Args:
        u, v: tensors whose last axis is (q1, q2, p1, p2).

    Returns:
        ``torch.vstack([...]).T``: shape (n, 4) for (n, 4) inputs, and (1, 4)
        for 1-D (4,) inputs.
    """
    q1u, q2u, p1u, p2u = (u[..., k] for k in range(4))
    q1v, q2v, p1v, p2v = (v[..., k] for k in range(4))

    # Average of dH/dq1 = q1 + 2 q1 q2 over the segment.
    avg_dHdq1 = (2 * q1u * q2u / 3 + q1u * q2v / 3 + q1u / 2
                 + q2u * q1v / 3 + 2 * q1v * q2v / 3 + q1v / 2)
    # Average of dH/dq2 = q2 + q1^2 - q2^2 over the segment.
    avg_dHdq2 = (q1u ** 2 / 3 + q1u * q1v / 3 - q2u ** 2 / 3 - q2u * q2v / 3
                 + q2u / 2 + q1v ** 2 / 3 - q2v ** 2 / 3 + q2v / 2)
    # dH/dp is linear in p, so its average is the midpoint value.
    avg_dHdp1 = p1u / 2 + p1v / 2
    avg_dHdp2 = p2u / 2 + p2v / 2

    columns = [avg_dHdp1, avg_dHdp2, -avg_dHdq1, -avg_dHdq2]
    return torch.vstack(columns).T

def avffun(u, v, h):
    """Residual of one AVF step of size ``h`` from ``v`` to the candidate ``u``.

    Returns ``SAVFgrad(u, v) * h - (u - v)``, which is zero exactly when ``u``
    is the AVF update of ``v``.  It is passed as the function to
    ``rootfinder`` with ``params=(v, h)``.
    """
    displacement = u - v
    return SAVFgrad(u, v) * h - displacement


def avf_find_4(v, h):
    """Advance ``v`` by one step of size ``h`` using a triple-jump AVF composition.

    Three AVF sub-steps of sizes r1*h, r2*h, r1*h are applied in sequence
    (r1 + r2 + r1 == 1).  Each implicit sub-step is solved with
    ``rootfinder`` at a tolerance 100x tighter than ``rootorder_set``.
    """
    cbrt2 = 2 ** (1 / 3)
    r1 = 1 / (2 - cbrt2)
    r2 = -cbrt2 / (2 - cbrt2)  # negative middle (backward) sub-step
    tol = rootorder_set * 1e-2

    state = v
    for weight in (r1, r2, r1):
        # Each sub-step starts the solver from, and linearises about, the previous state.
        state = rootfinder(avffun, state, params=(state, weight * h), f_tol=tol)
    return state


def dudt(x, t):
    """Right-hand side of Hamilton's equations for the Henon-Heiles system.

    Args:
        x: state(s) whose last axis is (q1, q2, p1, p2).  ``odeint`` passes a
            1-D array of length 4; batched arrays of shape (..., 4) are also
            accepted.
        t: time.  Unused (the system is autonomous), but required by the
            ``func(y, t)`` signature that ``scipy.integrate.odeint`` expects.

    Returns:
        NumPy array with the same shape as ``x`` holding
        (dH/dp1, dH/dp2, -dH/dq1, -dH/dq2).
    """
    x = np.asarray(x)
    q1, q2, p1, p2 = (x[..., k] for k in range(4))
    dHdq1 = q1 + 2 * q1 * q2
    dHdq2 = q2 + q1 ** 2 - q2 ** 2
    # dH/dp1 = p1 and dH/dp2 = p2.  Stack along a new last axis so batched
    # states keep their (..., 4) layout; np.hstack would instead concatenate
    # the per-component arrays end to end into one flat vector.
    return np.stack([p1, p2, -dHdq1, -dHdq2], axis=-1)

def dudt_solution(u0, t):
    """Integrate ``dudt`` from ``u0`` and return the states at the times ``t``.

    Uses ``scipy.integrate.odeint`` with rtol = atol = ``dataorder_set`` and
    ``mxstep=5000``.  The result has one row per entry of ``t``.
    """
    tolerance = dataorder_set
    trajectory = odeint(dudt, u0, t, rtol=tolerance, atol=tolerance, mxstep=5000)
    return trajectory




# Canonical symplectic matrix [[0, I], [-I, 0]] for the state ordering
# (q1, q2, p1, p2): s @ (g_q1, g_q2, g_p1, g_p2) == (g_p1, g_p2, -g_q1, -g_q2).
s = torch.tensor(
    [
        [0, 0, 1, 0],
        [0, 0, 0, 1],
        [-1, 0, 0, 0],
        [0, -1, 0, 0],
    ],
    dtype=torch.float64,
)




class Bgrad(torch.nn.Module):
    """Learned discrete-gradient map for the Henon-Heiles system.

    ``forward(u, v)`` returns ``SAVFgrad(u, v)`` plus a learned correction
    ``s @ (A - A^T) @ (u - v)``.  ``A`` is a per-sample ``dim x dim`` matrix
    produced by a tanh MLP that reads ``hstack([u, v])``.  One integration
    step from ``v`` is the root ``u`` of ``Bfun(u, v, h)``.

    Args:
        dim: size of the square matrix the MLP output is reshaped into.
        layers: number of Linear layers.  For ``layers > 1`` there are
            ``layers - 1`` Linear+Tanh pairs followed by the output Linear;
            otherwise a single Linear maps input to output.
        ind: MLP input width; must equal the width of ``hstack([u, v])``.
        outd: MLP output width; must equal ``dim * dim`` for the reshape.
        width: hidden width.
        initializer: stored on the instance; not otherwise used by this class.
    """

    def __init__(self, dim=4, layers=layers_set, ind=8, outd=16, width=width_set, initializer='default'):
        super().__init__()
        self.dim, self.ind, self.outd = dim, ind, outd
        self.layers, self.width = layers, width
        self.initializer = initializer
        self.modus = self.__init_modules()

    def forward(self, u, v):
        """Return ``SAVFgrad(u, v) + s @ (A - A^T) @ (u - v)`` row by row.

        ``u`` is the candidate next state and ``v`` the current state.  Both
        are rank-2 tensors of shape (n, dim).
        """
        avf_part = SAVFgrad(u, v)

        # MLP over the concatenated pair, carried as (n, 1, ind) -> (n, 1, outd).
        hidden = torch.hstack([u, v]).unsqueeze(1)
        for idx in range(1, self.layers):
            hidden = self.modus[f'N{idx}'](self.modus[f'L{idx}'](hidden))
        hidden = self.modus['Lout'](hidden)
        batch = len(hidden)

        # A - A^T is skew-symmetric for every sample in the batch.
        mat = torch.reshape(hidden, (batch, self.dim, self.dim))
        skew = mat - mat.transpose(1, 2)

        step = (u - v).unsqueeze(2)  # (n, dim, 1) column vectors
        correction = torch.bmm(torch.matmul(s, skew), step).squeeze(2)
        return avf_part + correction

    def Bfun(self, u, v, h):
        """Step residual ``forward(u, v) * h - (u - v)``; zero at the next state ``u``."""
        drift = self.forward(u, v)
        return drift * h - (u - v)

    def predict(self, xh, steps=1):
        """Roll the learned scheme forward ``steps`` times.

        Args:
            xh: tensor whose last axis holds the initial state followed by the
                step size ``h``.  A rank-2 (n, dim + 1) input is expected,
                because ``forward`` requires rank-2 states.
            steps: number of implicit steps, each solved with ``rootfinder``.

        Returns:
            For rank-2 ``xh``, a tensor of shape (n, steps + 1, dim) whose
            first entry along axis 1 is the initial state.
        """
        rank = xh.dim()
        state, h = xh[..., :-1], xh[..., -1:]
        trajectory = [state]
        for _ in range(steps):
            state = rootfinder(self.Bfun, state, params=(state, h), f_tol=rootorder_set)
            trajectory.append(state)

        shape = [-1, steps + 1, self.dim][2 - rank:]
        return torch.cat(trajectory, dim=-1).view(shape)

    def __init_modules(self):
        """Build the MLP as a ModuleDict keyed 'L1', 'N1', ..., 'Lout'."""
        modules = nn.ModuleDict()
        if self.layers <= 1:
            # No hidden layers: a single affine map from input to output.
            modules['Lout'] = nn.Linear(self.ind, self.outd).double()
            return modules

        # Layers are created in the order L1, N1, L2, N2, ..., Lout.
        fan_in = self.ind
        for idx in range(1, self.layers):
            modules[f'L{idx}'] = nn.Linear(fan_in, self.width).double()
            modules[f'N{idx}'] = torch.nn.Tanh()
            fan_in = self.width
        modules['Lout'] = nn.Linear(self.width, self.outd).double()
        return modules
    

def _rollout(step, start, n_steps):
    """Apply ``step`` ``n_steps`` times from ``start``; return all visited states.

    Returns a NumPy array whose first row is ``start`` and whose row k is
    ``step`` applied k times.  For a 1-D ``start`` of length dim the shape is
    (n_steps + 1, dim).  Every call of ``step`` must return a tensor shaped
    like ``start`` so that the states can be stacked.
    """
    states = [start]
    for _ in range(n_steps):
        states.append(step(states[-1]))
    return torch.stack(states).cpu().detach().numpy()


def _global_error(reference, approx):
    """Euclidean distance between matching rows of two (T, 4) trajectories."""
    return np.linalg.norm(reference - approx, axis=1)


def _plot_phase_portrait(position, flow, title, **line_kwargs):
    """Plot the (q1, q2) trajectory of ``flow`` in subplot ``position`` with shared limits."""
    plt.subplot(position)
    plt.plot(flow[:, 0], flow[:, 1], linewidth=0.5, **line_kwargs)
    plt.ylim([-0.5, 0.75])
    plt.xlim([-0.7, 0.7])
    plt.xlabel("q1")
    plt.ylabel("q2")
    plt.title(title)


def main():
    """Train ``Bgrad`` on Henon-Heiles data and compare it with AVF integrators.

    Steps:
    1. Generate reference trajectories with ``odeint``.
    2. Fit ``Bgrad`` to the finite-difference slopes (x_{n+1} - x_n) / h.
    3. Roll out the learned scheme and three AVF variants from the first
       test state: AVF with step h, AVF with step h/2, and the triple-jump
       composition.  Each rollout is timed.
    4. Print the timings and save the comparison figure to ./henon.pdf.
    """
    x0test = x0h_set
    x0 = x0test[0][:4].cpu().detach().numpy()
    h = x0test[0][4].cpu().detach().numpy()
    train_num = train_num_set
    test_num = test_num_set

    lr = lr_set
    iterations = iterations_set
    print_every = print_every_set

    # --- Reference data from a tight-tolerance odeint solve -----------------
    u_train = dudt_solution(x0, ttrain)
    y_trainnp = u_train[1:train_num + 1]
    x_trainnp = u_train[0:train_num]
    x_trainnp = np.hstack([x_trainnp, h * np.ones([x_trainnp.shape[0], 1])])

    # The test trajectory continues from the last training state.
    x0_test = u_train[-1]
    u_test = dudt_solution(x0_test, ttest)
    y_testnp = u_test[1:test_num + 1]
    x_testnp = u_test[0:test_num]
    x_testnp = np.hstack([x_testnp, h * np.ones([x_testnp.shape[0], 1])])

    x_train = torch.from_numpy(x_trainnp)
    y_train = torch.from_numpy(y_trainnp)
    x_test = torch.from_numpy(x_testnp)
    y_test = torch.from_numpy(y_testnp)

    # Regression targets (x_{n+1} - x_n) / h do not change during training.
    target_train = (y_train - x_train[..., :4]) / x_train[..., -1:]
    target_test = (y_test - x_test[..., :4]) / x_test[..., -1:]

    # --- Training -------------------------------------------------------------
    Bnet = Bgrad()
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(Bnet.parameters(), lr=lr)

    loss_history = []
    for epoch in range(iterations):
        grad_pred = Bnet(y_train, x_train[..., :4])
        loss = criterion(grad_pred, target_train)

        if epoch % print_every == 0 or epoch + 1 == iterations:
            # The test loss is only reported, never back-propagated.  Evaluate it
            # on logging epochs only and without building an autograd graph.
            with torch.no_grad():
                loss_test = criterion(Bnet(y_test, x_test[..., :4]), target_test)
            print('{:<9}Train loss: {:<25}Test loss: {:<25}'.format(epoch, loss.item(), loss_test.item()), flush=True)
            loss_history.append([epoch, loss.item(), loss_test.item()])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    loss_history = np.array(loss_history)

    # --- Rollouts from the first test state ----------------------------------
    t1 = time.time()
    flow_pred = Bnet.predict(x_test[0].unsqueeze(0), steps=test_num - 1).squeeze(0).cpu().detach().numpy()
    t_net = time.time() - t1

    flow_true = x_test[..., :4].cpu().detach().numpy()
    flow_train = x_train[..., :4].cpu().detach().numpy()
    start = x_test[0][:4]

    t1 = time.time()
    flow_AVF = _rollout(
        lambda y: rootfinder(avffun, y, params=(y, h), f_tol=rootorder_set),
        start, test_num - 1)
    t_AVF = time.time() - t1

    t1 = time.time()
    flow_AVF_com = _rollout(lambda y: avf_find_4(y, h), start, test_num - 1)
    t_AVF_com = time.time() - t1

    t1 = time.time()
    flow_AVF_2 = _rollout(
        lambda y: rootfinder(avffun, y, params=(y, 0.5 * h), f_tol=rootorder_set),
        start, 2 * test_num - 1)
    t_AVF_2 = time.time() - t1

    # Every second half-step state lies on the same time grid as the others.
    AVFpred2flow = flow_AVF_2[::2]

    # Integer-indexed grid: np.arange with a float step can gain an extra
    # sample through rounding and then no longer match the test_num rows.
    t_test = h * np.arange(test_num)
    flow_energy = ENERGY(flow_pred)
    true_energy = ENERGY(flow_true)
    AVF_energy = ENERGY(flow_AVF)

    en_up = true_energy[0] + 0.0001
    en_down = true_energy[0] - 0.0001

    # --- Timings --------------------------------------------------------------
    print("t_net", t_net)
    print("t_AVF", t_AVF)
    print("t_AVF_2", t_AVF_2)
    print("t_AVF_com", t_AVF_com)

    # --- Figure ---------------------------------------------------------------
    plt.figure(figsize=[8, 8])

    plt.subplot(331)
    plt.scatter(flow_train[:, 0], flow_train[:, 1], color='b', label='Train data', zorder=0, s=2)
    plt.title('Traindata')
    plt.legend()

    plt.subplot(332)
    plt.plot(t_test, flow_true[:, 0], color='b', label='Ground truth', zorder=0)
    plt.scatter(t_test, flow_pred[:, 0], color='r', label='q1', zorder=1, s=2)
    plt.scatter(t_test, flow_AVF[:, 0], color='green', label='AVF', zorder=1, s=2)
    plt.scatter(t_test, AVFpred2flow[:, 0], color='cyan', label='AVF_half', zorder=1, s=2)
    plt.scatter(t_test, flow_AVF_com[:, 0], color='orange', label='AVF_4order', zorder=1, s=2)
    plt.ylim([-1, 1])
    plt.title('q1')

    plt.subplot(333)
    plt.plot(t_test, _global_error(flow_true, flow_pred), label='Proposed method', color='r', zorder=0)
    plt.plot(t_test, _global_error(flow_true, flow_AVF), label='AVF', color='green', zorder=0)
    plt.plot(t_test, _global_error(flow_true, AVFpred2flow), label='AVF half stepsize', color='cyan', zorder=0)
    plt.plot(t_test, _global_error(flow_true, flow_AVF_com), label='AVF_com_4', color='orange', zorder=0)
    plt.title('Global error')

    plt.subplot(336)
    plt.plot(t_test, true_energy, color='b', label='Ground truth')
    plt.plot(t_test, flow_energy, color='r', label='Proposed method')
    plt.plot(t_test, AVF_energy, color='green', label='AVF')
    plt.ylim([en_down, en_up])
    plt.title('Energy')
    plt.legend()

    _plot_phase_portrait(337, flow_true, 'Ground truth', color='b', label='Ground truth', zorder=0)
    _plot_phase_portrait(338, flow_pred, 'Proposed method', color='r', label='Predicted solution', zorder=2)
    _plot_phase_portrait(339, flow_AVF, 'AVF solution', color='green', label='AVF solution', zorder=1)
    _plot_phase_portrait(334, flow_AVF_2, 'AVF half stepsize', color='cyan', label='AVF half stepsize', zorder=1)
    _plot_phase_portrait(335, flow_AVF_com, 'AVF order4', color='orange', label='AVF com order4', zorder=1)

    plt.tight_layout()
    plt.savefig('./henon.pdf')
    plt.show()



# Script entry point: call main() only when the file is executed directly, not when it is imported.
if __name__ == '__main__':
    main()






