{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gmXDVaxWe8uV"
   },
   "outputs": [],
   "source": [
    "# Copy of \"Copy_of_rr_centerline_random_wind.ipynb\"\n",
    "# Note, may need to change projection size depending on problem\n",
    "\n",
    "# Project-local modules, each imported exactly once.\n",
    "import scripts.algorithm as algorithm\n",
    "import scripts.obstacles as obstacles\n",
    "import scripts.racer as racer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XAtOHQyKf3KA"
   },
   "outputs": [],
   "source": [
    "import importlib\n",
    "import os\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as onp\n",
    "from IPython.display import HTML\n",
    "from matplotlib import pyplot as plt\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EmUCeLPJDKFZ"
   },
   "outputs": [],
   "source": [
    "# --- Held constant across the experiments in this notebook ---\n",
    "initial_clearance = 20\n",
    "n_gates = 50\n",
    "gate_gap = 30\n",
    "\n",
    "# --- Values that the parameter sweep further down varies ---\n",
    "min_width = 10\n",
    "minDist = 10\n",
    "\n",
    "# --- Derived: minimum and maximum gate width are set equal ---\n",
    "max_width = min_width\n",
    "\n",
    "# Course height = initial_clearance + (n_gates - 1) * gate_gap.\n",
    "env, env_state = obstacles.get_slalom(\n",
    "    min_gate_width=min_width,\n",
    "    max_gate_width=max_width,\n",
    "    obstacle_radius=2,\n",
    "    num_gates=n_gates,\n",
    "    init_clearance=initial_clearance,\n",
    "    height=initial_clearance + gate_gap * (n_gates - 1),\n",
    "    d_min=minDist,\n",
    ")\n",
    "\n",
    "# Alternative (centerline) course, kept for reference:\n",
    "# env, env_state = obstacles.get_centerline(num_obstacles=30, init_clearance=12.5, height=12.5+27*50, obstacle_radius=3.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 64537,
     "status": "ok",
     "timestamp": 1659460020760,
     "user": {
      "displayName": "Luna Xia",
      "userId": "11036929073323501801"
     },
     "user_tz": 240
    },
    "id": "R6LAhsJomfL4",
    "outputId": "ca86ded2-1609-49c6-c229-5d7c996f3469",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# PRNG seed for this run; its string form kS tags the files saved below.\n",
    "keyVal = 31415\n",
    "kS = f'{keyVal}'\n",
    "\n",
    "# An earlier, longer configuration used eta=0.005, T=1200.\n",
    "results, w = algorithm.algorithm_1(\n",
    "    env,\n",
    "    jax.random.PRNGKey(keyVal),\n",
    "    env_state,\n",
    "    eta=2e-3,\n",
    "    T=200,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional debug dump of the raw planner output; set the flag to True to inspect it.\n",
    "# The following cells read results[k][0] as the loss, results[k][1] as the action and\n",
    "# results[k][2] as the state (whose .arr is copied into STATES).\n",
    "INSPECT_RESULTS = False\n",
    "if INSPECT_RESULTS:\n",
    "    print(len(results))\n",
    "    print()\n",
    "    for idx in range(4):\n",
    "        print(results[1][idx])\n",
    "    print(w.shape)\n",
    "    print(w[:, 0])\n",
    "    print(w[0, :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split each planner step into its state (entry 2) and action (entry 1).\n",
    "listStates = [step[2] for step in results]\n",
    "listActions = [step[1] for step in results]\n",
    "n_steps = len(listStates)\n",
    "print(n_steps)\n",
    "\n",
    "# Column t of STATES / ACTIONS holds the state / action recorded at step t.\n",
    "STATES = onp.zeros((4, n_steps))\n",
    "ACTIONS = onp.zeros((2, len(listActions)))\n",
    "for t, (state, action) in enumerate(zip(listStates, listActions)):\n",
    "    STATES[:, t:t+1] = onp.array(state.arr)\n",
    "    ACTIONS[:, t:t+1] = onp.array(action)\n",
    "\n",
    "# Per-step quadratic cost on the first state component and the first action component.\n",
    "COSTS = onp.array([0.5*(0.05*(STATES[0, t]**2) + 0.025*(ACTIONS[0, t]**2)) for t in range(n_steps)])\n",
    "\n",
    "# Condensed cost: each full block of 100 consecutive steps, summed and scaled by 1/50.\n",
    "C2 = onp.array([(1./50.)*onp.sum(COSTS[100*b:100*(b+1)]) for b in range(int(n_steps/100))])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "height": 295
    },
    "executionInfo": {
     "elapsed": 7668,
     "status": "ok",
     "timestamp": 1659460093175,
     "user": {
      "displayName": "Luna Xia",
      "userId": "11036929073323501801"
     },
     "user_tz": 240
    },
    "id": "pnKHom_f5mHJ",
    "outputId": "39f47270-db7d-42aa-ee00-5eb4f4acd39e",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Render the trajectory and save it as an MP4 named after the run's seed string kS.\n",
    "# Create the output directory first: saving into a missing directory fails.\n",
    "os.makedirs('animations', exist_ok=True)\n",
    "anim = env.render(listStates)\n",
    "anim.save('animations/basic_animation_rand'+kS+'.mp4', fps=20, extra_args=['-vcodec', 'libx264'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8z7rbRbHtVSb"
   },
   "outputs": [],
   "source": [
    "# Entry 0 of each planner step; plotted below as the minimum distance to obstacles.\n",
    "losses = [step[0] for step in results]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IQ_f2055XssX"
   },
   "outputs": [],
   "source": [
    "# Persist this run's arrays as CSV files tagged with the seed string kS.\n",
    "# STATES / ACTIONS are transposed so that each CSV row is one step.\n",
    "csv_outputs = [\n",
    "    ('losses_rand'+kS+'.csv', losses),\n",
    "    ('states_rand'+kS+'.csv', STATES.T),\n",
    "    ('actions_rand'+kS+'.csv', ACTIONS.T),\n",
    "    ('costs_rand'+kS+'.csv', COSTS),\n",
    "    ('costs_rand'+kS+'_condensed.csv', C2),\n",
    "]\n",
    "for csv_path, csv_data in csv_outputs:\n",
    "    onp.savetxt(csv_path, csv_data, delimiter=',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "height": 316
    },
    "executionInfo": {
     "elapsed": 384,
     "status": "ok",
     "timestamp": 1652984146519,
     "user": {
      "displayName": "Luna Xia",
      "userId": "11036929073323501801"
     },
     "user_tz": 240
    },
    "id": "D2_WwbJbPVXl",
    "outputId": "78022e9e-2344-4967-e69e-2b9202317e75"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################\n",
    "### HOLD CONSTANT ###\n",
    "#####################\n",
    "initial_clearance = 20\n",
    "n_gates = 50\n",
    "gate_gap = 30\n",
    "\n",
    "######################\n",
    "### VARYING PARAMS ###\n",
    "######################\n",
    "MIN_WIDTH = onp.arange(6, 12)\n",
    "MIN_DIST = onp.arange(0, 14, 2)\n",
    "print(MIN_WIDTH)\n",
    "print(MIN_DIST)\n",
    "print(len(MIN_WIDTH))\n",
    "print(len(MIN_DIST))\n",
    "keyVal = 166\n",
    "\n",
    "# Loop-invariant quantities, computed once instead of on every run.\n",
    "course_height = initial_clearance + (n_gates - 1) * gate_gap\n",
    "n_runs = len(MIN_DIST) * len(MIN_WIDTH)\n",
    "# Costs are condensed in blocks of this many steps (each block sum is scaled by 1/50).\n",
    "STEPS_PER_BLOCK = 100\n",
    "# anim.save writes into this directory, so make sure it exists before the first run.\n",
    "os.makedirs('animations', exist_ok=True)\n",
    "\n",
    "for kk, min_width in enumerate(MIN_WIDTH):\n",
    "    max_width = min_width\n",
    "    for jj, minDist in enumerate(MIN_DIST):\n",
    "        # Every (min_width, minDist) pair gets its own seed: 167, 168, ...\n",
    "        keyVal = keyVal + 1\n",
    "        kS = str(keyVal)\n",
    "        # Filename tag shared by all CSV outputs of this run.\n",
    "        tag = f'{kS}_{min_width}_{minDist}'\n",
    "\n",
    "        print(f'Starting simulation {len(MIN_DIST)*kk + jj + 1} of {n_runs}')\n",
    "        print('Current keyVal is: '+kS)\n",
    "\n",
    "        env, env_state = obstacles.get_slalom(min_gate_width=min_width, max_gate_width=max_width, obstacle_radius=2, num_gates=n_gates, init_clearance=initial_clearance, height=course_height, d_min=minDist)\n",
    "        results, w = algorithm.algorithm_1(env, jax.random.PRNGKey(keyVal), env_state, eta=2e-3, T=1200)\n",
    "\n",
    "        # Entry 0 = loss, 1 = action, 2 = state of each planner step.\n",
    "        listStates = [result[2] for result in results]\n",
    "        listActions = [result[1] for result in results]\n",
    "        losses = [result[0] for result in results]\n",
    "\n",
    "        # Column k of STATES / ACTIONS holds the state / action at step k.\n",
    "        STATES = onp.zeros((4, len(listStates)))\n",
    "        ACTIONS = onp.zeros((2, len(listActions)))\n",
    "        for k in range(len(listStates)):\n",
    "            STATES[:, k:k+1] = onp.array(listStates[k].arr)\n",
    "            ACTIONS[:, k:k+1] = onp.array(listActions[k])\n",
    "        # Per-step quadratic cost on the first state component and the first action component.\n",
    "        COSTS = 0.5*(0.05*STATES[0]**2 + 0.025*ACTIONS[0]**2)\n",
    "        # Trailing steps that do not fill a whole block are dropped.\n",
    "        n_blocks = len(COSTS) // STEPS_PER_BLOCK\n",
    "        C2 = (1./50.)*COSTS[:n_blocks*STEPS_PER_BLOCK].reshape(n_blocks, STEPS_PER_BLOCK).sum(axis=1)\n",
    "\n",
    "        anim = env.render(listStates)\n",
    "        anim.save('animations/basic_animation_rand'+kS+'.mp4', fps=20, extra_args=['-vcodec', 'libx264'])\n",
    "\n",
    "        onp.savetxt('losses_rand'+tag+'.csv', losses, delimiter=',')\n",
    "        onp.savetxt('states_rand'+tag+'.csv', STATES.T, delimiter=',')\n",
    "        onp.savetxt('actions_rand'+tag+'.csv', ACTIONS.T, delimiter=',')\n",
    "        onp.savetxt('costs_rand'+tag+'.csv', COSTS, delimiter=',')\n",
    "        onp.savetxt('costs_rand'+tag+'_condensed.csv', C2, delimiter=',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "height": 309
    },
    "executionInfo": {
     "elapsed": 19629,
     "status": "ok",
     "timestamp": 1659460131262,
     "user": {
      "displayName": "Luna Xia",
      "userId": "11036929073323501801"
     },
     "user_tz": 240
    },
    "id": "Uqr_FNL0glGn",
    "outputId": "ad2e01bd-91bf-4cf5-aa05-267e98b56d23"
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# Embed the most recently created animation (`anim`) as an HTML5 video.\n",
    "video_html = anim.to_html5_video()\n",
    "HTML(video_html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# x-axis: 0.5 * step index, labelled as Y-distance (m).\n",
    "y_dist = 0.5*onp.arange(len(losses))\n",
    "\n",
    "# Reference curve: clearance when driving straight along x = 0 past radius-3 obstacles\n",
    "# centred every 50 m at y = 12.5 (mod 50), i.e. the centerline course geometry.\n",
    "# NOTE(review): these values do not match the slalom built above (gate_gap = 30,\n",
    "# obstacle_radius = 2); confirm which course this curve should describe.\n",
    "OBSTACLE_PERIOD = 50.\n",
    "OBSTACLE_OFFSET = 12.5\n",
    "OBSTACLE_RADIUS = 3.\n",
    "phase = y_dist % OBSTACLE_PERIOD - OBSTACLE_OFFSET\n",
    "centerline_clearance = onp.minimum(onp.absolute(phase), OBSTACLE_PERIOD - onp.absolute(phase)) - OBSTACLE_RADIUS\n",
    "\n",
    "f = plt.figure(figsize=(6,4))\n",
    "plt.plot(y_dist, losses)\n",
    "plt.plot(y_dist, centerline_clearance, 'k--')\n",
    "# 3 m threshold; axhline spans the full x-range whatever the course length.\n",
    "plt.axhline(3.0, color='r', linestyle='--')\n",
    "plt.ylabel(\"min distance to obstacles (m)\", fontsize=12)\n",
    "plt.xlabel(\"Y-distance (m)\", fontsize=12)\n",
    "plt.yticks(onp.arange(0, 25, step=5))\n",
    "plt.ylim(-1, 25.5)\n",
    "plt.title(\"(b) centerline\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row 0 of w (from column 1 on) against the same Y-distance axis as the distance plot.\n",
    "f2 = plt.figure(figsize=(6,4))\n",
    "plt.plot(0.5*onp.arange(len(losses)), w[0,1:], 'g')\n",
    "plt.xlabel(\"Y-distance (m)\", fontsize=12)\n",
    "plt.ylabel(\"w[0, 1:]\", fontsize=12)\n",
    "plt.yticks(onp.arange(-5, 6)/10)\n",
    "plt.ylim([-.51, 0.51]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "height": 91
    },
    "executionInfo": {
     "elapsed": 20633,
     "status": "ok",
     "timestamp": 1652984171840,
     "user": {
      "displayName": "Luna Xia",
      "userId": "11036929073323501801"
     },
     "user_tz": 240
    },
    "id": "P4BfUsydPd-y",
    "outputId": "1fa24834-210b-49c2-8a60-0cadd4c534ff"
   },
   "outputs": [],
   "source": [
    "# Export figure f (the distance plot) to PDF and store it via colabtools drive.SaveFile.\n",
    "f.savefig(\"foo.pdf\", bbox_inches='tight')\n",
    "!ls -l foo.pdf\n",
    "from colabtools import drive\n",
    "# Read inside a context manager so the file handle is closed, and send the registered\n",
    "# PDF media type (application/pdf); image/pdf is not a registered MIME type.\n",
    "with open('foo.pdf', 'rb') as pdf_file:\n",
    "    pdf_bytes = pdf_file.read()\n",
    "file_id = drive.SaveFile('random-centerline-wind.pdf', pdf_bytes, mime_type='application/pdf')\n",
    "file_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 138,
     "status": "ok",
     "timestamp": 1652984173382,
     "user": {
      "displayName": "Luna Xia",
      "userId": "11036929073323501801"
     },
     "user_tz": 240
    },
    "id": "44E0CjpvSuK7",
    "outputId": "fa471f58-c4a6-45ea-ecf1-3c4b08285941"
   },
   "outputs": [],
   "source": [
    "# Scratch check of JAX's functional update: .at[i].set(v) returns a new array.\n",
    "foobar = jnp.zeros(10).at[3].set(1)\n",
    "print(foobar)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload an earlier run (v1, sin105): states first, then actions, reporting each shape.\n",
    "csv_opts = dict(delimiter=',')\n",
    "S1 = onp.loadtxt('v1/states_sin105.csv', **csv_opts)\n",
    "print(S1.shape)\n",
    "A1 = onp.loadtxt('v1/actions_sin105.csv', **csv_opts)\n",
    "print(A1.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recompute per-step and condensed costs for the v1 run loaded above.\n",
    "# Off by default: when enabled it writes (and overwrites) the generic *_rand.csv files.\n",
    "# S1 / A1 store one step per row (the trajectory cell below reads S1[:, 0] as X and\n",
    "# S1[:, 1] as Y), so the first state/action component is column 0, and the arrays are\n",
    "# written back in that same one-step-per-row layout.\n",
    "RECOMPUTE_V1_COSTS = False\n",
    "if RECOMPUTE_V1_COSTS:\n",
    "    COSTS1 = 0.5*(0.05*S1[:, 0]**2 + 0.025*A1[:, 0]**2)\n",
    "    # Blocks of 100 steps, each sum scaled by 1/50; a trailing partial block is dropped.\n",
    "    n_blocks1 = len(COSTS1) // 100\n",
    "    C1 = (1./50.)*COSTS1[:n_blocks1*100].reshape(n_blocks1, 100).sum(axis=1)\n",
    "    onp.savetxt('states_rand.csv', S1, delimiter=',')\n",
    "    onp.savetxt('actions_rand.csv', A1, delimiter=',')\n",
    "    onp.savetxt('costs_rand.csv', COSTS1, delimiter=',')\n",
    "    onp.savetxt('costs_rand_condensed.csv', C1, delimiter=',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlay 100-step segments 2-9 of the v1 trajectory, with Y folded into one 50 m period,\n",
    "# over the radius-3 circle at (0, 12.5) and the dashed x = 0 centerline.\n",
    "XX = S1[:,0]\n",
    "YY = S1[:,1]\n",
    "plt.figure(3)\n",
    "for k in range(2, 10):\n",
    "    # Each slice stops one sample before the end of its 100-step segment.\n",
    "    segment = slice(100*k, 100*(k+1)-1)\n",
    "    plt.plot(XX[segment], YY[segment] % 50., 'r')\n",
    "theta = onp.deg2rad(onp.arange(360))\n",
    "plt.plot(3.0*onp.cos(theta), 12.5+3.0*onp.sin(theta), 'k')\n",
    "plt.plot((0, 0), (0, 50), 'k--')\n",
    "plt.title('Trajectories for Online Planner [Sin]', fontsize=18)\n",
    "plt.xlabel('X (m)', fontsize=14)\n",
    "plt.ylabel('Y (m)', fontsize=14)\n",
    "plt.ylim([0, 50])\n",
    "plt.xlim([-10., 10.])\n",
    "plt.savefig('OnlineTraj_Sin105.png')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "last_runtime": {
    "build_target": "//learning/deepmind/dm_python:dm_notebook3",
    "kind": "private"
   },
   "provenance": [
    {
     "file_id": "1x_t0kp04Y1FANAl3R4NeuUUyUyDdAloD",
     "timestamp": 1652923414869
    },
    {
     "file_id": "1qRFbjEl8kfPrjw_-Yn5Z2ZBNMGEqsp0u",
     "timestamp": 1652718659978
    },
    {
     "file_id": "1DyM43vFyAMTA8Yrw-PGoJAg5itY3kS17",
     "timestamp": 1652462998849
    },
    {
     "file_id": "1oK4m1Tx6IgDbxR4YFKXITM8osokyDuZs",
     "timestamp": 1644253893925
    }
   ]
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
