import scipy as sci
from TorchDiffEqPack import odesolve_adjoint_sym12
import torch
from torch import nn
import numpy as np







# Optimiser schedule: starting learning rate and the per-epoch multiplicative decay.
lr = 1e-1
lr_decay = 0.95

# Set to False to skip the parameter-fitting loop.
TrainMode = True

## Kepler
# Ground-truth force constant used to generate the reference trajectory.
a = np.pi / 4

# Initial state laid out as [q_x, q_y, p_x, p_y] (positions first, then the
# momenta that KeplerEquations uses as dq/dt).
x0 = np.asarray([0.75, 0.0, 0.0, 0.9 * (np.pi / 4) * np.sqrt(5 / 3)], dtype="float64")

# Independent 1-D copy handed to scipy's odeint.
init_params = x0.flatten()


def KeplerEquations(w, t, a):
    """Right-hand side of the planar Kepler problem for scipy's ``odeint``.

    Args:
        w: state array whose last axis is ``[q_x, q_y, p_x, p_y]``; any leading
            (batch) axes are supported.
        t: current time (unused; the system is autonomous but ``odeint``
            requires the argument).
        a: strength of the central force.

    Returns:
        Array with the same shape as ``w`` holding ``[dq/dt, dp/dt]`` where
        ``dq/dt = p`` and ``dp/dt = -a * q / |q|**3``.
    """
    q = w[..., :2]
    p = w[..., 2:4]

    # Radius per sample; keepdims so it broadcasts against q. Reducing over the
    # whole array would mix samples together when w carries batch axes.
    radius = np.linalg.norm(q, axis=-1, keepdims=True)
    dqdt = p
    dpdt = (-a / radius ** 3) * q

    # Concatenate along the state axis (not axis 0) so batched input keeps its shape.
    return np.concatenate((dqdt, dpdt), axis=-1)

import scipy.integrate

# Reference ("ground truth") solution: integrate the true Kepler system over
# five consecutive segments, each starting from the previous segment's final
# state, and keep the end position of every segment as a training target.
_SEGMENT_EDGES = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
_SAMPLES_PER_SEGMENT = 2000

_reference_solutions = []
_state = init_params
for _t_start, _t_end in zip(_SEGMENT_EDGES[:-1], _SEGMENT_EDGES[1:]):
    time_span = np.linspace(_t_start, _t_end, _SAMPLES_PER_SEGMENT)
    _sol = sci.integrate.odeint(
        KeplerEquations, _state, time_span,
        args=(a,), rtol=1e-7, atol=1e-8, hmax=1e-5,
    )
    _reference_solutions.append(_sol)
    _state = _sol[-1, :]

sol02, sol04, sol06, sol08, sol1 = _reference_solutions
# End-of-segment positions (first two state components).
q02, q04, q06, q08, q1 = (sol[-1, 0:2] for sol in _reference_solutions)


def _to_batched_float_tensor(array):
    """Convert a NumPy array to a float32 tensor with a leading axis of size 1."""
    return torch.unsqueeze(torch.from_numpy(array).float(), 0)


initial_condition = _to_batched_float_tensor(x0)
traj_q02 = _to_batched_float_tensor(q02)
traj_q04 = _to_batched_float_tensor(q04)
traj_q06 = _to_batched_float_tensor(q06)
traj_q08 = _to_batched_float_tensor(q08)
traj_q1 = _to_batched_float_tensor(q1)




class TorchKeplerEquations(nn.Module):
    """Planar Kepler right-hand side with a learnable force constant ``a``.

    The state ``w`` stores positions in ``w[..., :2]`` and momenta in
    ``w[..., 2:4]``; any leading (batch) dimensions are supported.
    """

    def __init__(self):
        super().__init__()
        # Learnable force constant, initial guess 0.1.
        self.a = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, t, w):
        """Return ``[dq/dt, dp/dt] = [p, -a * q / |q|**3]`` (``t`` is unused)."""
        q = w[..., :2]
        p = w[..., 2:4]

        # Radius per sample; keepdim so it broadcasts against q. A norm over the
        # whole tensor would also reduce over the batch axis and couple samples.
        radius = torch.norm(q, dim=-1, keepdim=True)
        dqdt = p
        dpdt = (-self.a / radius ** 3) * q

        return torch.cat((dqdt, dpdt), dim=-1)





