{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Hypersolvers for Optimal Control - Direct Optimal Control"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "import sys; sys.path.append(2*'../') # go n dirs back\n",
    "from src import *"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "import time\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from torchdyn.core import NeuralODE\n",
    "from torchdyn.datasets import *\n",
    "from torchdyn.numerics import odeint, Euler, HyperEuler"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Optimal Control\n",
    "\n",
    "We want to control an inverted pendulum and stabilize it in the upright position. The equations in Hamiltonian form describing an inverted pendulum with a torsional spring are as following:\n",
    "\n",
    "$$\\begin{equation}\n",
    "    \\begin{bmatrix} \\dot{q}\\\\ \\dot{p}\\\\ \\end{bmatrix} = \n",
    "    \\begin{bmatrix}\n",
    "    0& 1/m \\\\\n",
    "    -k& -\\beta/m\\\\\n",
    "    \\end{bmatrix}\n",
    "    \\begin{bmatrix} q\\\\ p\\\\ \\end{bmatrix} -\n",
    "    \\begin{bmatrix}\n",
    "    0\\\\\n",
    "    mgl \\sin{q}\\\\\n",
    "    \\end{bmatrix}+\n",
    "    \\begin{bmatrix}\n",
    "    0\\\\\n",
    "    1\\\\\n",
    "    \\end{bmatrix} u\n",
    "\\end{equation}$$"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "source": [
    "class ControlledPendulum(nn.Module):\n",
    "    \"\"\"\n",
    "    Inverted pendulum with torsional spring\n",
    "    \"\"\"\n",
    "    def __init__(self, u, m=1., k=.5, l=1., qr=0., β=.01, g=9.81):\n",
    "        super().__init__()\n",
    "        self.u = u # controller (nn.Module)\n",
    "        self.nfe = 0 # number of function evaluations\n",
    "        self.cur_f = None # current function evaluation\n",
    "        self.cur_u = None # current controller evaluation \n",
    "        self.m, self.k, self.l, self.qr, self.β, self.g = m, k, l, qr, β, g # physics\n",
    "        \n",
    "    def forward(self, t, x):\n",
    "        self.nfe += 1\n",
    "        q, p = x[..., :1], x[..., 1:]\n",
    "        self.cur_u = self.u(t, x)\n",
    "        dq = p/self.m\n",
    "        dp = -self.k*(q - self.qr) - self.m*self.g*self.l*torch.sin(q) \\\n",
    "            -self.β*p/self.m + self.cur_u\n",
    "        self.cur_f = torch.cat([dq, dp], -1)\n",
    "        return self.cur_f"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "In order to control the pendulum, we have to define a proper _integral cost function_ which will be our loss to be minimized during training. In a general form, it can be defined as:\n",
    "\n",
    "\\begin{equation}\n",
    "        J = x^\\top(t_f)\\mathbf{P} x(t_f) + \\int_{t_0}^{t_f} \\left[ x^\\top(t) \\mathbf{Q} x(t) + u^\\top(t) \\mathbf{R} u(t) \\right] dt\n",
    "\\end{equation}\n",
    "\n",
    "where $ x = \\begin{bmatrix} q\\\\ p\\\\ \\end{bmatrix}$ is the state and $\\mathbf{u}$ is the controller and matrices $\\mathbf{P},~\\mathbf{Q}, ~ \\mathbf{R}$ are weights for controlling the performance."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "source": [
    "class IntegralCost(nn.Module):\n",
    "    '''Integral cost function\n",
    "    Args:\n",
    "        x_star: torch.tensor, target position\n",
    "        P: float, terminal cost weights\n",
    "        Q: float, state weights\n",
    "        R: float, controller regulator weights\n",
    "    '''\n",
    "    def __init__(self, x_star, P=0, Q=1, R=0):\n",
    "        super().__init__()\n",
    "        self.x_star = x_star\n",
    "        self.P, self.Q, self.R, = P, Q, R\n",
    "        \n",
    "    def forward(self, x, u=torch.Tensor([0.])):\n",
    "        \"\"\"\n",
    "        x: trajectory\n",
    "        u: control input\n",
    "        \"\"\"\n",
    "        cost = self.P*torch.norm(x[-1] - self.x_star, p=2, dim=-1).mean()\n",
    "        cost += self.Q*torch.norm(x - self.x_star, p=2, dim=-1).mean()\n",
    "        cost += self.R*torch.norm(u - 0, p=2).mean()\n",
    "        return cost"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "source": [
    "# Change device according to your configuration\n",
    "# device = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')\n",
    "device = torch.device('cpu') # feel free to change :)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "source": [
    "# The controller is a simple MLP with one hidden layer with bounded output\n",
    "class NeuralController(nn.Module):\n",
    "    def __init__(self, model, u_min=-20, u_max=20):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "        self.u_min, self.u_max = u_min, u_max\n",
    "        \n",
    "    def forward(self, t, x):\n",
    "        x = self.model(x)\n",
    "        return torch.clamp(x, self.u_min, self.u_max)\n",
    "\n",
    "model = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 1)).to(device)\n",
    "u = NeuralController(model) \n",
    "for p in u.model[-1].parameters(): torch.nn.init.zeros_(p)\n",
    "\n",
    "# Controlled system\n",
    "sys = ControlledPendulum(u).to(device)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "source": [
    "from math import pi as π\n",
    "\n",
    "# Loss function declaration\n",
    "x_star = torch.Tensor([0., 0.]).to(device)\n",
    "cost_func = IntegralCost(x_star)\n",
    "\n",
    "# Time span\n",
    "dt = 0.2\n",
    "t0, tf = 0, 3 # initial and final time for controlling the system\n",
    "steps = int((tf - t0)/dt) + 1 # so we have a time step of 0.2s\n",
    "t_span = torch.linspace(t0, tf, steps).to(device)\n",
    "\n",
    "# Initial distribution\n",
    "x0 = π # limit of the state distribution (in rads and rads/second)\n",
    "init_dist = torch.distributions.Uniform(torch.Tensor([-x0, -x0]), torch.Tensor([x0, x0]))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "source": [
    "# We consider the controller fixed during each solver step\n",
    "class RandConstController(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.u0 = torch.Tensor(1024, 1).uniform_(-10,10).to(device)\n",
    "        \n",
    "    def forward(self, t, x):\n",
    "        return self.u0\n",
    "    \n",
    "# Save previously learned controller\n",
    "u_no_hypersolver = sys.u\n",
    "sys.u = RandConstController() # modify controller for training"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "source": [
    "class VanillaHyperNet(nn.Module):\n",
    "    \"\"\"Simple hypernetwork for controlled systems\n",
    "    Input: current x, f and u from the controlled system\n",
    "    Output: p-th order residuals\"\"\"\n",
    "    def __init__(self, net):\n",
    "        super().__init__()\n",
    "        self.net = net\n",
    "        \n",
    "    def forward(self, t, x):\n",
    "        xfu = torch.cat([x, sys.cur_f, sys.cur_u], -1)\n",
    "        return self.net(xfu)\n",
    "    \n",
    "net = nn.Sequential(nn.Linear(5, 32), nn.Softplus(), nn.Linear(32, 32), nn.Tanh(), nn.Linear(32, 2))\n",
    "hypersolver = HyperEuler(VanillaHyperNet(net))\n",
    "# model = nn.DataParallel(hypersolver, device_ids=[1]) # feel free to change here according to your setup and GPU available.\n",
    "# model = model.to(device)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "source": [
    "model = torch.load('saved_models/hs_torchdyn.pt').to(device)\n",
    "hypersolver = model.module"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Optimal Control with Hypersolvers\n",
    "We can use the trained hypernetwork to generate trajectories of the pendulum: then, we cast the same optimal control problem defined at the beginning of the notebook and train the neural controller."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "source": [
    "sys = ControlledPendulum(None)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "source": [
    "# Reinstantiate controller\n",
    "# net = nn.Sequential(nn.Linear(2, 64), nn.Softplus(), nn.Linear(64, 64), nn.Tanh(), nn.Linear(64, 1)).to(device)\n",
    "# u = NeuralController(net) \n",
    "# for p in u.model[-1].parameters(): torch.nn.init.zeros_(p)\n",
    "u = BoxConstrainedController(2, 1, constrained=True, num_layers=3, output_scaling=torch.Tensor([-5, 5]).to(device))\n",
    "sys.u = u\n",
    "\n",
    "# Time span\n",
    "t0, tf = 0, 3 # initial and final time for controlling the system\n",
    "steps = int((tf - t0)/dt) + 1 # so we have a time step of 0.2s\n",
    "t_span = torch.linspace(t0, tf, steps).to(device)\n",
    "\n",
    "# Initial distribution\n",
    "x0 = π # limit of the state distribution (in rads and rads/second)\n",
    "init_dist = torch.distributions.Uniform(torch.Tensor([-x0, -x0]), torch.Tensor([x0, x0]))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "source": [
    "\n",
    "u_min, u_max = -5, 5 # constraints for the controller\n",
    "cost_func = IntegralCost(x_star, P=1, Q=1) # final position is more important\n",
    "\n",
    "# Hyperparameters\n",
    "lr = 3e-3\n",
    "epochs = 1000\n",
    "bs = 1024"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Euler"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "source": [
    "# Reinstantiate controller\n",
    "# sys = ControlledPendulum(u)\n",
    "u = BoxConstrainedController(2, 1, constrained=True, num_layers=3, output_scaling=torch.Tensor([-5, 5]).to(device))\n",
    "sys.u = u.to(device)\n",
    "opt = torch.optim.Adam(u.parameters(), lr=lr)\n",
    "\n",
    "# Training loop\n",
    "t0 = time.time(); losses=[]\n",
    "for e in range(epochs):\n",
    "    x0 = init_dist.sample((bs,)).to(device)\n",
    "    _, trajectory = odeint(sys, x0, t_span, solver='euler', atol=1e-7, rtol=1e-7)    \n",
    "    loss = cost_func(trajectory); losses.append(loss.detach().cpu().item())\n",
    "    loss.backward(); opt.step(); opt.zero_grad()\n",
    "    print('Loss {:.4f} , epoch {}'.format(loss.item(), e), end='\\r')\n",
    "timing = time.time() - t0; print('\\nTraining time: {:.4f} s'.format(timing))\n",
    "\n",
    "u_euler = sys.u"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "/home/botu/anaconda3/lib/python3.8/site-packages/torchdyn/numerics/odeint.py:82: UserWarning: Setting tolerances has no effect on fixed-step methods\n",
      "  warn(\"Setting tolerances has no effect on fixed-step methods\")\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Loss 5.2740 , epoch 999\n",
      "Training time: 28.8481 s\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### HyperEuler"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "source": [
    "# Reinstantiate controller\n",
    "u = BoxConstrainedController(2, 1, constrained=True, num_layers=3, output_scaling=torch.Tensor([-5, 5]).to(device))\n",
    "sys.u = u.to(device)\n",
    "opt = torch.optim.Adam(u.parameters(), lr=lr)\n",
    "\n",
    "# Training loop\n",
    "# Time should be taken with gpu, neural network inference is slower on cpu\n",
    "t0 = time.time(); losses=[]\n",
    "for e in range(epochs):\n",
    "    x0 = init_dist.sample((bs,)).to(device)\n",
    "    _, trajectory = odeint(sys, x0, t_span, solver=hypersolver)    \n",
    "    loss = cost_func(trajectory); losses.append(loss.detach().cpu().item())\n",
    "    loss.backward(); opt.step(); opt.zero_grad()\n",
    "    print('Loss {:.4f} , epoch {}'.format(loss.item(), e), end='\\r')\n",
    "timing = time.time() - t0; print('\\nTraining time: {:.4f} s'.format(timing))\n",
    "\n",
    "u_hyper = sys.u"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Loss 1.1993 , epoch 999\n",
      "Training time: 40.9267 s\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Midpoint"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "source": [
    "# Reinstantiate controller\n",
    "u = BoxConstrainedController(2, 1, constrained=True, num_layers=3, output_scaling=torch.Tensor([-5, 5]).to(device))\n",
    "sys.u = u.to(device)\n",
    "opt = torch.optim.Adam(u.parameters(), lr=lr)\n",
    "\n",
    "# Training loop\n",
    "t0 = time.time(); losses=[]\n",
    "for e in range(epochs):\n",
    "    x0 = init_dist.sample((bs,)).to(device)\n",
    "    _, trajectory = odeint(sys, x0, t_span, solver='midpoint')    \n",
    "    loss = cost_func(trajectory); losses.append(loss.detach().cpu().item())\n",
    "    loss.backward(); opt.step(); opt.zero_grad()\n",
    "    print('Loss {:.4f} , epoch {}'.format(loss.item(), e), end='\\r')\n",
    "timing = time.time() - t0; print('\\nTraining time: {:.4f} s'.format(timing))\n",
    "\n",
    "u_mp = sys.u"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Loss 1.1765 , epoch 999\n",
      "Training time: 54.4334 s\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### RK4"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "source": [
    "# Reinstantiate controller\n",
    "u = BoxConstrainedController(2, 1, constrained=True, num_layers=3, output_scaling=torch.Tensor([-5, 5]).to(device))\n",
    "sys.u = u.to(device)\n",
    "opt = torch.optim.Adam(u.parameters(), lr=lr)\n",
    "\n",
    "# Training loop\n",
    "t0 = time.time(); losses=[]\n",
    "for e in range(epochs):\n",
    "    x0 = init_dist.sample((bs,)).to(device)\n",
    "    _, trajectory = odeint(sys, x0, t_span, solver='rk4')    \n",
    "    loss = cost_func(trajectory); losses.append(loss.detach().cpu().item())\n",
    "    loss.backward(); opt.step(); opt.zero_grad()\n",
    "    print('Loss {:.4f} , epoch {}'.format(loss.item(), e), end='\\r')\n",
    "timing = time.time() - t0; print('\\nTraining time: {:.4f} s'.format(timing))\n",
    "\n",
    "u_rk4 = sys.u"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Loss 1.0977 , epoch 999\n",
      "Training time: 109.5863 s\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "source": [
    "# Saving learned controllers\n",
    "torch.save(u_euler, 'saved_models/u_euler.pt')\n",
    "torch.save(u_hyper, 'saved_models/u_hyper.pt')\n",
    "torch.save(u_mp, 'saved_models/u_mp.pt')\n",
    "torch.save(u_rk4, 'saved_models/u_rk4.pt')"
   ],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.8.5 64-bit ('base': conda)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "interpreter": {
   "hash": "d77f8d9122331bf0c813f643ab906d6086736a1197fa074182ffff0b1ac62f18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}