{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1861de8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy as sci\n",
    "\n",
    "from TorchDiffEqPack import odesolve_adjoint_sym12\n",
    "import torch\n",
    "\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "\n",
    "from skopt.space import Space\n",
    "from skopt.sampler import Halton\n",
    "import random \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1dbb9b73",
   "metadata": {},
   "source": [
    "Constants for the coupled oscillators problem with 2 particles. This is the last example of the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a1ddee7",
   "metadata": {},
   "outputs": [],
   "source": [
    "a1 = 2.0\n",
    "a2 = 0.7\n",
    "b1 = -0.4\n",
    "b2 = 3.0\n",
    "e = 1.0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7cc54cb",
   "metadata": {},
   "source": [
    "Initial and final points of 1000 trajectories: the initial points are drawn from a Halton sequence and the final points are obtained by integrating the true dynamics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d4a5c5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.integrate  # ensures the `integrate` submodule is loaded for `sci.integrate`\n",
    "\n",
    "# Centre of the sampling box; entries are ordered (q[0], q[1], p[0], p[1]) as in DuffingEquations.\n",
    "x0 = np.array([0.8, -0.4, 0.0, 0.0], dtype=\"float64\")\n",
    "\n",
    "# Every coordinate is sampled within +/- 1.0 of the corresponding entry of x0.\n",
    "space = Space([(float(c) - 1.0, float(c) + 1.0) for c in x0])\n",
    "\n",
    "n_samples = 1000\n",
    "halton = Halton()\n",
    "start = halton.generate(space.dimensions, n_samples)\n",
    "start_n = np.array(start, dtype=\"float64\")\n",
    "initial_data = torch.from_numpy(start_n).float()\n",
    "\n",
    "# NOTE(review): `time_span` was used below but never defined anywhere in the notebook,\n",
    "# so a fresh kernel failed with NameError. The grid is reconstructed from the 2000-row\n",
    "# trajectory buffer and from the solver options t0=0.0 / t1=0.5 used during training\n",
    "# (the training loss compares against the last point of each trajectory) -- confirm.\n",
    "time_span = np.linspace(0.0, 0.5, 2000)\n",
    "\n",
    "\n",
    "def DuffingEquations(w, t, a1, a2, b1, b2, e):\n",
    "    \"\"\"Right-hand side of the true dynamics, in the (y, t, *args) form expected by odeint.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    w : array of length 4, laid out as (q[0], q[1], p[0], p[1]).\n",
    "    t : time; unused, the system is autonomous.\n",
    "    a1, a2 : coefficients of the linear terms in q[0] and q[1].\n",
    "    b1, b2 : coefficients of the cubic terms in q[0] and q[1].\n",
    "    e : coefficient of the coupling term (q[1] - q[0]).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    Array of length 4 holding (dq/dt, dp/dt).\n",
    "    \"\"\"\n",
    "    q = w[..., :2]\n",
    "    p = w[..., 2:4]\n",
    "\n",
    "    dqdt = p\n",
    "    # The coupling enters the two momentum equations with opposite signs.\n",
    "    coupling = e * (q[1] - q[0])\n",
    "    dp1dt = -a1 * q[0] - b1 * q[0] ** 3 + coupling\n",
    "    dp2dt = -a2 * q[1] - b2 * q[1] ** 3 - coupling\n",
    "    dpdt = np.array([dp1dt, dp2dt])\n",
    "\n",
    "    return np.concatenate((dqdt, dpdt))\n",
    "\n",
    "\n",
    "# One integrated trajectory per Halton sample: (sample, time index, state component).\n",
    "q01 = np.zeros((n_samples, len(time_span), 4))\n",
    "\n",
    "for x1 in range(n_samples):\n",
    "    init_params = start_n[x1]\n",
    "\n",
    "    three_body_sol01 = sci.integrate.odeint(\n",
    "        DuffingEquations,\n",
    "        init_params,\n",
    "        time_span,\n",
    "        args=(a1, a2, b1, b2, e),\n",
    "        rtol=1e-13,\n",
    "        atol=1e-14,\n",
    "    )\n",
    "\n",
    "    q01[x1] = three_body_sol01[:, 0:4]\n",
    "\n",
    "# Shape (1, n_samples, len(time_span), 4); the training cell indexes it as traj_q01[0, idx, -1].\n",
    "traj_q01 = torch.from_numpy(q01).float()\n",
    "traj_q01 = torch.unsqueeze(traj_q01, 0)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0fdc8732",
   "metadata": {},
   "source": [
    "Torch function, where the derivatives of the 5 unknown potentials are parameterized as neural networks dV1, dV2, dV3, dV4, dV5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe0f1b01",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TorchDuffingEquations(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(TorchDuffingEquations, self).__init__()\n",
    "        \n",
    "        \n",
    "        self.nout = 100\n",
    "        \n",
    "        self.dV1 = nn.Sequential(\n",
    "        nn.Linear(1, self.nout),\n",
    "        nn.Tanh(),\n",
    "        nn.Linear(self.nout, self.nout),\n",
    "        nn.Tanh(),\n",
    "        nn.Linear(self.nout, 1))\n",
    "        \n",
    "        self.dV2 = nn.Sequential(\n",
    "        nn.Linear(1, self.nout),\n",
    "        nn.Tanh(),\n",
    "        nn.Linear(self.nout, self.nout),\n",
    "        nn.Tanh(),\n",
    "        nn.Linear(self.nout, 1))\n",
    "        \n",
    "        self.dV3 = nn.Sequential(\n",
    "        nn.Linear(1, self.nout),\n",
    "        nn.Tanh(),\n",
    "        nn.Linear(self.nout, self.nout),\n",
    "        nn.Tanh(),\n",
    "        nn.Linear(self.nout, 1))\n",
    "        \n",
    "        self.dV4 = nn.Sequential(\n",
    "        nn.Linear(1, self.nout),\n",
    "        nn.Tanh(),\n",
    "        nn.Linear(self.nout, self.nout),\n",
    "        nn.Tanh(),\n",
    "        nn.Linear(self.nout, 1))\n",
    "        \n",
    "        self.dV5 = nn.Sequential(\n",
    "        nn.Linear(1, self.nout),\n",
    "        nn.Tanh(),\n",
    "        nn.Linear(self.nout, self.nout),\n",
    "        nn.Tanh(),\n",
    "        nn.Linear(self.nout, 1))\n",
    "        \n",
    "    \n",
    "    def getdV1(self, w):\n",
    "        \n",
    "        return self.dV1(torch.reshape(w,(1,)))\n",
    "    def getdV2(self, w):\n",
    "        \n",
    "        return self.dV2(torch.reshape(w,(1,)))\n",
    "    def getdV3(self, w):\n",
    "        \n",
    "        return self.dV3(torch.reshape(w,(1,)))\n",
    "    def getdV4(self, w):\n",
    "        \n",
    "        return self.dV4(torch.reshape(w,(1,)))\n",
    "    def getdV5(self, w):\n",
    "        \n",
    "        return self.dV5(torch.reshape(w,(1,)))\n",
    "    \n",
    "\n",
    "    \n",
    "    \n",
    "    def forward(self, t, w):\n",
    "        \n",
    "        q = w[...,:2]\n",
    "        p = w[...,2:4]\n",
    "        \n",
    "        A=torch.ones(1, 2)\n",
    "        A[0,0] =-self.dV1(torch.reshape(q[0],(1,))) - self.dV2(torch.reshape(q[0],(1,))) + self.dV5(torch.reshape(q[1],(1,))-torch.reshape(q[0],(1,)))\n",
    "        A[0,1] = -self.dV3(torch.reshape(q[1],(1,))) - self.dV4(torch.reshape(q[1],(1,))) - self.dV5(torch.reshape(q[1],(1,))-torch.reshape(q[0],(1,)))\n",
    "        \n",
    "        \n",
    "        dqdt = p\n",
    "        dpdt = torch.reshape(A,(-1,))\n",
    "        \n",
    "        \n",
    "        derivs = torch.cat((dqdt,dpdt), dim=-1)\n",
    "        return derivs\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1531974",
   "metadata": {},
   "source": [
    "Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f23bbcf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# Seed every RNG *before* the model is built, so that both the initial weights\n",
    "# and the random choice of training trajectories are reproducible.\n",
    "torch.manual_seed(1229)\n",
    "random.seed(1230)\n",
    "np.random.seed(1234)\n",
    "\n",
    "func = TorchDuffingEquations()\n",
    "\n",
    "# Options forwarded to odesolve_adjoint_sym12.\n",
    "options = {\n",
    "    'method': 'yoshida_alf2',  # alternatives noted by the author: fixedstep_yoshida_alf2, fixedstep_sym12async, suzuki_alf2\n",
    "    'h': 0.1,\n",
    "    't0': 0.0,\n",
    "    't1': 0.5,\n",
    "    'rtol': 1e-4,\n",
    "    'atol': 1e-5,\n",
    "    'print_neval': False,\n",
    "    'neval_max': 1000000,\n",
    "    'safety': None,\n",
    "    't_eval': None,\n",
    "    'interpolation_method': 'cubic',\n",
    "    'regenerate_graph': True,\n",
    "}\n",
    "\n",
    "lr = 1e-3\n",
    "# Only the non-default / tuned arguments are passed; the version-specific keyword\n",
    "# arguments (foreach, capturable, differentiable, fused, ...) were all at their defaults.\n",
    "optimizer = torch.optim.AdamW(func.parameters(), lr=lr, betas=(0.50, 0.50), eps=1e-08, weight_decay=0.01)\n",
    "scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995)\n",
    "\n",
    "NbTraj = 300       # trajectories sampled (with replacement) per epoch\n",
    "MAX_EPOCHS = 300   # hard cap on the number of epochs\n",
    "LOSS_TOL = 1e-2    # stop once the mean epoch loss falls to this value\n",
    "\n",
    "\n",
    "def trajectory_loss(model, idx):\n",
    "    \"\"\"Squared error between the solver output for sample ``idx`` and the stored final point.\n",
    "\n",
    "    Integrates ``model`` from ``initial_data[idx]`` with the global ``options`` and\n",
    "    compares the first four output components with ``traj_q01[0, idx, -1]``.\n",
    "    \"\"\"\n",
    "    out = odesolve_adjoint_sym12(model, initial_data[idx], options=options)\n",
    "    diff = out[..., :4] - traj_q01[0, idx, -1]\n",
    "    return torch.sum(diff ** 2)\n",
    "\n",
    "\n",
    "func.train()\n",
    "\n",
    "best_loss = np.inf\n",
    "loss = 10.0\n",
    "i = 0\n",
    "\n",
    "start_time = time.time()\n",
    "while loss > LOSS_TOL and i < MAX_EPOCHS:\n",
    "    i = i + 1\n",
    "\n",
    "    # Gradients accumulate in .grad by default; clear them at the start of every epoch.\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    l = 0.0\n",
    "    for k in range(NbTraj):\n",
    "        x1 = random.randint(0, n_samples - 1)\n",
    "        l = l + trajectory_loss(func, x1)\n",
    "    l = l / float(NbTraj)\n",
    "\n",
    "    l.backward()\n",
    "    optimizer.step()\n",
    "    # The scheduler was created but never advanced, so the learning rate never decayed.\n",
    "    scheduler.step()\n",
    "\n",
    "    # The stopping test uses the same mean loss that is printed.\n",
    "    loss = l.item()\n",
    "    best_loss = min(best_loss, loss)\n",
    "\n",
    "    print('Epoch %d: Loss: %.8f' % (i, loss))\n",
    "print('Finished training')\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