func = TorchKeplerEquations()

# TorchDiffEqPack solver configuration. 't0', 't1' (and 't_eval') are
# overwritten for every segment inside the training loop.
# Other method names previously noted as alternatives:
# 'fixedstep_yoshida_alf2', 'fixedstep_sym12async', 'suzuki_alf2'.
options = {
    'method': 'yoshida_alf2',
    'h': 0.01,  # step size setting; presumably initial/fixed step — confirm in TorchDiffEqPack docs
    't0': 0.0,
    't1': 1.0,
    'rtol': 1e-4,
    'atol': 1e-5,
    'print_neval': False,
    'neval_max': 1000000,
    'safety': None,
    'interpolation_method': 'cubic',
    'regenerate_graph': False,
}

# Start from the module-level learning rate instead of a duplicated literal;
# the training loop rescales it each epoch via adjust_learning_rate.
optimizer = torch.optim.SGD(func.parameters(), lr=lr)


def adjust_learning_rate(optimizer, lr):
    """Set the learning rate of every parameter group of ``optimizer`` to ``lr``.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dicts with
            an ``'lr'`` entry (e.g. a ``torch.optim.Optimizer``).
        lr: new learning rate applied to all groups.
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr



  
func.train()

# `l` holds the latest total loss (seeded above the stopping threshold),
# `i` counts epochs.
l = 10.0
i = 0

import time

# (t0, t1) of each integration segment; must match the segments used to build
# the reference targets traj_q02 ... traj_q1.
TRAIN_SEGMENTS = ((0.0, 0.2), (0.2, 0.4), (0.4, 0.6), (0.6, 0.8), (0.8, 1.0))


def _solve_segment(y0, t_start, t_end, n_eval=2000):
    """Integrate ``func`` from ``t_start`` to ``t_end`` starting at ``y0``.

    Updates the module-level ``options`` dict ('t_eval', 't0', 't1') in place
    and returns whatever ``odesolve_adjoint_sym12`` returns.
    """
    options.update({'t_eval': np.linspace(t_start, t_end, n_eval).tolist()})
    options.update({'t0': t_start})
    options.update({'t1': t_end})
    return odesolve_adjoint_sym12(func, y0, options=options)


if TrainMode:
    start_time = time.time()
    targets = (traj_q02, traj_q04, traj_q06, traj_q08, traj_q1)
    while l > 1e-8 and i < 600:

        print('a {}, estimated a {}'.format(a, func.a.item()))

        i = i + 1
        lr *= lr_decay
        adjust_learning_rate(optimizer, lr)
        optimizer.zero_grad()

        func.eval()
        # Chain the segments: each solve starts from the previous output.
        # NOTE(review): assumes odesolve_adjoint_sym12 returns only the state at
        # t1 (same shape as initial_condition) even with t_eval set — confirm
        # against the TorchDiffEqPack docs.
        outs = []
        state = initial_condition
        for t_start, t_end in TRAIN_SEGMENTS:
            state = _solve_segment(state, t_start, t_end)
            outs.append(state)
        out, out2, out3, out4, out5 = outs

        # Total loss = squared position error at the end of EVERY segment,
        # including the first one.
        loss = sum(
            torch.norm(seg_out[..., :2] - target[0]) ** 2
            for seg_out, target in zip(outs, targets)
        )
        # loss is a non-negative scalar, so its value is its own norm.
        l = loss.item()

        loss.backward()

        optimizer.step()
        print('Epoch %d: Loss: %.8f' % (i, loss.item()))

    print('Finished training')
    print("--- %s seconds ---" % (time.time() - start_time))

