{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Training the Hypersolver"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "source": [
    "import os, sys\r\n",
    "import matplotlib.pyplot as plt\r\n",
    "import numpy as np\r\n",
    "import torch\r\n",
    "import torch.nn as nn\r\n",
    "import sys; sys.path.append(2*'../') # go 2 dirs back\r\n",
    "from src import *\r\n",
    "from math import pi\r\n",
    "import matplotlib\r\n"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "source": [
    "device = 'cpu'"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "source": [
    "ControlledSystem = QuadcopterGym\r\n",
    "\r\n",
    "# ## Scaling\r\n",
    "# out_scal = torch.ones(21).to(device)\r\n",
    "# for i in range(4):\r\n",
    "#     out_scal[i+9] = 1/500\r\n",
    "#     out_scal[i+13] = 1/10000\r\n",
    "\r\n",
    "# in_scal = torch.cat([out_scal, out_scal])\r\n",
    "# u_scal = 1/100*torch.ones(4).to(device)\r\n",
    "# in_scal = torch.cat([in_scal, u_scal])\r\n",
    "\r\n",
    "# print('Input scaling:\\n', in_scal)\r\n",
    "# print('Output scaling:\\n', out_scal)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Generate distribution of initial conditions"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "source": [
    "from torch.distributions import Uniform\r\n",
    "\r\n",
    "# Half-width of the sampling box for each of the 12 state components:\r\n",
    "# indices 0:3 -> 5, 3:6 -> 50, 6:9 -> 50, 9:12 -> 100.\r\n",
    "high = torch.ones(12).to(device)\r\n",
    "high[0:3] *= 5\r\n",
    "high[3:6] *= 50\r\n",
    "high[6:9] *= 50\r\n",
    "high[9:12] *= 100\r\n",
    "# The box is symmetric around the origin; negation allocates a new tensor\r\n",
    "low = -high\r\n",
    "\r\n",
    "dist = Uniform(low, high)\r\n",
    "# `sample_n` is deprecated: `sample((n,))` draws the same (n, 12) batch\r\n",
    "print(dist.sample((2,)))"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "tensor([[ -4.5748,   0.9163,   4.8489, -20.9084, -19.7354,  26.2020, -17.7314,\n",
      "          27.5852, -12.5779, -55.1637,  78.2594, -54.2277],\n",
      "        [  4.9443,  -0.3451,  -4.4303, -42.4282,  -2.7365, -12.2402,  10.2114,\n",
      "         -14.0203,  30.4343,  95.8709,  41.5496, -28.9956]])\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "/home/botu/anaconda3/lib/python3.8/site-packages/torch/distributions/distribution.py:161: UserWarning: sample_n will be deprecated. Use .sample((n,)) instead\n",
      "  warnings.warn('sample_n will be deprecated. Use .sample((n,)) instead', UserWarning)\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## New output scaling\n",
    "Since the derivatives of the positions and angular positions, $\\dot{x},~\\dot{\\psi}$, are propagated through the augmented state without any additional calculation, there is no need for the hypersolver to add a correction term for them. Indeed, such a term would only degrade the performance unless it were always 0. We therefore set the output scaling to 0 for the updates of the positions and angular positions."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "source": [
    "# Per-feature input scaling for the 28 network inputs [x, f(x), u] built in `residual_loss`\r\n",
    "in_scal = torch.cat([1/high, 1/high])\r\n",
    "u_scal = 1/20000*torch.ones(4).to(device)\r\n",
    "in_scal = torch.cat([in_scal, u_scal])\r\n",
    "\r\n",
    "# Clone before editing: `out_scal = high` only aliases `high`, so the in-place zeroing\r\n",
    "# below would also overwrite high[0:6], which is still needed by the sampling bounds of\r\n",
    "# `dist`, by the `(high - low)` loss normalisation, and by `1/high` if this cell is re-run.\r\n",
    "out_scal = high.clone()\r\n",
    "\r\n",
    "# Zero out positions contribution term\r\n",
    "out_scal[0:6] = 0\r\n",
    "print('Input scaling:\\n', in_scal)\r\n",
    "print('Output scaling:\\n', out_scal)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Input scaling:\n",
      " tensor([2.0000e-01, 2.0000e-01, 2.0000e-01, 2.0000e-02, 2.0000e-02, 2.0000e-02,\n",
      "        2.0000e-02, 2.0000e-02, 2.0000e-02, 1.0000e-02, 1.0000e-02, 1.0000e-02,\n",
      "        2.0000e-01, 2.0000e-01, 2.0000e-01, 2.0000e-02, 2.0000e-02, 2.0000e-02,\n",
      "        2.0000e-02, 2.0000e-02, 2.0000e-02, 1.0000e-02, 1.0000e-02, 1.0000e-02,\n",
      "        5.0000e-05, 5.0000e-05, 5.0000e-05, 5.0000e-05])\n",
      "Output scaling:\n",
      " tensor([  0.,   0.,   0.,   0.,   0.,   0.,  50.,  50.,  50., 100., 100., 100.])\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Scaling\n",
    "We apply a scaled residual loss: each state component of the residual is divided by its range (`high - low`), so that all components are given the same \"importance\"."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "source": [
    "Δt = 0.02\r\n",
    "\r\n",
    "x_min_train, x_max_train = -1000, 1000\r\n",
    "\r\n",
    "bs_hs = 1024\r\n",
    "n_grid = 100\r\n",
    "\r\n",
    "def residual_loss(g, X, u, Δt, method, order=1):\r\n",
    "    \"\"\"Scaled one-step residual between a reference and a base-solver solution.\r\n",
    "\r\n",
    "    Integrates the controlled system from `X` over one step `Δt` with `dopri5`\r\n",
    "    (reference) and with `method` (base solver), then measures how well the\r\n",
    "    correction `Δt**(order+1) * g(0, [X, f(0, X), u.u0])` explains the difference.\r\n",
    "\r\n",
    "    Args:\r\n",
    "        g: hypersolver network, called as `g(t, xfu)`.\r\n",
    "        X: batch of initial states (state along the last dimension).\r\n",
    "        u: controller given to `ControlledSystem`; its `u0` is also fed to `g`.\r\n",
    "        Δt: step size.\r\n",
    "        method: name of the base solver passed to `odeint` (e.g. 'euler').\r\n",
    "        order: order of the base solver; the residual is divided by Δt**(order+1).\r\n",
    "\r\n",
    "    Returns:\r\n",
    "        2-norm over the last dimension of the residual, each component being\r\n",
    "        divided by the notebook-level range `high - low`.\r\n",
    "    \"\"\"\r\n",
    "    # Build the system once and reuse it for both integrations and for f(0, X)\r\n",
    "    system = ControlledSystem(u)\r\n",
    "    t_span = torch.tensor([0, Δt]).to(device)\r\n",
    "    x_fine = odeint(system._dynamics, X, t_span, method='dopri5')[-1]\r\n",
    "    x_coarse = odeint(system._dynamics, X, t_span, method=method)[-1]\r\n",
    "    xfu = torch.cat([X, system._dynamics(0, X), u.u0], -1)\r\n",
    "    # Divide by the per-component range so that all components weigh comparably\r\n",
    "    residual = x_fine - x_coarse - (Δt**(order+1))*g(0, xfu)\r\n",
    "    return torch.norm(residual/(high-low), p=2, dim=-1)/(Δt**(order+1))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Hypersolver"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "source": [
    "\r\n",
    "hdim = 64\r\n",
    "\r\n",
    "class Hypersolver(nn.Module):\r\n",
    "    \"\"\"MLP mapping the 28 scaled input features to a 12-dimensional scaled output.\"\"\"\r\n",
    "\r\n",
    "    def __init__(self):\r\n",
    "        super().__init__()\r\n",
    "        # Two hidden layers of width `hdim` with Softplus activations\r\n",
    "        self.layers = nn.Sequential(\r\n",
    "            nn.Linear(28, hdim),\r\n",
    "            nn.Softplus(),\r\n",
    "            nn.Linear(hdim, hdim),\r\n",
    "            nn.Softplus(),\r\n",
    "            nn.Linear(hdim, 12),\r\n",
    "        )\r\n",
    "\r\n",
    "    def forward(self, t, x):\r\n",
    "        # `t` is not used; inputs and outputs are rescaled with the\r\n",
    "        # notebook-level `in_scal` / `out_scal` tensors\r\n",
    "        hidden = self.layers(x*in_scal)\r\n",
    "        return hidden*out_scal\r\n",
    "    \r\n",
    "hs = Hypersolver().to(device)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "source": [
    "# hs = torch.load('saved_models/hypersolver_0.02_new_quadcopter.pt')"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Train"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "source": [
    "drone = ControlledSystem(None)\r\n",
    "max_controller = drone.MAX_RPM"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "source": [
    "from tqdm import trange\r\n",
    "\r\n",
    "bs_hs = 1024\r\n",
    "epochs = 100000\r\n",
    "losses = []\r\n",
    "opt_hs = torch.optim.Adam(hs.parameters(), lr=1e-3)\r\n",
    "\r\n",
    "# Hypersolver training loop for the fixed step size Δt\r\n",
    "with trange(0, epochs, desc=\"Epochs\") as epcs:\r\n",
    "    for epoch in epcs:\r\n",
    "        opt_hs.zero_grad()\r\n",
    "\r\n",
    "        # Randomize state and controller values\r\n",
    "        x = dist.sample((bs_hs,)).to(device)\r\n",
    "        u_rand = RandConstController((3, 3), 1, 1).to(device)\r\n",
    "        # Constant control per sample, drawn uniformly in [0, max_controller]; shape (1, bs_hs, 4)\r\n",
    "        u_rand.u0 = torch.empty(1, bs_hs, 4).uniform_(0, max_controller).to(device)\r\n",
    "\r\n",
    "        loss = residual_loss(hs, x[None], u_rand, Δt, 'euler').mean()\r\n",
    "        # Optimization step\r\n",
    "        loss.backward()\r\n",
    "        opt_hs.step()\r\n",
    "\r\n",
    "        # Read the scalar back once and reuse it for the history and the progress bar\r\n",
    "        loss_value = loss.item()\r\n",
    "        losses.append(loss_value)\r\n",
    "        epcs.set_postfix(loss=loss_value)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Epochs:  28%|██▊       | 27901/100000 [18:47<48:32, 24.75it/s, loss=21.6]\n"
     ]
    },
    {
     "output_type": "error",
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-44-4bed65f2710f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     16\u001b[0m         \u001b[0mu_rand\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mu0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbs_hs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muniform_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_controller\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m         \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresidual_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mu_rand\u001b[0m\u001b[0;34m,\u001b[0m  \u001b[0mΔt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'euler'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     19\u001b[0m         \u001b[0;31m# Optimization step\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m         \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-41-c82013d4da5e>\u001b[0m in \u001b[0;36mresidual_loss\u001b[0;34m(g, X, u, Δt, method, order)\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mresidual_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mΔt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m     \u001b[0mt_span\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mΔt\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m     \u001b[0mx_fine\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0modeint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mControlledSystem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mu\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dynamics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt_span\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'dopri5'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     11\u001b[0m     \u001b[0mx_coarse\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0modeint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mControlledSystem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mu\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dynamics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt_span\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmethod\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m     \u001b[0mxfu\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mControlledSystem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mu\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dynamics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mu0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torchdiffeq/_impl/odeint.py\u001b[0m in \u001b[0;36modeint\u001b[0;34m(func, y0, t, rtol, atol, method, options, event_fn)\u001b[0m\n\u001b[1;32m     75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     76\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mevent_fn\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 77\u001b[0;31m         \u001b[0msolution\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msolver\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mintegrate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     78\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     79\u001b[0m         \u001b[0mevent_t\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msolution\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msolver\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mintegrate_until_event\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevent_fn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torchdiffeq/_impl/solvers.py\u001b[0m in \u001b[0;36mintegrate\u001b[0;34m(self, t)\u001b[0m\n\u001b[1;32m     28\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_before_integrate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     29\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m             \u001b[0msolution\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_advance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     31\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0msolution\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     32\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torchdiffeq/_impl/rk_common.py\u001b[0m in \u001b[0;36m_advance\u001b[0;34m(self, next_t)\u001b[0m\n\u001b[1;32m    192\u001b[0m         \u001b[0;32mwhile\u001b[0m \u001b[0mnext_t\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrk_state\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    193\u001b[0m             \u001b[0;32massert\u001b[0m \u001b[0mn_steps\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_num_steps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'max_num_steps exceeded ({}>={})'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_steps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_num_steps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 194\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrk_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_adaptive_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrk_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    195\u001b[0m             \u001b[0mn_steps\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    196\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0m_interp_evaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrk_state\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minterp_coeff\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrk_state\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt0\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrk_state\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnext_t\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torchdiffeq/_impl/rk_common.py\u001b[0m in \u001b[0;36m_adaptive_step\u001b[0;34m(self, rk_state)\u001b[0m\n\u001b[1;32m    253\u001b[0m         \u001b[0;31m# trigger both. (i.e. interleaving them would be wrong.)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    254\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 255\u001b[0;31m         \u001b[0my1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my1_error\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_runge_kutta_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtableau\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtableau\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    256\u001b[0m         \u001b[0;31m# dtypes:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    257\u001b[0m         \u001b[0;31m# y1.dtype == self.y0.dtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torchdiffeq/_impl/rk_common.py\u001b[0m in \u001b[0;36m_runge_kutta_step\u001b[0;34m(func, y0, f0, t0, dt, t1, tableau)\u001b[0m\n\u001b[1;32m     74\u001b[0m             \u001b[0mperturb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mPerturb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNONE\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     75\u001b[0m         \u001b[0myi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my0\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m...\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0mi\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbeta_i\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mdt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview_as\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 76\u001b[0;31m         \u001b[0mf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mti\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mperturb\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mperturb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     77\u001b[0m         \u001b[0mk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_UncheckedAssign\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m...\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     78\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1104\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1105\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1106\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1107\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1108\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torchdiffeq/_impl/misc.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, t, y, perturb)\u001b[0m\n\u001b[1;32m    189\u001b[0m             \u001b[0;31m# Do nothing.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    190\u001b[0m             \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 191\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbase_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    192\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    193\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/amazing-mpc/src/env/quadcopter_gym.py\u001b[0m in \u001b[0;36m_dynamics\u001b[0;34m(self, t, x)\u001b[0m\n\u001b[1;32m     55\u001b[0m         \u001b[0mthrust\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m...\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mthrust_z\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     56\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 57\u001b[0;31m         \u001b[0mrotation\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0meuler_matrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrpy\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m...\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrpy\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m...\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrpy\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m...\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     58\u001b[0m         \u001b[0mthrust_world_frame\u001b[0m \u001b[0;34m=\u001b[0m  \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meinsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'...ij, ...j-> ...i'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrotation\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mthrust\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     59\u001b[0m         \u001b[0mforce_world_frame\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mthrust_world_frame\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGRAVITY\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/amazing-mpc/src/env/quadcopter_gym.py\u001b[0m in \u001b[0;36meuler_matrix\u001b[0;34m(ai, aj, ak, repetition)\u001b[0m\n\u001b[1;32m     91\u001b[0m     \u001b[0;31m# Tricks to create batched matrix [...,3,3]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     92\u001b[0m     \u001b[0;31m# any suggestion to make code more readable is welcome!\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m     \u001b[0mM\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mai\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m...\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     94\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mrepetition\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     95\u001b[0m         \u001b[0mM\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m...\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcj\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "\n",
    "# `losses` is a plain list of Python floats, so no `torch.no_grad()` context is needed\n",
    "ax.plot(losses)\n",
    "ax.set_yscale('log')\n",
    "ax.set_xlabel('Training step')\n",
    "ax.set_ylabel('Residual loss')\n",
    "ax.set_title('Hypersolver training loss');"
   ],
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD4CAYAAAD4k815AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAawElEQVR4nO3deXxU1d0G8OdkmYSELIQkZAMSCGsgCAQUFBFBdsS1gq/a+loprfhqa9sXX8Aq1Ypt0YpYFZfWpSLuiMgmi4ACIQGEREjCmgVCAlkh+8x5/5ghZJJJCMlkzsyZ5/v55MPMuXfu/M7c4Zk7d+49V0gpQURE+vBQXQAREdkXg52ISDMMdiIizTDYiYg0w2AnItKMl+oCACA0NFTGxsaqLoOIyKWkpqaek1KGNW53imCPjY1FSkqK6jKIiFyKEOKUrXbuiiEi0gyDnYhIMwx2IiLNMNiJiDTDYCci0gyDnYhIMwx2IiLNuHSwV9YY8fm+XHDoYSKiy5ziBKW2Wvx1OlYm5yA8wBc39AlVXQ4RkVNw6S32lck5AID73t6juBIiIuehNNiFEDOEECtKS0tVlkFEpBWlwS6lXCOlnBMUFKSyDCIirbj0rhgiImpKm2CvM5pUl0BE5BS0CfYDOSWqSyAicgraBPui1emqSyAicgraBHvm2XLVJRAROQVtgt1o4tmnRESARsFORERmLh3sexdMUF0CEZHTcelgDwvwUV0CEZHTcelgJyKiphjsRESaYbATEWnG5YN9wdQBqksgInIqLh/sg2M4MiQRUUMuH+yhnQ2qSyAiciouH+zx4QGqSyAiciouH+xERGSNwU5EpBkGOxGRZhjsRESaYbATEWlGq2AvqahRXQIRkXJaBXtucaXqEoiIlNMq2IVQXQERkXpaBTsREWkW7ALcZCci0ivYmetERHoFOxERaRbsUqqugIhIPa2Cfd6H+1SXQESknFbBfvzcRdUlEBEpp1WwExERg52ISDsMdiIizWgR7P0jeHk8IqJLtAj2G/uGqS6BiMhpaBHsUUG+qksgInIaWgT7xIQI1SUQETkNLYLd04ODxBARXaJFsHtw9C8ionpaBDu32ImILtMj2LnFTkRUT4tg9/HWohtERHahRSL6enuqLoGIyGloEexERHQZg52ISDMMdiIizTDYiYg0w2AnItIMg52ISDMMdiIizTDYiYg0o12wm0xSdQlERErZPdiFELcJId4UQqwWQky09/KvxCgZ7ETk3loV7EKId4QQBUKItEbtk4UQGUKIo0KI+QAgpfxSSvkwgF8AuMfuFV+BkVvsROTmWrvF/m8Akxs2CCE8AbwKYAqAgQBmCyEGNphloWW6Q9Ux2InIzbUq2KWU2wEUNWoeCeColPK4lLIGwEcAZgqzFwCsk1Lua26ZQog5QogUIURKYWFhW+tvwmhksBORe2vPPvZoADkN7uda2h4FMAHAXUKIuc09WEq5QkqZJKVMCgsLa0cZ1rZmFNhtWURErsirHY+1dXULKaVcBmBZO5bbLm9sP47bhkarenoiIuXas8WeC6B7g/sxAE63r5z2q6ypU10CEZFS7Qn2vQD6CCHihBAGALMAfGWfstru5PkK1SUQESnV2sMdVwLYBaCfECJXCPGQlLIOwDwAGwAcBvCxlDK940olIqLWaNU+dinl7GbavwHwjV0rIiKidtFuSAEiInfHYCci0ozSYBdCzBBCrCgtLVVZBhGRVpQGu5RyjZRyTlBQkMoyiIi0os2umAVTB6gugYjIKWgT7F07G1SXQETkFLQJdi9PbbpCRNQu2qSht4etoWuIiNyPNsF+Q59Q1SUQETkFbYLdz9CegSqJiPShTbBzTwwRkZk2wS4Ek52ICOCZp0RE2uGZp0REmtFmVwwREZkx2ImINKNlsFfXGVWXQESkjJbB/lFyjuoSiIiU0TLY88uqVJdARKSMlsGekV+uugQiImW0DPYtRwpUl0BEpIyW
wU5E5M4Y7EREmuGQAkREmuGQAkREmuGuGCIizTDYiYg0o1Wwv/izIapLICJSTqtgH9s3THUJRETKaRXsXh5adYeIqE20SsLOvrygNRGRVsHuyStaExHpFexERMRgJyLSDoOdiEgzDHYiIs1oOwhYndFk92USEbkCbQcBK6uqs/syiYhcgba7Yg7lcShgInJP2gb7tgxeHo+I3JO2wf7+rlOqSyAiUkLbYK8zSdUlEBEpoW2wExG5KwY7EZFmGOxERJphsBMRaYbBTkSkGe2CfergiPrbUvLIGCJyP9oF+y9Gx9Xfzi6qUFgJEZEa2gX7yLiQ+tu3Lv9eYSVERGpoF+wNlVbWqi6BiMjhtA52IiJ3pO147JccLSjvsGUTETkjbcdjv+SJTw522LKJiJyR9rtiThReUF0CEZFDaRnsiTGXvwHwSkpE5G60DPab+oapLoGISBktg33ezX1Ul0BEpIyWwW7wsu5WWRWPZyci96FlsDf2+EcHVJdAROQwbhHsW47wwtZE5D60DfafFk9SXQIRkRLaBrufwcvq/uoDeYoqISJyLG2DvbHHPjqAC9U8pp2I9Oc2wQ4AP+aUqC6BiKjDuVWwnzrPC28Qkf60DvbMZ6dY3ffx0rq7REQANA/2xicq1RpNiiohInIcrYO9sfmfH0JpBc9CJSK9aR/sD4zqaXV/yOKNiiohInIM7YO9W6Cv6hKIiBxK+0vjDe0R3GHLJiJyRtpfGm9079AOWzYRkTPSfleMLcUXa1SXQETUYdwi2A89PdHqfuqpYkWVEBF1PLcI9gBfb6v7v3wvRVElREQdzy2CnYjInbhNsC+cNsDqfl5JpaJKiIg6ltsE+/TEKKv7Ww6fVVQJEVHHcptgjwiyPlFp0ep0RZUQEXUstwl2IiJ34VbB/ueZCVb3f/1BKkd8JCLtuFWw3z8q1ur+urR8LN2YqaYYIqIO4lbBbsvr3x1D8oki1WUQEdmN2wc7APzsjV2qSyAishu3C/ZHb45XXQIRUYdyu2B/bHwfm+2fpOQ4uBIioo7hdsHu5Wm7y3/49KCDKyEi6hhuF+wAcH18V9UlEBF1GLcM9uE9Q1SXQETUYdwy2CcMCLfZLqV0cCVERPbnlsGeGBNss/3tnSccWwgRUQdwy2AHgNuHRjdpe3btYQWVEBHZl9sG+0v3XGOzPXb+WpRU8JqoROS63DbYW3LN4k2qSyAiajOlwS6EmCGEWFFaWqrk+ZMXjG922o85JY4rhIjIjpQGu5RyjZRyTlBQkJLnDw/wxRO39LU5bear3zu4GiIi+3D7XTE+3s2/BNxqJyJX5PbB3tKh64+u3O+4QoiI7ITB3sK07KIKh9VBRGQvbh/skxIiWpx+/9t7HFQJEZF9uH2wx4X6Y/m9QzEitovN6TuyzvEKS0TkUtw+2AFgemIUPpk7GhGBvjan8wpLRORKGOwN7Pjfcc1O23XsvAMrISJqOwZ7A97NXIQDAGa/uduBlRARtR2D/Sp8uT8PO7IKVZdBRNQiBnsjh56eCEMzW+6PrzqA+99ORuqpYgdXRUTUegz2RgJ8vZHx7OQW57nztR8cVA0R0dXzUl2AMxJCXHGe2St2IzLYFzOGRGFcP9tXZCIiUoFb7M1Y//iYFqfvOn4en+/Lw4P/2uugioiIWofB3ox+3QJaPW/m2XIUX+TFOYjIOTDYmyGEwJE/t7yv/ZKJL23HjOU7O7giIqLWYbC3wNfbs9Xz5hZXYn1afgdWQ0TUOgz2K5g9snur5537QSpi56/FyuRsZOSX48VNmR1YGRGRbUK2NCC5gyQlJcmUlBTVZbQodv7aNj/260dvwKBo66tElVbW4mBuCcb0CWtvaUTkpoQQqVLKpMbt3GJvpc9+PbrNj53+yk7cunwnlm3OQlpeKSprjHjkP/tw/9vJyHHTMd+llHCGjQoiHXGL/Sp8sPsUFn6ZZvflvnH/cMSHd0Z0cCf85ZvDGBgZiNDOPhg/INzqmPpaowmnSyoR
FdypxXFtXMErm7OwdFMmflo8CX6GpqdTGE3m4PdyUD8LyqoAYb4OLpGraG6LnScoXYX7ruuJFzdlosjOhzb+6v1Um+3zxsXjne9PoKLGaHP6hAHdkBAViOmJkYjp4geDlwde2pSJjLPlKKusxT//axhC/A0wmi4H5Ib0fKzYfhx3D4/B/M8PIeu5KaioNmJfdjGC/bzRxc+A8EAfm2F7iZQSB3NL0T3ED6dLKpvsZsotrsCqvTkor6pDZJAvfjW2t9V0k0liqeX3h20ZhZg6OLLJc4xfug2niipw4vlpTaaVVNTg3IUafLgnG4umD2jVCWWX6n74vVQ8MKonbogPhQTw9cHTyCupxF/XZwAATi5p+nxXo7yqFuVVdQjs5I0zJZVYfeA0lm89iuX3DsX0xKgm81fVGrE9sxATEyKw5/h59O0WgC7+hibz5RRVYMX243h8Qh90MnjCz+CF0yWViAzyhRACOUUV2JF1Di9uykDKwlvqH1dQVoXy6jr0DuvcbM1Gk0RVrRH+Pi3Hgckk8U3aGYztG4YAX++reFU6Tq8n1+JnSd2x5M7ENj2+zmjCxRojgjrZvz/PrEnHv74/2e73VFtwi/0qbTlyFv/9b9eo1V5G9+6KSQkR2Hn0HB4b3wfTX7E+tPP1+4Zj7gepeHnWNZh5TTRuXb4TB3NL66eH+Bvw7e/G4kJVHQ7lleJIfhle2XK0fvpT0wdixpAoLNuchW2ZBZgzphcWrU4HAHz3h5vQs6s/8koqAQBbDp+tnwYAW39/Ezak5yMxJgije4fWt584dxHHCi4gt9i8q0sIgZv7h2PMX7cCAAZEBuLwmbImfX3pniFY+EUaFs8chIkJ3XCmtAqxXf1xrPACBkQGAgC+OXQGkUG+GNqjC1JPFaFboC/CAnzw4Z5sPLPmJ5uvYYi/AfsWXQ7cqlojsosq8NtVB5B+ugwLpw3As2sPIyzAB3sXTDD37UgBNqTn48Hr47Dgi0NIsYxRFBnki3d+MQJTXt6BRdMHYmiPYNzxz8vDXPy0eBIGPrUBM4ZEYc2Pp82vx/NTm3z7qzWacCC7BO/vPoV1afnYt+gWhPgb8P3RcyitrLX6wD1/oRpPfPIjtmWYB8G7e3gMJiVEIDTAB/0jAvDB7lM4XVKF5JPnseSORCzfchSPTegDP4Mn/rMnG09O6W/1/McLLyCmix8+TsnBoOggXNM9GACwMT0f/j5eGN27K9794SRmjezR5Oi0OqMJG9LPIjzQB3e/fvlaCbufHI/sogoUV9RgUkIEtmUU4Jk1P2HVnOvQxd+AhKc24Pk7BkMIYH1aPl6eNRRPrU7DJ6m5Vt8ctx4pQOGFakxPjMQX+/Nw9/Du6LtwHeaO7Y27k2KQdfYC9p4sQmllLXqG+OGXY3phX3YxXt6cheQTRTjy58lYmXz5vfC3uxLxz23H8Np9wzD5HzvQ1d+AEbEhuDspBuMHdLP5fmmt5rbYGexXKft8BW7821bVZZALGhAZiNziCggAZVV1zc53csm0K/5YP7ZvGL7L7NiRRh8Y1RPhAT74+0bnOLprzbwbUF5di3vftM/lKh8eE4c3d5wAAAzpHoz0vFLUmZrm4R8n96v/Rmdv3p4CWc9NbfPjGex2dLqkEiYpccMLDHgiap/27KrhPnY7igrupLoEIqJmufahFYodXjwZ/3pwBADAy6N1P+ARETXUEeNMcYu9HToZPDGuXziev2MwpgyKgNEk8d6uU3h5c5bq0ojIRXh62n+jkFvsdjB7ZA8E+xnQtbMPHr6xFwZFB+KWgeZfu0f16qq4OiJyZj5e9o9hBruddfbxwtePjsHjE/oAABZOH4Ctv7+pfnr6M5PqbycvGI/wAB9Hl0hETsTbg8HuMhKignByyTQkRAUhLtS/vt3fxwt+BvNxuX4GLyQvmIBbhzQ9cYWI3INHB/w+x2B3oIQo8wkuPx8dC+DyV7Bls4dixf3DkbxgPF66ZwiWzR6KpXcPwYr7h9c/9sTz
to91/fw3zY9h0z/CfLEQP4MnHrw+1g49ICJXwOPYHaSipg5eHh4weHlYBsC68id1ndGEif/YjvmT+2NiQgRi569Fz65++O4P4+pPYDm5ZBpyiytQWF4Nb08PbPrpLBJjgnAorxT/c3Mfq+coLK9GQXkVeoV2xtNfpePJqf1xzeJNNp/7qekDsWT9EdTUmQAAX/xmNCKDOuG65zfj9xP7Ylz/cOQUVSLQ1wvVRhPGxIfCy9MDI5/7FgXl1bixbxi2W06gefTmePh6e2LZ5ixIAJnPTgFgPsU/7slvAACHnp6IAF9vqxNzkv9vPEb+ZbPN+ob1CMa+7BIA5jNjgzp5Y10L4+FnPjsFfReuq78/rl8YtmZYn+AzKDoQaXlNz0a15dO5o5AUG4KlGzNgkhKvbj12xcd8+7sbYfD0xO4T5zGqV1fkFFXg3reu7mSbEH9Dq4a0aHjyjS2Bvl6ICu6EJXcm4rZXv69vDw/wQUF5tdW8EwZ0w7eHz7b4fMF+3iipqL1iXbYIAUgJ/GpsL7zx3XGraW89kIRfvmfOhnuSumNVSk6bnqOhdY+NwZSXd9TfX3LHYMz//NBVL+eH+Tdj9JItNqe9cOdg/O9nV17m/kW32BxCorWc8gQlIcQMADPi4+MfzsrikSRXUlVrhIcQMHh5IPVUMQrKqjDFxjgrV6OmzoTSylqMeO7b+raWTpgor6pFZx+vZsdnkVKisLwa4YG+ePeHkxjbNwyxDXZF2Vqer7dn/aBmJpPEkvVHMGtEd/QK64yiizXw9fbA0o2ZeHunOajuvbYH/nL7YBSUV8HHyxNBnbwhpcS2jEI8+G/zNWgfvD4WT00fiMpaI3y9POHhIVBrNOH1bcfwyLh4eHgIHMgpQWxXv/pagv0M2J9djLS8UtwzogeyCsrxzaEzmJ4YhX7dAnD7az/gx5wSAMC3vxuL+PDL469IKfHWjhO4ZWA33PT3bYgK8sXp0ioAwIcPX2s13EFDVbVG9F+0vv51v/TB9tsJfXHn8Gi8svkoRsSF4Nq4EHQP8UOt0YQ+C9ZhWmIkls0aCk8PgXkf7kNpZS3ef+haPPbRfpwpqcLHc0ehus6Ii9VG+Bk88fB7KfjTjAT0CvXHd1mFuKlvWP06XJ+Wj7kfpOKRcb1xT1IPrDl4GncMi8bL32bhiYn9EBbggzqj+QP+qa/S8eGebADms1/PllVhZFwIFs8cBJNJ4kh+OcqqajEyNgQeHgJ1RhOGLt6E8uo6LJ6ZgNuHRiMtrwyz39yNeePiMa5/GIL9DNieWYjZI3vUvxa/ndAXj94cX7+MHVnnMKZPKN7aeQJL1h2pf/0CfLxQXl2HmC6dcNfwGGSfr8ALdyXij58exBf786xe68kJEVifno+TS6ahus6IE+cuorC82mro7IYbFvsX3YLtWYUYERuCTT+dxaDoIAzv2cVqo6rxe3nw0xvrp7214zjCAnwwKSHCah0DwIsbMxAV3AmzRvaw+b5oLacM9kvcYYvd2VXXGXHuQg0iAn3h6cTH5F96vzb3wXIwtwSxof4I7IBBqkwmiRqjCWl5pUiKDbni/JcCYNWc63BtC0dHNQyK93adxFOr0/GnGQPx4PVx9in8CqSU+GJ/HqYnRsFwhSM0pJS4WGPE3PdT8extg1r80G6LL/fnYWRcSIsnAcbOX4vRvbvikXHxuD4+FJU1Rnh4AD5e1mPKmEwSG9Lz0cPy4T3QMtZPS4PGVdcZISBQVlWL0M62D2wY9/dtiOnSCe8/dG2TaUaThIdo+hypp4rgIQSG9ujS7HO3BYOdyMGKL9Zg5d5s/Hps7xbD5OO9OegV5o+k2BAYTRKfpebijmHRDhuy2NVU1xnh5eHh1BsgjsJgJyLSDK+gRETkJhjsRESaYbATEWmGwU5EpBkGOxGRZhjsRESaYbATEWmGwU5EpBmnOEFJCFEI4FQbHx4K4Jwdy3EWuvYLYN9cka79Aly7bz2llGGNG50i2NtDCJFi68wrV6dr
vwD2zRXp2i9Az75xVwwRkWYY7EREmtEh2FeoLqCD6NovgH1zRbr2C9Cwby6/j52IiKzpsMVOREQNMNiJiDTjssEuhJgshMgQQhwVQsxXXU9rCSFOCiEOCSEOCCFSLG0hQohNQogsy79dGsz/pKWPGUKISQ3ah1uWc1QIsUy0dImejunHO0KIAiFEWoM2u/VDCOEjhFhlad8jhIhV3LenhRB5lvV2QAgx1dX6JoToLoTYKoQ4LIRIF0I8Zml36fXWQr9cfp21mZTS5f4AeAI4BqAXAAOAHwEMVF1XK2s/CSC0UdtfAcy33J4P4AXL7YGWvvkAiLP02dMyLRnAKAACwDoAUxzcjxsBDAOQ1hH9APAbAK9bbs8CsEpx354G8Hsb87pM3wBEAhhmuR0AINNSv0uvtxb65fLrrK1/rrrFPhLAUSnlcSllDYCPAMxUXFN7zATwruX2uwBua9D+kZSyWkp5AsBRACOFEJEAAqWUu6T5nfZeg8c4hJRyO4CiRs327EfDZX0KYLyjvpU007fmuEzfpJRnpJT7LLfLARwGEA0XX28t9Ks5LtGv9nDVYI8GkNPgfi5aXpHORALYKIRIFULMsbR1k1KeAcxvUgDhlvbm+hltud24XTV79qP+MVLKOgClALp2WOWtM08IcdCyq+bS7gqX7JtlV8JQAHug0Xpr1C9Ao3V2NVw12G19UrrKcZvXSymHAZgC4BEhxI0tzNtcP12t/23ph7P18TUAvQFcA+AMgKWWdpfrmxCiM4DPADwupSxraVYbbU7bNxv90madXS1XDfZcAN0b3I8BcFpRLVdFSnna8m8BgC9g3q101vI1EJZ/CyyzN9fPXMvtxu2q2bMf9Y8RQngBCELrd4/YnZTyrJTSKKU0AXgT5vUGuFjfhBDeMIfff6SUn1uaXX692eqXLuusLVw12PcC6COEiBNCGGD+MeMrxTVdkRDCXwgRcOk2gIkA0mCu/eeW2X4OYLXl9lcAZll+kY8D0AdAsuXrcrkQ4jrLfr4HGjxGJXv2o+Gy7gKwxbLfU4lLwWdxO8zrDXChvlnqeBvAYSnliw0mufR6a65fOqyzNlP9621b/wBMhfnX72MAFqiup5U194L51/gfAaRfqhvmfXWbAWRZ/g1p8JgFlj5moMGRLwCSYH6jHgOwHJaziB3Yl5Uwf72thXlr5iF79gOAL4BPYP5hKxlAL8V9ex/AIQAHYf5PHulqfQNwA8y7Dw4COGD5m+rq662Ffrn8OmvrH4cUICLSjKvuiiEiomYw2ImINMNgJyLSDIOdiEgzDHYiIs0w2ImINMNgJyLSzP8D9cpYDgjJWy4AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     }
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "source": [
    "# Uncomment to save hypersolver model\n",
    "# with torch.no_grad():\n",
    "#     torch.save(hs, 'saved_models/hypersolver_'+str(Δt)+'_new_quadcopter.pt')"
   ],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.8.5 64-bit ('base': conda)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "interpreter": {
   "hash": "d77f8d9122331bf0c813f643ab906d6086736a1197fa074182ffff0b1ac62f18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}