
import scipy as sci

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import animation

from TorchDiffEqPack.odesolver import odesolve
from TorchDiffEqPack import odesolve_adjoint_sym12
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
import os
import pandas as pd
import shutil



## Kepler
# Force constant of the reference system: dp/dt = -a * q / |q|**3.
a = np.pi/4

# Nominal initial state; the first two entries are the position q and the
# last two the momentum p (the same split KeplerEquations uses).
x0 = np.array([0.75, 0, 0, 0.9*(np.pi/4)*np.sqrt(5/3)], dtype="float64")

# Grid of 3**4 = 81 initial states around x0: each of the four components is
# shifted independently by -IniStep, 0 or +IniStep.  Index 0/1/2 along each
# grid axis corresponds to the shift -1/0/+1 (times IniStep).
IniStep = 0.1
IniD = np.zeros((3, 3, 3, 3, 4))

# _shifts[i1, i2, i3, i4] == [i1 - 1, i2 - 1, i3 - 1, i4 - 1]
_shifts = np.moveaxis(np.indices((3, 3, 3, 3)), 0, -1) - 1
# Kept as a nested Python list (shape 3x3x3x3x4), as later code indexes it
# with InitialData[i][j][k][l] before converting it to an array.
InitialData = (x0 + IniStep*_shifts).tolist()



def KeplerEquations(w, t, a):
    """Right-hand side of the planar Kepler problem, in ``odeint`` signature.

    The state is split as position ``q = w[..., :2]`` and momentum
    ``p = w[..., 2:4]``, and the equations of motion are::

        dq/dt = p
        dp/dt = -a * q / |q|**3

    Args:
        w: state array whose last axis holds ``[q, p]`` (length 4).  Either a
            single 1-D state (as passed by ``scipy.integrate.odeint``) or a
            batch of states stacked along leading axes.
        t: time.  Unused (the system is autonomous) but required by ``odeint``.
        a: force constant.

    Returns:
        Array of the same shape as ``w[..., :4]`` holding ``[dq/dt, dp/dt]``.
    """
    q = w[..., :2]
    p = w[..., 2:4]

    # One norm per state (reduced over the last axis), kept as a length-1 axis
    # so it broadcasts against q.  A norm over the whole array would mix all
    # states of a batch together.
    r = np.linalg.norm(q, axis=-1, keepdims=True)
    dqdt = p
    dpdt = (-a/r**3)*q

    # Join along the state axis so batched inputs keep their leading axes.
    return np.concatenate((dqdt, dpdt), axis=-1)

import scipy.integrate

# Reference positions of the true system (force constant `a`) at t = 0.2,
# 0.4, 0.6, 0.8 and 1.0 for every grid point of InitialData; each array has
# shape (3, 3, 3, 3, 2).
q02 = np.zeros((3,3,3,3,2))
q04 = np.zeros((3,3,3,3,2))
q06 = np.zeros((3,3,3,3,2))
q08 = np.zeros((3,3,3,3,2))
q1 = np.zeros((3,3,3,3,2))

# (segment start, segment end, array receiving the position at segment end)
_reference_segments = (
    (0.0, 0.2, q02),
    (0.2, 0.4, q04),
    (0.4, 0.6, q06),
    (0.6, 0.8, q08),
    (0.8, 1.0, q1),
)

for grid_idx in np.ndindex(3, 3, 3, 3):
    n1, n2, n3, n4 = grid_idx
    state = np.array(InitialData[n1][n2][n3][n4], dtype="float64").flatten()
    # odeint is restarted at every segment boundary from the final state of
    # the previous segment; only the end-of-segment position is stored.
    for seg_start, seg_end, positions in _reference_segments:
        trajectory = sci.integrate.odeint(
            KeplerEquations,
            state,
            np.linspace(seg_start, seg_end, 2000),
            args=(a,),
            rtol=1e-7,
            atol=1e-8,
            hmax=1e-5,
        )
        state = trajectory[-1, :]
        positions[grid_idx] = trajectory[-1, 0:2]


def _to_batched_float_tensor(array):
    """Convert a numpy array to a float32 tensor with a leading axis of size 1."""
    return torch.from_numpy(array).float().unsqueeze(0)


# Reference end-of-segment positions as tensors of shape (1, 3, 3, 3, 3, 2).
traj_q02 = _to_batched_float_tensor(q02)
traj_q04 = _to_batched_float_tensor(q04)
traj_q06 = _to_batched_float_tensor(q06)
traj_q08 = _to_batched_float_tensor(q08)
traj_q1 = _to_batched_float_tensor(q1)

# Grid of initial states: converted from the nested list to a float64 array,
# then to a tensor of shape (1, 3, 3, 3, 3, 4).
InitialData = np.array(InitialData, dtype="float64")
initial_condition = _to_batched_float_tensor(InitialData)



class TorchKeplerEquations(nn.Module):
    """Torch right-hand side of the planar Kepler problem with learnable ``a``.

    Implements the same equations as ``KeplerEquations``::

        dq/dt = p
        dp/dt = -a * q / |q|**3

    where the force constant ``a`` is an ``nn.Parameter`` so gradients with
    respect to it can be computed.
    """

    def __init__(self):
        super().__init__()
        # Learnable force constant, initialised to 0.1.
        self.a = nn.Parameter(torch.ones(1)*0.1)

    def forward(self, t, w):
        """Return ``[dq/dt, dp/dt]`` for state ``w``.

        Args:
            t: time (unused; the system is autonomous, but the ODE solver
                calls ``func(t, w)``).
            w: tensor whose last dimension holds ``[q, p]`` (length 4); a single
                state or a batch of states along leading dimensions.

        Returns:
            Tensor of the same shape as ``w[..., :4]``.
        """
        q = w[..., :2]
        p = w[..., 2:4]

        # Per-state norm over the last dim, kept as a size-1 dim so it
        # broadcasts against q; a whole-tensor norm would mix batch entries.
        r = torch.norm(q, dim=-1, keepdim=True)
        dqdt = p
        dpdt = (-self.a/r**3)*q

        return torch.cat((dqdt, dpdt), dim=-1)






