{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4c8cc91",
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy as sci\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from TorchDiffEqPack import odesolve_adjoint_sym12\n",
    "import torch\n",
    "\n",
    "from torch import nn\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5caec1f0",
   "metadata": {},
   "source": [
    "True value of the parameter $\\alpha$ of the Kepler problem $\\dot q = p$, $\\dot p = -\\alpha\\, q/|q|^3$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cfe9b67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# True value of the Kepler-problem parameter alpha; the reference trajectories\n",
    "# below are generated with it and it is marked as 'true value' in the final plot.\n",
    "a = np.pi / 4"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29370026",
   "metadata": {},
   "source": [
    "Initial data on a $3 \\times 3 \\times 3 \\times 3$ grid (81 states) around the point $x_0$, with spacing 0.1 in each coordinate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2955a259",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Centre of the grid: position q = (0.75, 0), momentum p = (0, 0.9*(pi/4)*sqrt(5/3)).\n",
    "x0 = np.array([0.75, 0, 0, 0.9*(np.pi/4)*np.sqrt(5/3)], dtype='float64')\n",
    "IniStep = 0.1\n",
    "offsets = (-1, 0, 1)\n",
    "\n",
    "# InitialData[i][j][k][l] = x0 + IniStep*(i-1, j-1, k-1, l-1): a nested 3x3x3x3 list\n",
    "# holding all 81 ways of shifting each coordinate of x0 by -IniStep, 0 or +IniStep.\n",
    "InitialData = [[[[[x0[0] + d1*IniStep, x0[1] + d2*IniStep,\n",
    "                   x0[2] + d3*IniStep, x0[3] + d4*IniStep]\n",
    "                  for d4 in offsets]\n",
    "                 for d3 in offsets]\n",
    "                for d2 in offsets]\n",
    "               for d1 in offsets]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97abb962",
   "metadata": {},
   "source": [
    "Reference trajectories: positions at $t = 0.2, 0.4, 0.6, 0.8, 1.0$ for all the initial data, computed with `scipy.integrate.odeint`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4c25abf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.integrate\n",
    "\n",
    "\n",
    "def KeplerEquations(w, t, a):\n",
    "    \"\"\"Right-hand side of the planar Kepler problem in odeint's signature f(w, t, a).\n",
    "\n",
    "    w stacks the position q = w[..., :2] and the momentum p = w[..., 2:4]; the result is\n",
    "    (dq/dt, dp/dt) = (p, -a * q / |q|**3). t is unused (the system is autonomous) but is\n",
    "    required by odeint. |q| is the norm over all entries of q, so w should be a single state.\n",
    "    \"\"\"\n",
    "    position, momentum = w[..., :2], w[..., 2:4]\n",
    "    radius = sci.linalg.norm(position)\n",
    "    acceleration = -a / radius**3 * position\n",
    "    return np.concatenate((momentum, acceleration))\n",
    "\n",
    "\n",
    "# Reference positions q(t) at t = 0.2, 0.4, 0.6, 0.8, 1.0: one (3, 3, 3, 3, 2) array per\n",
    "# time, indexed like InitialData.\n",
    "q02, q04, q06, q08, q1 = (np.zeros((3, 3, 3, 3, 2)) for _ in range(5))\n",
    "segments = [(0.0, 0.2, q02), (0.2, 0.4, q04), (0.4, 0.6, q06), (0.6, 0.8, q08), (0.8, 1.0, q1)]\n",
    "\n",
    "for i1, i2, i3, i4 in np.ndindex(3, 3, 3, 3):\n",
    "    state = np.array(InitialData[i1][i2][i3][i4], dtype='float64').flatten()\n",
    "    # Integrate one segment at a time, restarting odeint from the previous end state.\n",
    "    for t_start, t_end, q_at_end in segments:\n",
    "        time_span = np.linspace(t_start, t_end, 2000)\n",
    "        sol = sci.integrate.odeint(KeplerEquations, state, time_span, args=(a,),\n",
    "                                   rtol=1e-7, atol=1e-8, hmax=1e-5)\n",
    "        q_at_end[i1, i2, i3, i4] = sol[-1, 0:2]\n",
    "        state = sol[-1, :]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c860b72f",
   "metadata": {},
   "source": [
    "PyTorch right-hand side of the Kepler problem with a trainable parameter $\\alpha$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7905986",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TorchKeplerEquations(nn.Module):\n",
    "    \"\"\"Kepler right-hand side dw/dt = (p, -a * q / |q|**3) with a trainable parameter a.\n",
    "\n",
    "    The state w stores the position q = w[..., :2] and the momentum p = w[..., 2:4].\n",
    "    Leading batch dimensions are allowed: |q| is computed per state along the last axis.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Initial guess 0.1; the loss scan below overwrites func.a with each scanned value.\n",
    "        self.a = nn.Parameter(torch.ones(1)*0.1)\n",
    "\n",
    "    def forward(self, t, w):\n",
    "        \"\"\"Return dw/dt at state w; t is unused (autonomous system) but required by the solver.\"\"\"\n",
    "        q = w[..., :2]\n",
    "        p = w[..., 2:4]\n",
    "        # Reduce over the last axis only (keepdim for broadcasting against q), so that each\n",
    "        # state in a batch gets its own radius instead of one norm over the whole batch.\n",
    "        q12 = torch.norm(q, dim=-1, keepdim=True)\n",
    "        dqdt = p\n",
    "        dpdt = (-self.a/q12**3)*q\n",
    "        return torch.cat((dqdt, dpdt), dim=-1)\n",
    "\n",
    "\n",
    "func = TorchKeplerEquations()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f38a8bd",
   "metadata": {},
   "source": [
    "Conversion of the training data into PyTorch tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c64b1d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _as_batched_float_tensor(array):\n",
    "    \"\"\"Wrap a numpy array as a float32 torch tensor with an extra leading axis of size 1.\"\"\"\n",
    "    return torch.from_numpy(array).float().unsqueeze(0)\n",
    "\n",
    "\n",
    "# Reference positions, each of shape (1, 3, 3, 3, 3, 2).\n",
    "traj_q02 = _as_batched_float_tensor(q02)\n",
    "traj_q04 = _as_batched_float_tensor(q04)\n",
    "traj_q06 = _as_batched_float_tensor(q06)\n",
    "traj_q08 = _as_batched_float_tensor(q08)\n",
    "traj_q1 = _as_batched_float_tensor(q1)\n",
    "\n",
    "# Grid of initial states, shape (1, 3, 3, 3, 3, 4). InitialData is rebound to an ndarray here.\n",
    "InitialData = np.array(InitialData, dtype='float64')\n",
    "initial_condition = _as_batched_float_tensor(InitialData)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13294e6d",
   "metadata": {},
   "source": [
    "Integration options for `odesolve_adjoint_sym12`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "403d0306",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solver settings shared by every call to odesolve_adjoint_sym12; the loss scan below\n",
    "# additionally sets 'method', 't_eval', 't0' and 't1' before each solve.\n",
    "options = {\n",
    "    'h': 0.01,\n",
    "    'rtol': 1e-4,\n",
    "    'atol': 1e-5,\n",
    "    'print_neval': False,\n",
    "    'neval_max': 1000000,\n",
    "    'safety': None,\n",
    "    'interpolation_method': 'cubic',\n",
    "    'regenerate_graph': False,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e76e8846",
   "metadata": {},
   "source": [
    "Computation of the loss for 300 values of the parameter $\\alpha$, spaced $10^{-6}$ apart around the true value $\\pi/4$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b4b10b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scan N_ALPHA values of alpha, ALPHA_SPACING apart and centred on the true value pi/4.\n",
    "N_ALPHA = 300\n",
    "ALPHA_SPACING = 0.000001\n",
    "SEGMENTS = [(0.0, 0.2), (0.2, 0.4), (0.4, 0.6), (0.6, 0.8), (0.8, 1.0)]\n",
    "TARGETS = [traj_q02, traj_q04, traj_q06, traj_q08, traj_q1]\n",
    "GRID_SHAPE = tuple(initial_condition.shape[1:5])  # (3, 3, 3, 3)\n",
    "\n",
    "\n",
    "def mean_grid_loss(method):\n",
    "    \"\"\"Mean over the initial-data grid of the summed squared position errors at t = 0.2, ..., 1.0.\n",
    "\n",
    "    Uses the current func.a. Each initial state is integrated with odesolve_adjoint_sym12 and\n",
    "    the given method one segment at a time, each segment starting from the previous output,\n",
    "    and the squared errors of all five segments are added. Mutates the global `options`.\n",
    "    \"\"\"\n",
    "    options.update({'method': method})\n",
    "    total = 0.0\n",
    "    for idx in np.ndindex(*GRID_SHAPE):\n",
    "        out = initial_condition[(0,) + idx]\n",
    "        loss = 0.0\n",
    "        for (t0, t1), target in zip(SEGMENTS, TARGETS):\n",
    "            options.update({'t_eval': np.linspace(t0, t1, 2000).tolist(), 't0': t0, 't1': t1})\n",
    "            out = odesolve_adjoint_sym12(func, out, options=options)\n",
    "            # Add this segment's squared error; every segment, including the first, counts.\n",
    "            loss = loss + torch.norm(out[..., :2] - target[(0,) + idx])**2\n",
    "        # Accumulate over grid points; the mean is taken once, after the loop.\n",
    "        total += loss.item()\n",
    "    return total / int(np.prod(GRID_SHAPE))\n",
    "\n",
    "\n",
    "valuesALF = np.zeros(N_ALPHA)\n",
    "valuesYoshida = np.zeros(N_ALPHA)\n",
    "steps = np.zeros(N_ALPHA)\n",
    "\n",
    "for k in range(N_ALPHA):\n",
    "    alpha = np.pi/4 - (N_ALPHA//2 - k)*ALPHA_SPACING\n",
    "    steps[k] = alpha\n",
    "    func.a = torch.nn.Parameter(torch.ones(1)*alpha)\n",
    "    valuesALF[k] = mean_grid_loss('fixedstep_sym12async')\n",
    "    valuesYoshida[k] = mean_grid_loss('fixedstep_yoshida_alf2')\n",
    "\n",
    "np.savetxt('steps.txt', steps, delimiter=',', newline='\\n')\n",
    "np.savetxt('ALF.txt', valuesALF, delimiter=',', newline='\\n')\n",
    "np.savetxt('Yoshida.txt', valuesYoshida, delimiter=',', newline='\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4820961c",
   "metadata": {},
   "source": [
    "Plot of the loss values, which form the error landscape for ALF and for the Yoshida composition of ALF2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e56e1e86",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(15, 15))\n",
    "\n",
    "# One loss curve per integrator over the scanned parameter values.\n",
    "for values, label in ((valuesALF, 'ALF'), (valuesYoshida, 'Yoshida')):\n",
    "    ax.plot(steps, values, label=label)\n",
    "\n",
    "ax.axhline(y=0, color='black')\n",
    "ax.axvline(x=a, color='purple', label='true value')\n",
    "\n",
    "ax.set_xlabel('parameter', fontsize=14)\n",
    "ax.set_ylabel('loss', fontsize=14)\n",
    "\n",
    "ax.legend(loc='upper left', fontsize=14)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
