



from TorchDiffEqPack import odesolve_adjoint_sym12
import torch
from torch import nn
import numpy as np
import random 


    


# Initial AdamW learning rate; the ExponentialLR scheduler decays it every epoch.
lr = 0.1

# When False the optimisation loop is skipped and only the parameter error of
# the randomly initialised model is reported.
TrainMode = True




# Nominal (reference) coefficients of the ten coupled oscillators.  The model's
# parameters are initialised around these values, and the reported "error" is
# the mean absolute deviation of the learned parameters from them.
#   a_k  : multiplies q_k      in the force on oscillator k
#   b_k  : multiplies q_k**3   in the force on oscillator k
#   e_ij : symmetric coupling between oscillators i and j (stored once, i < j)
a1, a2, a3, a4, a5 = 2.0, 0.7, -0.5, 2.4, 0.8
a6, a7, a8, a9, a10 = -2.4, -1.3, 0.3, 2.7, 2.8

b1, b2, b3, b4, b5 = -0.4, 3.0, -1.4, 1.9, -0.5
b6, b7, b8, b9, b10 = 3.0, 1.2, -1.4, -0.3, 1.7

e12, e13, e14, e15, e16, e17, e18, e19, e110 = 1.0, 0.9, 0.3, 0.7, 1.0, 1.1, 0.6, 1.3, 1.2
e23, e24, e25, e26, e27, e28, e29, e210 = 1.5, 0.5, 0.9, 0.9, 1.3, 1.3, 0.7, 1.0
e34, e35, e36, e37, e38, e39, e310 = 1.3, 1.0, 0.9, 1.1, 1.2, 1.3, 1.0
e45, e46, e47, e48, e49, e410 = 0.8, 1.2, 1.2, 0.7, 0.6, 1.0
e56, e57, e58, e59, e510 = 1.5, 0.8, 1.1, 0.7, 1.3
e67, e68, e69, e610 = 0.8, 1.0, 1.2, 1.4
e78, e79, e710 = 1.5, 1.0, 0.6
e89, e810 = 0.6, 0.8
e910 = 1.3


# Initial states to integrate from; row k is handed to the solver as one state.
initial_data = torch.from_numpy(
    np.loadtxt('Initial_data.txt', delimiter=',')
).float()

# Target states: row k is the reference for the trajectory started from
# initial_data[k].  A leading singleton axis is added, so rows are addressed
# as traj_q01[0, k].
q01 = np.loadtxt('Final_points.txt', delimiter=',')
traj_q01 = torch.from_numpy(q01).float().unsqueeze(0)


       


