{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Run MPC with pre-trained `Multi-Stage HyperEuler`"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "import sys; sys.path.append(2*'../') # go n dirs back to import\n",
    "from src import *"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "import time\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from torchdyn.core import NeuralODE\n",
    "from torchdyn.datasets import *\n",
    "from torchdyn.numerics import odeint, Euler, HyperEuler"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# We modify each single weight instead of P, Q, R although it would be the same\n",
    "class IntegralCost(nn.Module):\n",
    "    '''Integral cost function\n",
    "    Args:\n",
    "        x_star: torch.tensor, target position\n",
    "        P: float, terminal cost weights\n",
    "        Q: float, state weights\n",
    "        R: float, controller regulator weights\n",
    "    '''\n",
    "    def __init__(self, x_star, P=0, Q=1, R=0):\n",
    "        super().__init__()\n",
    "        self.x_star = x_star\n",
    "        self.P, self.Q, self.R, = P, Q, R\n",
    "        \n",
    "    def forward(self, x, u=torch.Tensor([0.])):\n",
    "        \"\"\"\n",
    "        x: trajectory\n",
    "        u: control input\n",
    "        \"\"\"\n",
    "        cost = 0.2*torch.norm(x[..., -1, :] - self.x_star, p=2).mean()\n",
    "        cost += torch.norm(x[..., -1, 2] - self.x_star[2], p=2).mean()\n",
    "        cost += 0.001*torch.norm(x - self.x_star, p=2).mean() # regulator\n",
    "        cost += 0.2*torch.norm(x[..., 2] - self.x_star[2]).mean()\n",
    "        cost += 0.2*torch.norm(x[..., 0] - self.x_star[0]).mean()\n",
    "        return cost"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Change device according to your configuration\n",
    "# device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n",
    "device=torch.device('cpu')"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "from math import pi as π\n",
    "\n",
    "# Loss function declaration\n",
    "x_star = torch.Tensor([0. , 0.,  0., 0.]).to(device)\n",
    "cost_func = IntegralCost(x_star)\n",
    "\n",
    "# Time span\n",
    "dt = 0.01\n",
    "t0, tf = 0, 3 # initial and final time for controlling the system\n",
    "steps = int((tf - t0)/dt) + 1 # so we have a time step of 0.2s\n",
    "t_span = torch.linspace(t0, tf, steps).to(device)\n",
    "\n",
    "# Initial distribution\n",
    "x0 = π # limit of the state distribution (in rads and rads/second)\n",
    "eps = 0.1\n",
    "init = torch.Tensor([-2, 0, -x0, 0])\n",
    "var = torch.Tensor([.05, 1e-5, .05, 1e-5])\n",
    "init_dist = torch.distributions.Uniform(init-var, init+var)\n",
    "# init_dist = torch.distributions.Uniform(torch.Tensor([-eps, -eps, -3/2*x0, -eps]), torch.Tensor([eps, eps, -1/2*x0, -0]))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "class VanillaHyperNet(nn.Module):\n",
    "    \"\"\"Simple hypernetwork for controlled systems\n",
    "    Input: current x, f and u from the controlled system\n",
    "    Output: p-th order residuals\"\"\"\n",
    "    def __init__(self, net):\n",
    "        super().__init__()\n",
    "        self.net = net\n",
    "        \n",
    "    def forward(self, t, x):\n",
    "        xfu = torch.cat([x, sys.cur_f, sys.cur_u], -1)\n",
    "#         print(xfu.shape, x.shape, sys.cur_f.shape)\n",
    "        return self.net(xfu)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Common MPC simulation variables\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "const_u = RandConstController([1, 1], -1, 1).to(device) # dummy constant controller for simulation\n",
    "real_system = CartPole(const_u, solver='tsit5')\n",
    "real_system.masspole=1.0 # actual parameter of the pole mass\n",
    "\n",
    "steps_nom = 5 # Nominal steps to do between each MPC step\n",
    "max_iters = 20\n",
    "eps_accept = 1e-3\n",
    "lookahead_steps = int(1.5/dt) # 2 seconds ahead\n",
    "bs = 32 # we use batched training so to see more initial conditions\n",
    "lr =3e-3 # adjust learning rate for avoiding \"underflow in dt nan\" from torchdiffeq\n",
    "weight_decay = 0\n",
    "\n",
    "#OVERRIDE\n",
    "# x0 = torch.Tensor([-2, 0, -π, 0])\n",
    "x0 = init_dist.sample((bs,)).to(device)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Run MPCs\n",
    "Here we run the experiments with different solvers and save data"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "u = BoxConstrainedController(4, 1, num_layers=3, constrained=True, output_scaling=torch.Tensor([-10, 10])).to(device)\n",
    "opt = torch.optim.Adam(u.parameters(), lr=lr, weight_decay=weight_decay) # optimizer\n",
    "\n",
    "sys = CartPole(u, solver='euler').to(device) # 'wrong' system\n",
    "hypersolver = torch.load('saved_models/hs_multistage.pt').to(device)\n",
    "sys.hypersolve = hypersolver.module.hypernet\n",
    "sys.multistage = True\n",
    "\n",
    "mpc = TorchMPC(sys, cost_func, t_span, opt, eps_accept=eps_accept, max_g_iters=max_iters,\n",
    "            lookahead_steps=lookahead_steps, lower_bounds=None,\n",
    "            upper_bounds=None, penalties=None).to(device)\n",
    "\n",
    "loss_mpc = mpc.forward_simulation(real_system, x0, t_span)\n",
    "\n",
    "torch.save(mpc.trajectory_nominal.cpu(), 'data/multihyper_traj.pt')\n",
    "torch.save(mpc.control_inputs.cpu(), 'data/multihyper_controls.pt')\n",
    "torch.save(loss_mpc, 'data/multihyper_loss.pt')"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "u = BoxConstrainedController(4, 1, num_layers=3, constrained=True, output_scaling=torch.Tensor([-10, 10])).to(device)\n",
    "opt = torch.optim.Adam(u.parameters(), lr=lr, weight_decay=weight_decay) # optimizer\n",
    "\n",
    "sys = CartPole(u, solver='euler').to(device) # 'wrong' system\n",
    "sys.masspole = 1.0\n",
    "\n",
    "mpc = TorchMPC(sys, cost_func, t_span, opt, eps_accept=eps_accept, max_g_iters=max_iters,\n",
    "            lookahead_steps=lookahead_steps, lower_bounds=None,\n",
    "            upper_bounds=None, penalties=None).to(device)\n",
    "\n",
    "loss_mpc = mpc.forward_simulation(real_system, x0, t_span)\n",
    "\n",
    "\n",
    "torch.save(mpc.trajectory_nominal.cpu(), 'data/euler_traj.pt')\n",
    "torch.save(mpc.control_inputs.cpu(), 'data/euler_controls.pt')\n",
    "torch.save(loss_mpc, 'data/euler_loss.pt')"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "u = BoxConstrainedController(4, 1, num_layers=3, constrained=True, output_scaling=torch.Tensor([-10, 10])).to(device)\n",
    "opt = torch.optim.Adam(u.parameters(), lr=lr, weight_decay=weight_decay) # optimizer\n",
    "\n",
    "sys = CartPole(u, solver='tsit5').to(device) # 'wrong' system\n",
    "# sys.masspole=1.0 # actual parameter of the pole mass\n",
    "\n",
    "mpc = TorchMPC(sys, cost_func, t_span, opt, eps_accept=eps_accept, max_g_iters=max_iters,\n",
    "            lookahead_steps=lookahead_steps, lower_bounds=None,\n",
    "            upper_bounds=None, penalties=None).to(device)\n",
    "\n",
    "loss_mpc = mpc.forward_simulation(real_system, x0, t_span)\n",
    "\n",
    "torch.save(mpc.trajectory_nominal.cpu(), 'data/tsit_fake_traj.pt')\n",
    "torch.save(mpc.control_inputs.cpu(), 'data/tsit_fake_controls.pt')\n",
    "torch.save(loss_mpc, 'data/tsit_fake_loss.pt')"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "u = BoxConstrainedController(4, 1, num_layers=3, constrained=True, output_scaling=torch.Tensor([-10, 10])).to(device)\n",
    "opt = torch.optim.Adam(u.parameters(), lr=lr, weight_decay=weight_decay) # optimizer\n",
    "\n",
    "sys = CartPole(u, solver='tsit5').to(device) # 'wrong' system\n",
    "sys.masspole=1.0 # actual parameter of the pole mass\n",
    "\n",
    "mpc = TorchMPC(sys, cost_func, t_span, opt, eps_accept=eps_accept, max_g_iters=max_iters,\n",
    "            lookahead_steps=lookahead_steps, lower_bounds=None,\n",
    "            upper_bounds=None, penalties=None).to(device)\n",
    "\n",
    "loss_mpc = mpc.forward_simulation(real_system, x0, t_span)\n",
    "\n",
    "torch.save(mpc.trajectory_nominal.cpu(), 'data/tsit_true_traj.pt')\n",
    "torch.save(mpc.control_inputs.cpu(), 'data/tsit_true_controls.pt')\n",
    "torch.save(loss_mpc, 'data/tsit_fake_loss.pt')\n",
    "print(sys.nfe)"
   ],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.8.5 64-bit ('base': conda)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "interpreter": {
   "hash": "d77f8d9122331bf0c813f643ab906d6086736a1197fa074182ffff0b1ac62f18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}