
import scipy as sci

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import animation

from TorchDiffEqPack.odesolver import odesolve
from TorchDiffEqPack import odesolve_adjoint_sym12
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
import os
import pandas as pd
import shutil
from skopt.space import Space
from skopt.sampler import Halton
import random 





# ---------------------------------------------------------------------------
# Hyper-parameters, physical constants and sampled initial conditions.
# ---------------------------------------------------------------------------
lr = 1e-3  # learning rate handed to the optimizer
TrainMode = True  # when True the training loop at the end of the script runs

# 2000 evenly spaced times on [0, 0.5] at which reference solutions are stored.
time_span = np.linspace(0, 0.5, 2000)

# Coefficients of the coupled oscillators (passed to DuffingEquations).
a1, a2 = 2.0, 0.7
b1, b2 = -0.4, 3.0
e = 1.0

# Centre of the sampling box, laid out as [q1, q2, p1, p2].
x0 = np.array([0.8, -0.4, 0.0, 0.0], dtype="float64")

# Each state component is sampled from the interval x0[k] +/- 1.
space = Space([(centre - 1.0, centre + 1.0) for centre in x0])

# Low-discrepancy (Halton) sample of initial conditions inside the box.
n_samples = 1000
halton = Halton()
start = halton.generate(space.dimensions, n_samples)
start_n = np.array(start, dtype="float64")
initial_data = torch.from_numpy(start_n).float()  # float32 copy for the model


def DuffingEquations(w, t, a1, a2, b1, b2, e):
    """Right-hand side of two linearly coupled Duffing oscillators.

    The state is laid out along its last axis as ``[q1, q2, p1, p2]``::

        dq/dt  = p
        dp1/dt = -a1*q1 - b1*q1**3 + e*(q2 - q1)
        dp2/dt = -a2*q2 - b2*q2**3 - e*(q2 - q1)

    Args:
        w: state of shape ``(4,)`` or a batch of states ``(..., 4)``.
        t: time; unused (the system is autonomous) but required by the
            ``func(y, t, *args)`` calling convention of ``scipy.integrate.odeint``.
        a1, a2: coefficients of the linear terms for oscillator 1 and 2.
        b1, b2: coefficients of the cubic terms for oscillator 1 and 2.
        e: coefficient of the coupling term ``q2 - q1``.

    Returns:
        ``numpy.ndarray`` of shape ``(..., 4)`` holding ``dw/dt`` in the same
        ``[q1, q2, p1, p2]`` layout.
    """
    w = np.asarray(w)
    q = w[..., :2]
    p = w[..., 2:4]
    # Index the LAST axis so a batch of states (..., 4) is handled per state;
    # plain q[0] / q[1] would pick rows of the batch instead.
    q1 = q[..., 0]
    q2 = q[..., 1]

    coupling = e * (q2 - q1)
    dp1dt = -a1 * q1 - b1 * q1**3 + coupling
    dp2dt = -a2 * q2 - b2 * q2**3 - coupling

    dpdt = np.stack((dp1dt, dp2dt), axis=-1)
    return np.concatenate((p, dpdt), axis=-1)


import scipy.integrate

# Reference data: integrate the true dynamics from every sampled initial
# condition with very tight tolerances.
# q01 has shape (n_samples, len(time_span), 4); the time axis is sized from
# time_span itself so changing the time grid cannot cause a shape mismatch.
n_time_steps = len(time_span)
q01 = np.zeros((n_samples, n_time_steps, 4))

for x1 in range(n_samples):
    init_params = np.asarray(start[x1], dtype="float64").flatten()

    duffing_sol01 = sci.integrate.odeint(
        DuffingEquations, init_params, time_span,
        args=(a1, a2, b1, b2, e), rtol=1e-13, atol=1e-14,
    )

    q01[x1] = duffing_sol01[:, 0:4]

# float32 copy with a leading singleton axis:
# shape (1, n_samples, len(time_span), 4).
traj_q01 = torch.from_numpy(q01).float()
traj_q01 = torch.unsqueeze(traj_q01, 0)


       

        

