import os
import torch
import torch.nn as nn
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
from xitorch.optimize import rootfinder
from scipy.integrate import odeint
from torchsummary import summary

# Two-body (Kepler) problem.
# Initial condition and step size, laid out as (q1, q2, p1, p2, h).
x0h_set = torch.tensor([[-0.8, 0.0, 0.0, 1.0, 0.1]], dtype=torch.float64)
h = x0h_set[0, 4].cpu().detach().numpy()

# Tolerances: rtol/atol for the odeint reference data, f_tol for the root solves.
dataorder_set = 1e-10
rootorder_set = 1e-10

# Number of steps of size h in the training and test trajectories, and the
# matching time grids (each has one more point than steps).
train_num_set, test_num_set = 30, 300
ttrain = np.linspace(0, train_num_set * h, train_num_set + 1)
ttest = np.linspace(0, test_num_set * h, test_num_set + 1)

# Optimizer settings.
lr_set = 0.0001
iterations_set = 10000
print_every_set = 500

# Network size.
layers_set, width_set = 5, 50

def ENERGYele(q1, q2, p1, p2):
    """Return the Kepler Hamiltonian H = (p1^2 + p2^2)/2 - 1/|q|.

    Uses only ``**``, ``+``, ``*`` and ``/``, so it evaluates elementwise on
    Python floats, NumPy arrays or torch tensors alike.
    """
    kinetic = 0.5 * (p1**2 + p2**2)
    radius = (q1**2 + q2**2) ** 0.5
    return kinetic - 1 / radius



def dudt(x, t):
    """Hamiltonian vector field of the Kepler problem, in ``odeint`` form.

    Args:
        x: state(s) whose last axis is (q1, q2, p1, p2); a 1-D array of length
            4 (as passed by ``odeint``) or any batch of such rows.
        t: time; unused because the system is autonomous, but required by the
            ``odeint`` callback signature.

    Returns:
        Array with the same shape as ``x`` holding
        ``(dH/dp1, dH/dp2, -dH/dq1, -dH/dq2)`` along the last axis.
    """
    q1 = x[..., 0]
    q2 = x[..., 1]
    p1 = x[..., 2]
    p2 = x[..., 3]

    # |q|^3, shared by both position partials.
    r3 = (q1**2 + q2**2) ** 1.5
    dHdq1 = q1 / r3
    dHdq2 = q2 / r3

    # np.stack along the last axis keeps batched input as (..., 4); the previous
    # nested np.hstack flattened batches into a component-major 1-D vector.
    return np.stack([p1, p2, -dHdq1, -dHdq2], axis=-1)

def dudt_solution(u0, t):
    """Integrate ``dudt`` from ``u0`` over the time grid ``t`` with ``odeint``.

    Both the relative and absolute tolerance are ``dataorder_set``; at most
    5000 internal steps are allowed per output interval. Returns the
    ``odeint`` solution array (one row per entry of ``t``).
    """
    tol = dataorder_set
    return odeint(dudt, u0, t, rtol=tol, atol=tol, mxstep=5000)

def ENERGY(xxx):
    """Evaluate ``ENERGYele`` on states whose last axis holds (q1, q2, p1, p2)."""
    return ENERGYele(*(xxx[..., k] for k in range(4)))



