import os
import torch
import torch.nn as nn
import time
import numpy as np
import matplotlib.pyplot as plt
import torch

from xitorch.optimize import rootfinder
from scipy.integrate import odeint
# Pendulum experiment settings.
# Initial condition and step size, packed as a single row (q1, p1, h).
x0h_set = torch.tensor([[2, 0, 0.5]], dtype=torch.float64)
h = x0h_set[0, 2].cpu().detach().numpy()

# Step counts of the training and test trajectories; each time grid holds
# one more point than its step count (the initial state).
train_num_set = 15
test_num_set = 200
ttrain = np.linspace(0, train_num_set * h, train_num_set + 1)
ttest = np.linspace(0, test_num_set * h, test_num_set + 1)

# Tolerances: odeint rtol/atol for the reference data, root-finder f_tol.
dataorder_set = 1e-12
rootorder_set = 1e-12

# Optimiser settings.
lr_set = 0.0001
iterations_set = 10000
print_every_set = 500

# Network architecture (used as Bgrad's default layers / width).
layers_set = 5
width_set = 50



def ENERGY(xxx):
    """Return the pendulum Hamiltonian H = p**2 / 2 - cos(q) of numpy states.

    The last axis of ``xxx`` holds (q, p); the result drops that axis.
    """
    angle, momentum = xxx[..., 0], xxx[..., 1]
    kinetic = 0.5 * momentum ** 2
    return kinetic - np.cos(angle)


# 2x2 structure matrix J = [[0, 1], [-1, 0]]: J @ (dH/dq, dH/dp) = (dH/dp, -dH/dq),
# the component order produced by dudt and SAVFgrad.
s = torch.tensor([[0.0, 1.0], [-1.0, 0.0]], dtype=torch.float64)

def SAVFgrad(u, v):
    """Return J times the AVF discrete gradient of H(q, p) = p**2 / 2 - cos(q).

    The discrete partial derivatives between states ``u`` and ``v`` are
        dH/dp = (p_u + p_v) / 2
        dH/dq = (cos(q_v) - cos(q_u)) / (q_u - q_v)
    and the result has columns (dH/dp, -dH/dq).

    Args:
        u, v: tensors whose last axis holds (q, p).

    Returns:
        For 2-D inputs of shape (n, 2), a tensor of shape (n, 2). For 1-D inputs
        of shape (2,), a tensor of shape (1, 2) (a consequence of ``vstack(...).T``).
    """
    uq1, up1 = u[..., 0], u[..., 1]
    vq1, vp1 = v[..., 0], v[..., 1]

    # Exact identity: (cos(vq) - cos(uq)) / (uq - vq)
    #   = sin((uq + vq) / 2) * sin(d) / d,  with d = (uq - vq) / 2.
    # The product form has no 0/0 when uq == vq (it tends to sin(uq), the true
    # derivative of -cos) and avoids cancellation when the angles are close.
    # torch.sinc(x) = sin(pi x) / (pi x), so sin(d) / d = sinc(d / pi).
    half_diff = (uq1 - vq1) / 2
    dq1 = torch.sin((uq1 + vq1) / 2) * torch.sinc(half_diff / np.pi)
    dp1 = up1 / 2 + vp1 / 2
    return torch.vstack([dp1, -dq1]).T

def dudt(x, t):
    """Right-hand side of the pendulum ODE in ``odeint``'s ``f(y, t)`` form.

    ``x`` holds (q, p); returns (dq/dt, dp/dt) = (p, -sin(q)). ``t`` is
    required by the ``odeint`` signature and is not used.
    """
    angle, momentum = x[..., 0], x[..., 1]
    return np.hstack([momentum, -np.sin(angle)])

def dudt_solution(u0, t):
    """Integrate ``dudt`` from ``u0`` and sample the solution on the grid ``t``.

    Uses scipy ``odeint`` with rtol = atol = ``dataorder_set`` and up to 5000
    internal steps per output point. The result has one row per entry of ``t``.
    """
    tolerance = dataorder_set
    return odeint(dudt, u0, t, mxstep=5000, rtol=tolerance, atol=tolerance)

def avffun(u, v, h):
    """Residual of one implicit AVF step of size ``h`` from state ``v`` to ``u``.

    It is zero exactly when u = v + h * SAVFgrad(u, v), and is the function
    handed to ``rootfinder``.
    """
    drift = SAVFgrad(u, v) * h
    return drift - (u - v)


