{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6567e296",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "#import matplotlib.pyplot as plt\n",
    "#from matplotlib.patches import Rectangle,Polygon\n",
    "##### Load and test with onnx\n",
    "import onnx\n",
    "import onnxruntime as ort\n",
    "import numpy as np\n",
    "\n",
    "# Zero-filled float32 array of shape (1, 2); none of the cells shown in this notebook read it.\n",
    "observation = np.zeros((1, 2), dtype=np.float32)\n",
    "def simulate_run(start_state, file, n=3000, dt=0.1):\n",
    "    \"\"\"Roll out the ONNX network in `file` as a closed-loop controller.\n",
    "\n",
    "    At each step the network is fed the current [position, velocity] and its\n",
    "    output is applied as a constant acceleration for one step of length `dt`:\n",
    "\n",
    "        vel = action*dt + vel_0\n",
    "        pos = action*dt**2/2 + vel_0*dt + pos_0\n",
    "\n",
    "    The rollout stops after `n` steps, or earlier once the position leaves\n",
    "    the interval (0, 100]. \"CRASH\" is printed when it stops at a position <= 0.\n",
    "\n",
    "    Args:\n",
    "        start_state: two-element sequence [[pos], [vel]]. It is not modified.\n",
    "        file: path of the ONNX model; it must have an input named 'input.1'.\n",
    "        n: maximum number of steps.\n",
    "        dt: length of one integration step (0.1 by default).\n",
    "\n",
    "    Returns:\n",
    "        Tuple (rPos, rVel, actions). rPos and rVel start with the two entries\n",
    "        of start_state and gain one entry per step; actions holds the network\n",
    "        output of each step.\n",
    "    \"\"\"\n",
    "    ort_sess = ort.InferenceSession(file)\n",
    "    # Copy the outer list: the loop rebinds state[0] and state[1], and without\n",
    "    # the copy that would overwrite the caller's start_state.\n",
    "    state = [start_state[0], start_state[1]]\n",
    "    print(f\"Starting with: {state}\")\n",
    "    rPos = [state[0]]\n",
    "    rVel = [state[1]]\n",
    "    actions = []\n",
    "    for _ in range(n):\n",
    "        pos_0 = state[0][0]\n",
    "        vel_0 = state[1][0]\n",
    "        action = ort_sess.run(None, {'input.1': [[pos_0, vel_0]]})[0][0]\n",
    "        actions.append(action)\n",
    "        vel = action*dt + vel_0\n",
    "        pos = action*dt**2/2 + vel_0*dt + pos_0\n",
    "        state[0] = pos\n",
    "        state[1] = vel\n",
    "        rPos.append(pos)\n",
    "        rVel.append(vel)\n",
    "        if pos <= 0 or pos > 100:\n",
    "            # Only the lower bound is reported as a crash; exceeding 100 just ends the run.\n",
    "            if pos <= 0:\n",
    "                print(\"CRASH\")\n",
    "            break\n",
    "    return rPos, rVel, actions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "dc770116",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[99.99999]\n",
      "[False]\n"
     ]
    }
   ],
   "source": [
    "# One-step check of the network at a state on the curve pos = vel**2/200:\n",
    "# print the network output, then whether the state reached after one 0.1 step\n",
    "# still satisfies next_pos >= next_vel**2/200.\n",
    "# NOTE(review): 200 presumably stands for 2 * (a maximum output of 100) - confirm.\n",
    "pos = 74.999997\n",
    "vel = -np.sqrt(pos*200)\n",
    "ort_sess = ort.InferenceSession(\"../../test/networks/acc-2000000-64-64-64-64.onnx\")\n",
    "action = ort_sess.run(\n",
    "    None,\n",
    "    {'input.1': [[pos, vel]]},\n",
    ")[0][0]\n",
    "print(action)\n",
    "# Same constant-acceleration update as simulate_run uses, written out for a single step.\n",
    "next_vel = vel + action*0.1\n",
    "next_pos = pos + vel*0.1 + action*0.1**2/2\n",
    "print(next_pos >= next_vel**2/200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "18d9da88",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting with: [[8], [-40.0]]\n",
      "CRASH\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([[8],\n",
       "  array([4.5], dtype=float32),\n",
       "  array([1.8934178], dtype=float32),\n",
       "  array([-0.0156281], dtype=float32)],\n",
       " [[-40.0],\n",
       "  array([-30.], dtype=float32),\n",
       "  array([-22.131645], dtype=float32),\n",
       "  array([-16.049273], dtype=float32)],\n",
       " [array([100.], dtype=float32),\n",
       "  array([78.683556], dtype=float32),\n",
       "  array([60.823727], dtype=float32)])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Closed-loop rollout starting at pos = 8 with vel = -sqrt(200 * pos); the tuple\n",
    "# returned by simulate_run is the last expression, so it is shown as the cell output.\n",
    "pos = 8\n",
    "simulate_run(\n",
    "    start_state=[[pos], [-np.sqrt(pos*200)]],\n",
    "    file=\"../../test/networks/acc-2000000-64-64-64-64.onnx\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a13225a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (nnv)",
   "language": "python",
   "name": "nnv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
