{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Multistage Hypersolver pre-train\n",
    "We consider two version of the cartpole:\n",
    "1. Known model with pole mass = 0.1kg\n",
    "2. Real model with a much bigger mass of the pole (1kg)\n",
    "\n",
    "We want to train the multi-stage HyperEuler model to see how it will perform"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "import sys; sys.path.append(2*'../') # go n dirs back\n",
    "from src import *"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "import time\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from torchdyn.core import NeuralODE\n",
    "from torchdyn.datasets import *\n",
    "from torchdyn.numerics import odeint, Euler, HyperEuler\n",
    "\n",
    "device = 'cpu' # feel free to change!"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# The controller is a simple MLP with one hidden layer with bounded output\n",
    "class NeuralController(nn.Module):\n",
    "    def __init__(self, model, u_min=-20, u_max=20):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "        self.u_min, self.u_max = u_min, u_max\n",
    "        \n",
    "    def forward(self, t, x):\n",
    "        x = self.model(x)\n",
    "        return torch.clamp(x, self.u_min, self.u_max)\n",
    "\n",
    "model = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 1)).to(device)\n",
    "u = NeuralController(model) \n",
    "for p in u.model[-1].parameters(): torch.nn.init.zeros_(p)\n",
    "\n",
    "# Controlled system\n",
    "sys = CartPole(u).to(device)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Generate distribution\n",
    "\n",
    "We generate a distribution $\\xi(x,u)$ to train the hypersolver"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "from math import pi as π\n",
    "\n",
    "# Loss function declaration\n",
    "x_star = torch.Tensor([0. , 0.,  0., 0.]).to(device)\n",
    "cost_func = IntegralCost(x_star)\n",
    "\n",
    "# Time span\n",
    "dt = 0.01\n",
    "t0, tf = 0, 5 # initial and final time for controlling the system\n",
    "steps = int((tf - t0)/dt) + 1 # so we have a time step of 0.2s\n",
    "t_span = torch.linspace(t0, tf, steps).to(device)\n",
    "\n",
    "# Initial distribution\n",
    "x0 = π # limit of the state distribution (in rads and rads/second)\n",
    "init_dist = torch.distributions.Uniform(torch.Tensor([-x0, -x0, -x0, -x0]), torch.Tensor([x0, x0, x0, x0]))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# We consider the controller fixed during each solver step\n",
    "class RandConstController(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.u0 = torch.Tensor(1024, 1).uniform_(-10,10).to(device)\n",
    "        \n",
    "    def forward(self, t, x):\n",
    "        return self.u0\n",
    "    \n",
    "# Save previously learned controller\n",
    "u_no_hypersolver = sys.u\n",
    "sys.u = RandConstController() # modify controller for training"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "class VanillaHyperNet(nn.Module):\n",
    "    \"\"\"Simple hypernetwork for controlled systems\n",
    "    Input: current x, f and u from the controlled system\n",
    "    Output: p-th order residuals\"\"\"\n",
    "    def __init__(self, net):\n",
    "        super().__init__()\n",
    "        self.net = net\n",
    "        \n",
    "    def forward(self, t, x):\n",
    "        xfu = torch.cat([x, sys.cur_f, sys.cur_u], -1)\n",
    "#         print(xfu.shape, x.shape, sys.cur_f.shape)\n",
    "        return self.net(xfu)\n",
    "    \n",
    "net = nn.Sequential(nn.Linear(9, 32), nn.Softplus(), nn.Linear(32, 32), nn.Softplus(), nn.Linear(32, 32), nn.Softplus(), nn.Linear(32, 8))\n",
    "hypersolver = HyperEuler(VanillaHyperNet(net))\n",
    "model = nn.DataParallel(hypersolver) # feel free to change here according to your setup and GPU available.\n",
    "model = model.to(device)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Initial distribution\n",
    "x0 = 2*π # limit of the state distribution (in rads and rads/second)\n",
    "init_dist = torch.distributions.Uniform(torch.Tensor([-x0, -x0, -x0, -x0]), torch.Tensor([x0, x0, x0, x0]))\n",
    "\n",
    "base_solver = Euler()\n",
    "# Time span\n",
    "t0, tf = 0, 2 # initial and final time for controlling the system\n",
    "steps = int((tf - t0)/dt) + 1 # so we have a time step of 0.2s\n",
    "t_span = torch.linspace(t0, tf, steps).to(device)\n",
    "dt = (t_span[1] - t_span[0]).detach().cpu().item()\n",
    "print(init_dist.sample((10,)))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "sys_nominal = CartPole(RandConstController())\n",
    "sys_nominal.masspole = 1 # change the pole mass of the nominal controller\n",
    " "
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Training loop\n",
    "We train via stochastic exploration\n",
    "(This will take some time)"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "opt = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
    "loss_func = nn.MSELoss()\n",
    "epochs = 1000000\n",
    "bs = 2048\n",
    "hypernet = model.module.hypernet\n",
    "dt = 0.01\n",
    "span = torch.linspace(0, dt, 2)\n",
    "losses = []\n",
    "# sys._use_xfu=True\n",
    "\n",
    "for i in range(epochs):\n",
    "    # Sample random intial states and controllers\n",
    "    x0 = init_dist.sample((bs,)).to(device)\n",
    "    val = torch.Tensor(bs, 1).uniform_(-10, 10).to(device)\n",
    "    sys.u.u0 = val\n",
    "    sys_nominal.u.u0 = val\n",
    "    \n",
    "    # Compute residuals\n",
    "    _, sol_gt = odeint(sys_nominal._dynamics, x0, span, solver='tsit5', atol=1e-6, rtol=1e-6)[-1]\n",
    "    f = sys._dynamics(0, x0)\n",
    "    sys.cur_f = f\n",
    "    g = hypernet(0, x0)\n",
    "    half_state_dim = g.shape[-1] // 2\n",
    "    g1, g2 = g[..., :half_state_dim], g[..., half_state_dim:]\n",
    "    sol = x0 + dt* (f + g1) + dt**2*g2\n",
    "    loss = loss_func(sol_gt, sol)\n",
    "\n",
    "    # Optimization step\n",
    "    loss.backward(); opt.step(); opt.zero_grad()\n",
    "    print(f'Step: {i}, Residual loss: {loss:.8f}', end='\\r')\n",
    "    losses.append(loss.detach().cpu().item())"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "fig, ax = plt.subplots(1, 1)\n",
    "ax.plot(losses)\n",
    "ax.set_yscale('log')"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Save the model\n",
    "torch.save(model, 'saved_models/hs_multistage.pt')"
   ],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "d77f8d9122331bf0c813f643ab906d6086736a1197fa074182ffff0b1ac62f18"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.8.5 64-bit ('base': conda)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}