def avf_find_4(v, h):
    """Advance ``v`` by one composed AVF step of total size ``h``.

    Three implicit AVF substeps of sizes r1*h, r2*h, r1*h are applied in turn,
    with the triple-jump coefficients r1 = 1 / (2 - 2**(1/3)) and
    r2 = -2**(1/3) / (2 - 2**(1/3)). Each substep solves ``avffun(u, state, step) = 0``
    for ``u`` from an all-zero initial guess, with a tolerance 100x tighter than
    ``rootorder_set``.
    """
    cube_root_two = 2 ** (1 / 3)
    r1 = 1 / (2 - cube_root_two)
    r2 = -cube_root_two / (2 - cube_root_two)

    state = v
    for coeff in (r1, r2, r1):
        state = rootfinder(
            avffun,
            torch.tensor([0, 0], dtype=torch.float64),
            params=(state, coeff * h),
            f_tol=rootorder_set * 1e-2,
        )
    return state



class Bgrad(torch.nn.Module):
    """Learned discrete-gradient vector field built on ``SAVFgrad``.

    ``forward(u, v)`` returns ``SAVFgrad(u, v) + s @ (A - A^T) @ (u - v)``, where
    ``A`` is the ``dim x dim`` reshaped output of an MLP applied to ``[u, v]``.
    ``A - A^T`` is skew-symmetric, and the learned correction vanishes when
    ``u == v``. ``predict`` rolls the model out as an implicit one-step map.

    Args:
        dim: state dimension; the MLP output is reshaped to ``(dim, dim)``.
        layers: number of Linear layers: ``layers - 1`` Linear+Tanh stages
            followed by the output layer ``Lout``.
        ind: MLP input width. ``forward`` feeds ``[u, v]``, so this must be ``2 * dim``.
        outd: MLP output width; must be ``dim * dim``.
        width: hidden-layer width.
        initializer: stored on the instance; not used by this class.
    """

    def __init__(self, dim=2, layers=layers_set, ind=4, outd=4, width=width_set, initializer='default'):
        super().__init__()
        self.dim = dim
        self.ind = ind
        self.outd = outd
        self.layers = layers
        self.width = width
        self.initializer = initializer

        self.modus = self.__init_modules()

    def forward(self, u, v):
        """Evaluate the learned vector field for the step from ``v`` to ``u``.

        Args:
            u: next states, shape (n, dim). ``bmm`` below requires 2-D input.
            v: current states, shape (n, dim).

        Returns:
            Tensor of shape (n, dim).
        """
        savf = SAVFgrad(u, v)

        # MLP on the concatenated pair: (n, 1, ind) -> (n, 1, outd).
        x = torch.hstack([u, v]).unsqueeze(1)
        for i in range(1, self.layers):
            x = self.modus['N{}'.format(i)](self.modus['L{}'.format(i)](x))
        x = self.modus['Lout'](x)

        # Skew-symmetric matrix from the MLP output, shape (n, dim, dim).
        A = torch.reshape(x, (len(x), self.dim, self.dim))
        skew = A - A.permute(0, 2, 1)

        # (u - v) as a batch of column vectors, shape (n, dim, 1).
        diff = (u - v).unsqueeze(2)

        # s.to(skew) keeps the module-level structure matrix on the same
        # device and dtype as the network output.
        correction = torch.bmm(torch.matmul(s.to(skew), skew), diff).squeeze(2)
        return savf + correction

    def Bfun(self, u, v, h):
        """Residual ``forward(u, v) * h - (u - v)``; zero when ``u`` is the implicit step from ``v``."""
        return self.forward(u, v) * h - (u - v)

    def predict(self, xh, steps=1):
        """Roll the learned implicit map forward from the initial rows of ``xh``.

        Args:
            xh: tensor whose last axis holds the state followed by the step size,
                e.g. shape (batch, dim + 1). ``forward`` needs 2-D states.
            steps: number of implicit steps to take.

        Returns:
            Tensor of shape (batch, steps + 1, dim), with the initial state first.
        """
        x0, h = xh[..., :-1], xh[..., -1:]
        states = [x0]
        for _ in range(steps):
            # All-zero initial guess shaped like the state. For a single row this
            # equals the former hard-coded [[0, 0]]; it also works for batches and
            # for tensors on other devices.
            guess = x0.new_zeros(x0.shape)
            states.append(rootfinder(self.Bfun, guess, params=(states[-1], h), f_tol=rootorder_set))

        shape = [-1, steps + 1, self.dim][2 - xh.dim():]
        return torch.cat(states, dim=-1).view(shape)

    def __init_modules(self):
        """Build the MLP: Linear+Tanh stages L1/N1 ... followed by the output layer Lout (float64)."""
        modules = nn.ModuleDict()
        if self.layers > 1:
            modules['L1'] = nn.Linear(self.ind, self.width).double()
            modules['N1'] = torch.nn.Tanh()
            for i in range(2, self.layers):
                modules['L{}'.format(i)] = nn.Linear(self.width, self.width).double()
                modules['N{}'.format(i)] = torch.nn.Tanh()
            modules['Lout'] = nn.Linear(self.width, self.outd).double()
        else:
            modules['Lout'] = nn.Linear(self.ind, self.outd).double()

        return modules


