{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.animation as animation\n",
    "import torch, time, sys\n",
    "import numpy as np\n",
    "import scipy.integrate\n",
    "solve_ivp = scipy.integrate.solve_ivp\n",
    "\n",
    "EXPERIMENT_DIR = '.'\n",
    "sys.path.append(EXPERIMENT_DIR)\n",
    "\n",
    "\n",
    "from data import get_dataset, coords2state, get_orbit, random_config, states2coords\n",
    "from data import potential_energy, kinetic_energy, total_energy, angular_momentum, momentum_x, momentum_y\n",
    "\n",
    "import modelloader\n",
    "\n",
    "from sklearn.linear_model import LinearRegression, Ridge, Lasso\n",
    "from sklearn.metrics import mean_squared_error, r2_score\n",
    "from sklearn.preprocessing import PolynomialFeatures\n",
    "from sklearn.decomposition import PCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "DPI = 300\n",
    "FORMAT = 'pdf'\n",
    "\n",
    "def get_args():\n",
    "    '''Return the hyper-parameters of the 2-body experiment as a plain dict.\n",
    "\n",
    "    The notebook wraps the result in ObjectView and passes it to\n",
    "    modelloader.pnn_loader and get_dataset.\n",
    "    '''\n",
    "    return {'input_dim': 2*4,  # two bodies, each with q_x, q_y, p_x, p_y\n",
    "            'hidden_dim': 200,\n",
    "            'learn_rate': 1e-3,\n",
    "            'batch_size': 200,\n",
    "            't_points': 2000,\n",
    "            't_span': 50,\n",
    "            'total_steps': 50000,\n",
    "            'test_dim': 4,\n",
    "            'print_every': 100,\n",
    "            'verbose': True,\n",
    "            'name': '2body',\n",
    "            'seed': 2,\n",
    "            'save_dir': '{}'.format(EXPERIMENT_DIR),\n",
    "            'fig_dir': './figures'}\n",
    "\n",
    "class ObjectView(object):\n",
    "    '''Expose the keys of a dict as attributes (view.key instead of view['key']).\n",
    "\n",
    "    The dict itself becomes the instance __dict__ (no copy is made), so setting\n",
    "    an attribute on the view also updates the original dict.\n",
    "    '''\n",
    "    def __init__(self, d):\n",
    "        self.__dict__ = d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = ObjectView(get_args())\n",
    "np.random.seed(args.seed)  # seed NumPy's global RNG from the experiment config\n",
    "\n",
    "# Build the three SCNN models analysed by the cells below. The keyword arguments\n",
    "# are interpreted by modelloader.pnn_loader (not shown here); test_dim appears to\n",
    "# be the number of cyclic coordinates of each model - TODO confirm.\n",
    "model_scnn_1 = modelloader.pnn_loader(args, test_dim=4, HPQ_trainable=True,\n",
    "                                      number=1)\n",
    "model_scnn_2 = modelloader.pnn_loader(args, test_dim=2, HPQ_trainable=True,\n",
    "                                      number=2, num_hidden=0)\n",
    "model_scnn_3 = modelloader.pnn_loader(args, test_dim=4, HPQ_trainable=True,\n",
    "                                      number=3, num_hidden=2, momentum=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully loaded data from ./2body-orbits-dataset.pkl\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): per the stored output this reads a pickled dataset file;\n",
    "# unpickling can execute arbitrary code, so only load trusted dataset files.\n",
    "data = get_dataset(args.name,\n",
    "                   args.save_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have the following mapping of the coordinates (by definition of the data):<br>\n",
    "x0 -> qx1<br>\n",
    "x1 -> qx2<br>\n",
    "x2 -> qy1<br>\n",
    "x3 -> qy2<br>\n",
    "x4 -> px1<br>\n",
    "x5 -> px2<br>\n",
    "x6 -> py1<br>\n",
    "x7 -> py2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the first case, with 4 cyclic coordinates none of which are fixed, momentum conservation can already be read off directly from the first two coordinates. Additionally, the angular momentum appears in the fourth coordinate, but with other terms that have non-vanishing prefactors. Coordinate three should therefore contain most of the Hamiltonian, but because of its 1/x factors it cannot be read off as a low-degree polynomial (no fit up to degree 3 reaches $R^2 > 0.9$, so it is not printed below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Coordinate:1\n",
      "x4 -4.21\n",
      "x5 -4.21\n",
      "x6 -1.26\n",
      "x7 -1.29\n",
      "____________________________________________________________________________________________________\n",
      "\n",
      "Coordinate:2\n",
      "x4 0.93\n",
      "x5 0.92\n",
      "x6 -3.23\n",
      "x7 -3.22\n",
      "____________________________________________________________________________________________________\n",
      "\n",
      "Coordinate:4\n",
      "x0 x6 -1.07\n",
      "x0 x7 0.88\n",
      "x1 x6 0.93\n",
      "x1 x7 -1.03\n",
      "x2 x4 1.01\n",
      "x2 x5 -0.89\n",
      "x3 x4 -0.92\n",
      "x3 x5 0.99\n",
      "____________________________________________________________________________________________________\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Fit a sparse (Lasso) polynomial surrogate to each of the 4 output coordinates of\n",
    "# model_scnn_1.forward_1 and print the dominant monomials (|coef| > 0.1) of the lowest\n",
    "# degree (0..3) whose fit reaches R^2 > 0.9. Coordinates without such a fit are skipped.\n",
    "x=torch.FloatTensor(data['test_coords'])\n",
    "y=model_scnn_1.forward_1(x)[0]\n",
    "y=y.detach().numpy()\n",
    "x=x.numpy()\n",
    "for j in range(4):\n",
    "    y1=y[:,j:j+1]\n",
    "    for i in range(4):\n",
    "        polynomial_features= PolynomialFeatures(degree=i)\n",
    "        x_poly = polynomial_features.fit_transform(x)\n",
    "        model_2= Lasso(alpha=.001)\n",
    "        model_2.fit(x_poly, y1)\n",
    "        y_poly_pred = model_2.predict(x_poly)\n",
    "        r2 = r2_score(y1,y_poly_pred)\n",
    "        if r2>0.9:\n",
    "            print('Coordinate:'+ str(j+1))\n",
    "            # get_feature_names() was removed in scikit-learn 1.2; prefer its\n",
    "            # replacement get_feature_names_out() whenever it is available.\n",
    "            if hasattr(polynomial_features, 'get_feature_names_out'):\n",
    "                feature_names = polynomial_features.get_feature_names_out()\n",
    "            else:\n",
    "                feature_names = polynomial_features.get_feature_names()\n",
    "            for feature, coef in zip(feature_names, model_2.coef_.flatten()):\n",
    "                if np.abs(coef)>0.1:\n",
    "                    print(feature, np.round(coef,2))\n",
    "            print('_'*100)\n",
    "            print()\n",
    "            break\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This model shows that it is straightforward to find the conservation of momentum with the SCNNs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Coordinate 1\n",
      "x4 7.23\n",
      "x5 7.23\n",
      "____________________________________________________________________________________________________\n",
      "\n",
      "Coordinate 2\n",
      "x6 6.95\n",
      "x7 6.95\n",
      "____________________________________________________________________________________________________\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Sparse (Lasso) polynomial read-out of the 2 output coordinates of model_scnn_2 on\n",
    "# data['coords']: for each coordinate, print the monomials with |coef| > 0.1 of the\n",
    "# lowest degree (1..3) whose fit reaches R^2 > 0.95.\n",
    "x=torch.FloatTensor(data['coords'])\n",
    "y=model_scnn_2.forward_1(x)[0]\n",
    "y=y.detach().numpy()\n",
    "x=x.numpy()\n",
    "for j in range(2):\n",
    "    y1=y[:,j:j+1]\n",
    "    print('Coordinate '+ str(j+1))\n",
    "    for i in range(1,4):\n",
    "        polynomial_features= PolynomialFeatures(degree=i)\n",
    "        x_poly = polynomial_features.fit_transform(x)\n",
    "        model_2= Lasso(alpha=.001)\n",
    "        model_2.fit(x_poly, y1)\n",
    "        y_poly_pred = model_2.predict(x_poly)\n",
    "        r2 = r2_score(y1,y_poly_pred)\n",
    "        if r2>0.95:\n",
    "            # get_feature_names() was removed in scikit-learn 1.2; prefer its\n",
    "            # replacement get_feature_names_out() whenever it is available.\n",
    "            if hasattr(polynomial_features, 'get_feature_names_out'):\n",
    "                feature_names = polynomial_features.get_feature_names_out()\n",
    "            else:\n",
    "                feature_names = polynomial_features.get_feature_names()\n",
    "            for feature, coef in zip(feature_names, model_2.coef_.flatten()):\n",
    "                if np.abs(coef)>0.1:\n",
    "                    print(feature, np.round(coef,2))\n",
    "            print('_'*100)\n",
    "            print()\n",
    "            break\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "scrolled": false
   },
   "source": [
    "And here the SCNN detects the angular momentum as a conserved quantity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/software/opt/bionic/x86_64/python/3.7-2019.07/lib/python3.7/site-packages/sklearn/linear_model/coordinate_descent.py:475: ConvergenceWarning: Objective did not converge. You might want to increase the number of iterations. Duality gap: 10.422080993652344, tolerance: 0.026553429663181305\n",
      "  positive)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Coordinate 4\n",
      "x0 x6 -0.97\n",
      "x0 x7 0.92\n",
      "x1 x6 0.93\n",
      "x1 x7 -0.97\n",
      "x2 x4 0.92\n",
      "x2 x5 -0.88\n",
      "x3 x4 -0.88\n",
      "x3 x5 0.92\n",
      "____________________________________________________________________________________________________\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Sparse (Lasso) polynomial read-out of output coordinates 3 and 4 of model_scnn_3 on\n",
    "# data['coords']: print the monomials with |coef| > 0.1 of the lowest degree (0..4)\n",
    "# whose fit reaches R^2 > 0.9. Coordinates without such a fit are skipped.\n",
    "# NOTE(review): the stored stderr shows a Lasso ConvergenceWarning for one of these\n",
    "# fits; consider raising Lasso's max_iter if the printed coefficients look unstable.\n",
    "x=torch.FloatTensor(data['coords'])\n",
    "y=model_scnn_3.forward_1(x)[0]\n",
    "y=y.detach().numpy()\n",
    "x=x.numpy()\n",
    "for j in range(2,4):\n",
    "    y1=y[:,j:j+1]\n",
    "    for i in range(0,5):\n",
    "        polynomial_features= PolynomialFeatures(degree=i)\n",
    "        x_poly = polynomial_features.fit_transform(x)\n",
    "        model_2= Lasso(alpha=.001)\n",
    "        model_2.fit(x_poly, y1)\n",
    "        y_poly_pred = model_2.predict(x_poly)\n",
    "        r2 = r2_score(y1,y_poly_pred)\n",
    "        if r2>0.9:\n",
    "            print('Coordinate '+ str(j+1))\n",
    "            # get_feature_names() was removed in scikit-learn 1.2; prefer its\n",
    "            # replacement get_feature_names_out() whenever it is available.\n",
    "            if hasattr(polynomial_features, 'get_feature_names_out'):\n",
    "                feature_names = polynomial_features.get_feature_names_out()\n",
    "            else:\n",
    "                feature_names = polynomial_features.get_feature_names()\n",
    "            for feature, coef in zip(feature_names, model_2.coef_.flatten()):\n",
    "                if np.abs(coef)>0.1:\n",
    "                    print(feature, np.round(coef,2))\n",
    "            print('_'*100)\n",
    "            print()\n",
    "            break\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
