{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b9a24267",
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy as sci\n",
    "from TorchDiffEqPack import odesolve_adjoint_sym12\n",
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df4c2b6e",
   "metadata": {},
   "source": [
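    "Ground-truth parameter $a$ of the Kepler problem; it is used to generate the data and is the value the training below tries to recover."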
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "12334409",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth force parameter used to generate the data (the value training should recover)\n",
    "a = np.pi / 4"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ea3f194",
   "metadata": {},
   "source": [
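    "Generation of the training data: the initial state `x0` (positions $q$ and momenta $p$) and the position coordinates $q(t)$ of the reference trajectory,\n",
    "computed with SciPy's `odeint` using the true parameter $a$, at times $t_1 = 0.2$, $t_2 = 0.4$, $t_3 = 0.6$, $t_4 = 0.8$, $t_5 = 1.0$."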
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e043f2a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.integrate\n",
    "\n",
    "# Initial state (q1, q2, p1, p2) of the reference trajectory.\n",
    "x0 = np.array([0.75, 0, 0, 0.9*(np.pi/4)*np.sqrt(5/3)], dtype=\"float64\")\n",
    "init_params = x0.flatten()  # 1-D copy of x0 used as the odeint start state\n",
    "\n",
    "# Observation times: the reference solution is integrated segment by segment between them.\n",
    "obs_times = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]\n",
    "n_eval_points = 2000  # grid points per segment passed to odeint\n",
    "\n",
    "\n",
    "def KeplerEquations(w, t, a):\n",
    "    \"\"\"Right-hand side of the Kepler problem in the signature odeint expects.\n",
    "\n",
    "    w: state whose last axis is (q1, q2, p1, p2); q is the position, p the momentum.\n",
    "    t: time; unused, but required by odeint.\n",
    "    a: force parameter.\n",
    "    Returns dw/dt = (p, -a * q / |q|**3), with |q| taken along the last axis.\n",
    "    \"\"\"\n",
    "    q = w[..., :2]\n",
    "    p = w[..., 2:4]\n",
    "    # Per-state norm (last axis) so leading batch axes are not reduced together.\n",
    "    q_norm = np.linalg.norm(q, axis=-1, keepdims=True)\n",
    "    dqdt = p\n",
    "    dpdt = (-a / q_norm**3) * q\n",
    "    return np.concatenate((dqdt, dpdt), axis=-1)\n",
    "\n",
    "\n",
    "# Each segment starts from the final state of the previous one; the final\n",
    "# position of each segment is the observation at t = 0.2, 0.4, ..., 1.0.\n",
    "segment_solutions = []\n",
    "state = init_params\n",
    "for t_start, t_end in zip(obs_times[:-1], obs_times[1:]):\n",
    "    time_span = np.linspace(t_start, t_end, n_eval_points)\n",
    "    sol = sci.integrate.odeint(KeplerEquations, state, time_span, args=(a,),\n",
    "                               rtol=1e-7, atol=1e-8, hmax=1e-5)\n",
    "    segment_solutions.append(sol)\n",
    "    state = sol[-1, :]\n",
    "\n",
    "sol02, sol04, sol06, sol08, sol1 = segment_solutions\n",
    "q02, q04, q06, q08, q1 = (segment[-1, 0:2] for segment in segment_solutions)\n",
    "\n",
    "\n",
    "def _to_batched_tensor(array):\n",
    "    \"\"\"Convert a NumPy array to a float32 tensor with a leading batch axis of size 1.\"\"\"\n",
    "    return torch.unsqueeze(torch.from_numpy(array).float(), 0)\n",
    "\n",
    "\n",
    "initial_condition = _to_batched_tensor(x0)\n",
    "traj_q02 = _to_batched_tensor(q02)\n",
    "traj_q04 = _to_batched_tensor(q04)\n",
    "traj_q06 = _to_batched_tensor(q06)\n",
    "traj_q08 = _to_batched_tensor(q08)\n",
    "traj_q1 = _to_batched_tensor(q1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ef906f1",
   "metadata": {},
   "source": [
    "PyTorch module for the Kepler dynamics, with the force parameter $a$ as a learnable `nn.Parameter`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cc55e2f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TorchKeplerEquations(nn.Module):\n",
    "    \"\"\"Kepler dynamics as a torch module with a learnable force parameter `a`.\n",
    "\n",
    "    The state w has its last axis laid out as (q1, q2, p1, p2); leading axes\n",
    "    are treated as independent states. `a` is the only trainable parameter.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Initial guess for a; training adjusts it.\n",
    "        self.a = nn.Parameter(torch.ones(1) * 0.1)\n",
    "\n",
    "    def forward(self, t, w):\n",
    "        \"\"\"Return dw/dt = (p, -a * q / |q|**3) for state w; t is unused.\"\"\"\n",
    "        q = w[..., :2]\n",
    "        p = w[..., 2:4]\n",
    "        # Reduce over the last axis only, so each state in a batch uses its own |q|\n",
    "        # (a plain torch.norm(q) would mix all states of the batch together).\n",
    "        q_norm = torch.norm(q, dim=-1, keepdim=True)\n",
    "        dqdt = p\n",
    "        dpdt = (-self.a / q_norm**3) * q\n",
    "        return torch.cat((dqdt, dpdt), dim=-1)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5911d4cd",
   "metadata": {},
   "source": [
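    "Training: fit $a$ by integrating the learned dynamics from `initial_condition` with `odesolve_adjoint_sym12` and matching the predicted positions to the observed ones."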
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f2bfd222",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a 0.7853981633974483, estimated a 0.10000000149011612\n",
      "Epoch 1: Loss: 0.62450743\n",
      "a 0.7853981633974483, estimated a 0.3048591613769531\n",
      "Epoch 2: Loss: 0.34059983\n",
      "a 0.7853981633974483, estimated a 0.48041051626205444\n",
      "Epoch 3: Loss: 0.15203212\n",
      "a 0.7853981633974483, estimated a 0.6189231872558594\n",
      "Epoch 4: Loss: 0.04971118\n",
      "a 0.7853981633974483, estimated a 0.7122034430503845\n",
      "Epoch 5: Loss: 0.01031680\n",
      "a 0.7853981633974483, estimated a 0.7601274847984314\n",
      "Epoch 6: Loss: 0.00128593\n",
      "a 0.7853981633974483, estimated a 0.7778493165969849\n",
      "Epoch 7: Loss: 0.00011800\n",
      "a 0.7853981633974483, estimated a 0.7831444144248962\n",
      "Epoch 8: Loss: 0.00001112\n",
      "a 0.7853981633974483, estimated a 0.7847052812576294\n",
      "Epoch 9: Loss: 0.00000124\n",
      "a 0.7853981633974483, estimated a 0.7852001786231995\n",
      "Epoch 10: Loss: 0.00000016\n",
      "a 0.7853981633974483, estimated a 0.7853711843490601\n",
      "Epoch 11: Loss: 0.00000004\n",
      "a 0.7853981633974483, estimated a 0.7854477763175964\n",
      "Epoch 12: Loss: 0.00000000\n",
      "Finished training\n",
      "--- 3.0147175788879395 seconds ---\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "func = TorchKeplerEquations()\n",
    "\n",
    "# Solver settings; t0, t1 and t_eval are overwritten for every segment below.\n",
    "# Other methods tried: 'fixedstep_yoshida_alf2', 'fixedstep_sym12async', 'suzuki_alf2'.\n",
    "options = {\n",
    "    'method': 'yoshida_alf2',\n",
    "    'h': 0.01,\n",
    "    't0': 0.0,\n",
    "    't1': 1.0,\n",
    "    'rtol': 1e-4,\n",
    "    'atol': 1e-5,\n",
    "    'print_neval': False,\n",
    "    'neval_max': 1000000,\n",
    "    'safety': None,\n",
    "    'interpolation_method': 'cubic',\n",
    "    'regenerate_graph': False,\n",
    "}\n",
    "\n",
    "lr = 1e-1\n",
    "lr_decay = 0.95  # multiplicative decay applied at the start of every epoch\n",
    "optimizer = torch.optim.SGD(func.parameters(), lr=lr)\n",
    "\n",
    "\n",
    "def adjust_learning_rate(optimizer, lr):\n",
    "    \"\"\"Set the learning rate of every parameter group of `optimizer` to `lr`.\"\"\"\n",
    "    for param_group in optimizer.param_groups:\n",
    "        param_group['lr'] = lr\n",
    "\n",
    "\n",
    "# Segment boundaries and the observed positions at the end of each segment.\n",
    "segment_bounds = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]\n",
    "targets = [traj_q02, traj_q04, traj_q06, traj_q08, traj_q1]\n",
    "n_eval_points = 2000  # t_eval grid points per segment\n",
    "\n",
    "func.train()\n",
    "\n",
    "current_loss = 10.0  # starts above the tolerance so the loop is entered\n",
    "epoch = 0\n",
    "\n",
    "start_time = time.time()\n",
    "while current_loss > 1e-8 and epoch < 600:\n",
    "    print('a {}, estimated a {}'.format(a, func.a.item()))\n",
    "\n",
    "    epoch = epoch + 1\n",
    "    lr *= lr_decay\n",
    "    adjust_learning_rate(optimizer, lr)\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    # Integrate segment by segment, each starting from the previous end state,\n",
    "    # and collect the squared position error at every observation time.\n",
    "    state = initial_condition\n",
    "    segment_losses = []\n",
    "    for t_start, t_end, target in zip(segment_bounds[:-1], segment_bounds[1:], targets):\n",
    "        time_span = np.linspace(t_start, t_end, n_eval_points)\n",
    "        options.update({'t_eval': time_span.tolist(), 't0': t_start, 't1': t_end})\n",
    "        state = odesolve_adjoint_sym12(func, state, options=options)\n",
    "        segment_losses.append(torch.norm(state[..., :2] - target[0])**2)\n",
    "\n",
    "    # Sum all five segment errors (a former `loss =+ ...` typo dropped the first one).\n",
    "    loss = torch.stack(segment_losses).sum()\n",
    "    current_loss = loss.item()\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    print('Epoch %d: Loss: %.8f' % (epoch, loss.item()))\n",
    "print('Finished training')\n",
    "print(\"--- %s seconds ---\" % (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11772177",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