def _make_datasets(x0, h, train_num, test_num):
    """Build one-step (input, target) pairs from reference pendulum trajectories.

    The training trajectory is integrated from ``x0`` on the grid ``ttrain``. The
    test trajectory starts at the last training state and uses ``ttest``. Each
    input row is ``[q, p, h]`` and its target is the state one grid step later.

    Returns:
        ``(x_train, y_train, x_test, y_test)`` as torch tensors.
    """
    u_train = dudt_solution(x0, ttrain)
    u_test = dudt_solution(u_train[-1], ttest)

    def to_pairs(traj, num):
        # Pair state k (with h appended) with state k + 1, for k < num.
        states = traj[:num]
        inputs = np.hstack([states, h * np.ones([states.shape[0], 1])])
        return torch.from_numpy(inputs), torch.from_numpy(traj[1:num + 1])

    x_train, y_train = to_pairs(u_train, train_num)
    x_test, y_test = to_pairs(u_test, test_num)
    return x_train, y_train, x_test, y_test


def _train(x_train, y_train, x_test, y_test, lr, iterations, print_every):
    """Fit a ``Bgrad`` so that ``Bnet(y, x)`` matches the difference quotient ``(y - x) / h``.

    The test loss is evaluated only on logging epochs and without building an
    autograd graph. It does not feed into the optimisation.

    Returns:
        ``(Bnet, loss_history)``, where each history row is [epoch, train_loss, test_loss].
    """
    Bnet = Bgrad()
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(Bnet.parameters(), lr=lr)

    def loss_on(x, y):
        # Input rows are [q, p, h]; the target is the one-step difference quotient.
        return criterion(Bnet(y, x[..., :2]), (y - x[..., :2]) / x[..., -1:])

    loss_history = []
    for epoch in range(iterations):
        loss = loss_on(x_train, y_train)

        if epoch % print_every == 0 or epoch + 1 == iterations:
            with torch.no_grad():
                loss_test = loss_on(x_test, y_test)
            print('{:<9}Train loss: {:<25}Test loss: {:<25}'.format(epoch, loss.item(), loss_test.item()), flush=True)
            loss_history.append([epoch, loss.item(), loss_test.item()])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return Bnet, loss_history


def _rollout(x0, n_states, step_fn):
    """Return ``n_states`` states as a numpy array, starting at ``x0``.

    Each subsequent state is ``step_fn(k, previous_state)`` for k = 0, 1, ...
    """
    states = [x0]
    for k in range(n_states - 1):
        states.append(step_fn(k, states[-1]))
    return torch.stack(states).cpu().detach().numpy()


