{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "611881de-174f-481d-84af-124b81c064f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "using Revise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd452f8b-c1a7-4f37-9f32-6bfa1447caef",
   "metadata": {},
   "outputs": [],
   "source": [
    "using BSON\n",
    "using DataDeps, MAT, MLUtils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1c14eb8-4c36-4861-8774-9d80c3ceecfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label each BCKS trajectory (one row of `u`) for the \"_new_10\" dataset.\n",
    "#   safe_labels[i,j]: 0 where u > threshold; 1 on the trailing run of in-band steps\n",
    "#     after the last exceedance (the whole row if there is none); -1 for in-band\n",
    "#     steps that are followed later by an exceedance.\n",
    "#   pf_labels[i,:]: 0 if any -1 appears at or after column `pf_start_col`, else 1.\n",
    "# NOTE(review): a row that exceeds the threshold at every step from `pf_start_col`\n",
    "# on contains no -1 there and is therefore labelled pf = 1 -- confirm this is intended.\n",
    "T = Float32\n",
    "data_path = \"../PDEControlGym/examples/transportPDE/data_bcks_hyperbolic.mat\"\n",
    "out_path = \"data_bcks_hyperbolic_1_new_10.mat\"\n",
    "threshold = 1\n",
    "pf_start_col = 10  # first column inspected for the pf label\n",
    "\n",
    "# try/finally closes the handle even if a read fails.\n",
    "file = matopen(data_path)\n",
    "x_data, y_data = try\n",
    "    T.(collect(read(file, \"a\"))), T.(collect(read(file, \"u\")))\n",
    "finally\n",
    "    close(file)\n",
    "end\n",
    "@show size(x_data)\n",
    "@show size(y_data)\n",
    "# Labels take x_data's shape but are filled from y_data, so the two must match.\n",
    "@assert size(x_data) == size(y_data) \"x_data and y_data must have the same shape\"\n",
    "\n",
    "safe_labels = -ones(size(x_data))\n",
    "safe_labels[y_data .> threshold] .= 0\n",
    "pf_labels = zeros(size(x_data))\n",
    "for i in axes(safe_labels, 1)\n",
    "    row = view(safe_labels, i, :)\n",
    "    # In-band steps after the last exceedance are safe (whole row if none).\n",
    "    row[something(findlast(==(0), row), 0)+1:end] .= 1\n",
    "    pf_labels[i, :] .= any(==(-1), row[pf_start_col:end]) ? 0 : 1\n",
    "end\n",
    "\n",
    "matwrite(out_path, Dict(\n",
    "    \"a\" => x_data,\n",
    "    \"u\" => y_data,\n",
    "    \"pf\" => pf_labels,\n",
    "    \"safe\" => safe_labels\n",
    "))\n",
    "@show sum(pf_labels[:, 1])\n",
    "@show count(==(0), safe_labels)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbfae270-0e03-4904-9f9a-9c7aee680570",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check: columns 35 onward of row 4 in whichever y_data the last-run cell loaded\n",
    "y_data[4,35:end]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04caa769-1478-4146-801d-a6be00784767",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label BCKS trajectories with a signed threshold (\"minus\" variant).\n",
    "#   safe_labels[i,j]: 0 where u > threshold; for pf rows, the trailing run of in-band\n",
    "#     steps after the last exceedance becomes 1; every other entry stays -1.\n",
    "#   pf_labels[i,:]: 0 if the final step exceeds the threshold, else 1.\n",
    "T = Float32\n",
    "# NOTE(review): the other BCKS cells load this file from the path below; the root path\n",
    "# \"/data_bcks_hyperbolic.mat\" looked truncated -- confirm.\n",
    "data_path = \"/data/code/PDEControlGym/examples/transportPDE/data_bcks_hyperbolic.mat\"\n",
    "out_path = \"data_bcks_hyperbolic_1_minus.mat\"\n",
    "threshold = 1\n",
    "\n",
    "# try/finally closes the handle even if a read fails.\n",
    "file = matopen(data_path)\n",
    "x_data, y_data = try\n",
    "    T.(collect(read(file, \"a\"))), T.(collect(read(file, \"u\")))\n",
    "finally\n",
    "    close(file)\n",
    "end\n",
    "@show size(x_data)\n",
    "@show size(y_data)\n",
    "# Labels take x_data's shape but are filled from y_data, so the two must match.\n",
    "@assert size(x_data) == size(y_data) \"x_data and y_data must have the same shape\"\n",
    "\n",
    "safe_labels = -ones(size(x_data))\n",
    "safe_labels[y_data .> threshold] .= 0\n",
    "pf_labels = ones(size(x_data))\n",
    "for i in axes(y_data, 1)\n",
    "    if y_data[i, end] > threshold\n",
    "        pf_labels[i, :] .= 0  # not pf: the row keeps its -1 entries\n",
    "    else\n",
    "        row = view(safe_labels, i, :)\n",
    "        # In-band steps after the last exceedance are safe (whole row if none).\n",
    "        row[something(findlast(==(0), row), 0)+1:end] .= 1\n",
    "    end\n",
    "end\n",
    "\n",
    "matwrite(out_path, Dict(\n",
    "    \"a\" => x_data,\n",
    "    \"u\" => y_data,\n",
    "    \"pf\" => pf_labels,\n",
    "    \"safe\" => safe_labels\n",
    "))\n",
    "@show sum(pf_labels[:, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a5d8c29-3765-4363-97f7-742ad4fcceec",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "472347c0-cae2-47db-aa59-7d255c07ddcc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "898861f7-4bc3-4628-b963-41e699555a3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label SAC trajectories with an absolute-value threshold.\n",
    "#   safe_labels[i,j]: 0 where |u| > threshold, 1 otherwise.\n",
    "#   pf_labels[i,:]:   0 when |u| at the final step exceeds the threshold, 1 otherwise.\n",
    "T = Float32\n",
    "threshold = 1\n",
    "\n",
    "file = matopen(\"/data/code/PDEControlGym/examples/transportPDE/data_sac_hyperbolic.mat\")\n",
    "x_data = T.(collect(read(file, \"a\")))\n",
    "y_data = T.(collect(read(file, \"u\")))\n",
    "close(file)\n",
    "@show size(x_data)\n",
    "@show size(y_data)\n",
    "\n",
    "pf_labels = ones(size(x_data))\n",
    "safe_labels = ones(size(x_data))\n",
    "for r in axes(x_data, 1)\n",
    "    for c in axes(x_data, 2)\n",
    "        abs(y_data[r, c]) > threshold && (safe_labels[r, c] = 0)\n",
    "    end\n",
    "    # Only the final step decides pf for this dataset.\n",
    "    abs(y_data[r, end]) > threshold && (pf_labels[r, :] .= 0)\n",
    "end\n",
    "\n",
    "matwrite(\"data_sac_hyperbolic_1_abs.mat\", Dict(\n",
    "    \"a\" => x_data,\n",
    "    \"u\" => y_data,\n",
    "    \"pf\" => pf_labels,\n",
    "    \"safe\" => safe_labels\n",
    "))\n",
    "@show sum(pf_labels[:, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06bd9416-98f0-4248-9578-6c08957c75ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label BCKS trajectories with an absolute-value threshold (\"abs_minus\" variant).\n",
    "#   safe_labels[i,j]: 0 where |u| > threshold; for rows whose final |u| is in band,\n",
    "#     the trailing run of in-band steps after the last exceedance becomes 1;\n",
    "#     every other entry stays -1.\n",
    "#   pf_labels[i,:]: 0 when |u| at the final step exceeds the threshold, 1 otherwise.\n",
    "T = Float32\n",
    "threshold = 1\n",
    "\n",
    "file = matopen(\"/data/code/PDEControlGym/examples/transportPDE/data_bcks_hyperbolic.mat\")\n",
    "x_data = T.(collect(read(file, \"a\")))\n",
    "y_data = T.(collect(read(file, \"u\")))\n",
    "close(file)\n",
    "@show size(x_data)\n",
    "@show size(y_data)\n",
    "\n",
    "pf_labels = ones(size(x_data))\n",
    "safe_labels = -ones(size(x_data))\n",
    "for r in axes(x_data, 1)\n",
    "    for c in axes(x_data, 2)\n",
    "        abs(y_data[r, c]) > threshold && (safe_labels[r, c] = 0)\n",
    "    end\n",
    "    if abs(y_data[r, end]) > threshold\n",
    "        pf_labels[r, :] .= 0\n",
    "    else\n",
    "        # Mark everything after the last out-of-band step as safe.\n",
    "        row = view(safe_labels, r, :)\n",
    "        last_unsafe = findlast(==(0), row)\n",
    "        row[(last_unsafe === nothing ? 1 : last_unsafe + 1):end] .= 1\n",
    "    end\n",
    "end\n",
    "\n",
    "matwrite(\"data_bcks_hyperbolic_1_abs_minus.mat\", Dict(\n",
    "    \"a\" => x_data,\n",
    "    \"u\" => y_data,\n",
    "    \"pf\" => pf_labels,\n",
    "    \"safe\" => safe_labels\n",
    "))\n",
    "@show sum(pf_labels[:, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6b0b796-c4d2-4e7a-9a89-1d96950f4307",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label PPO trajectories with an absolute-value threshold and three pf classes.\n",
    "#   safe_labels[i,j]: 0 where |u + bias| > threshold + bias; 1 on the trailing run of\n",
    "#     in-band steps after the last exceedance; -1 for in-band steps followed later\n",
    "#     by an exceedance.\n",
    "#   pf_labels[i,:]: 0 = not pf (final step out of band);\n",
    "#                   2 = pf but not used for training (out of band somewhere in the\n",
    "#                       last pf_tail+1 steps);\n",
    "#                   1 = pf and used for training.\n",
    "T = Float32\n",
    "data_path = \"/data/code/PDEControlGym/examples/transportPDE/data_ppo_hyperbolic.mat\"\n",
    "threshold = 1\n",
    "bias = 0      # out of band means |u + bias| > threshold + bias; 0 gives |u| > threshold\n",
    "pf_tail = 5   # tail window is columns end-pf_tail:end\n",
    "save_output = false  # set to true to write the labelled dataset\n",
    "# NOTE(review): the disabled export used \"data_sac_hyperbolic_1_abs_minus_5pf.mat\" even\n",
    "# though this cell loads PPO data; renamed to match the source -- confirm.\n",
    "out_path = \"data_ppo_hyperbolic_1_abs_minus_5pf.mat\"\n",
    "\n",
    "# try/finally closes the handle even if a read fails.\n",
    "file = matopen(data_path)\n",
    "x_data, y_data = try\n",
    "    T.(collect(read(file, \"a\"))), T.(collect(read(file, \"u\")))\n",
    "finally\n",
    "    close(file)\n",
    "end\n",
    "@show size(x_data)\n",
    "@show size(y_data)\n",
    "# Labels take x_data's shape but are filled from y_data, so the two must match.\n",
    "@assert size(x_data) == size(y_data) \"x_data and y_data must have the same shape\"\n",
    "\n",
    "out_of_band = abs.(y_data .+ bias) .> threshold + bias\n",
    "safe_labels = -ones(size(x_data))\n",
    "safe_labels[out_of_band] .= 0\n",
    "pf_labels = ones(size(x_data))\n",
    "not_pf = 0\n",
    "pf_training = 0\n",
    "pf_excluded = 0\n",
    "for i in axes(y_data, 1)\n",
    "    row = view(safe_labels, i, :)\n",
    "    # In-band steps after the last exceedance are safe (whole row if none).\n",
    "    row[something(findlast(==(0), row), 0)+1:end] .= 1\n",
    "    if out_of_band[i, end]\n",
    "        pf_labels[i, :] .= 0\n",
    "        not_pf += 1\n",
    "    elseif any(out_of_band[i, end-pf_tail:end])\n",
    "        pf_labels[i, :] .= 2\n",
    "        pf_excluded += 1\n",
    "    else\n",
    "        pf_labels[i, :] .= 1\n",
    "        pf_training += 1\n",
    "    end\n",
    "end\n",
    "# Every row falls into exactly one class (replaces the hard-coded 50000 total).\n",
    "@assert not_pf + pf_training + pf_excluded == size(y_data, 1)\n",
    "\n",
    "if save_output\n",
    "    matwrite(out_path, Dict(\n",
    "        \"a\" => x_data,\n",
    "        \"u\" => y_data,\n",
    "        \"pf\" => pf_labels,\n",
    "        \"safe\" => safe_labels\n",
    "    ))\n",
    "end\n",
    "@show not_pf, pf_training, pf_excluded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e69fd6c-0c4c-42b3-81fe-4d511320757c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# @show a = rand(10)\n",
    "# any(abs.(a .-0.5) .> 0.4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d06ddcef-e107-4b27-8102-c7117f360fe3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label BCKS trajectories with a signed threshold and three pf classes (\"minus_pf5\").\n",
    "#   safe_labels[i,j]: 0 where u > threshold; 1 on the trailing run of in-band steps\n",
    "#     after the last exceedance; -1 for in-band steps followed later by an exceedance.\n",
    "#   pf_labels[i,:]: 0 = not pf (final step exceeds the threshold);\n",
    "#                   2 = pf but not used for training (an exceedance in the last\n",
    "#                       pf_tail+1 steps);\n",
    "#                   1 = pf and used for training.\n",
    "T = Float32\n",
    "data_path = \"/data/code/PDEControlGym/examples/transportPDE/data_bcks_hyperbolic.mat\"\n",
    "out_path = \"data_bcks_hyperbolic_1_minus_pf5.mat\"\n",
    "threshold = 1\n",
    "pf_tail = 5  # tail window is columns end-pf_tail:end\n",
    "\n",
    "# try/finally closes the handle even if a read fails.\n",
    "file = matopen(data_path)\n",
    "x_data, y_data = try\n",
    "    T.(collect(read(file, \"a\"))), T.(collect(read(file, \"u\")))\n",
    "finally\n",
    "    close(file)\n",
    "end\n",
    "@show size(x_data)\n",
    "@show size(y_data)\n",
    "# Labels take x_data's shape but are filled from y_data, so the two must match.\n",
    "@assert size(x_data) == size(y_data) \"x_data and y_data must have the same shape\"\n",
    "\n",
    "exceeds = y_data .> threshold\n",
    "safe_labels = -ones(size(x_data))\n",
    "safe_labels[exceeds] .= 0\n",
    "pf_labels = ones(size(x_data))\n",
    "not_pf = 0\n",
    "pf_training = 0\n",
    "pf_excluded = 0\n",
    "for i in axes(y_data, 1)\n",
    "    row = view(safe_labels, i, :)\n",
    "    # In-band steps after the last exceedance are safe (whole row if none).\n",
    "    row[something(findlast(==(0), row), 0)+1:end] .= 1\n",
    "    if exceeds[i, end]\n",
    "        pf_labels[i, :] .= 0\n",
    "        not_pf += 1\n",
    "    elseif any(exceeds[i, end-pf_tail:end])\n",
    "        pf_labels[i, :] .= 2\n",
    "        pf_excluded += 1\n",
    "    else\n",
    "        pf_labels[i, :] .= 1\n",
    "        pf_training += 1\n",
    "    end\n",
    "end\n",
    "# Every row falls into exactly one class (replaces the hard-coded 50000 total).\n",
    "@assert not_pf + pf_training + pf_excluded == size(y_data, 1)\n",
    "\n",
    "matwrite(out_path, Dict(\n",
    "    \"a\" => x_data,\n",
    "    \"u\" => y_data,\n",
    "    \"pf\" => pf_labels,\n",
    "    \"safe\" => safe_labels\n",
    "))\n",
    "@show not_pf, pf_training, pf_excluded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26952d60-47d3-438b-bdb4-08351d12e57a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.9.4",
   "language": "julia",
   "name": "julia-1.9"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.9.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