class TorchDuffingEquations(nn.Module):
    """Right-hand side of ten coupled Duffing-type oscillators with trainable coefficients.

    The state ``w`` holds 10 positions ``q`` followed by 10 momenta ``p``.
    Oscillator ``i`` (1-based) is driven by the force

        -a_i * q_i - b_i * q_i**3 - sum_{j != i} e_ij * (q_i - q_j)

    where the coupling is symmetric and stored once as ``e{min}{max}``
    (``e12`` ... ``e110``, ``e23`` ... ``e910``).

    Every coefficient is a 1-element ``nn.Parameter``, initialised to the
    module-level nominal constant of the same name times a random factor
    drawn from ``np.random`` in [0.5, 1.5).
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept fixed: each parameter consumes one
        # np.random draw, so the order determines every initial guess.
        self.a1 = self._perturbed(a1)
        self.a2 = self._perturbed(a2)
        self.b1 = self._perturbed(b1)
        self.b2 = self._perturbed(b2)
        self.a3 = self._perturbed(a3)
        self.a4 = self._perturbed(a4)
        self.b3 = self._perturbed(b3)
        self.b4 = self._perturbed(b4)
        self.a5 = self._perturbed(a5)
        self.a6 = self._perturbed(a6)
        self.b5 = self._perturbed(b5)
        self.b6 = self._perturbed(b6)
        self.a7 = self._perturbed(a7)
        self.a8 = self._perturbed(a8)
        self.b7 = self._perturbed(b7)
        self.b8 = self._perturbed(b8)
        self.a9 = self._perturbed(a9)
        self.a10 = self._perturbed(a10)
        self.b9 = self._perturbed(b9)
        self.b10 = self._perturbed(b10)

        self.e12 = self._perturbed(e12)
        self.e13 = self._perturbed(e13)
        self.e14 = self._perturbed(e14)
        self.e15 = self._perturbed(e15)
        self.e16 = self._perturbed(e16)
        self.e17 = self._perturbed(e17)
        self.e18 = self._perturbed(e18)
        self.e19 = self._perturbed(e19)
        # Seeded from its own nominal value e110.  It was previously seeded
        # from e210 by mistake.
        self.e110 = self._perturbed(e110)
        self.e23 = self._perturbed(e23)
        self.e24 = self._perturbed(e24)
        self.e25 = self._perturbed(e25)
        self.e26 = self._perturbed(e26)
        self.e27 = self._perturbed(e27)
        self.e28 = self._perturbed(e28)
        self.e29 = self._perturbed(e29)
        self.e210 = self._perturbed(e210)
        self.e34 = self._perturbed(e34)
        self.e35 = self._perturbed(e35)
        self.e36 = self._perturbed(e36)
        self.e37 = self._perturbed(e37)
        self.e38 = self._perturbed(e38)
        self.e39 = self._perturbed(e39)
        self.e310 = self._perturbed(e310)
        self.e45 = self._perturbed(e45)
        self.e46 = self._perturbed(e46)
        self.e47 = self._perturbed(e47)
        self.e48 = self._perturbed(e48)
        self.e49 = self._perturbed(e49)
        self.e410 = self._perturbed(e410)
        self.e56 = self._perturbed(e56)
        self.e57 = self._perturbed(e57)
        self.e58 = self._perturbed(e58)
        self.e59 = self._perturbed(e59)
        self.e510 = self._perturbed(e510)
        self.e67 = self._perturbed(e67)
        self.e68 = self._perturbed(e68)
        self.e69 = self._perturbed(e69)
        self.e610 = self._perturbed(e610)
        self.e78 = self._perturbed(e78)
        self.e79 = self._perturbed(e79)
        self.e710 = self._perturbed(e710)
        self.e89 = self._perturbed(e89)
        self.e810 = self._perturbed(e810)
        self.e910 = self._perturbed(e910)

    @staticmethod
    def _perturbed(nominal):
        """Return a 1-element parameter equal to ``nominal`` times a factor drawn uniformly from [0.5, 1.5)."""
        factor = 1.0 + 1.0 * np.random.random_sample(1) - 0.5
        return nn.Parameter(torch.ones(1) * nominal * factor)

    def _coupling(self, i, j):
        """Return the symmetric coupling parameter between 0-based oscillators ``i`` and ``j``."""
        lo, hi = min(i, j), max(i, j)
        return getattr(self, f"e{lo + 1}{hi + 1}")

    def forward(self, t, w):
        """Return the time derivative of the state ``w``.

        Args:
            t: integration time.  It is not used because the right-hand side
                does not depend on it.
            w: state tensor whose first 10 entries are positions ``q`` and
                next 10 are momenta ``p``.  It is expected to be 1-D (length
                20): ``q[i]`` below indexes the first axis, so a leading batch
                dimension is not supported.

        Returns:
            ``cat(p, forces)``: dq/dt = p and dp/dt = the per-oscillator force.
        """
        q = w[..., :10]
        p = w[..., 10:20]

        # Row buffer filled element-wise.  It follows w's dtype/device so the
        # concatenation with p below stays on one device.
        A = torch.ones(1, 10, dtype=w.dtype, device=w.device)
        for i in range(10):
            force = (
                -getattr(self, f"a{i + 1}") * q[i]
                - getattr(self, f"b{i + 1}") * q[i] ** 3
            )
            for j in range(10):
                if j != i:
                    force = force - self._coupling(i, j) * (q[i] - q[j])
            A[0, i] = force

        dqdt = p
        dpdt = torch.reshape(A, (-1,))
        return torch.cat((dqdt, dpdt), dim=-1)



# Seed every RNG *before* building the model.  TorchDuffingEquations draws its
# random initial guesses from np.random inside __init__, so seeding afterwards
# leaves the initial parameters unreproducible.  __init__ does not touch the
# `random` module, so the trajectory-sampling sequence (random.randint) starts
# from the same state as when seeding happened after construction.
torch.manual_seed(21)
random.seed(21)
np.random.seed(21)

func = TorchDuffingEquations()

# Options for TorchDiffEqPack's odesolve_adjoint_sym12.  Other method names
# noted for this script: fixedstep_yoshida_alf2, fixedstep_sym12async,
# sym12async, suzuki_alf2.
options = {
    'method': 'yoshida_alf2',
    'h': None,
    't0': 0.0,
    't1': 0.5,
    'rtol': 1e-4,
    'atol': 1e-5,
    'print_neval': False,
    'neval_max': 1000000,
    'safety': None,
    't_eval': None,
    'interpolation_method': 'cubic',
    'regenerate_graph': True,
}

# NOTE(review): AdamW's decoupled weight decay (0.01) pulls every coefficient
# toward zero, which may bias the recovered values; confirm this is intended.
optimizer = torch.optim.AdamW(func.parameters(), lr=lr, betas=(0.50, 0.50), eps=1e-08, weight_decay=0.01, amsgrad=False, maximize=False, foreach=None, capturable=False, differentiable=False, fused=None)
# Multiply the learning rate by 0.99 after every epoch.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=.99)

func.train()


# Training-run settings.
NbTraj = 200          # trajectories sampled (with replacement) per epoch
best_loss = np.inf    # not referenced elsewhere in this script
import time

loss = 10.0           # starts above the 1e-4 stopping threshold so the loop runs
i = 0                 # epoch counter
n_samples = 200       # trajectories are drawn from indices 0 .. n_samples-1
TrainLoss = np.zeros((500))   # per-epoch loss history (500-epoch cap)


func.eval()

# Mean absolute deviation of all 65 learned coefficients from their nominal
# values.  The nominal constants are module globals whose names match the
# model's parameter attributes.  The summation order is couplings, then a_k,
# then b_k.
_param_names = (
    [f"e{lo}{hi}" for lo in range(1, 10) for hi in range(lo + 1, 11)]
    + [f"a{k}" for k in range(1, 11)]
    + [f"b{k}" for k in range(1, 11)]
)
error = sum(
    np.linalg.norm(globals()[name] - getattr(func, name).item())
    for name in _param_names
) / 65.0
print('Initial error %.8f'% error)





if TrainMode:
    if NbTraj < 1:
        raise ValueError('NbTraj must be at least 1')
    start_time = time.time()
    # Train until the mean per-trajectory loss drops below 1e-4, or until 500
    # epochs have run (TrainLoss has exactly 500 slots).
    while loss > 1e-4 and i < 500:
        i = i + 1

        optimizer.zero_grad()

        # The model holds only nn.Parameters (no dropout / batch-norm layers),
        # so the train/eval flag does not change its output.
        func.eval()

        # Stochastic loss estimate: NbTraj trajectories drawn uniformly with
        # replacement.  Each contributes the squared error between the solver
        # output and the recorded target for that trajectory.
        l = 0.0
        for _ in range(NbTraj):
            x1 = random.randint(0, n_samples - 1)
            out01 = odesolve_adjoint_sym12(func, initial_data[x1], options=options)
            # First 20 entries of the solver output follow the model's q/p
            # state layout.
            dif01 = torch.sum((out01[..., :20] - traj_q01[0, x1]) ** 2, -1)
            # Squared errors are non-negative, so no abs() is needed.
            l = l + torch.sum(dif01)
        l = l / float(NbTraj)

        # l is a non-negative 0-d tensor, so its value is its own norm.
        loss = l.item()
        TrainLoss[i - 1] = loss

        l.backward()
        optimizer.step()
        scheduler.step()

        print('Epoch %d: Loss: %.8f' % (i, loss))
    print('Finished training')
    print("--- %s seconds ---" % (time.time() - start_time))


# Persist the per-epoch loss history.  Entries for epochs that never ran stay 0.
np.savetxt('TrainLoss.txt', TrainLoss, delimiter=',', newline='\n')


func.eval()

# Mean absolute deviation of all 65 learned coefficients from their nominal
# values (nominal constants are module globals named like the parameters).
# The summation order is couplings, then a_k, then b_k.
_param_names = (
    [f"e{lo}{hi}" for lo in range(1, 10) for hi in range(lo + 1, 11)]
    + [f"a{k}" for k in range(1, 11)]
    + [f"b{k}" for k in range(1, 11)]
)
error = sum(
    np.linalg.norm(globals()[name] - getattr(func, name).item())
    for name in _param_names
) / 65.0

print('Final error %.8f'% error)


