{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert to Julia file by jupytext --to jl RocketLander.ipynb\n",
    "\n",
    "include(\"../src/DSI.jl\")\n",
    "include(\"../src/Zono_utils.jl\")\n",
    "include(\"../src/PZono.jl\")\n",
    "include(\"../src/DSZ.jl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Networks converted to nnet format from https://github.com/V2A2/StarV/tree/main/StarV/util/data/nets/DRL\n",
    "# The three agent networks are read in one pass; every file is loaded with last_layer_activation = Id().\n",
    "rocket_lander_nnet0, rocket_lander_nnet1, rocket_lander_nnet2 = map(\n",
    "    nnet_path -> read_nnet(nnet_path, last_layer_activation = Id()),\n",
    "    (\n",
    "        \"./RocketLander_networks/unsafe_agent0.nnet\",\n",
    "        \"./RocketLander_networks/unsafe_agent1.nnet\",\n",
    "        \"./RocketLander_networks/unsafe_agent2.nnet\",\n",
    "    ),\n",
    ");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input Data from https://github.com/V2A2/StarV/blob/main/StarV/util/load.py#L168\n",
    "# Each property has an 11-entry lower-bound and an 11-entry upper-bound vector,\n",
    "# which are combined element-wise into a vector of intervals.\n",
    "# NOTE: minus signs must be the ASCII '-'; the Unicode minus sign (U+2212) is not\n",
    "# accepted as an operator by older Julia parsers.\n",
    "\n",
    "init_lb_rocket_1 = [-0.2, 0.02, -0.5, -1., -20*pi/180, -0.2, 0.0, 0.0, 0.0, -1.0, -15*pi/180]\n",
    "init_ub_rocket_1 = [0.2, 0.5, 0.5, 1., -6*pi/180, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]\n",
    "rocket_input_1 = interval.(init_lb_rocket_1, init_ub_rocket_1)\n",
    "\n",
    "init_lb_rocket_2 = [-0.2, 0.02, -0.5, -1., 6*pi/180, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n",
    "init_ub_rocket_2 = [0.2, 0.5, 0.5, 1., 20*pi/180, 0.2, 0.0, 0.0, 1.0, 1.0, 15*pi/180]\n",
    "rocket_input_2 = interval.(init_lb_rocket_2, init_ub_rocket_2)\n",
    "\n",
    "# Unsafety (from https://github.com/V2A2/StarV/blob/main/StarV/util/load.py#L168)\n",
    "# Each property is given as a 2x3 matrix and a length-2 right-hand side.\n",
    "# NOTE(review): presumably the unsafe set is {y : mat_spec * y <= rhs_spec} over the\n",
    "# network outputs -- confirm against dsz_approximate_nnet_and_condition_nostorage.\n",
    "\n",
    "mat_spec_1 = [[0.0 -1.0 0.0]\n",
    "              [0.0 0.0 -1.0]]\n",
    "rhs_spec_1 = [0.0, 0.0]\n",
    "\n",
    "mat_spec_2 = [[0.0 1.0 0.0]\n",
    "              [0.0 0.0 1.0]]\n",
    "rhs_spec_2 = [0.0, 0.0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Focal-element counts passed to init_pbox_Normal (11 entries, same length as the bound vectors).\n",
    "# Alternative setting: vect_nb_focal_elem = [5, 11, 8, 13, 8, 6, 1, 1, 2, 2, 2]\n",
    "vect_nb_focal_elem = [7, 12, 10, 17, 9, 7, 1, 1, 2, 1, 1]\n",
    "\n",
    "# the true flag in init_pbox_Normal means truncating the focal elements to restrict the range to [lb,ub]\n",
    "rocket_inputpbox_1 = init_pbox_Normal(init_lb_rocket_1, init_ub_rocket_1, vect_nb_focal_elem, true)\n",
    "rocket_inputpbox_2 = init_pbox_Normal(init_lb_rocket_2, init_ub_rocket_2, vect_nb_focal_elem, true)\n",
    "\n",
    "# One entry per (network, property) pair:\n",
    "# (label, network, input p-box, specification matrix, specification right-hand side).\n",
    "# NOTE(review): rocket_lander_nnet2 is loaded earlier in the notebook but not evaluated here -- confirm this is intentional.\n",
    "rocket_experiments = (\n",
    "    (\"Rocket lander nnet0, Prop 1\", rocket_lander_nnet0, rocket_inputpbox_1, mat_spec_1, rhs_spec_1),\n",
    "    (\"Rocket lander nnet0, Prop 2\", rocket_lander_nnet0, rocket_inputpbox_2, mat_spec_2, rhs_spec_2),\n",
    "    (\"Rocket lander nnet1, Prop 1\", rocket_lander_nnet1, rocket_inputpbox_1, mat_spec_1, rhs_spec_1),\n",
    "    (\"Rocket lander nnet1, Prop 2\", rocket_lander_nnet1, rocket_inputpbox_2, mat_spec_2, rhs_spec_2),\n",
    ")\n",
    "\n",
    "# Every experiment goes through the same call and the same println, so all result\n",
    "# lines share one output format (\": \" before the reported value).\n",
    "for (label, nnet, inputpbox, mat_spec, rhs_spec) in rocket_experiments\n",
    "    # vec_proba stays a notebook-level global so the last result remains readable after the loop.\n",
    "    global vec_proba\n",
    "    @time vec_proba = dsz_approximate_nnet_and_condition_nostorage(nnet, inputpbox, mat_spec, rhs_spec, false)\n",
    "    # Only the last entry of vec_proba is reported.\n",
    "    println(label, \", vect_nb_focal_elem=\", vect_nb_focal_elem, \": \", vec_proba[end])\n",
    "end"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia (11 threads) 1.11.2",
   "language": "julia",
   "name": "julia-_11-threads_-1.11"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
