{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aced20a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from TorchDiffEqPack import odesolve_adjoint_sym12"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3850b0d",
   "metadata": {},
   "source": [
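    "## Parameters of the coupled-oscillator problem\n",
    "\n",
    "Reference values of the coefficients that training tries to recover. Each of the ten oscillators obeys\n",
    "\n",
    "$$\\ddot q_i = -a_i q_i - b_i q_i^3 - \\sum_{j \\ne i} e_{ij}\\,(q_i - q_j), \\qquad i = 1, \\dots, 10,$$\n",
    "\n",
    "where the couplings are symmetric, $e_{ij} = e_{ji}$, so only the pairs with $i < j$ are listed."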
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a45de0dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference (ground-truth) coefficients of the 10 coupled oscillators.\n",
    "\n",
    "# Linear on-site coefficients a_i\n",
    "a1, a2, a3, a4, a5, a6, a7, a8, a9, a10 = 2.0, 0.7, -0.5, 2.4, 0.8, -2.4, -1.3, 0.3, 2.7, 2.8\n",
    "\n",
    "# Cubic on-site coefficients b_i\n",
    "b1, b2, b3, b4, b5, b6, b7, b8, b9, b10 = -0.4, 3.0, -1.4, 1.9, -0.5, 3.0, 1.2, -1.4, -0.3, 1.7\n",
    "\n",
    "# Symmetric pair couplings e_ij (i < j), one row per first oscillator i\n",
    "e12, e13, e14, e15, e16, e17, e18, e19, e110 = 1.0, 0.9, 0.3, 0.7, 1.0, 1.1, 0.6, 1.3, 1.2\n",
    "e23, e24, e25, e26, e27, e28, e29, e210 = 1.5, 0.5, 0.9, 0.9, 1.3, 1.3, 0.7, 1.0\n",
    "e34, e35, e36, e37, e38, e39, e310 = 1.3, 1.0, 0.9, 1.1, 1.2, 1.3, 1.0\n",
    "e45, e46, e47, e48, e49, e410 = 0.8, 1.2, 1.2, 0.7, 0.6, 1.0\n",
    "e56, e57, e58, e59, e510 = 1.5, 0.8, 1.1, 0.7, 1.3\n",
    "e67, e68, e69, e610 = 0.8, 1.0, 1.2, 1.4\n",
    "e78, e79, e710 = 1.5, 1.0, 0.6\n",
    "e89, e810 = 0.6, 0.8\n",
    "e910 = 1.3"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1c3f482",
   "metadata": {},
   "source": [
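    "## Training data\n",
    "\n",
    "The training data consist of the initial and final states of 200 trajectories, one trajectory per row: `Initial_data.txt` holds the initial states and `Final_points.txt` the corresponding final states (the training loop compares them with the model state at $t_1 = 0.5$)."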
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ab500d0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initial states of the trajectories, one per row, as float32 tensors.\n",
    "initial_data = torch.from_numpy(np.loadtxt('Initial_data.txt', delimiter=',')).float()\n",
    "\n",
    "# Matching final states, one per row. The extra leading axis is kept because the\n",
    "# training cell indexes the targets as traj_q01[0, k].\n",
    "q01 = np.loadtxt('Final_points.txt', delimiter=',')\n",
    "traj_q01 = torch.unsqueeze(torch.from_numpy(q01).float(), 0)\n",
    "\n",
    "# Every initial state needs a recorded final state of the same size.\n",
    "assert initial_data.shape == traj_q01.shape[1:], (initial_data.shape, traj_q01.shape)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c767edfd",
   "metadata": {},
   "source": [
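    "## Model\n",
    "\n",
    "PyTorch module that evaluates the right-hand side $dw/dt$ of the coupled-oscillator system for the state $w = (q, p)$; all 65 coefficients ($a_i$, $b_i$, $e_{ij}$) are trainable parameters."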
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2bec41a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TorchDuffingEquations(nn.Module):\n",
    "    \"\"\"Right-hand side of 10 coupled Duffing oscillators with trainable coefficients.\n",
    "\n",
    "    The state is w = (q, p) with positions q = w[..., :10] and momenta p = w[..., 10:20]:\n",
    "\n",
    "        dq_i/dt = p_i\n",
    "        dp_i/dt = -a_i q_i - b_i q_i**3 - sum_{j != i} e_ij (q_i - q_j)\n",
    "\n",
    "    Couplings are symmetric (e_ij == e_ji), so only i < j is stored. Every coefficient is an\n",
    "    nn.Parameter named like its constant in the parameter cell (a1..a10, b1..b10, e12..e910).\n",
    "    \"\"\"\n",
    "\n",
    "    N = 10  # number of oscillators\n",
    "\n",
    "    def __init__(self, reference=None):\n",
    "        \"\"\"Create every coefficient as a randomly perturbed copy of its reference value.\n",
    "\n",
    "        Args:\n",
    "            reference: optional mapping from parameter name ('a1', 'b3', 'e110', ...) to\n",
    "                its reference value. Defaults to this notebook's module-level constants.\n",
    "        \"\"\"\n",
    "        super(TorchDuffingEquations, self).__init__()\n",
    "        if reference is None:\n",
    "            reference = globals()  # the constants defined in the parameter cell\n",
    "        n = self.N\n",
    "        # Creation order a1, a2, b1, b2, a3, a4, b3, b4, ... followed by the couplings; this fixes\n",
    "        # the order of parameters() and of the np.random draws below.\n",
    "        onsite_names = [name\n",
    "                        for k in range(1, n + 1, 2)\n",
    "                        for name in (f'a{k}', f'a{k + 1}', f'b{k}', f'b{k + 1}')]\n",
    "        # Upper triangle in row-major order: e12..e110, e23..e210, ..., e910.\n",
    "        self._coupling_names = tuple(f'e{i}{j}' for i in range(1, n + 1) for j in range(i + 1, n + 1))\n",
    "        for name in onsite_names + list(self._coupling_names):\n",
    "            # Factor drawn uniformly from [0.5, 1.5); each parameter starts from the reference\n",
    "            # value of the *same* name. torch.ones(1) * float keeps the parameter float32.\n",
    "            factor = float(1.0 + 1.0 * np.random.random_sample() - 0.5)\n",
    "            setattr(self, name, nn.Parameter(torch.ones(1) * (float(reference[name]) * factor)))\n",
    "\n",
    "    def _coupling_matrix(self):\n",
    "        \"\"\"Symmetric N x N matrix E with E[i-1, j-1] = E[j-1, i-1] = e_ij and a zero diagonal.\"\"\"\n",
    "        n = self.N\n",
    "        values = torch.cat([getattr(self, name) for name in self._coupling_names])\n",
    "        rows, cols = torch.triu_indices(n, n, offset=1)  # row-major, same order as _coupling_names\n",
    "        upper = values.new_zeros(n, n).index_put((rows, cols), values)\n",
    "        return upper + upper.T\n",
    "\n",
    "    def forward(self, t, w):\n",
    "        \"\"\"Return dw/dt = (p, force) for the state w.\n",
    "\n",
    "        Args:\n",
    "            t: time; unused because the dynamics do not depend on t.\n",
    "            w: tensor whose last axis holds (q_1..q_10, p_1..p_10); any leading axes are\n",
    "                treated as a batch.\n",
    "\n",
    "        Returns:\n",
    "            Tensor with the leading axes of w and a last axis of size 20: (dq/dt, dp/dt).\n",
    "        \"\"\"\n",
    "        n = self.N\n",
    "        q = w[..., :n]\n",
    "        p = w[..., n:2 * n]\n",
    "        a = torch.cat([getattr(self, f'a{k}') for k in range(1, n + 1)])\n",
    "        b = torch.cat([getattr(self, f'b{k}') for k in range(1, n + 1)])\n",
    "        coupling = self._coupling_matrix()\n",
    "        # sum_{j != i} e_ij (q_i - q_j) = q_i * sum_j E_ij - sum_j E_ij q_j, since diag(E) = 0 and E = E^T.\n",
    "        coupling_force = q * coupling.sum(-1) - q @ coupling\n",
    "        dpdt = -a * q - b * q ** 3 - coupling_force\n",
    "        return torch.cat((p, dpdt), dim=-1)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04224b08",
   "metadata": {},
   "source": [
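    "## Training\n",
    "\n",
    "Each epoch draws 200 trajectories with replacement, integrates their initial states to $t_1 = 0.5$ with `odesolve_adjoint_sym12`, and minimises the mean squared distance to the recorded final states. The error printed before and after training is the mean absolute deviation of the 65 learned coefficients from their reference values."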
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2f2acf66",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initial error 0.30549358\n",
      "Epoch 1: Loss: 0.15404925\n",
      "Epoch 2: Loss: 0.06942090\n",
      "Finished training\n",
      "--- 1351.631936788559 seconds ---\n",
      "Final error 0.24116988\n"
     ]
    }
   ],
   "source": [
    "# Training hyper-parameters\n",
    "lr = 1e-1\n",
    "NbTraj = 200  # trajectories sampled (with replacement) per epoch\n",
    "MAX_EPOCHS = 500\n",
    "LOSS_TOL = 1e-4  # stop once the epoch loss falls below this value\n",
    "\n",
    "# Seed every RNG *before* building the model: its initial parameters are drawn from\n",
    "# np.random, so seeding afterwards left the initialisation unreproducible.\n",
    "torch.manual_seed(21)\n",
    "random.seed(21)\n",
    "np.random.seed(21)\n",
    "\n",
    "func = TorchDuffingEquations()\n",
    "\n",
    "# Settings passed to odesolve_adjoint_sym12: each initial state is integrated from t0 to t1.\n",
    "# Other methods tried: fixedstep_yoshida_alf2, fixedstep_sym12async, sym12async, suzuki_alf2.\n",
    "options = {\n",
    "    'method': 'yoshida_alf2',\n",
    "    'h': None,\n",
    "    't0': 0.0,\n",
    "    't1': 0.5,\n",
    "    'rtol': 1e-4,\n",
    "    'atol': 1e-5,\n",
    "    'print_neval': False,\n",
    "    'neval_max': 1000000,\n",
    "    'safety': None,\n",
    "    't_eval': None,\n",
    "    'interpolation_method': 'cubic',\n",
    "    'regenerate_graph': True,\n",
    "}\n",
    "\n",
    "optimizer = torch.optim.AdamW(func.parameters(), lr=lr, betas=(0.50, 0.50), eps=1e-08,\n",
    "                              weight_decay=0.01, amsgrad=False)\n",
    "scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=.99)\n",
    "\n",
    "n_samples = initial_data.shape[0]  # number of recorded trajectories\n",
    "TrainLoss = np.zeros(MAX_EPOCHS)  # entries for epochs that never run stay 0\n",
    "\n",
    "\n",
    "def parameter_error(model):\n",
    "    \"\"\"Mean absolute deviation of the model's coefficients from their reference values.\n",
    "\n",
    "    Each parameter is compared with the module-level constant of the same name\n",
    "    (a1..a10, b1..b10, e12..e910), so the mean runs over all 65 coefficients.\n",
    "    \"\"\"\n",
    "    deviations = [abs(globals()[name] - param.item()) for name, param in model.named_parameters()]\n",
    "    return sum(deviations) / len(deviations)\n",
    "\n",
    "\n",
    "def trajectory_loss(model, index):\n",
    "    \"\"\"Squared distance between the model state at t1 and the recorded final state of trajectory `index`.\"\"\"\n",
    "    final_state = odesolve_adjoint_sym12(model, initial_data[index], options=options)\n",
    "    residual = final_state[..., :20] - traj_q01[0, index]\n",
    "    return torch.sum(residual ** 2)\n",
    "\n",
    "\n",
    "print('Initial error %.8f' % parameter_error(func))\n",
    "\n",
    "func.train()\n",
    "start_time = time.time()\n",
    "loss = float('inf')\n",
    "i = 0\n",
    "while loss > LOSS_TOL and i < MAX_EPOCHS:\n",
    "    i += 1\n",
    "    optimizer.zero_grad()\n",
    "    # Mean loss over NbTraj trajectory indices drawn uniformly with replacement.\n",
    "    batch_loss = sum(trajectory_loss(func, random.randint(0, n_samples - 1))\n",
    "                     for _ in range(NbTraj)) / float(NbTraj)\n",
    "    loss = batch_loss.item()\n",
    "    TrainLoss[i - 1] = loss\n",
    "    batch_loss.backward()\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    print('Epoch %d: Loss: %.8f' % (i, loss))\n",
    "print('Finished training')\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))\n",
    "\n",
    "np.savetxt('TrainLoss.txt', TrainLoss, delimiter=',', newline='\\n')\n",
    "\n",
    "func.eval()\n",
    "print('Final error %.8f' % parameter_error(func))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e34892c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
