{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Models\n",
    "- Trained for 2,000,000 time steps (took 6601 seconds) using environment and procedure documented in `ad6f5817258303e2e092b4fbdb4fd0dc9356373b`\n",
    "- Trained for 100,000 time steps (took 251 seconds) using environment 3 and procedure documented in `c4b74d2e110af467a0e0af745b4d31d6675bcc44`\n",
    "- Small network (8-8) trained for 100,000 time steps (took 330 seconds) using environment 3 and procedure documented in `830eb4bf42dfbfc5b809ad59ce94335557a352f3`\n",
    "- Small network (8-8) further trained for 400,000 time steps using procedure documented in `b6631ec96d9f604dc2f69fc393a47a25bfd68505`\n",
    "- Small network (8-8) trained for 1,400,000 time steps (took 4,885 seconds) using environment 3 with modified rewards and procedure documented in `be827b0d3deb94789fd8adcc8d294afbfb015876` (evaluation over 30,000 steps claims perfect results i.e. reward 1.0+-0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines3 import PPO\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class OnnxableActionPolicy(torch.nn.Module):\n",
    "    def __init__(self, extractor, action_net, value_net):\n",
    "        super(OnnxableActionPolicy, self).__init__()\n",
    "        self.extractor = extractor\n",
    "        self.action_net = action_net\n",
    "        self.value_net = value_net\n",
    "        \n",
    "        normalize_linear1 = torch.nn.Linear(2, 8)\n",
    "        # ((max(0,x) - max(0,-x)) - max(0,x-1) + max(0,-x-1))\n",
    "        normalize_linear1.weight.data = torch.Tensor([\n",
    "            [1,0],[-1,0],[1,0],[-1,0],\n",
    "            [0,1],[0,-1],[0,1],[0,-1]\n",
    "        ])\n",
    "        normalize_linear1.bias.data=torch.Tensor([0,0,-1,-1,0,0,-1,-1])\n",
    "        A=1\n",
    "        normalize_linear2 = torch.nn.Linear(8,2)\n",
    "        normalize_linear2.weight.data = torch.Tensor([[A,-A,-A,A,0,0,0,0],[0,0,0,0,A,-A,-A,A]])\n",
    "        normalize_linear2.bias.data=torch.Tensor([0])\n",
    "        self.normalizer = torch.nn.Sequential(\n",
    "            normalize_linear1,\n",
    "            torch.nn.ReLU(),\n",
    "            normalize_linear2)\n",
    "\n",
    "    def forward(self, observation):\n",
    "        # NOTE: You may have to process (normalize) observation in the correct\n",
    "        #       way before using this. See `common.preprocessing.preprocess_obs`\n",
    "        action_hidden, value_hidden = self.extractor(observation)\n",
    "        action = self.action_net(action_hidden)\n",
    "        return self.normalizer(action) #, self.value_net(value_hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example: model = PPO(\"MlpPolicy\", \"Pendulum-v0\")\n",
    "model = PPO.load(\"model_backup/zeppelin-avoidance-windsystem-small2-1400000-1000000-0.5\")\n",
    "model.policy.to(\"cpu\")\n",
    "onnxable_model = OnnxableActionPolicy(model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "dummy_input = torch.randn(1, 4)\n",
    "torch.onnx.export(onnxable_model, dummy_input, \"zeppelin-avoidance-small2-1400000-retrain-1000000-0.5.onnx\", opset_version=9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "##### Load and test with onnx\n",
    "\n",
    "import onnx\n",
    "import onnxruntime as ort\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "onnx_model = onnx.load(\"zeppelin-avoidance-small2-1400000.onnx\")\n",
    "onnx.checker.check_model(onnx_model)\n",
    "\n",
    "observation = np.zeros((1, 7)).astype(np.float32)\n",
    "ort_sess = ort.InferenceSession(\"zeppelin-avoidance-small2-1400000.onnx\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[array([[1., 1.]], dtype=float32)]\n"
     ]
    }
   ],
   "source": [
    "print(ort_sess.run(None, {'input.1': [[72.,131.,80.,30.]]}))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Rename output nodes to not purely numeric names!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "onnx_model = onnx.load('zeppelin-avoidance-small2-1400000-retrain-1000000-0.5.onnx')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "onnx_model.graph.output[0].name = \"out1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "onnx_model.graph.node[-1].output[0]=\"out1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "onnx_model.graph.node[len(onnx_model.graph.node)-1].output[0]=\"out1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'input.1'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "onnx_model.graph.input[0].name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "onnx.save(onnx_model, 'zeppelin-avoidance-small2-1400000-retrain-1000000-0.5.onnx')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": null,
   "lastKernelId": null
  },
  "kernelspec": {
   "display_name": "nnequiv-tf1",
   "language": "python",
   "name": "nnequiv-tf1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
