{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines3 import PPO\n",
    "import torch\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class OnnxableActionPolicy(torch.nn.Module):\n",
    "    \"\"\"Expose a policy's action head as a plain module suitable for ONNX export.\n",
    "\n",
    "    ``forward`` runs ``extractor`` and ``action_net`` and then passes the raw\n",
    "    action through ``self.normalizer``, a fixed Linear/ReLU stack that computes\n",
    "    approximately ``100 * clip(action + 0.05, -1, 1)``.\n",
    "\n",
    "    Args:\n",
    "        extractor: Callable returning ``(action_hidden, value_hidden)`` for an\n",
    "            observation.\n",
    "        action_net: Module mapping ``action_hidden`` to the raw action. The\n",
    "            normalizer starts with ``Linear(1, 1)``, so the last dimension of\n",
    "            the raw action must have size 1.\n",
    "        value_net: Stored as an attribute; ``forward`` does not evaluate it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, extractor, action_net, value_net):\n",
    "        super(OnnxableActionPolicy, self).__init__()\n",
    "        self.extractor = extractor\n",
    "        self.action_net = action_net\n",
    "        self.value_net = value_net\n",
    "\n",
    "        # Layer 0: y = x + 0.05.\n",
    "        normalize_linear0 = torch.nn.Linear(1, 1)\n",
    "        normalize_linear0.weight.data = torch.Tensor([[1]])\n",
    "        normalize_linear0.bias.data = torch.Tensor([0.05])\n",
    "\n",
    "        # Layer 1 + ReLU: the four hinge terms\n",
    "        #   max(0, y), max(0, -y), max(0, y - 1), max(0, -y - 1)\n",
    "        normalize_linear1 = torch.nn.Linear(1, 4)\n",
    "        normalize_linear1.weight.data = torch.Tensor([[1], [-1], [1], [-1]])\n",
    "        normalize_linear1.bias.data = torch.Tensor([0, 0, -1, -1])\n",
    "\n",
    "        # Layer 2 combines the hinges:\n",
    "        #   A * (max(0, y) - max(0, -y) - max(0, y - 1) + max(0, -y - 1))\n",
    "        # which equals A * clip(y, -1, 1). The 1e-3 / 3e-3 offsets make the\n",
    "        # positive side scale by A + 1e-3 and the negative side by A - 3e-3.\n",
    "        # NOTE(review): the purpose of these offsets is not visible here -\n",
    "        # confirm before changing them.\n",
    "        A = 100\n",
    "        # in_features is 4 (one per hinge term) so the declared layer shape\n",
    "        # matches the 1x4 weight assigned below.\n",
    "        normalize_linear2 = torch.nn.Linear(4, 1)\n",
    "        normalize_linear2.weight.data = torch.Tensor([[A + 1e-3, -A + 3e-3, -A - 1e-3, A - 3e-3]])\n",
    "        normalize_linear2.bias.data = torch.Tensor([0])\n",
    "\n",
    "        self.normalizer = torch.nn.Sequential(\n",
    "            normalize_linear0,\n",
    "            normalize_linear1,\n",
    "            torch.nn.ReLU(),\n",
    "            normalize_linear2)\n",
    "\n",
    "    def forward(self, observation):\n",
    "        \"\"\"Return the normalized action computed for ``observation``.\n",
    "\n",
    "        NOTE: You may have to process (normalize) the observation in the\n",
    "        correct way before using this. See `common.preprocessing.preprocess_obs`.\n",
    "        \"\"\"\n",
    "        action_hidden, value_hidden = self.extractor(observation)\n",
    "        action = self.action_net(action_hidden)\n",
    "        # Only the action is returned; self.value_net(value_hidden) is not evaluated.\n",
    "        return self.normalizer(action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the trained PPO agent (a fresh one would be e.g. PPO(\"MlpPolicy\", \"Pendulum-v0\")).\n",
    "model = PPO.load(\"model_backup/acc-2000000-64-64-64-64-100000-200000-0.9\")\n",
    "model.policy.to(\"cpu\")\n",
    "\n",
    "# Hand the policy's sub-networks to the export wrapper.\n",
    "onnxable_model = OnnxableActionPolicy(\n",
    "    extractor=model.policy.mlp_extractor,\n",
    "    action_net=model.policy.action_net,\n",
    "    value_net=model.policy.value_net,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([100.0010], grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: an input above the clip range saturates at the upper bound (about +100).\n",
    "onnxable_model.normalizer(torch.Tensor([2]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-99.9970], grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: an input below the clip range saturates at the lower bound (about -100).\n",
    "onnxable_model.normalizer(torch.Tensor([-2]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dummy batch holding one 2-element observation, used only to trace the model.\n",
    "dummy_input = torch.randn(1, 2)\n",
    "torch.onnx.export(\n",
    "    onnxable_model,\n",
    "    dummy_input,\n",
    "    \"acc-2000000-64-64-64-64-100000-200000-0.9.onnx\",\n",
    "    opset_version=9,\n",
    "    # Pin the graph input name: later cells feed the session through 'input.1',\n",
    "    # so do not rely on the exporter's auto-generated name.\n",
    "    input_names=[\"input.1\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "##### Load and test with onnx\n",
    "\n",
    "import onnx\n",
    "import onnxruntime as ort\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the exported file and run the ONNX checker on it.\n",
    "onnx_model = onnx.load(\"acc-2000000-64-64-64-64-100000-200000-0.9.onnx\")\n",
    "onnx.checker.check_model(onnx_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open an onnxruntime session on the exported file; the model itself was\n",
    "# already loaded and checked in the previous cell.\n",
    "ort_sess = ort.InferenceSession(\"acc-2000000-64-64-64-64-100000-200000-0.9.onnx\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[array([[100.001]], dtype=float32)]\n"
     ]
    }
   ],
   "source": [
    "# Spot-check a single observation; 'input.1' is the name of the exported graph's input.\n",
    "print(ort_sess.run(None, {'input.1': [[0.08,-4]]}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/steuber/anaconda3/envs/nnequiv-tf1/lib/python3.7/site-packages/gym/logger.py:34: UserWarning: \u001b[33mWARN: Environment '<class 'acc.ACCEnv2'>' has deprecated methods '_step' and '_reset' rather than 'step' and 'reset'. Compatibility code invoked. Set _gym_disable_underscore_compat = True to disable this behavior.\u001b[0m\n",
      "  warnings.warn(colorize(\"%s: %s\" % (\"WARN\", msg % args), \"yellow\"))\n"
     ]
    }
   ],
   "source": [
    "# Build the environment under test; the id is presumably registered by `import acc` - confirm.\n",
    "env = gym.make(\"acc-variant-v1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f90d40769b0>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fix the environment and torch seeds before running the rollouts.\n",
    "env.seed(2022)\n",
    "torch.manual_seed(2022)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "28.776630715521794\n",
      ".28.775768132958248\n",
      ".28.72983095656857\n",
      "."
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-18-dd1d2c36eab8>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      8\u001b[0m         \u001b[0maction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maction\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1.\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1.\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m         \u001b[0mobs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrewards\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdones\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m         \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     11\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mdones\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m             \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/nnequiv-tf1/lib/python3.7/site-packages/gym/core.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, mode, **kwargs)\u001b[0m\n\u001b[1;32m    252\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    253\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"human\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 254\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    255\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    256\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/nnequiv-tf1/lib/python3.7/site-packages/gym/envs/registration.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(mode)\u001b[0m\n\u001b[1;32m    209\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    210\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 211\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_render\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    212\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    213\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Dokumente/Hortus/Univ/KIT/Master/Work/repos/SNNT/experiments/acc/training/acc.py\u001b[0m in \u001b[0;36m_render\u001b[0;34m(self, mode, close)\u001b[0m\n\u001b[1;32m    540\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcarttrans2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_translation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mleaderx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcarty\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    541\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 542\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mviewer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreturn_rgb_array\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m'rgb_array'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    543\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    544\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mACCEnv3\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mACCEnv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/nnequiv-tf1/lib/python3.7/site-packages/gym/envs/classic_control/rendering.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, return_rgb_array)\u001b[0m\n\u001b[1;32m    121\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_rgb_array\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    122\u001b[0m         \u001b[0mglClearColor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 123\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    124\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mswitch_to\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    125\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_events\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/nnequiv-tf1/lib/python3.7/site-packages/pyglet/window/__init__.py\u001b[0m in \u001b[0;36mclear\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1346\u001b[0m         \u001b[0mbuffer\u001b[0m\u001b[0;34m.\u001b[0m  \u001b[0mThe\u001b[0m \u001b[0mwindow\u001b[0m \u001b[0mmust\u001b[0m \u001b[0mbe\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mactive\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0msee\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mswitch_to\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1347\u001b[0m         \"\"\"\n\u001b[0;32m-> 1348\u001b[0;31m         \u001b[0mgl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mglClear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGL_COLOR_BUFFER_BIT\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0mgl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGL_DEPTH_BUFFER_BIT\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1349\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1350\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mdispatch_event\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/nnequiv-tf1/lib/python3.7/site-packages/pyglet/gl/lib.py\u001b[0m in \u001b[0;36merrcheck\u001b[0;34m(result, func, arguments)\u001b[0m\n\u001b[1;32m     85\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     86\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0merrcheck\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marguments\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     88\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0m_debug_gl_trace\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     89\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "N_EPISODES = 10000    # number of rollouts to simulate\n",
    "MAX_STEPS = 500       # step limit per rollout\n",
    "ACTION_SCALE = 100.0  # the exported normalizer scales actions by 100 (A in OnnxableActionPolicy)\n",
    "\n",
    "# Roll out the exported ONNX controller and report every episode that ends with obs[0] <= 0.\n",
    "for _episode in range(N_EPISODES):\n",
    "    obs = env.reset()\n",
    "    starting_state = obs\n",
    "    # To replay a fixed start state, overwrite obs here (e.g. obs = [0.08, -4]);\n",
    "    # it is written into the unwrapped environment on the next line.\n",
    "    env.unwrapped.state = obs\n",
    "    for _step in range(MAX_STEPS):\n",
    "        # Feed an explicit float32 batch of one observation instead of a nested Python list.\n",
    "        ort_input = np.array([[obs[0], obs[1]]], dtype=np.float32)\n",
    "        action = ort_sess.run(None, {'input.1': ort_input})[0][0] / ACTION_SCALE\n",
    "        action = np.clip(action, -1., 1.)\n",
    "        obs, _reward, done, _info = env.step(action)\n",
    "        env.render()\n",
    "        if done:\n",
    "            print(obs[0])\n",
    "            print(\".\",end=\"\")\n",
    "            if obs[0] <= 0:\n",
    "                print(\"\\nEncountered unsafe behaviour!\")\n",
    "                print(\"Starting state: \",starting_state)\n",
    "            break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Close the environment once the rollouts are done.\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Rename the output node so that its name is not purely numeric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the exported model so that its graph can be edited.\n",
    "onnx_model = onnx.load('acc-2000000-64-64-64-64-100000-200000-0.9.onnx')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rename the graph's first declared output to 'out1'.\n",
    "onnx_model.graph.output[0].name = \"out1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Give the final node's first output the same name, 'out1'.\n",
    "onnx_model.graph.node[-1].output[0] = \"out1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'input.1'"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the graph's input name (it is not changed by the renaming above).\n",
    "onnx_model.graph.input[0].name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the edited model back to the same path, replacing the exported file.\n",
    "onnx.save(onnx_model, 'acc-2000000-64-64-64-64-100000-200000-0.9.onnx')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": null,
   "lastKernelId": null
  },
  "kernelspec": {
   "display_name": "nnequiv-tf1",
   "language": "python",
   "name": "nnequiv-tf1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