def _plot_results(t_test, flow_train, flow_true, flow_pred, flow_AVF, flow_AVF_half, flow_AVF_com):
    """Draw the 3x2 comparison figure, save it to ./pendulum.pdf, and show it.

    Each flow argument is a numpy array with columns (q, p). The test flows have
    one row per entry of ``t_test``.
    """
    # (flow, label, colour, zorder) for the time-series and error panels.
    methods = [
        (flow_pred, 'Proposed method', 'r', 2),
        (flow_AVF, 'AVF', 'green', 1),
        (flow_AVF_half, 'AVF half stepsize', 'cyan', 1),
        (flow_AVF_com, 'AVF order4', 'orange', 1),
    ]
    true_energy = ENERGY(flow_true)

    plt.figure(figsize=[8, 8])

    plt.subplot(321)
    plt.scatter(flow_train[:, 0], flow_train[:, 1], color='b', label='Train data', zorder=0, s=8)
    plt.xlabel("q")
    plt.ylabel("p")
    plt.title('Traindata')
    plt.legend()

    plt.subplot(322)
    plt.plot(flow_true[:, 0], flow_true[:, 1], color='b', label='Ground truth', zorder=0)
    plt.scatter(flow_pred[:, 0], flow_pred[:, 1], color='r', label='Proposed method', zorder=1, s=5)
    plt.scatter(flow_AVF[:, 0], flow_AVF[:, 1], color='green', label='AVF', zorder=2, s=5)
    plt.xlabel("q")
    plt.ylabel("p")
    plt.title('Pendulum')
    plt.legend()

    # Time series of q (panel 323) and p (panel 324).
    for panel, (comp, title) in zip((323, 324), enumerate(('q', 'p'))):
        plt.subplot(panel)
        plt.plot(t_test, flow_true[:, comp], color='b', label='Ground truth', zorder=0)
        for flow, label, color, zorder in methods:
            plt.scatter(t_test, flow[:, comp], color=color, label=label, zorder=zorder, s=8)
        plt.title(title)
        plt.legend()

    plt.subplot(325)
    for flow, label, color, _ in methods:
        # Euclidean distance to the reference state at each time.
        plt.plot(t_test, np.linalg.norm(flow_true - flow, axis=1), label=label, color=color, zorder=0)
    plt.title('Global error')
    plt.legend()

    plt.subplot(326)
    plt.plot(t_test, true_energy, color='b', label='Ground truth')
    plt.plot(t_test, ENERGY(flow_pred), color='r', label='Proposed method')
    plt.plot(t_test, ENERGY(flow_AVF), color='green', label='AVF')
    # Zoom to +/- 5e-4 around the initial reference energy.
    plt.ylim([true_energy[0] - 0.0005, true_energy[0] + 0.0005])
    plt.title('Energy')
    plt.legend()
    plt.tight_layout()

    plt.savefig('./pendulum.pdf')
    plt.show()


def main():
    """Train ``Bgrad`` on pendulum data, compare it with AVF baselines, and plot.

    The baselines on the test trajectory are:
    - AVF with step ``h``;
    - the composed AVF of ``avf_find_4`` with step ``h``;
    - AVF with step ``h / 2``, sampled at every other state.

    Their wall times are printed. The figure is saved to ``./pendulum.pdf`` and shown.
    """
    # x0h_set holds a single row (q1, p1, h).
    x0 = x0h_set[0][:2].cpu().detach().numpy()
    # A plain float keeps arithmetic with torch tensors and numpy arrays unambiguous.
    h = float(x0h_set[0][2])
    train_num = train_num_set
    test_num = test_num_set

    x_train, y_train, x_test, y_test = _make_datasets(x0, h, train_num, test_num)
    Bnet, _ = _train(x_train, y_train, x_test, y_test, lr_set, iterations_set, print_every_set)

    flow_train = x_train[..., :2].cpu().detach().numpy()
    flow_true = x_test[..., :2].cpu().detach().numpy()
    flow_pred = Bnet.predict(x_test[0].unsqueeze(0), steps=test_num - 1).squeeze(0).cpu().detach().numpy()

    start = x_test[0][:2]

    # AVF with step h; the root finder is seeded with the reference next state.
    t1 = time.time()
    flow_AVF = _rollout(start, test_num, lambda k, v: rootfinder(
        avffun, torch.from_numpy(flow_true[k + 1]), params=(v, h), f_tol=rootorder_set))
    t_AVF = time.time() - t1

    # Composed AVF (avf_find_4) with step h.
    t1 = time.time()
    flow_AVF_com = _rollout(start, test_num, lambda k, v: avf_find_4(v, h))
    t_AVF_com = time.time() - t1

    # AVF with step h/2 for twice as many steps; every other state lies on the test grid.
    t1 = time.time()
    flow_AVF_2 = _rollout(start, 2 * test_num, lambda k, v: rootfinder(
        avffun, torch.zeros(2, dtype=torch.float64), params=(v, 0.5 * h), f_tol=rootorder_set))
    t_AVF_2 = time.time() - t1
    AVFpred2flow = flow_AVF_2[::2]

    print('Baseline wall time [s]: AVF {:.3f}, AVF order4 {:.3f}, AVF half stepsize {:.3f}'.format(
        t_AVF, t_AVF_com, t_AVF_2), flush=True)

    # np.arange with a float step can gain or lose a point through rounding;
    # scaling an integer range always yields exactly test_num times.
    t_test = h * np.arange(test_num)
    _plot_results(t_test, flow_train, flow_true, flow_pred, flow_AVF, AVFpred2flow, flow_AVF_com)



if __name__ == '__main__':
    # Run the full experiment only when this file is executed as a script.
    main()