# Torch model of the Kepler right-hand side used by the solver sweeps below.
# Its parameter `a` is replaced by a new nn.Parameter for every candidate
# value that is evaluated.
func = TorchKeplerEquations()





# Solver settings shared by every odesolve_adjoint_sym12 call; the sweep code
# adds 'method', 't_eval', 't0' and 't1' before each call.
options = {
    'h': 0.01,
    'rtol': 1e-4,
    'atol': 1e-5,
    'print_neval': False,
    'neval_max': 1000000,
    'safety': None,
    'interpolation_method': 'cubic',
    'regenerate_graph': False,
}

# One slot per candidate value of `a` (300 candidates): the average loss
# obtained with each solver, and the candidate value itself.
valuesALF, valuesYoshida, steps = (np.zeros(300) for _ in range(3))




def _mean_grid_loss(method, a_value):
    """Return the average trajectory-matching loss of ``func`` for one value of ``a``.

    Sets ``options['method']`` to ``method`` and replaces ``func.a`` with a new
    ``nn.Parameter`` equal to ``a_value``.  Then, for each of the 3x3x3x3
    initial states in ``initial_condition``, it integrates the model with
    ``odesolve_adjoint_sym12`` over the consecutive segments [0, 0.2],
    [0.2, 0.4], [0.4, 0.6], [0.6, 0.8] and [0.8, 1.0].  Each segment starts
    from the previous segment's output.

    For each grid point, the squared Euclidean distance between the predicted
    position (first two state components) and the reference position in
    ``traj_q02`` ... ``traj_q1`` is summed over all five segments.  That sum is
    then averaged over the grid points.

    Args:
        method: solver name written to ``options['method']``.
        a_value: candidate value for the model parameter ``func.a``.

    Returns:
        The mean loss over the grid as a Python float.
    """
    options.update({'method': method})
    func.a = torch.nn.Parameter(torch.ones(1)*a_value)

    # (t0, t1, evaluation times, reference positions at t1).  The evaluation
    # times do not depend on the grid point, so compute them once per call.
    segments = [
        (t0, t1, np.linspace(t0, t1, 2000).tolist(), target)
        for t0, t1, target in (
            (0.0, 0.2, traj_q02),
            (0.2, 0.4, traj_q04),
            (0.4, 0.6, traj_q06),
            (0.6, 0.8, traj_q08),
            (0.8, 1.0, traj_q1),
        )
    ]

    total = 0.0
    n_points = 0
    for idx in np.ndindex(3, 3, 3, 3):
        grid_index = (0,) + idx
        state = initial_condition[grid_index]
        point_loss = 0.0
        for t0, t1, t_eval, target in segments:
            # Pass a fresh list per call (as the original did), in case the
            # solver modifies it.
            options.update({'t_eval': list(t_eval), 't0': t0, 't1': t1})
            # Chain the segments: the next solve starts from this output.
            state = odesolve_adjoint_sym12(func, state, options=options)
            position = state[..., :2]
            # Accumulate every segment's loss (the previous `loss =+ ...`
            # assigned instead of adding and dropped the first segment).
            point_loss = point_loss + torch.norm(position - target[grid_index])**2
        # Accumulate over grid points (the previous `l =+ ...` kept only the
        # last grid point, and `l` was never reset between candidates).
        total += point_loss.item()
        n_points += 1
    return total / n_points


# Sweep 300 candidate values of `a` around pi/4 (the value used for the
# reference data) in increments of 1e-6, recording the average loss of each.
for k in range(300):
    a_k = np.pi/4 - (150-k)*0.000001
    valuesALF[k] = _mean_grid_loss('fixedstep_sym12async', a_k)
    steps[k] = a_k

# NOTE(review): this sweep selects 'fixedstep_yoshida_alf2' but still goes
# through odesolve_adjoint_sym12 -- confirm that entry point honours it.
for k in range(300):
    a_k = np.pi/4 - (150-k)*0.000001
    valuesYoshida[k] = _mean_grid_loss('fixedstep_yoshida_alf2', a_k)
     
# Persist the sweep: the candidate values of `a` and the average loss per solver.
for _fname, _column in (
    ('steps.txt', steps),
    ('ALF.txt', valuesALF),
    ('Yoshida.txt', valuesYoshida),
):
    np.savetxt(_fname, _column, delimiter=',', newline='\n')

# Loss as a function of the candidate parameter, one curve per solver.
fig = plt.figure(figsize=(15, 15))
ax = fig.add_subplot(111)

for _series, _label in ((valuesALF, 'ALF'), (valuesYoshida, 'Yoshida')):
    ax.plot(steps, _series, label=_label)

# Reference lines: zero loss, and the parameter value used for the reference data.
ax.axhline(y=0, color='black')
ax.axvline(x=a, color='purple', label='true value')

ax.set_xlabel("parameter", fontsize=14)
ax.set_ylabel("loss", fontsize=14)
ax.legend(loc="upper left", fontsize=14)

 