class TorchDuffingEquations(nn.Module):
    """Learnable vector field for two coupled oscillators.

    The state ``w`` is laid out along its last axis as ``[q1, q2, p1, p2]``
    and ``forward`` returns ``dw/dt`` in the same layout::

        dq/dt  = p
        dp1/dt = -dV1(q1) - dV2(q1) + dV5(q2 - q1)
        dp2/dt = -dV3(q2) - dV4(q2) - dV5(q2 - q1)

    Each ``dV*`` is an independent scalar-to-scalar MLP
    (``1 -> nout -> nout -> 1`` with tanh activations).

    Args:
        nout: hidden width of every MLP (default 100).
    """

    def __init__(self, nout=100):
        super().__init__()
        self.nout = nout
        # Keep the dV1..dV5 attribute names: they become the state_dict key
        # prefixes of the five sub-networks.
        self.dV1 = self._make_mlp()
        self.dV2 = self._make_mlp()
        self.dV3 = self._make_mlp()
        self.dV4 = self._make_mlp()
        self.dV5 = self._make_mlp()

    def _make_mlp(self):
        """Build one ``1 -> nout -> nout -> 1`` tanh MLP."""
        return nn.Sequential(
            nn.Linear(1, self.nout),
            nn.Tanh(),
            nn.Linear(self.nout, self.nout),
            nn.Tanh(),
            nn.Linear(self.nout, 1),
        )

    # Single-value accessors: each reshapes its argument to shape (1,), so
    # ``w`` must contain exactly one element; the result has shape (1,).
    def getdV1(self, w):
        """Evaluate ``dV1`` at the single value ``w``."""
        return self.dV1(torch.reshape(w, (1,)))

    def getdV2(self, w):
        """Evaluate ``dV2`` at the single value ``w``."""
        return self.dV2(torch.reshape(w, (1,)))

    def getdV3(self, w):
        """Evaluate ``dV3`` at the single value ``w``."""
        return self.dV3(torch.reshape(w, (1,)))

    def getdV4(self, w):
        """Evaluate ``dV4`` at the single value ``w``."""
        return self.dV4(torch.reshape(w, (1,)))

    def getdV5(self, w):
        """Evaluate ``dV5`` at the single value ``w``."""
        return self.dV5(torch.reshape(w, (1,)))

    def forward(self, t, w):
        """Return ``dw/dt`` for a state ``w`` of shape ``(4,)`` or ``(..., 4)``.

        ``t`` is unused (autonomous system); it is presumably kept for the ODE
        solver's ``func(t, w)`` calling convention.
        """
        q = w[..., :2]
        p = w[..., 2:4]
        # Slice with 0:1 / 1:2 (not 0 / 1) so each coordinate keeps a trailing
        # axis of size 1, which is what nn.Linear(1, ...) consumes, for both a
        # single state and a batch of states.
        q1 = q[..., 0:1]
        q2 = q[..., 1:2]

        # The coupling network enters both momentum equations with opposite
        # signs; evaluate it once.
        coupling = self.dV5(q2 - q1)
        dp1dt = -self.dV1(q1) - self.dV2(q1) + coupling
        dp2dt = -self.dV3(q2) - self.dV4(q2) - coupling

        # Built with torch.cat (no pre-allocated buffer) so the result follows
        # the device/dtype of the inputs and the network outputs.
        dpdt = torch.cat((dp1dt, dp2dt), dim=-1)
        return torch.cat((p, dpdt), dim=-1)



# Seed every RNG before the model is built so that its weight initialisation,
# as well as the trajectory sampling in the training loop, is reproducible.
torch.manual_seed(1229)
random.seed(1230)
np.random.seed(1234)

func = TorchDuffingEquations()
t_list = time_span.tolist()  # NOTE(review): not used below.

# TorchDiffEqPack solver configuration. The integration interval is taken from
# time_span so the solver horizon matches the reference point
# traj_q01[..., -1] used as the training target (assuming the solver returns
# the state at t1 when t_eval is None -- confirm against the library docs).
# Alternative method names: fixedstep_yoshida_alf2, fixedstep_sym12async,
# suzuki_alf2.
options = {
    'method': 'yoshida_alf2',
    'h': 0.1,
    't0': float(time_span[0]),
    't1': float(time_span[-1]),
    'rtol': 1e-4,
    'atol': 1e-5,
    'print_neval': False,
    'neval_max': 1000000,
    'safety': None,
    't_eval': None,
    'interpolation_method': 'cubic',
    'regenerate_graph': True,
}

# NOTE(review): beta2 = 0.5 is far from the PyTorch default of 0.999, so the
# second-moment estimate adapts very quickly -- confirm this is intentional.
optimizer = torch.optim.AdamW(
    func.parameters(), lr=lr, betas=(0.50, 0.50), eps=1e-08, weight_decay=0.01,
)
# Multiplies the learning rate by 0.995 every time scheduler.step() is called.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995)

func.train()

NbTraj = 300    # initial conditions sampled per training epoch
n_epochs = 300  # length of the per-epoch loss histories below

best_loss = np.inf  # NOTE(review): not used below.
import time

loss = 10.0  # any value above the 1e-2 stopping tolerance starts the loop
i = 0        # epoch counter

TrainLoss = np.zeros(n_epochs)
TestlossVF = np.zeros(n_epochs)
TestlossV = np.zeros(n_epochs)





def _final_state_error(idx):
    """Return the squared error of the learned final state for sample ``idx``.

    Integrates ``func`` from ``initial_data[idx]`` with TorchDiffEqPack's
    ``odesolve_adjoint_sym12`` (configured by ``options``) and compares the
    result with the last stored reference state ``traj_q01[0, idx, -1]``.
    Returns a 0-d tensor that stays attached to the autograd graph.
    """
    out01 = odesolve_adjoint_sym12(func, initial_data[idx], options=options)
    position01 = out01[..., :4]
    return torch.sum((position01 - traj_q01[0, idx, -1]) ** 2)


if TrainMode:
    start_time = time.time()
    max_epochs = len(TrainLoss)  # one history slot per epoch
    while loss > 1e-2 and i < max_epochs:
        i = i + 1

        # The model is being optimised, so keep it in training mode.
        func.train()
        # .grad buffers accumulate by default; clear the previous epoch's
        # gradients before back-propagating this epoch's loss.
        optimizer.zero_grad()

        # Mean final-state error over NbTraj initial conditions, each drawn
        # uniformly at random (with replacement) from the sampled set.
        l = sum(
            _final_state_error(random.randint(0, n_samples - 1))
            for _ in range(NbTraj)
        )
        l = l / float(NbTraj)

        l.backward()
        optimizer.step()
        # Apply the exponential learning-rate decay once per epoch.
        scheduler.step()

        # The epoch-average loss drives both logging and the stopping test.
        loss = l.item()
        TrainLoss[i - 1] = loss

        print('Epoch %d: Loss: %.8f' % (i, loss))
    print('Finished training')
    print("--- %s seconds ---" % (time.time() - start_time))