# Canonical symplectic matrix S = [[0, I], [-I, 0]] for the state
# (q1, q2, p1, p2): S @ (dH/dq, dH/dp) = (dH/dp, -dH/dq).
s = torch.tensor(
    [
        [0.0, 0.0, 1.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
        [-1.0, 0.0, 0.0, 0.0],
        [0.0, -1.0, 0.0, 0.0],
    ],
    dtype=torch.float64,
)


def sitograd(u, v):
    """Itoh-Abe discrete gradient of ``ENERGYele``, premultiplied by S.

    Coordinates are switched from ``v`` to ``u`` one at a time in the order
    q1, q2, p1, p2. Each partial is the resulting energy change divided by the
    change in that coordinate. The changes telescope, so (up to rounding) the
    partials dotted with ``u - v`` equal ``ENERGYele(u) - ENERGYele(v)``.

    Args:
        u, v: torch tensors whose last axis is (q1, q2, p1, p2).

    Returns:
        ``[dH/dp1, dH/dp2, -dH/dq1, -dH/dq2]`` built with
        ``torch.vstack(...).T``: shape (n, 4) for (n, 4) inputs and (1, 4)
        for 1-D inputs. A component where ``u`` and ``v`` coincide divides
        by zero.
    """
    new = [u[..., k] for k in range(4)]
    old = [v[..., k] for k in range(4)]

    # energies[k] is the energy with the first k coordinates taken from u and
    # the remaining ones from v: energies[0] = H(v), energies[4] = H(u).
    energies = [ENERGYele(*(new[:k] + old[k:])) for k in range(5)]

    dq1, dq2, dp1, dp2 = (
        (energies[k + 1] - energies[k]) / (new[k] - old[k]) for k in range(4)
    )
    return torch.vstack([dp1, dp2, -dq1, -dq2]).T


def symsitograd(u, v):
    """Symmetrized Itoh-Abe gradient: the mean of ``sitograd(u, v)`` and ``sitograd(v, u)``."""
    forward_order = sitograd(u, v)
    reverse_order = sitograd(v, u)
    return (forward_order + reverse_order) / 2


def itofun(u, v, h):
    """Residual of one symmetrized Itoh-Abe step of size ``h`` from ``v`` to ``u``.

    The next state is the root in ``u`` of ``symsitograd(u, v) * h - (u - v)``.
    """
    scaled_gradient = symsitograd(u, v) * h
    return scaled_gradient - (u - v)


def ito_find_4(v, h):
    """Advance ``v`` by ``h`` with a triple-jump composition of Itoh-Abe steps.

    Three symmetrized Itoh-Abe steps (roots of ``itofun``) are taken with step
    sizes ``r1*h``, ``r2*h``, ``r1*h``, where ``r1 = 1/(2 - 2^(1/3))`` and
    ``r2 = -2^(1/3)/(2 - 2^(1/3))``. This is the composition commonly used to
    lift a symmetric second-order scheme to fourth order. Each solve starts
    from the fixed guess (10, 10, 10, 10) with ``f_tol = rootorder_set * 1e-2``.
    """
    cbrt2 = 2 ** (1 / 3)
    outer = 1 / (2 - cbrt2)
    inner = -cbrt2 / (2 - cbrt2)

    state = v
    for coeff in (outer, inner, outer):
        guess = torch.tensor([10, 10, 10, 10], dtype=torch.float64)
        state = rootfinder(itofun, guess, params=(state, coeff * h), f_tol=rootorder_set * 1e-2)
    return state




class Bgrad(torch.nn.Module):
    """Learned correction to the symmetrized Itoh-Abe discrete gradient.

    ``forward(u, v)`` returns ``symsitograd(u, v) + s @ (A - A^T) @ (u - v)``,
    where ``s`` is the module-level 4x4 matrix and ``A`` is a ``dim x dim``
    matrix produced by an MLP from the concatenation ``[u, v]``. ``predict``
    uses this as an implicit one-step map by solving ``Bfun(u, v, h) = 0``
    for ``u``.

    Args:
        dim: state dimension. ``forward`` multiplies by the 4x4 ``s``, so only
            ``dim=4`` works as written.
        layers: number of linear layers (at least one is always built).
        ind: MLP input width; must equal ``2 * dim`` (``[u, v]`` is fed in).
        outd: MLP output width; must equal ``dim * dim`` (reshaped to ``A``).
        width: hidden-layer width.
        initializer: stored on the instance; not read by this class.
    """

    def __init__(self, dim=4, layers=layers_set, ind=8, outd=16, width=width_set, initializer='default'):
        super().__init__()
        self.dim = dim
        self.ind = ind
        self.outd = outd
        self.layers = layers
        self.width = width
        self.initializer = initializer

        self.modus = self.__init_modules()

    def forward(self, u, v):
        """Corrected discrete-gradient vector field for the step ``v -> u``.

        Args:
            u: next states, shape (n, dim).
            v: current states, shape (n, dim).

        Returns:
            Tensor of shape (n, dim).
        """
        sito = symsitograd(u, v)

        x = torch.hstack([u, v]).unsqueeze(1)  # (n, 1, 2*dim)
        for i in range(1, self.layers):
            x = self.modus['N{}'.format(i)](self.modus['L{}'.format(i)](x))
        x = self.modus['Lout'](x)  # (n, 1, dim*dim)

        # Turn the MLP output into a skew-symmetric matrix A - A^T per sample.
        A = x.reshape(len(x), self.dim, self.dim)
        skew = A - A.transpose(1, 2)

        du = (u - v).unsqueeze(2)  # (n, dim, 1)
        return sito + torch.bmm(torch.matmul(s, skew), du).squeeze(2)

    def Bfun(self, u, v, h):
        """Residual of one learned implicit step from ``v`` to ``u``: ``forward(u, v) * h - (u - v)``."""
        return self.forward(u, v) * h - (u - v)

    def predict(self, xh, steps=1):
        """Roll out the learned implicit scheme from a batch of initial states.

        Args:
            xh: tensor of shape (n, dim + 1); the leading ``dim`` columns are
                the initial state and the last column is the step size h.
            steps: number of implicit steps to take.

        Returns:
            Tensor of shape (n, steps + 1, dim): the initial state followed by
            each predicted state.
        """
        x0, h = xh[..., :-1], xh[..., -1:]
        pred = [x0]
        for _ in range(steps):
            # Initial guess of 10 in every component, shaped like the state so
            # that batches with n > 1 work. It must stay away from the current
            # state: sitograd divides by (u - v), so u == v would evaluate 0/0.
            guess = torch.full(x0.shape, 10.0, dtype=torch.float64, device=x0.device)
            next_state = rootfinder(self.Bfun, guess, params=(pred[-1], h), f_tol=rootorder_set)
            pred.append(next_state)

        return torch.stack(pred, dim=-2)

    def __init_modules(self):
        """Build the MLP as an ``nn.ModuleDict`` of float64 layers.

        With ``layers > 1``, the dict holds:
          * ``L1`` (ind -> width) followed by ``N1`` = tanh;
          * ``L2`` .. ``L{layers-1}`` (width -> width), each followed by tanh;
          * ``Lout`` (width -> outd).
        Otherwise it holds a single ``Lout`` (ind -> outd).
        """
        modules = nn.ModuleDict()
        if self.layers > 1:
            modules['L1'] = nn.Linear(self.ind, self.width).double()
            modules['N1'] = torch.nn.Tanh()
            for i in range(2, self.layers):
                modules['L{}'.format(i)] = nn.Linear(self.width, self.width).double()
                modules['N{}'.format(i)] = torch.nn.Tanh()
            modules['Lout'] = nn.Linear(self.width, self.outd).double()
        else:
            modules['Lout'] = nn.Linear(self.ind, self.outd).double()

        return modules
    


def _rollout(step, x_start, n_steps):
    """Apply ``step`` ``n_steps`` times from ``x_start``; return all states as NumPy.

    Args:
        step: callable mapping a state tensor to the next state tensor.
        x_start: initial state tensor.
        n_steps: number of applications of ``step``.

    Returns:
        Array of shape ``(n_steps + 1, *x_start.shape)`` whose first row is
        ``x_start``.
    """
    states = [x_start]
    for _ in range(n_steps):
        states.append(step(states[-1]))
    return torch.stack(states).cpu().detach().numpy()


def _global_error(reference, approx):
    """Per-time-step Euclidean distance between two (T, 4) trajectories."""
    return np.linalg.norm(reference - approx, axis=1)


def main():
    """Train ``Bgrad`` on a Kepler trajectory and compare it with Itoh-Abe baselines.

    1. Generate reference data with ``odeint`` from ``x0h_set``. The first
       ``train_num_set`` steps are for training; a further ``test_num_set``
       steps, continuing from the last training state, are for testing.
    2. Fit ``Bgrad`` so its output matches the finite-difference velocity
       ``(x_{n+1} - x_n) / h``.
    3. Roll out, and time, four integrators:
       * the learned model;
       * Itoh-Abe with step ``h``;
       * the ``ito_find_4`` composition;
       * Itoh-Abe with step ``h/2``.
    4. Plot trajectories, global errors and energies; save ``./2body.pdf``.
    """
    x0 = x0h_set[0][:4].cpu().detach().numpy()
    h = x0h_set[0][4].cpu().detach().numpy()
    train_num = train_num_set
    test_num = test_num_set

    # ---- reference data --------------------------------------------------
    u_train = dudt_solution(x0, ttrain)
    x_trainnp = u_train[:train_num]
    y_trainnp = u_train[1:train_num + 1]
    # Append h as a 5th column, matching the (q1, q2, p1, p2, h) layout.
    x_trainnp = np.hstack([x_trainnp, h * np.ones([x_trainnp.shape[0], 1])])

    # The test trajectory continues from the last training state.
    u_test = dudt_solution(u_train[-1], ttest)
    x_testnp = u_test[:test_num]
    y_testnp = u_test[1:test_num + 1]
    x_testnp = np.hstack([x_testnp, h * np.ones([x_testnp.shape[0], 1])])

    x_train = torch.from_numpy(x_trainnp)
    y_train = torch.from_numpy(y_trainnp)
    x_test = torch.from_numpy(x_testnp)
    y_test = torch.from_numpy(y_testnp)

    # Regression targets: finite-difference velocities (x_{n+1} - x_n) / h.
    x_train_state = x_train[..., :4]
    x_test_state = x_test[..., :4]
    target_train = (y_train - x_train_state) / x_train[..., -1:]
    target_test = (y_test - x_test_state) / x_test[..., -1:]

    # ---- training --------------------------------------------------------
    Bnet = Bgrad()
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(Bnet.parameters(), lr=lr_set)

    loss_history = []
    for epoch in range(iterations_set):
        loss = criterion(Bnet(y_train, x_train_state), target_train)

        if epoch % print_every_set == 0 or epoch + 1 == iterations_set:
            # The test loss is only reported, so evaluate it without autograd
            # and only when it is actually printed.
            with torch.no_grad():
                loss_test = criterion(Bnet(y_test, x_test_state), target_test)
            print('{:<9}Train loss: {:<25}Test loss: {:<25}'.format(epoch, loss.item(), loss_test.item()), flush=True)
            loss_history.append([epoch, loss.item(), loss_test.item()])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    loss_history = np.array(loss_history)

    # ---- roll-outs from the first test state ------------------------------
    x_start = x_test[0][:4]

    def itoh_step(step_size):
        """Return a one-step symmetrized Itoh-Abe map with the given step size."""
        def step(x):
            guess = torch.tensor([10, 10, 10, 10], dtype=torch.float64)
            return rootfinder(itofun, guess, params=(x, step_size), f_tol=rootorder_set)
        return step

    t1 = time.perf_counter()
    flow_pred = Bnet.predict(x_test[0].unsqueeze(0), steps=test_num - 1).squeeze(0).cpu().detach().numpy()
    t_net = time.perf_counter() - t1

    t1 = time.perf_counter()
    flow_ITOH = _rollout(itoh_step(h), x_start, test_num - 1)
    t_ITOH = time.perf_counter() - t1

    t1 = time.perf_counter()
    flow_ITOH_com = _rollout(lambda x: ito_find_4(x, h), x_start, test_num - 1)
    t_ITOH_com = time.perf_counter() - t1

    t1 = time.perf_counter()
    flow_ITOH_2 = _rollout(itoh_step(0.5 * h), x_start, 2 * test_num - 1)
    t_ITOH_2 = time.perf_counter() - t1
    # Half-step states 0, 2, 4, ... lie on the same time grid as the others.
    ITOHpred2flow = flow_ITOH_2[::2]

    flow_true = x_test[..., :4].cpu().detach().numpy()
    flow_train = x_train[..., :4].cpu().detach().numpy()

    # Build the grid from integer indices so it always has exactly test_num
    # points; np.arange with a float step does not guarantee the length.
    t_test = h * np.arange(test_num)
    flow_energy = ENERGY(flow_pred)
    true_energy = ENERGY(flow_true)
    ITOH_energy = ENERGY(flow_ITOH)

    # Energy plot window: +/- 1e-4 around the initial true energy.
    en_up = true_energy[0] + 0.0001
    en_down = true_energy[0] - 0.0001

    print("t_net", t_net)
    print("t_ITOH", t_ITOH)
    print("t_ITOH_2", t_ITOH_2)
    print("t_ITOH_com", t_ITOH_com)

    # ---- plots -------------------------------------------------------------
    plt.figure(figsize=[8, 8])

    plt.subplot(331)
    plt.scatter(flow_train[:, 0], flow_train[:, 1], color='b', label='Train data', zorder=0, s=2)
    plt.title('Traindata')
    plt.legend()

    plt.subplot(332)
    plt.plot(t_test, flow_true[:, 0], color='b', label='Ground truth', zorder=0)
    for flow, color, label in (
        (flow_pred, 'r', 'q1'),
        (flow_ITOH, 'green', 'ITOH'),
        (ITOHpred2flow, 'cyan', 'ITOH_half'),
        (flow_ITOH_com, 'orange', 'ITOH_4order'),
    ):
        plt.scatter(t_test, flow[:, 0], color=color, label=label, zorder=1, s=2)
    plt.ylim([-1, 1])
    plt.title('q1')

    plt.subplot(333)
    for flow, color, label in (
        (flow_pred, 'r', 'Proposed method'),
        (flow_ITOH, 'green', 'ITOH'),
        (ITOHpred2flow, 'cyan', 'ITOH half stepsize'),
        (flow_ITOH_com, 'orange', 'ITOH_com_4'),
    ):
        plt.plot(t_test, _global_error(flow_true, flow), label=label, color=color, zorder=0)
    plt.title('Global error')

    plt.subplot(336)
    plt.plot(t_test, true_energy, color='b', label='Ground truth')
    plt.plot(t_test, flow_energy, color='r', label='Proposed method')
    plt.plot(t_test, ITOH_energy, color='green', label='Itoh-Abe')
    plt.ylim([en_down, en_up])
    plt.title('Energy')
    plt.legend()

    # (q1, q2) trajectories, one method per panel, with shared axis limits.
    for position, flow, color, label, title in (
        (337, flow_true, 'b', 'Ground truth', 'Ground truth'),
        (338, flow_pred, 'r', 'Predicted solution', 'Proposed method'),
        (339, flow_ITOH, 'green', 'ITOH solution', 'Itoh-Abe solution'),
        (334, flow_ITOH_2, 'cyan', 'ITOH half stepsize', 'Itoh-Abe half stepsize'),
        (335, flow_ITOH_com, 'orange', 'ITOH com order4', 'Itoh-Abe order4'),
    ):
        plt.subplot(position)
        plt.plot(flow[:, 0], flow[:, 1], color=color, label=label, linewidth=0.5)
        plt.ylim([-0.7, 0.9])
        plt.xlim([-0.9, 0.7])
        plt.xlabel("q1")
        plt.ylabel("q2")
        plt.title(title)

    plt.tight_layout()
    plt.savefig('./2body.pdf')
    plt.show()



if __name__ == "__main__":
    # Run the experiment only when this file is executed as a script.
    main()






