{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bbca90f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "using Revise\n",
    "using LazySets\n",
    "using DifferentialEquations\n",
    "using LazySets\n",
    "using ProgressMeter\n",
    "using ProgressBars\n",
    "using JLD2\n",
    "using Flux\n",
    "using LinearAlgebra\n",
    "using Zygote\n",
    "using ReverseDiff\n",
    "using Plots\n",
    "using Statistics\n",
    "using Optimisers, ParameterSchedulers\n",
    "using RobotDynamics\n",
    "using RobotZoo\n",
    "using Random\n",
    "using Rotations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d229df6b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "eda60602",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "generate_random_traj (generic function with 1 method)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "using RobotZoo\n",
    "import RobotDynamics as RD\n",
    "include(\"quadrotor_euler.jl\")\n",
    "\n",
    "# Draw a point uniformly at random from `hyperrectangle`.\n",
    "#\n",
    "# With `q=true`, entries 4:7 are overwritten by the corresponding entries of a random\n",
    "# RobotZoo quadrotor state (its quaternion slots). Returns `(point, admissible)`, where\n",
    "# `admissible` is true when no `non_admissible_area` is given or the point lies outside it.\n",
    "function random_point_in_hyperrectangle(hyperrectangle::Hyperrectangle, non_admissible_area=nothing; q=false)\n",
    "    lo = low(hyperrectangle)\n",
    "    hi = high(hyperrectangle)\n",
    "    point = Float64[rand() * (hi[i] - lo[i]) + lo[i] for i in 1:dim(hyperrectangle)]\n",
    "    if q\n",
    "        quad_state, _ = rand(RobotZoo.Quadrotor())\n",
    "        point[4:7] .= quad_state[4:7]\n",
    "    end\n",
    "    admissible = isnothing(non_admissible_area) || point ∉ non_admissible_area\n",
    "    return point, admissible\n",
    "end\n",
    "\n",
    "# Roll out one random reference trajectory starting at `x_0`.\n",
    "#\n",
    "# For each of the n_steps ≈ T/dt steps, up to `max_u` controls are drawn uniformly from `U`;\n",
    "# the first whose successor state lies in `X` and outside `X_unsafe` is kept. If none does,\n",
    "# the last state/control pair is popped (one step of backtracking) and the step counter\n",
    "# still advances, so the result can be shorter than n_steps.\n",
    "#\n",
    "# `euler=true` integrates `quadrotor_dynamics_euler!` (presumably from quadrotor_euler.jl)\n",
    "# over one `dt` with Tsit5 and never touches `dmodel`; otherwise\n",
    "# `RD.discrete_dynamics(dmodel, ...)` is used.\n",
    "#\n",
    "# Returns `(Xref, Uref)` with `length(Xref) == length(Uref) + 1`.\n",
    "function generate_Xref(dmodel, x_0, dt, T, X, X_unsafe, U; max_u=10000, euler=false)\n",
    "    ratio = T / dt\n",
    "    # A bare `floor(T / dt)` can lose a step to round-off (0.3 / 0.1 == 2.9999999999999996),\n",
    "    # so snap to the nearest integer when the ratio is integral up to floating-point noise.\n",
    "    n_steps = isapprox(ratio, round(ratio)) ? round(Int, ratio) : floor(Int, ratio)\n",
    "    tspan = (0.0, dt)\n",
    "    Uref = []\n",
    "    Xref = []\n",
    "    push!(Xref, x_0)\n",
    "    for _ in 1:n_steps\n",
    "        u = nothing\n",
    "        x = Xref[end]\n",
    "        x′ = nothing\n",
    "        feasible = false\n",
    "        for _ in 1:max_u\n",
    "            u, _ = random_point_in_hyperrectangle(U)\n",
    "            if euler\n",
    "                # Freeze the current control in a `let` so the ODE right-hand side captures a\n",
    "                # concrete value instead of the reassigned (boxed) loop variable `u`.\n",
    "                rhs = let u_fixed = u\n",
    "                    (state, p, t) -> quadrotor_dynamics_euler!(state, u_fixed)\n",
    "                end\n",
    "                sol = solve(ODEProblem(rhs, x, tspan), Tsit5(), reltol = 1e-8, abstol = 1e-8)\n",
    "                x′ = sol.u[end]\n",
    "            else\n",
    "                x′ = RD.discrete_dynamics(dmodel, x, u, 0.0, dt)\n",
    "            end\n",
    "            if (x′ ∉ X_unsafe) && (x′ ∈ X)\n",
    "                feasible = true\n",
    "                break\n",
    "            end\n",
    "        end\n",
    "        if !feasible\n",
    "            # Stop instead of backtracking when at most one transition has been kept.\n",
    "            (length(Uref) == 1) && (return Xref, Uref)\n",
    "            (length(Xref) == 1) && (return Xref, Uref)\n",
    "            pop!(Xref)\n",
    "            pop!(Uref)\n",
    "            continue\n",
    "        end\n",
    "        push!(Xref, x′)\n",
    "        push!(Uref, u)\n",
    "    end\n",
    "    return Xref, Uref\n",
    "end\n",
    "\n",
    "# Sample `num` reference trajectories, showing a progress bar.\n",
    "#\n",
    "# Each start state is redrawn from `X` until it lies outside `X_unsafe` (with `q=true` its\n",
    "# quaternion slots come from a random quadrotor state), then rolled out with\n",
    "# `generate_Xref`. Returns `(Xrefs, Urefs)`: one state list and one control list per trajectory.\n",
    "function generate_random_traj(dmodel, num, dt, T, X, X_unsafe, U; q=false, euler=false)\n",
    "    Xrefs = []\n",
    "    Urefs = []\n",
    "    @showprogress for _ = 1:num\n",
    "        x_0, admissible = random_point_in_hyperrectangle(X, X_unsafe; q=q)\n",
    "        while !admissible\n",
    "            x_0, admissible = random_point_in_hyperrectangle(X, X_unsafe; q=q)\n",
    "        end\n",
    "        Xref, Uref = generate_Xref(dmodel, x_0, dt, T, X, X_unsafe, U; euler=euler)\n",
    "        push!(Xrefs, Xref)\n",
    "        push!(Urefs, Uref)\n",
    "    end\n",
    "    return Xrefs, Urefs\n",
    "end\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f9ea4e12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# a=[1,2,3]\n",
    "# pop!(a)\n",
    "# @show a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0182cac5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "plot_function (generic function with 1 method)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Plot the x-y projection of the state box and the unsafe box (red), then overlay every\n",
    "# trajectory whose control sequence is longer than `n_ignore`, dropping its last\n",
    "# `n_ignore` steps (the same steps `build_dataset` discards).\n",
    "#\n",
    "# `urefs`, `state_set` and `unsafe_set` default to the notebook globals `Urefs`, `X` and\n",
    "# `X_unsafe` (looked up when called); pass them explicitly to plot other data. With\n",
    "# `q=true` the third state coordinate is passed as z.\n",
    "# NOTE(review): the base plot is 2-D, so `q=true` adds 3-D series to 2-D axes — confirm.\n",
    "#\n",
    "# Prints and returns `valid_num`, the total number of plotted (kept) steps.\n",
    "function plot_function(Xrefs; n_ignore=50, q=false, urefs=Urefs, state_set=X, unsafe_set=X_unsafe)\n",
    "    plt1 = plot(Hyperrectangle(low=low(state_set)[1:2], high=high(state_set)[1:2]))\n",
    "    plot!(plt1, Hyperrectangle(low=low(unsafe_set)[1:2], high=high(unsafe_set)[1:2]), fillcolor=:red)\n",
    "    valid_num = 0\n",
    "    for k in eachindex(Xrefs)\n",
    "        n_keep = length(urefs[k]) - n_ignore\n",
    "        n_keep < 1 && continue\n",
    "        @assert length(urefs[k]) == length(Xrefs[k]) - 1\n",
    "        xs = [Xrefs[k][i][1] for i in 1:n_keep]\n",
    "        ys = [Xrefs[k][i][2] for i in 1:n_keep]\n",
    "        # Draw into `plt1` explicitly instead of relying on Plots' implicit current plot.\n",
    "        if q\n",
    "            zs = [Xrefs[k][i][3] for i in 1:n_keep]\n",
    "            plot!(plt1, xs, ys, zs, legend = false)\n",
    "        else\n",
    "            plot!(plt1, xs, ys, legend = false)\n",
    "        end\n",
    "        valid_num += n_keep\n",
    "    end\n",
    "    display(plt1)\n",
    "    @show valid_num\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c6a7591c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "build_dataset (generic function with 1 method)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build a labelled dataset of [state, control, label] columns and save a shuffled,\n",
    "# disjoint train/test split with JLD2 `save_object`.\n",
    "#\n",
    "# Positives: for each trajectory with more than `n_ignore` controls, every\n",
    "# (Xrefs[k][i], Urefs[k][i]) pair except the last `n_ignore` steps, labelled [true].\n",
    "# Negatives: floor(0.8 * #positives) states sampled from `X_unsafe` (quaternion slots\n",
    "# resampled when `q=true`) paired with controls sampled from `U`, labelled [false].\n",
    "# `X` is not used; it is kept for call compatibility.\n",
    "#\n",
    "# After shuffling, the last `n_test` columns form the test split and all earlier columns the\n",
    "# training split. They are written to \"<prefix>_training_data.jld2\" and\n",
    "# \"<prefix>_test_data.jld2\"; the default prefix reproduces the previous file names.\n",
    "function build_dataset(Xrefs, Urefs, X, X_unsafe, U; n_ignore=50, q=false, n_test=10000, prefix=\"quadrotorEuler_seq\")\n",
    "    data = []\n",
    "    for k in eachindex(Xrefs)\n",
    "        n_keep = length(Urefs[k]) - n_ignore\n",
    "        n_keep < 1 && continue\n",
    "        for i in 1:n_keep\n",
    "            push!(data, [Xrefs[k][i], Urefs[k][i], [true]]) # safe and persistently feasible\n",
    "        end\n",
    "    end\n",
    "    isempty(data) && error(\"no trajectory has more than n_ignore = $n_ignore controls\")\n",
    "    # Number of unsafe (negative) samples: 80% of the positives.\n",
    "    n_unsafe = floor(Int, length(data) * 0.8)\n",
    "    for _ in 1:n_unsafe\n",
    "        random_x0, safe_flag = random_point_in_hyperrectangle(X_unsafe, X_unsafe; q=q)\n",
    "        random_u0, _ = random_point_in_hyperrectangle(U)\n",
    "        # Expected to be flagged unsafe; with q=true the resampled quaternion must also fit\n",
    "        # inside X_unsafe's bounds for this to hold.\n",
    "        @assert safe_flag == false\n",
    "        push!(data, [random_x0, random_u0, [safe_flag]])\n",
    "    end\n",
    "    data = reduce(hcat, data)\n",
    "    shuffled_indices = shuffle(1:size(data, 2))\n",
    "    data = data[:, shuffled_indices]\n",
    "    n_total = size(data, 2)\n",
    "    @assert n_total > n_test \"need more than n_test = $n_test samples, got $n_total\"\n",
    "    # Disjoint split: exactly n_test test columns, none shared with the training split.\n",
    "    training_data = data[:, 1:n_total-n_test]\n",
    "    test_data = data[:, n_total-n_test+1:end]\n",
    "    save_object(\"$(prefix)_training_data.jld2\", training_data)\n",
    "    save_object(\"$(prefix)_test_data.jld2\", test_data)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "210a6468",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Planar quadrotor with RK4 discretization: 6-D state box, 2-D control box.\n",
    "dyn_model = RobotZoo.PlanarQuadrotor()\n",
    "n, m = RD.dims(dyn_model)\n",
    "dmodel = RD.DiscretizedDynamics{RD.RK4}(dyn_model)\n",
    "\n",
    "# The unsafe set differs from X only in its first two coordinates\n",
    "# (1.5 ≤ x₁ ≤ 2.5, 0 ≤ x₂ ≤ 2).\n",
    "X = Hyperrectangle(low = [0, 0, -0.1, -1, -1, -1], high = [4, 4, 0.1, 1, 1, 1])\n",
    "U = Hyperrectangle(low = [4, 4], high = [6, 6])\n",
    "X_unsafe = Hyperrectangle(low = [1.5, 0, -0.1, -1, -1, -1], high = [2.5, 2, 0.1, 1, 1, 1])\n",
    "\n",
    "# 500000 trajectories with dt = 0.1 and horizon T = 10.\n",
    "Xrefs, Urefs = generate_random_traj(dmodel, 500000, 0.1, 10, X, X_unsafe, U);\n",
    "plot_function(Xrefs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a4bef7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Euler-angle quadrotor: 12-D state box, 4-D control box, integrated with\n",
    "# `quadrotor_dynamics_euler!` (euler=true). That path never uses the discrete-model\n",
    "# argument, so pass `nothing` rather than depending on a `dmodel` left by another cell.\n",
    "X = Hyperrectangle(low = [0, 0,0, -1, -1, -1, -10, -10, -10, -10, -10, -10], high = [4,4,4, 1,1,1,10,10,10,10,10,10])\n",
    "U = Hyperrectangle(low = [-10, -10,-10, -10], high = [10,10,10,10])\n",
    "X_unsafe = Hyperrectangle(low = [1.5, 0,0,-1, -1, -1, -10, -10, -10, -10, -10, -10], high = [2.5,2,4, 1,1,1,10,10,10,10,10,10])\n",
    "\n",
    "Xrefs, Urefs = generate_random_traj(nothing, 50, 0.005, 0.5, X, X_unsafe, U; euler=true);\n",
    "plot_function(Xrefs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70f78605",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quaternion quadrotor with RK4 discretization: 13-D state box, 4-D control box.\n",
    "# old, quaternion not friendly\n",
    "dyn_model = RobotZoo.Quadrotor()\n",
    "n, m = RD.dims(dyn_model)\n",
    "dmodel = RD.DiscretizedDynamics{RD.RK4}(dyn_model)\n",
    "@show n,m\n",
    "\n",
    "# Entries 4:7 are the quaternion slots. The unsafe set differs from X only in its first\n",
    "# two coordinates (1.5 ≤ x₁ ≤ 2.5, 0 ≤ x₂ ≤ 2).\n",
    "X = Hyperrectangle(low = [0, 0, 0, -1, -1, -1, -1, -10, -10, -10, -10, -10, -10], high = [4, 4, 4, 1, 1, 1, 1, 10, 10, 10, 10, 10, 10])\n",
    "U = Hyperrectangle(low = [-10, -10, -10, -10], high = [10, 10, 10, 10])\n",
    "X_unsafe = Hyperrectangle(low = [1.5, 0, 0, -1, -1, -1, -1, -10, -10, -10, -10, -10, -10], high = [2.5, 2, 4, 1, 1, 1, 1, 10, 10, 10, 10, 10, 10])\n",
    "\n",
    "# 500000 trajectories, dt = 0.005, T = 0.5, with quaternion start states (q=true).\n",
    "Xrefs, Urefs = generate_random_traj(dmodel, 500000, 0.005, 0.5, X, X_unsafe, U; q=true);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bea0c01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D view of the current trajectories.\n",
    "plot_function(Xrefs, q = false)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91342f3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same plot, passing the third state coordinate as z.\n",
    "plot_function(Xrefs, q = true)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0c1e8c08",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(n, m) = (13, 4)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Hyperrectangle{Float64, Vector{Float64}, Vector{Float64}}([2.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Quaternion quadrotor setup only (no trajectory generation).\n",
    "dyn_model = RobotZoo.Quadrotor()\n",
    "n, m = RD.dims(dyn_model)\n",
    "dmodel = RD.DiscretizedDynamics{RD.RK4}(dyn_model)\n",
    "@show n,m\n",
    "\n",
    "# Entries 4:7 are the quaternion slots. The unsafe set differs from X only in its first\n",
    "# two coordinates (1.5 ≤ x₁ ≤ 2.5, 0 ≤ x₂ ≤ 2).\n",
    "X = Hyperrectangle(low = [0, 0, 0, -1, -1, -1, -1, -10, -10, -10, -10, -10, -10], high = [4, 4, 4, 1, 1, 1, 1, 10, 10, 10, 10, 10, 10])\n",
    "U = Hyperrectangle(low = [-10, -10, -10, -10], high = [10, 10, 10, 10])\n",
    "X_unsafe = Hyperrectangle(low = [1.5, 0, 0, -1, -1, -1, -1, -10, -10, -10, -10, -10, -10], high = [2.5, 2, 4, 1, 1, 1, 1, 10, 10, 10, 10, 10, 10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "49a9b4ea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "size(safe_data) = (3, 2254154)\n",
      "(n_safe, size(data)) = (1.8033232000000002e6, (3, 1803323))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(1.8033232000000002e6, (3, 1803323))"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep only the positive (label 1) columns of the saved quaternion training set, then\n",
    "# draw fresh negatives: states uniform in `X_unsafe` (quaternion slots resampled) with\n",
    "# controls uniform in `U`, labelled [false].\n",
    "data = load_object(\"quadrotor_seq_training_data.jld2\")\n",
    "y_init = data[3, :]\n",
    "index = findall(x->x[1]==1, y_init)\n",
    "\n",
    "safe_data = data[:, index]\n",
    "@show size(safe_data)\n",
    "\n",
    "data = []\n",
    "# Number of negatives to draw (80% of the positives). Despite its name it counts unsafe\n",
    "# samples; `floor(Int, ...)` keeps it an integer so `1:n_safe` is an integer range.\n",
    "n_safe = floor(Int, size(safe_data, 2) * 0.8)\n",
    "for i in 1:n_safe\n",
    "    random_x0, safe_flag = random_point_in_hyperrectangle(X_unsafe, X_unsafe;q=true)\n",
    "    random_u0, _ = random_point_in_hyperrectangle(U)\n",
    "    @assert safe_flag==false\n",
    "    push!(data, [random_x0, random_u0, [safe_flag]])\n",
    "end\n",
    "\n",
    "data = reduce(hcat,data)\n",
    "@show n_safe, size(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "6dd53c8b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test_data[1, 1] = [1.8116177138347822, 1.626291831324185, 0.7685982025292071, -0.7197878830984011, -0.6909824888845153, -0.03788420253387495, 0.05489435852667852, -3.3175248703697235, -2.6668563763398767, -6.162063132398439, 8.50717718768993, 2.6217549957835633, 3.5646297528580746]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "13-element Vector{Float64}:\n",
       "  1.8116177138347822\n",
       "  1.626291831324185\n",
       "  0.7685982025292071\n",
       " -0.7197878830984011\n",
       " -0.6909824888845153\n",
       " -0.03788420253387495\n",
       "  0.05489435852667852\n",
       " -3.3175248703697235\n",
       " -2.6668563763398767\n",
       " -6.162063132398439\n",
       "  8.50717718768993\n",
       "  2.6217549957835633\n",
       "  3.5646297528580746"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1a08df4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge positives and negatives, shuffle, and hold out the last 10000 columns as the test\n",
    "# split. `end-9999:end` is exactly 10000 columns and disjoint from `1:end-10000`.\n",
    "all_data = hcat(safe_data, data)\n",
    "shuffled_indices = shuffle(1:size(all_data, 2))\n",
    "all_data = all_data[:, shuffled_indices]\n",
    "training_data = all_data[:, 1:end-10000]\n",
    "test_data = all_data[:, end-9999:end]\n",
    "save_object(\"quadrotor_seq_training_data.jld2\", training_data)\n",
    "save_object(\"quadrotor_seq_test_data.jld2\", test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ccc1a21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the archived training set, reshuffle its columns, spot-check the first 100\n",
    "# labels (row 3), and overwrite the current training file with the reshuffled copy.\n",
    "data = load_object(\"quadrotor_seq_training_data_old.jld2\")\n",
    "shuffled_indices = shuffle(1:size(data, 2))\n",
    "data = data[:, shuffled_indices]\n",
    "@show data[3,1:100]\n",
    "save_object(\"quadrotor_seq_training_data.jld2\", data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2c7b416e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data[3, 1:100] = AbstractVector{Float64}[[1.0], [0.0], [1.0], [0.0], [1.0], [1.0], [1.0], [0.0], [1.0], [1.0], [0.0], [1.0], [1.0], [0.0], [0.0], [0.0], [0.0], [0.0], [1.0], [1.0], [0.0], [0.0], [0.0], [1.0], [0.0], [1.0], [1.0], [1.0], [0.0], [0.0], [0.0], [1.0], [1.0], [1.0], [1.0], [1.0], [0.0], [1.0], [0.0], [0.0], [0.0], [1.0], [0.0], [1.0], [0.0], [0.0], [0.0], [0.0], [1.0], [1.0], [1.0], [1.0], [1.0], [0.0], [1.0], [0.0], [1.0], [1.0], [0.0], [0.0], [0.0], [1.0], [1.0], [1.0], [0.0], [0.0], [1.0], [1.0], [0.0], [1.0], [1.0], [1.0], [0.0], [0.0], [1.0], [1.0], [0.0], [0.0], [1.0], [1.0], [0.0], [0.0], [0.0], [1.0], [0.0], [1.0], [0.0], [0.0], [1.0], [1.0], [0.0], [1.0], [1.0], [1.0], [0.0], [1.0], [0.0], [1.0], [0.0], [1.0]]\n"
     ]
    }
   ],
   "source": [
    "# Build and save a dataset from the current `Xrefs`/`Urefs` with quaternion negatives.\n",
    "build_dataset(Xrefs, Urefs, X, X_unsafe, U, q = true)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "fd7fbd42",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(size(training_data), size(test_data)) = ((3, 4047477), (3, 10001))\n",
      "(size(training_data[1, :]), size(test_data[1, :])) = ((4047477,), (10001,))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "((4047477,), (10001,))"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reload the saved quaternion splits and report their sizes.\n",
    "training_data = load_object(\"quadrotor_seq_training_data.jld2\")\n",
    "test_data = load_object(\"quadrotor_seq_test_data.jld2\")\n",
    "\n",
    "@show size(training_data), size(test_data)\n",
    "@show size(training_data[1, :]), size(test_data[1, :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "50509d5b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(size(test_data[1, 1]))[1] = 12\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "12"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Length of the first stored test state (the recorded output shows 12 after conversion).\n",
    "@show size(test_data[1,1])[1] "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "8e445ea7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:50\u001b[39m:20\u001b[39m\n"
     ]
    }
   ],
   "source": [
    "# Convert a 3×3 rotation matrix to (roll, pitch, yaw) in the Z-Y-X convention,\n",
    "# i.e. R = Rz(yaw) * Ry(pitch) * Rx(roll), with pitch in [-π/2, π/2].\n",
    "#\n",
    "# At gimbal lock (|cos(pitch)| ≤ `tol`) only one combination of roll and yaw is observable,\n",
    "# so roll is fixed to 0 and the remaining rotation about z is reported as yaw.\n",
    "function rotation_matrix_to_euler(R::AbstractMatrix{T}; tol=1e-8) where T<:Real\n",
    "    # Clamp guards asin against entries that drift just outside [-1, 1] from round-off.\n",
    "    pitch = -asin(clamp(R[3, 1], -one(T), one(T)))\n",
    "\n",
    "    # An explicit threshold is needed: `cos(pitch) ≈ 0` only matches exact zero because\n",
    "    # isapprox has no absolute tolerance by default.\n",
    "    if abs(cos(pitch)) <= tol\n",
    "        # Gimbal lock case. With roll = 0: R[1,2] = -sin(yaw), R[2,2] = cos(yaw).\n",
    "        roll = zero(pitch)\n",
    "        yaw = atan(-R[1, 2], R[2, 2])\n",
    "    else\n",
    "        # cos(pitch) > 0 here, so it cancels inside the two-argument atan.\n",
    "        roll = atan(R[3, 2], R[3, 3])\n",
    "        yaw = atan(R[2, 1], R[1, 1])\n",
    "    end\n",
    "\n",
    "    return roll, pitch, yaw\n",
    "end\n",
    "\n",
    "\n",
    "# Replace the quaternion (entries 4:7) of every state stored in row 1 of a 3 × N dataset\n",
    "# with (roll, pitch, yaw), turning 13-entry states into 12-entry ones, in place.\n",
    "# Every column must still hold a 13-entry state, so re-running this cell after a\n",
    "# successful conversion fails the assertion instead of converting twice.\n",
    "function quaternion_states_to_euler!(dataset)\n",
    "    @showprogress for i in 1:size(dataset, 2)\n",
    "        old = dataset[1, i]\n",
    "        @assert length(old) == 13 \"column $i does not hold a 13-entry quaternion state\"\n",
    "        # QuatRotation takes the quaternion as (w, x, y, z).\n",
    "        R = RotMatrix(QuatRotation(old[4], old[5], old[6], old[7]))\n",
    "        roll, pitch, yaw = rotation_matrix_to_euler(R)\n",
    "        dataset[1, i] = [old[1:3]..., roll, pitch, yaw, old[8:13]...]\n",
    "    end\n",
    "    return dataset\n",
    "end\n",
    "\n",
    "# One shared routine for both splits; loop bounds follow the data instead of fixed counts.\n",
    "quaternion_states_to_euler!(training_data)\n",
    "quaternion_states_to_euler!(test_data);\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "22776fe7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the Euler-angle versions of both splits.\n",
    "save_object(\"quadrotorEuler_seq_training_data.jld2\", training_data)\n",
    "save_object(\"quadrotorEuler_seq_test_data.jld2\", test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "15672c12",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:26\u001b[39m\n"
     ]
    }
   ],
   "source": [
    "# Reload the Euler-angle splits and define the Euler-angle sets (entries 4:6 are\n",
    "# roll/pitch/yaw, bounded here to ±0.1).\n",
    "training_data = load_object(\"quadrotorEuler_seq_training_data.jld2\")\n",
    "test_data = load_object(\"quadrotorEuler_seq_test_data.jld2\")\n",
    "X = Hyperrectangle(low = [0, 0,0, -0.1,-0.1,-0.1, -1, -1, -1, -1, -1, -1], high = [4,4,4, 0.1,0.1,0.1,1,1,1,1,1,1])\n",
    "U = Hyperrectangle(low = [1, 1,1, 1], high = [1.5,1.5,1.5,1.5])\n",
    "X_unsafe = Hyperrectangle(low = [1.5, 0,0,-0.1,-0.1,-0.1, -1, -1, -1, -1, -1, -1], high = [2.5,2,4, 0.1,0.1,0.1,1,1,1,1,1,1])\n",
    "\n",
    "# Print the index of every (12-entry) training state that lies inside the new `X`; the\n",
    "# loop bound follows the data rather than a fixed count.\n",
    "# NOTE(review): the recorded min/max of the converted angle entries reach about ±3.13,\n",
    "# far outside X's ±0.1 bounds; confirm whether `∉ X` (report out-of-range samples) was meant.\n",
    "@showprogress for i in 1:size(training_data, 2)\n",
    "    old = training_data[1, i]\n",
    "    @assert size(old)[1] == 12\n",
    "    if old ∈ X\n",
    "        @show i\n",
    "    end\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1618bc63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dubins car with RK4 discretization: 3-D state box (third coordinate in [0, π]) and\n",
    "# 2-D control box; the unsafe set differs from X only in its first two coordinates.\n",
    "dyn_model = RobotZoo.DubinsCar()\n",
    "n, m = RD.dims(dyn_model)\n",
    "dmodel = RD.DiscretizedDynamics{RD.RK4}(dyn_model)\n",
    "\n",
    "X = Hyperrectangle(low = [0, 0, 0], high = [4, 4, π])\n",
    "U = Hyperrectangle(low = [-1, -1], high = [1, 1])\n",
    "X_unsafe = Hyperrectangle(low = [1.5, 0, 0], high = [2.5, 2, π])\n",
    "\n",
    "# 50000 trajectories with dt = 0.1 and horizon T = 10.\n",
    "Xrefs, Urefs = generate_random_traj(dmodel, 50000, 0.1, 10, X, X_unsafe, U);\n",
    "plot_function(Xrefs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfdc9628",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): build_dataset saves to the \"quadrotorEuler_seq_*\" file names, so this\n",
    "# overwrites the quadrotor dataset with Dubins-car samples — confirm that is intended.\n",
    "build_dataset(Xrefs, Urefs, X, X_unsafe, U; n_ignore=50, q=false)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22d65232",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D double integrator with RK4 discretization: 4-D state box, 2-D control box; the\n",
    "# unsafe set differs from X only in its first two coordinates.\n",
    "dyn_model = RobotZoo.DoubleIntegrator(2)\n",
    "n, m = RD.dims(dyn_model)\n",
    "dmodel = RD.DiscretizedDynamics{RD.RK4}(dyn_model)\n",
    "\n",
    "X = Hyperrectangle(low = [0, 0, -1, -1], high = [4, 4, 1, 1])\n",
    "U = Hyperrectangle(low = [-1, -1], high = [1, 1])\n",
    "X_unsafe = Hyperrectangle(low = [1.5, 0, -1, -1], high = [2.5, 2, 1, 1])\n",
    "\n",
    "# 500000 trajectories with dt = 0.1 and horizon T = 10.\n",
    "Xrefs, Urefs = generate_random_traj(dmodel, 500000, 0.1, 10, X, X_unsafe, U);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b491e67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): build_dataset saves to the \"quadrotorEuler_seq_*\" file names, so this\n",
    "# overwrites the quadrotor dataset with double-integrator samples — confirm that is intended.\n",
    "build_dataset(Xrefs, Urefs, X, X_unsafe, U; n_ignore=50, q=false)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "274c0ad7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D view of the current trajectories (defaults spelled out).\n",
    "plot_function(Xrefs; n_ignore = 50, q = false)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "dbfcdc84",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "minimum(x_list4, dims = 2) = [-3.129672478071869;;]\n",
      "maximum(x_list4, dims = 2) = [3.124800240395843;;]\n",
      "minimum(x_list5, dims = 2) = [-1.5464243166989073;;]\n",
      "maximum(x_list5, dims = 2) = [1.5207035832355726;;]\n",
      "minimum(x_list6, dims = 2) = [-3.1281365312920553;;]\n",
      "maximum(x_list6, dims = 2) = [3.139715551318598;;]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "1×1 Matrix{Float64}:\n",
       " 3.139715551318598"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Min/max of the roll, pitch and yaw entries (4:6) over the first 1000 Euler test states,\n",
    "# each collected as a 1×1000 row matrix.\n",
    "test_data = load_object(\"quadrotorEuler_seq_test_data.jld2\")\n",
    "x_list4 = reshape([test_data[1, i][4] for i in 1:1000], 1, :)\n",
    "x_list5 = reshape([test_data[1, i][5] for i in 1:1000], 1, :)\n",
    "x_list6 = reshape([test_data[1, i][6] for i in 1:1000], 1, :)\n",
    "\n",
    "@show minimum(x_list4, dims=2)\n",
    "@show maximum(x_list4, dims=2)\n",
    "@show minimum(x_list5, dims=2)\n",
    "@show maximum(x_list5, dims=2)\n",
    "@show minimum(x_list6, dims=2)\n",
    "@show maximum(x_list6, dims=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "344f499d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check of RobotDynamics on the RobotZoo quadrotor: continuous dynamics, ForwardAD\n",
    "# Jacobian, a Tsit5 integration over one `dt`, the Jacobian-based linear map, and two RK4\n",
    "# discrete steps. Overwrites the globals model, n, m, x, u, z, A, B, f and dmodel.\n",
    "using RobotZoo\n",
    "import RobotDynamics as RD\n",
    "# using StaticArrays\n",
    "\n",
    "model = RobotZoo.Quadrotor()#.DoubleIntegrator(2)#DubinsCar()#.Quadrotor()#\n",
    "n,m = RD.dims(model)\n",
    "\n",
    "# Generate random state and control vector\n",
    "x,u = rand(model)\n",
    "t = 0.0   # time (s)\n",
    "dt = 0.001  # time step (s)\n",
    "z = RD.KnotPoint(x,u,t,dt)\n",
    "\n",
    "# Evaluate the continuous dynamics and Jacobian (∇f is n × (n + m): [∂f/∂x  ∂f/∂u])\n",
    "ẋ = RobotDynamics.dynamics(model, x, u)\n",
    "∇f = zeros(n, n + m)\n",
    "RD.jacobian!(RD.StaticReturn(), RD.ForwardAD(), model, ∇f, zeros(n), z)\n",
    "@show ẋ\n",
    "\n",
    "# Integrate the continuous dynamics over [0, dt] with the control held at `u`, for\n",
    "# comparison with the RK4 step below. Note: this (re)defines the global method f(x, p, t).\n",
    "f(x, p, t) = RobotDynamics.dynamics(model, x, u)\n",
    "tspan = (0.0, dt)\n",
    "prob = ODEProblem(f , x, tspan)\n",
    "sol = solve(prob, Tsit5(), reltol = 1e-8, abstol = 1e-8)\n",
    "@show sol[end]\n",
    "\n",
    "# Split the Jacobian into its state block A (first n columns) and control block B.\n",
    "A = ∇f[:, 1:n]\n",
    "B = ∇f[:, n+1:end]\n",
    "# NOTE(review): A*x + B*u equals ẋ only if the dynamics are linear in (x, u); for the\n",
    "# quadrotor this is a comparison aid, not the true derivative.\n",
    "@show linear_x_dot = A * x + B * u\n",
    "@show typeof(∇f)\n",
    "@show typeof(linear_x_dot[1])\n",
    "@show x + linear_x_dot * dt\n",
    "\n",
    "# Evaluate the discrete (RK4) dynamics; the discrete Jacobian call is left commented out.\n",
    "dmodel = RD.DiscretizedDynamics{RD.RK4}(model)\n",
    "@show x, u\n",
    "x′ = RD.discrete_dynamics(dmodel, x, u, t, dt)\n",
    "# RD.jacobian!(RD.StaticReturn(), RD.ForwardAD(), dmodel, ∇f, x′, z)\n",
    "@show x′\n",
    "# Take a second step from the new state (overwrites the global x).\n",
    "x = x′\n",
    "x′ = RD.discrete_dynamics(dmodel, x, u, t, dt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ad91528",
   "metadata": {},
   "outputs": [],
   "source": [
    "include(\"dataset.jl\")\n",
    "# NOTE(review): dataset.jl is not shown here and may redefine helpers from earlier cells.\n",
    "\n",
    "# 4-D state box with a single control in [-2, 2].\n",
    "X = Hyperrectangle(low = [-1, π/2, -1, -1], high = [1, 3*π/2, 1, 1])\n",
    "U = Hyperrectangle(low = [-2], high = [2])\n",
    "X_unsafe = Hyperrectangle(low = [0, π/4, -1, -1], high = [1, 3*π/4, 1, 1])\n",
    "\n",
    "# 2-D double-integrator matrices (ẋ = A x + B u). NOTE(review): B has two input columns\n",
    "# while U is 1-D, and X_unsafe's second interval starts at π/4, below X's lower bound π/2 —\n",
    "# confirm these are meant to be used together.\n",
    "A = [0.0 0 1 0; 0 0 0 1; 0 0 0 0; 0 0 0 0]\n",
    "B = [0.0 0; 0 0; 1 0; 0 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20cbbfb9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77407cba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cartpole sanity check: integrate the continuous dynamics for 0.1 from a fixed state,\n",
    "# then evaluate the continuous-time Jacobian at that same (x0, u0).\n",
    "using RobotZoo\n",
    "import RobotDynamics as RD\n",
    "using StaticArrays\n",
    "\n",
    "model = RobotZoo.Cartpole()\n",
    "n,m = RD.dims(model)\n",
    "@show n,m\n",
    "# states are position, angle (vertical=π, forward horizontal = π/2), velocity, angular velocity\n",
    "x0 = @SVector [0,4*π/4, 1, 1]\n",
    "u0 = @SVector [1.]\n",
    "\n",
    "f(x, p, t) = RobotDynamics.dynamics(model, x, u0)\n",
    "tspan = (0.0, 0.1)\n",
    "prob = ODEProblem(f , x0, tspan)\n",
    "sol = solve(prob, Tsit5(), reltol = 1e-8, abstol = 1e-8)\n",
    "@show sol[end]\n",
    "\n",
    "# Build the knot point in this cell so the Jacobian is taken at this cell's (x0, u0) rather\n",
    "# than at a `z` left behind by another cell (e.g. the 13-state quadrotor).\n",
    "z = RD.KnotPoint(x0, u0, 0.0, 0.1)\n",
    "∇f = zeros(n, n + m)\n",
    "RobotDynamics.jacobian!(RD.StaticReturn(), RD.ForwardAD(), model, ∇f, zeros(n), z)\n",
    "A = ∇f[:, 1:n]\n",
    "B = ∇f[:, n+1:end]\n",
    "# A*x0 + B*u0 equals ẋ only if the dynamics are linear in (x, u); this is a comparison aid.\n",
    "@show A * x0 + B * u0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73ce5d0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import RobotDynamics as RD\n",
    "using RobotDynamics: KnotPoint, dynamics, dynamics!, jacobian!\n",
    "using RobotDynamics: StaticReturn, InPlace, ForwardAD, FiniteDifference\n",
    "using Random\n",
    "\n",
    "\"\"\"\n",
    "    test_model(model; evals=1, samples=1, tol=1e-6, customjacobian=false)\n",
    "\n",
    "Cross-check the continuous-dynamics Jacobian of `model` at a random knot point\n",
    "(drawn after `Random.seed!(1)`): ForwardAD vs FiniteDifference, with both the\n",
    "`StaticReturn` and `InPlace` signatures, and optionally against the model's\n",
    "`RD.UserDefined()` Jacobian when `customjacobian=true`.\n",
    "Throws an `AssertionError` on the first mismatch.\n",
    "\n",
    "- `tol`: absolute tolerance for every comparison involving a finite-difference Jacobian.\n",
    "- `evals`, `samples`: currently unused; kept for call compatibility.\n",
    "\"\"\"\n",
    "function test_model(model; evals=1, samples=1, tol=1e-6, customjacobian=false)\n",
    "    println(typeof(model))\n",
    "    t, dt = 1.1, 0.1\n",
    "    Random.seed!(1)\n",
    "    x, u = rand(model)\n",
    "    n, m = RD.dims(model)\n",
    "    z = KnotPoint(x, u, t, dt)\n",
    "    ∇c1 = zeros(n, n + m)\n",
    "    ∇c2 = zeros(n, n + m)\n",
    "    xdot = zeros(n)  # output buffer passed to jacobian!\n",
    "    @show xdot, RobotDynamics.dynamics(model, x, u)\n",
    "    RobotDynamics.jacobian!(StaticReturn(), ForwardAD(), model, ∇c1, xdot, z)\n",
    "    RobotDynamics.jacobian!(StaticReturn(), FiniteDifference(), model, ∇c2, xdot, z)\n",
    "    @show ∇c1 * z, ∇c2\n",
    "    # `@assert a ≈ b atol = tol` treats `atol = tol` as the failure message, not as a\n",
    "    # tolerance, so the tolerance is passed to isapprox explicitly.\n",
    "    @assert isapprox(∇c1, ∇c2; atol = tol)\n",
    "    RobotDynamics.jacobian!(InPlace(), ForwardAD(), model, ∇c2, xdot, z)\n",
    "    @assert ∇c1 ≈ ∇c2\n",
    "    RobotDynamics.jacobian!(InPlace(), FiniteDifference(), model, ∇c1, xdot, z)\n",
    "    @assert isapprox(∇c1, ∇c2; atol = tol)\n",
    "    if customjacobian\n",
    "        RobotDynamics.jacobian!(InPlace(), RD.UserDefined(), model, ∇c1, xdot, z)\n",
    "        @assert ∇c1 ≈ ∇c2\n",
    "    end\n",
    "end\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50b9a8c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cartpole: check the dimensions reported by RD.dims, then cross-check its Jacobians.\n",
    "cartpole = RobotZoo.Cartpole()\n",
    "@assert RD.dims(cartpole) == (4, 1, 4)\n",
    "test_model(cartpole)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "354bd652",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dimension checks and Jacobian cross-checks (via test_model) for each RobotZoo model.\n",
    "using Test  # provides @test\n",
    "\n",
    "# Acrobot\n",
    "acrobot = RobotZoo.Acrobot()\n",
    "@test RD.dims(acrobot) == (4, 1, 4)\n",
    "test_model(acrobot)\n",
    "\n",
    "# Car\n",
    "car = RobotZoo.DubinsCar()\n",
    "@test RD.dims(car) == (3, 2, 3)\n",
    "test_model(car, customjacobian=true)\n",
    "\n",
    "# Bicycle Car\n",
    "bicycle = RobotZoo.BicycleModel()\n",
    "@test RD.dims(bicycle) == (4, 2, 4)\n",
    "test_model(bicycle)\n",
    "\n",
    "bicycle = RobotZoo.BicycleModel(ref=:rear)\n",
    "@test RD.dims(bicycle) == (4, 2, 4)\n",
    "test_model(bicycle)\n",
    "\n",
    "# Rover\n",
    "rover = RobotZoo.Rover()\n",
    "@test RD.dims(rover) == (5, 2, 5)\n",
    "test_model(rover)\n",
    "\n",
    "# Planar Rocket\n",
    "rocket = RobotZoo.PlanarRocket()\n",
    "@test RD.dims(rocket) == (8, 2, 8)\n",
    "test_model(rocket)\n",
    "\n",
    "# Planar Quad\n",
    "quad = RobotZoo.PlanarQuadrotor()\n",
    "@test RD.dims(quad) == (6, 2, 6)\n",
    "test_model(quad)\n",
    "\n",
    "# Cartpole\n",
    "cartpole = RobotZoo.Cartpole()\n",
    "@test RD.dims(cartpole) == (4, 1, 4)\n",
    "test_model(cartpole)\n",
    "\n",
    "# Double Integrator\n",
    "# Named di_dim (not `dim`) so the global LazySets.dim function is not shadowed.\n",
    "di_dim = 3\n",
    "di = RobotZoo.DoubleIntegrator(di_dim)\n",
    "n, m = RD.dims(di)\n",
    "@test (n, m) == (6, 3)\n",
    "test_model(di)\n",
    "\n",
    "# Pendulum\n",
    "pend = RobotZoo.Pendulum()\n",
    "RobotZoo.Pendulum{Float64}(1, 1, 1, 1, 1, 1)  # constructor smoke test; result unused\n",
    "@test RD.dims(pend) == (2, 1, 2)\n",
    "test_model(pend)\n",
    "\n",
    "# Quadrotor\n",
    "quad = RobotZoo.Quadrotor()\n",
    "@test RD.dims(quad) == (13, 4, 13)\n",
    "test_model(quad, tol=1e-4)\n",
    "\n",
    "# Yak Plane\n",
    "yak = RobotZoo.YakPlane(MRP{Float64})\n",
    "@test RD.dims(yak) == (12, 4, 12)\n",
    "test_model(yak, tol=1e-4)\n",
    "\n",
    "# Test other functions\n",
    "dt = 0.1\n",
    "car = RobotZoo.DubinsCar()\n",
    "n, m = RD.dims(car)\n",
    "@test zeros(car) == (zeros(n), zeros(m))\n",
    "@test zeros(Int, car)[1] isa SVector{n,Int}\n",
    "@test fill(car, 0.1) == (fill(0.1, n), fill(0.1, m))\n",
    "# @test ones(Float32,car)[2] isa SVector{m,Float32}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4b807c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "?RobotZoo.Cartpole"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9829755b",
   "metadata": {},
   "outputs": [],
   "source": [
    "collect_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61b3b5cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "?RobotZoo.Acrobot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7de3bc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "?RobotZoo.DubinsCar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "705eb36d",
   "metadata": {},
   "outputs": [],
   "source": [
    "?RobotZoo.BicycleModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3945602d",
   "metadata": {},
   "outputs": [],
   "source": [
    "?RobotZoo.Pendulum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adca1802",
   "metadata": {},
   "outputs": [],
   "source": [
    "?RobotZoo.Quadrotor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aacade4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "?RobotZoo.YakPlane"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "94c6f2cb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(x, u) = ([0.29921217555875046, 0.04857853975877857, 0.7503270511041593], [0.8862787153590471, 0.8360264269811298])\n",
      "(lower_w, upper_w, lower_b, upper_b) = ([-0.0 -0.0 -0.6043339784480347; -0.0 -0.0 0.6482826573275373; 0.0 0.0 0.0], [0.0 0.0 -0.6043339784480345; 0.0 0.0 0.6482826573275375; 0.0 0.0 0.0], [1.0983914042687994; 0.11478281812143465; 0.8360264269811298;;], [1.101730789258496; 0.11790996389349534; 0.8360264269811298;;])\n",
      "(find_bounds(upper_w, upper_b, x .- 0.1 .* ones(size(x)), x .+ 0.1 .* ones(size(x))))[2] = [0.7087160551723412; 0.6691622441807885; 0.8360264269811298;;]\n",
      "RobotDynamics.dynamics(model, x, u) = [0.6482826573275375, 0.6043339784480345, 0.8360264269811298]\n",
      "(find_bounds(lower_w, lower_b, x .- 0.1 .* ones(size(x)), x .+ 0.1 .* ones(size(x))))[1] = [0.5845098744930375; 0.5363785669432202; 0.8360264269811298;;]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "3×1 Matrix{Float64}:\n",
       " 0.5845098744930375\n",
       " 0.5363785669432202\n",
       " 0.8360264269811298"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Affine Taylor-model bounds for the DubinsCar dynamics around a random state.\n",
    "using TaylorModels\n",
    "import RobotDynamics\n",
    "# NOTE: `taylor_model` below reads this global `model`.\n",
    "model = RobotZoo.DubinsCar()\n",
    "x, u = rand(model)\n",
    "# point = IntervalBox([interval(0),interval(0)])\n",
    "# region = IntervalBox([-0.6..0.6,-0.4..0.5])\n",
    "# t1 = TaylorModelN(2,1, point,region)\n",
    "# t2 = TaylorModelN(2,1, point,region)\n",
    "# @show ff = t2^2 + t1^2\n",
    "# evaluate(ff, region)\n",
    "# @show t2\n",
    "# @show t1\n",
    "# @show typeof(t2)\n",
    "@show x,  u\n",
    "# _dim = length(x)\n",
    "# Variant of `TaylorModelN(nv, ord, x0, dom)` whose polynomial part is `x0[nv] + vars[nv]`\n",
    "# for caller-supplied variables `vars` (e.g. from `set_variables`), with remainder `zero(dom[1])`.\n",
    "# `ord` is accepted but not used.\n",
    "myTaylorModelN(nv::Integer, ord::Integer, x0::IntervalBox{N,T}, dom::IntervalBox{N,T},vars::Vector) where {N,T} =\n",
    "    TaylorModelN(x0[nv] + vars[nv], zero(dom[1]), x0, dom)\n",
    "# var = set_variables(\"x\", numvars=dim, order=1)\n",
    "# taylor_var = [myTaylorModelN(i,1, IntervalBox([interval(0) for i in 1:dim]),IntervalBox([(0 .-0)..(0 .+0) for i in 1:dim]),var) for i in 1:dim]\n",
    "\"\"\"\n",
    "    taylor_model(center, radius; u, dyn_model=model)\n",
    "\n",
    "Affine lower/upper bounds on the continuous dynamics `ẋ = f(x, u)` of `dyn_model`\n",
    "over the box `center ± radius`, for a fixed control `u`, from first-order Taylor\n",
    "models (TaylorModels.jl) expanded around `center`.\n",
    "\n",
    "Returns `(lower_w, upper_w, lower_b, upper_b)`. Row `i` gives the affine maps\n",
    "`lower_w[i, :]' * x + lower_b[i]` and `upper_w[i, :]' * x + upper_b[i]`, built from the\n",
    "`inf` / `sup` of the polynomial's interval coefficients plus the remainder.\n",
    "Output components that are not `TaylorModelN` (i.e. do not depend on the state)\n",
    "get zero weights and a constant offset.\n",
    "\n",
    "`dyn_model` defaults to the global `model`. `set_variables` configures TaylorSeries'\n",
    "global variable set, so calling this affects other TaylorN code in the session.\n",
    "\"\"\"\n",
    "function taylor_model(center, radius; u, dyn_model=model)\n",
    "    _dim = length(center)\n",
    "    point = IntervalBox([interval(center[i]) for i in 1:_dim])\n",
    "    region = IntervalBox([(center[i].-radius[i])..(center[i].+radius[i]) for i in 1:_dim])\n",
    "\n",
    "    # One TaylorN variable per state dimension.\n",
    "    set_variables(\"x\", numvars=_dim, order=2)\n",
    "    taylor_var = [TaylorModels.TaylorModelN(i, 1, point, region) for i in 1:_dim]\n",
    "    dyn_x = RobotDynamics.dynamics(dyn_model, taylor_var, u)\n",
    "\n",
    "    lower_w = zeros(_dim, _dim)\n",
    "    upper_w = zeros(_dim, _dim)\n",
    "    lower_b = zeros(_dim, 1)\n",
    "    upper_b = zeros(_dim, 1)\n",
    "    for i in 1:_dim\n",
    "        if isa(dyn_x[i], TaylorModelN)\n",
    "            poly_i = polynomial(dyn_x[i])\n",
    "            rem_i = remainder(dyn_x[i])\n",
    "            for j in 1:_dim\n",
    "                lower_w[i, j] = inf(poly_i[1][j])\n",
    "                upper_w[i, j] = sup(poly_i[1][j])\n",
    "            end\n",
    "            # The polynomial is expanded around `center`, so moving the linear term to\n",
    "            # absolute coordinates subtracts w * center (not any global state vector).\n",
    "            lower_b[i, 1] = inf(poly_i[0][1]) + inf(rem_i) - sum(lower_w[i, j] * center[j] for j in 1:_dim)\n",
    "            upper_b[i, 1] = sup(poly_i[0][1]) + sup(rem_i) - sum(upper_w[i, j] * center[j] for j in 1:_dim)\n",
    "        else\n",
    "            # State-independent component: exact constant.\n",
    "            lower_b[i, 1] = dyn_x[i]\n",
    "            upper_b[i, 1] = dyn_x[i]\n",
    "        end\n",
    "    end\n",
    "    return lower_w, upper_w, lower_b, upper_b\n",
    "end\n",
    "\n",
    "\"\"\"\n",
    "    find_bounds(w, b, lower_x, upper_x)\n",
    "\n",
    "Elementwise range of the affine map `w * x + b` for `x` in the box `[lower_x, upper_x]`.\n",
    "The box corners are reshaped to `size(b)`; positive weights pick the matching corner\n",
    "and negative weights the opposite one. Returns `(low, up)`.\n",
    "\"\"\"\n",
    "function find_bounds(w, b, lower_x, upper_x)\n",
    "    lo_col = reshape(lower_x, size(b))\n",
    "    hi_col = reshape(upper_x, size(b))\n",
    "    w_pos = clamp.(w, 0, Inf)   # positive part of the weights\n",
    "    w_neg = clamp.(w, -Inf, 0)  # negative part of the weights\n",
    "    low = w_pos * lo_col + w_neg * hi_col + b\n",
    "    up = w_pos * hi_col + w_neg * lo_col + b\n",
    "    return low, up\n",
    "end\n",
    "\n",
    "# Evaluate the bounds on the box x ± 0.1 and compare them with the true dynamics at x:\n",
    "# the upper affine map's maximum over the box, f(x, u), then the lower map's minimum.\n",
    "# x = Vector(x)\n",
    "# u = Vector(u)\n",
    "lower_w, upper_w, lower_b, upper_b = taylor_model(x, 0.1 .* ones(size(x));u)\n",
    "@show lower_w, upper_w, lower_b, upper_b\n",
    "@show find_bounds(upper_w, upper_b, x .- 0.1 .* ones(size(x)), x .+ 0.1 .* ones(size(x)))[2]\n",
    "@show RobotDynamics.dynamics(model, x, u)\n",
    "@show find_bounds(lower_w, lower_b, x .- 0.1 .* ones(size(x)), x .+ 0.1 .* ones(size(x)))[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a5fccff",
   "metadata": {},
   "outputs": [],
   "source": [
    "include(\"dataset.jl\")\n",
    "# Region X (4-D), control set U (2-D) and the box X_unsafe, as hyperrectangles.\n",
    "X = Hyperrectangle(low = [0, 0, -1, -1], high = [4,4, 1, 1])\n",
    "U = Hyperrectangle(low = [-1, -1], high = [1,1])\n",
    "X_unsafe = Hyperrectangle(low = [1.5, 0, -1, -1], high = [2.5,2, 1, 1])\n",
    "\n",
    "# A maps states 3-4 onto the rates of states 1-2; B maps the two inputs onto the\n",
    "# rates of states 3-4 (double-integrator structure).\n",
    "# NOTE(review): presumably used as ẋ = A x + B u with state [px, py, vx, vy] — confirm.\n",
    "A = [0. 0 1 0;\n",
    "    0 0 0 1;\n",
    "    0 0 0 0;\n",
    "    0 0 0 0;]\n",
    "B = [0. 0;\n",
    "    0 0;\n",
    "    1 0;\n",
    "    0 1;]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "001b2182",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dfeed92",
   "metadata": {},
   "outputs": [],
   "source": [
    "include(\"affine_dynamics.jl\")\n",
    "include(\"dataset.jl\")\n",
    "include(\"visualize.jl\")\n",
    "\n",
    "# Load a trained network state from disk into `original_model` below; the Chain's\n",
    "# layer structure must match the saved state for Flux.loadmodel! to succeed.\n",
    "# model_state = JLD2.load(\"new_best_model_ada_0_17.jld2\", \"model_state\");\n",
    "model_state = JLD2.load(\"models/big_ce_lag_e-4_0_pgd1/new_nolabel_ce_model_ada_0_pgd_20.jld2\", \"model_state\");\n",
    "\n",
    "\n",
    "# α = 1.4474390157879287\n",
    "# λ = 0.040660218997413136\n",
    "\n",
    "\n",
    "# model = Chain(\n",
    "#     Dense(4 => 8, relu),   # activation function inside layer\n",
    "#     Dense(8 => 8, relu),   # activation function inside layer\n",
    "#     Dense(8 => 4, relu),   # activation function inside layer\n",
    "#     # BatchNorm(4),\n",
    "#     Dense(4 => 2))\n",
    "# model = MyModel(); # MyModel definition must be available\n",
    "\n",
    "# 4-input, 2-output MLP; the loaded weights replace its random initialisation.\n",
    "original_model = Chain(\n",
    "    Dense(4 => 16, relu),   # activation function inside layer\n",
    "    Dense(16 => 64, relu),   # activation function inside layer\n",
    "    Dense(64 => 16, relu),   # activation function inside layer\n",
    "    Dense(16 => 2)\n",
    ")\n",
    "\n",
    "Flux.loadmodel!(original_model, model_state);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d17349f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scalar head on top of the loaded network: phi(x) = out[1] - out[2], implemented as a\n",
    "# Dense(2 => 1) layer with weight [1 -1] and bias 0.\n",
    "# NOTE: the splatted layers are the same objects as in `original_model` (not copies),\n",
    "# so changing phi_model's hidden-layer parameters also changes original_model.\n",
    "phi_model = Chain(\n",
    "    original_model.layers...,\n",
    "    Flux.Dense(2 => 1)\n",
    ")\n",
    "phi_model[end].weight .= [1 -1]\n",
    "phi_model[end].bias .= [0]\n",
    "# @show phi_model[end].weight\n",
    "# @show phi_model[end].bias\n",
    "# @show phi_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3af97c02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find all candidate root cells of phi_model (cells of X where phi may cross zero).\n",
    "# A cell is discarded only when phi has the same strict sign at all of its vertices;\n",
    "# this vertex test is a heuristic and can miss sign changes strictly inside a cell.\n",
    "# NOTE: `split` presumably materialises every cell up front (dx*dy*dvx*dvy = 1e8 at\n",
    "# 100 per axis), which is very costly in time and memory; lower it for quick runs.\n",
    "dx = 100\n",
    "dy = 100\n",
    "dvx = 100\n",
    "dvy = 100\n",
    "sub_X_list = split(X, [dx, dy, dvx, dvy])\n",
    "root_region_list = eltype(sub_X_list)[]  # concretely typed, unlike `[]`\n",
    "for sub_X in sub_X_list\n",
    "    v_mat = reduce(hcat, vertices_list(sub_X))  # one vertex per column\n",
    "    phi_v_sub = phi_model(v_mat)\n",
    "    (all(<(0), phi_v_sub) || all(>(0), phi_v_sub)) && continue\n",
    "    push!(root_region_list, sub_X)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27a8cece",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of candidate root cells kept by the vertex-sign test.\n",
    "@show length(root_region_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "200a6c89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch: inspect the container type LazySets.split returns for X.\n",
    "# X1 = Hyperrectangle(low=[0.0, 0.0], high=[1.0, 1.0])\n",
    "typeof(X)\n",
    "@show typeof(split(X, [10,5,10,10]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "467339d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch: check what LazySets.split returns for a 2-D hyperrectangle split 10 x 5.\n",
    "# `density` is only consumed by the commented-out `split_hyperrectangle` code.\n",
    "# using LazySets\n",
    "\n",
    "# function split_hyperrectangle(hyperrectangle::Hyperrectangle, density::Int)\n",
    "#     # Calculate the number of intervals per dimension\n",
    "#     num_intervals = floor(Int, density^(1/dim(hyperrectangle)))\n",
    "    \n",
    "#     # Split the hyperrectangle\n",
    "#     splits = split(hyperrectangle, num_intervals)\n",
    "    \n",
    "#     return splits\n",
    "# end\n",
    "\n",
    "# # Example usage\n",
    "hyperrectangle = Hyperrectangle(low=[0.0, 0.0], high=[1.0, 1.0])\n",
    "density = 100\n",
    "@show typeof(split(hyperrectangle, [10,5]))\n",
    "# result = split_hyperrectangle(hyperrectangle, density)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "302364b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Contour and heatmap views of Phi_dot and h over the first two coordinates of X\n",
    "# on [0, 4] x [0, 4], with the remaining arguments held at the values below.\n",
    "α = 0.0\n",
    "vx = 0\n",
    "vy = 0\n",
    "ax = 0\n",
    "ay = 0\n",
    "pgd_lr = 1\n",
    "pgd_num_iter = 10\n",
    "\n",
    "# Grid axes are named xs/ys so the state vector `x` used elsewhere is not overwritten.\n",
    "xs = range(0, 4, length=100)\n",
    "ys = range(0, 4, length=100)\n",
    "Phi_dot_contour(x, y) = Phi_dot(original_model, A, B, x, y; α=α, vx=vx, vy=vy)[1]\n",
    "# Broadcasting over xs' and ys gives z[j, i] = f(xs[i], ys[j]).\n",
    "z1 = @. Phi_dot_contour(xs', ys)\n",
    "\n",
    "plt1 = plot_env(X, X_unsafe)\n",
    "contour!(plt1, xs, ys, z1, levels=10, color=:turbo, clabels=true, cbar=false, lw=1)\n",
    "contour!(plt1, xs, ys, z1, levels=[0], color=:black, clabels=true, cbar=false, lw=1)\n",
    "\n",
    "plt2 = plot_env(X, X_unsafe)\n",
    "heatmap!(plt2, xs, ys, z1)\n",
    "\n",
    "h_contour(x, y) = h(original_model, A, B, U, x, y; α=α, vx=vx, vy=vy, ax=ax, ay=ay, lr=pgd_lr, num_iter=pgd_num_iter)[1]\n",
    "z2 = @. h_contour(xs', ys)\n",
    "plt3 = plot_env(X, X_unsafe)\n",
    "contour!(plt3, xs, ys, z2, levels=10, color=:turbo, clabels=true, cbar=false, lw=1)\n",
    "contour!(plt3, xs, ys, z2, levels=[0], color=:black, clabels=true, cbar=false, lw=1)\n",
    "\n",
    "plt4 = plot_env(X, X_unsafe)\n",
    "heatmap!(plt4, xs, ys, z2)\n",
    "\n",
    "# Only the cell's last expression is displayed, so all four panels are combined here.\n",
    "plot(plt1, plt2, plt3, plt4, layout = (2, 2), size=(1000,1000))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9855eb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Residual wrapper: returns chain(x) + x, so `chain` must map x to an array of the same size.\n",
    "# The chain type is a type parameter so the field is concretely typed.\n",
    "struct CustomModel{C<:Chain}\n",
    "  chain::C\n",
    "end\n",
    "\n",
    "function (m::CustomModel)(x)\n",
    "  # Arbitrary code can go here, but note that everything will be differentiated.\n",
    "  # Zygote does not allow some operations, like mutating arrays.\n",
    "\n",
    "  return m.chain(x) + x\n",
    "end\n",
    "Flux.@functor CustomModel\n",
    "\n",
    "# Distinct global names so this demo does not overwrite `model`, which the\n",
    "# Taylor-model cell uses as the RobotZoo dynamics model.\n",
    "res_chain = Chain(Dense(10 => 10))\n",
    "custom_model = CustomModel(res_chain)\n",
    "custom_model(rand(10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3601f905",
   "metadata": {},
   "outputs": [],
   "source": [
    "using Flux\n",
    "using Flux: Params, gradient, update!\n",
    "\n",
    "# Minimal affine layer: y = W * x .+ b.\n",
    "struct MyLayer\n",
    "    W\n",
    "    b\n",
    "end\n",
    "\n",
    "# Random Float32 initialisation (plain arrays; `Flux.param` no longer exists).\n",
    "MyLayer(in_dim::Integer, out::Integer) = MyLayer(randn(Float32, out, in_dim), randn(Float32, out))\n",
    "\n",
    "# Expose W and b to Flux so they are trainable.\n",
    "Flux.@functor MyLayer\n",
    "\n",
    "# Forward pass\n",
    "function (m::MyLayer)(x)\n",
    "    m.W * x .+ m.b\n",
    "end\n",
    "\n",
    "# Append MyLayer to a *copy* of the loaded network so that training below does not\n",
    "# overwrite the weights of `original_model` used by other cells.\n",
    "new_model = Chain(\n",
    "    deepcopy(original_model),\n",
    "    MyLayer(2, 1)  # original_model ends in Dense(16 => 2)\n",
    ")\n",
    "\n",
    "# Smoke test: original_model takes 4 inputs.\n",
    "x = rand(Float32, 4)\n",
    "y = new_model(x)\n",
    "\n",
    "# Squared-error loss of the model's prediction (explicit-model style).\n",
    "loss(m, x, y) = sum((m(x) .- y).^2)\n",
    "\n",
    "# One epoch on random data using Flux's explicit training API.\n",
    "data = [(rand(Float32, 4), rand(Float32, 1)) for _ in 1:100]\n",
    "opt_state = Flux.setup(Adam(), new_model)\n",
    "Flux.train!(loss, new_model, data, opt_state)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.9.4",
   "language": "julia",
   "name": "julia-1.9"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.9.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
