{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "611881de-174f-481d-84af-124b81c064f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "using Revise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd452f8b-c1a7-4f37-9f32-6bfa1447caef",
   "metadata": {},
   "outputs": [],
   "source": [
    "using BSON\n",
    "using DataDeps, MAT, MLUtils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85cf11a2-d44c-4a5d-8615-caff4ec114fc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9eea4b7-b134-47c2-9557-af5085647eba",
   "metadata": {},
   "outputs": [],
   "source": [
    "using NPZ\n",
    "T = Float32\n",
    "file = matopen(\"../PDEControlGym/examples/NavierStokes/data_sac_ns_new1.mat\")\n",
    "# @show read(file, \"a\")\n",
    "x_data = T.(collect(read(file, \"a\")))\n",
    "y_data = T.(collect(read(file, \"u\")))\n",
    "close(file)\n",
    "# @show (x_data)\n",
    "# @show (y_data[1,:])\n",
    "\n",
    "file = matopen(\"../PDEControlGym/examples/NavierStokes/data_sac_ns_new2.mat\")\n",
    "    \n",
    "x_data = cat(x_data, T.(collect(read(file, \"a\"))),dims=1)\n",
    "y_data = cat(y_data, T.(collect(read(file, \"u\"))),dims=1)\n",
    "# y_data = T.(collect(read(file, \"u\")))\n",
    "close(file)\n",
    "# @show size(x_data)\n",
    "# @assert 1==2\n",
    "# @show (y_data[1,:])\n",
    "\n",
    "file = matopen(\"../PDEControlGym/examples/NavierStokes/data_sac_ns_new3.mat\")\n",
    "    \n",
    "x_data = cat(x_data, T.(collect(read(file, \"a\"))),dims=1)\n",
    "y_data = cat(y_data, T.(collect(read(file, \"u\"))),dims=1)\n",
    "close(file)\n",
    "\n",
    "# file = matopen(\"../PDEControlGym/examples/NavierStokes/data_ppo_ns4.mat\")\n",
    "    \n",
    "# x_data = cat(x_data, T.(collect(read(file, \"a\"))),dims=1)\n",
    "# y_data = cat(y_data, T.(collect(read(file, \"u\"))),dims=1)\n",
    "# close(file)\n",
    "# @show (x_data)\n",
    "# @show (y_data[1,:])\n",
    "\n",
    "# file = matopen(\"../PDEControlGym/examples/reactionDiffusionPDE/data_sac_parabolic_test_dense.mat\")\n",
    "    \n",
    "# x_data = cat(x_data, T.(collect(read(file, \"a\"))),dims=1)\n",
    "# y_data = cat(y_data, T.(collect(read(file, \"u\"))),dims=1)\n",
    "\n",
    "# x_data = T.(collect(read(file, \"a\")))\n",
    "# y_data = T.(collect(read(file, \"u\")))\n",
    "# close(file)\n",
    "# @show (x_data)\n",
    "# @show (y_data[1,:])\n",
    "\n",
    "@show size(x_data)\n",
    "@show size(y_data)\n",
    "\n",
    "target = npzread(\"../PDEControlGym/examples/NavierStokes/target.npz\")\n",
    "u_target = target[\"u\"]\n",
    "xu_target = ones(size(x_data))\n",
    "yu_target = ones(size(y_data))\n",
    "\n",
    "yu_target[1, :] .= u_target[:,end-1,11]\n",
    "# @show yu_target[1, 1:10]\n",
    "yu_target[:, :] .= yu_target[1:1, :]\n",
    "# @show yu_target[1, 1:10]\n",
    "# @show yu_target[2, 1:10]\n",
    "# @show keys(target)\n",
    "# @show target\n",
    "# @show size(target[\"v\"])\n",
    "# @show size(target[\"u\"])\n",
    "# plt.plot(_a[:,-2,10,0])\n",
    "# # plt.plot(u_target[:, -2,10])\n",
    "    # xs_opt.append(env.U[:,-1,1,0])\n",
    "    # ys_opt.append(env.U[:,-2,10,0])\n",
    "# @assert 1==2\n",
    "# u_target = np.load('target.npz')['u']\n",
    "# v_target = np.load('target.npz')['v']\n",
    "\n",
    "# file = matopen(\"matfile.mat\", \"w\")\n",
    "# write(file, \"varname\", variable)\n",
    "# close(file)\n",
    "\n",
    "threshold = 0.12\n",
    "# threshold =  1.1\n",
    "# pf: 1, not pf: 0\n",
    "pf_labels = ones(size(x_data))\n",
    "# safe: 1, not safe: 0\n",
    "safe_labels = -ones(size(x_data))\n",
    "not_pf = 0\n",
    "pf_training = 0\n",
    "for i in 1:size(x_data, 1)\n",
    "    for j in 1:size(x_data, 2)\n",
    "        if abs(y_data[i,j]-yu_target[i,j]) > threshold\n",
    "            \n",
    "            safe_labels[i,j] = 0\n",
    "        end\n",
    "    end\n",
    "    for j in 1:size(x_data, 2)\n",
    "        if safe_labels[i,end+1-j] == -1\n",
    "            safe_labels[i,end+1-j] = 1\n",
    "        else\n",
    "            break\n",
    "        end\n",
    "    end\n",
    "    # @show abs(y_data[i,end]-yu_target[i,end])\n",
    "    if abs(y_data[i,end]-yu_target[i,end]) > threshold\n",
    "        pf_labels[i, :] .= 0\n",
    "        not_pf += 1\n",
    "        # @show \"not pf\", y_data[i,end-10:end]\n",
    "    # elseif any(y_data[i,end:end] .> threshold)\n",
    "    #     # pf but not used for training\n",
    "    #     pf_labels[i, :] .= 2\n",
    "        # @show \"pf but not used for training\",y_data[i,end-10:end]\n",
    "    else\n",
    "        # pf and used for training\n",
    "        pf_labels[i, :] .= 1\n",
    "        pf_training += 1\n",
    "        # @show \"pf and used for training\",y_data[i,end-10:end]\n",
    "    end\n",
    "end\n",
    "\n",
    "matwrite(\"data_sac_ns_new_abs_0.12.mat\", Dict(\n",
    "\t\"a\" => x_data,\n",
    "\t\"u\" => y_data,\n",
    "    \"pf\" => pf_labels,\n",
    "    \"safe\" => safe_labels\n",
    "))\n",
    "@show not_pf, pf_training, 10000-not_pf-pf_training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e272302f-54c6-4dc6-92db-237604c502df",
   "metadata": {},
   "outputs": [],
   "source": [
    "using NPZ\n",
    "T = Float32\n",
    "file = matopen(\"../PDEControlGym/examples/NavierStokes/data_opt_ns__0init_1.mat\")\n",
    "# @show read(file, \"a\")\n",
    "x_data = T.(collect(read(file, \"a\")))\n",
    "y_data = T.(collect(read(file, \"u\")))\n",
    "close(file)\n",
    "# @show (x_data)\n",
    "# @show (y_data[1,:])\n",
    "\n",
    "file = matopen(\"../PDEControlGym/examples/NavierStokes/data_opt_ns__0init_2.mat\")\n",
    "    \n",
    "x_data = cat(x_data, T.(collect(read(file, \"a\"))),dims=1)\n",
    "y_data = cat(y_data, T.(collect(read(file, \"u\"))),dims=1)\n",
    "close(file)\n",
    "# @show size(x_data)\n",
    "# @assert 1==2\n",
    "# @show (y_data[1,:])\n",
    "\n",
    "# file = matopen(\"../PDEControlGym/examples/NavierStokes/data_sac_ns_new3.mat\")\n",
    "    \n",
    "# x_data = cat(x_data, T.(collect(read(file, \"a\"))),dims=1)\n",
    "# y_data = cat(y_data, T.(collect(read(file, \"u\"))),dims=1)\n",
    "# close(file)\n",
    "\n",
    "# # file = matopen(\"../PDEControlGym/examples/NavierStokes/data_ppo_ns4.mat\")\n",
    "    \n",
    "# x_data = cat(x_data, T.(collect(read(file, \"a\"))),dims=1)\n",
    "# y_data = cat(y_data, T.(collect(read(file, \"u\"))),dims=1)\n",
    "# close(file)\n",
    "# @show (x_data)\n",
    "# @show (y_data[1,:])\n",
    "\n",
    "# file = matopen(\"../PDEControlGym/examples/reactionDiffusionPDE/data_sac_parabolic_test_dense.mat\")\n",
    "    \n",
    "# x_data = cat(x_data, T.(collect(read(file, \"a\"))),dims=1)\n",
    "# y_data = cat(y_data, T.(collect(read(file, \"u\"))),dims=1)\n",
    "\n",
    "# x_data = T.(collect(read(file, \"a\")))\n",
    "# y_data = T.(collect(read(file, \"u\")))\n",
    "# close(file)\n",
    "# @show (x_data)\n",
    "# @show (y_data[1,:])\n",
    "\n",
    "@show size(x_data)\n",
    "@show size(y_data)\n",
    "\n",
    "target = npzread(\"../PDEControlGym/examples/NavierStokes/target.npz\")\n",
    "u_target = target[\"u\"]\n",
    "xu_target = ones(size(x_data))\n",
    "yu_target = ones(size(y_data))\n",
    "\n",
    "yu_target[1, :] .= u_target[:,end-1,11]\n",
    "# @show yu_target[1, 1:10]\n",
    "yu_target[:, :] .= yu_target[1:1, :]\n",
    "# @show yu_target[1, 1:10]\n",
    "# @show yu_target[2, 1:10]\n",
    "# @show keys(target)\n",
    "# @show target\n",
    "# @show size(target[\"v\"])\n",
    "# @show size(target[\"u\"])\n",
    "# plt.plot(_a[:,-2,10,0])\n",
    "# # plt.plot(u_target[:, -2,10])\n",
    "    # xs_opt.append(env.U[:,-1,1,0])\n",
    "    # ys_opt.append(env.U[:,-2,10,0])\n",
    "# @assert 1==2\n",
    "# u_target = np.load('target.npz')['u']\n",
    "# v_target = np.load('target.npz')['v']\n",
    "\n",
    "# file = matopen(\"matfile.mat\", \"w\")\n",
    "# write(file, \"varname\", variable)\n",
    "# close(file)\n",
    "\n",
    "threshold = 0.145\n",
    "# threshold =  1.1\n",
    "# pf: 1, not pf: 0\n",
    "pf_labels = ones(size(x_data))\n",
    "# safe: 1, not safe: 0\n",
    "safe_labels = -ones(size(x_data))\n",
    "not_pf = 0\n",
    "pf_training = 0\n",
    "for i in 1:size(x_data, 1)\n",
    "    for j in 1:size(x_data, 2)\n",
    "        if abs(y_data[i,j]-yu_target[i,j]) > threshold\n",
    "            \n",
    "            safe_labels[i,j] = 0\n",
    "        end\n",
    "    end\n",
    "    # @show  safe_labels[i,:]\n",
    "    for j in 1:size(x_data, 2)\n",
    "        if safe_labels[i,end+1-j] == -1\n",
    "            safe_labels[i,end+1-j] = 1\n",
    "        else\n",
    "            break\n",
    "        end\n",
    "    end\n",
    "    # @show  safe_labels[i,:]\n",
    "    # @show abs(y_data[i,end]-yu_target[i,end])\n",
    "    if abs(y_data[i,end]-yu_target[i,end]) > threshold\n",
    "        pf_labels[i, :] .= 0\n",
    "        not_pf += 1\n",
    "        # @show \"not pf\", y_data[i,end-10:end]\n",
    "    # elseif any(y_data[i,end:end] .> threshold)\n",
    "    #     # pf but not used for training\n",
    "    #     pf_labels[i, :] .= 2\n",
    "        # @show \"pf but not used for training\",y_data[i,end-10:end]\n",
    "    else\n",
    "        # pf and used for training\n",
    "        pf_labels[i, :] .= 1\n",
    "        pf_training += 1\n",
    "        # @show \"pf and used for training\",y_data[i,end-10:end]\n",
    "    end\n",
    "end\n",
    "\n",
    "matwrite(\"data_opt_ns_abs_0.145.mat\", Dict(\n",
    "\t\"a\" => x_data,\n",
    "\t\"u\" => y_data,\n",
    "    \"pf\" => pf_labels,\n",
    "    \"safe\" => safe_labels\n",
    "))\n",
    "@show not_pf, pf_training, 10000-not_pf-pf_training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1c14eb8-4c36-4861-8774-9d80c3ceecfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# using BSON\n",
    "# using DataDeps, MAT, MLUtils\n",
    "\n",
    "\n",
    "\n",
    "# T = Float32\n",
    "# file = matopen(\"../PDEControlGym/examples/reactionDiffusionPDE/data_bcks_parabolic.mat\")\n",
    "    \n",
    "# x_data = T.(collect(read(file, \"a\")))\n",
    "# y_data = T.(collect(read(file, \"u\")))\n",
    "# close(file)\n",
    "# @show size(x_data)\n",
    "# @show size(y_data)\n",
    "\n",
    "# # file = matopen(\"matfile.mat\", \"w\")\n",
    "# # write(file, \"varname\", variable)\n",
    "# # close(file)\n",
    "\n",
    "# threshold = 1\n",
    "# # pf: 1, not pf: 0\n",
    "# pf_labels = zeros(size(x_data))\n",
    "# # safe: 1, not safe: 0\n",
    "# safe_labels = -ones(size(x_data))\n",
    "# for i in 1:size(x_data, 1)\n",
    "#     for j in 1:size(x_data, 2)\n",
    "#         if y_data[i,j] > threshold\n",
    "#             safe_labels[i,j] = 0\n",
    "#         end\n",
    "#     end\n",
    "\n",
    "#     for j in 1:size(x_data, 2)\n",
    "#         if safe_labels[i,end+1-j] == -1\n",
    "#             safe_labels[i,end+1-j] = 1\n",
    "#         else\n",
    "#             break\n",
    "#         end\n",
    "#     end\n",
    "#     # @show safe_labels[i]\n",
    "#     if any(safe_labels[i,10:end] .== -1)\n",
    "#         pf_labels[i, :] .= 0\n",
    "#         # @show i\n",
    "#     else\n",
    "#         pf_labels[i, :] .= 1\n",
    "#         # safe_labels[i, 35:end] .= -2\n",
    "#         # safe_labels[i, 1:10] .= -2\n",
    "#         # if any(safe_labels[i,:] .== 0)\n",
    "#         #     @show (y_data[i,1])\n",
    "#         # end\n",
    "#     end\n",
    "#     # break\n",
    "# end\n",
    "\n",
    "# matwrite(\"data_bcks_hyperbolic_1_new_10.mat\", Dict(\n",
    "# \t\"a\" => x_data,\n",
    "# \t\"u\" => y_data,\n",
    "#     \"pf\" => pf_labels,\n",
    "#     \"safe\" => safe_labels\n",
    "# ))\n",
    "# @show sum(pf_labels[:, 1])\n",
    "# @show size(findall(x->x==0, safe_labels))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbfae270-0e03-4904-9f9a-9c7aee680570",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_data[4,35:end]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04caa769-1478-4146-801d-a6be00784767",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "T = Float32\n",
    "file = matopen(\"/data_bcks_hyperbolic.mat\")\n",
    "    \n",
    "x_data = T.(collect(read(file, \"a\")))\n",
    "y_data = T.(collect(read(file, \"u\")))\n",
    "close(file)\n",
    "@show size(x_data)\n",
    "@show size(y_data)\n",
    "\n",
    "# file = matopen(\"matfile.mat\", \"w\")\n",
    "# write(file, \"varname\", variable)\n",
    "# close(file)\n",
    "\n",
    "threshold = 1\n",
    "# pf: 1, not pf: 0\n",
    "pf_labels = ones(size(x_data))\n",
    "# safe: 1, not safe: 0\n",
    "safe_labels = -ones(size(x_data))\n",
    "for i in 1:size(x_data, 1)\n",
    "    for j in 1:size(x_data, 2)\n",
    "        if y_data[i,j] > threshold\n",
    "            safe_labels[i,j] = 0\n",
    "        end\n",
    "    end\n",
    "    if y_data[i,end] > threshold\n",
    "        pf_labels[i, :] .= 0\n",
    "    else\n",
    "        for j in 1:size(x_data, 2)\n",
    "            if safe_labels[i,end+1-j] == -1\n",
    "                safe_labels[i,end+1-j] = 1\n",
    "            else\n",
    "                break\n",
    "            end\n",
    "        end\n",
    "    end\n",
    "end\n",
    "\n",
    "matwrite(\"data_bcks_hyperbolic_1_minus.mat\", Dict(\n",
    "\t\"a\" => x_data,\n",
    "\t\"u\" => y_data,\n",
    "    \"pf\" => pf_labels,\n",
    "    \"safe\" => safe_labels\n",
    "))\n",
    "@show sum(pf_labels[:, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a5d8c29-3765-4363-97f7-742ad4fcceec",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "472347c0-cae2-47db-aa59-7d255c07ddcc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "898861f7-4bc3-4628-b963-41e699555a3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "T = Float32\n",
    "file = matopen(\"/data/code/PDEControlGym/examples/transportPDE/data_sac_hyperbolic.mat\")\n",
    "    \n",
    "x_data = T.(collect(read(file, \"a\")))\n",
    "y_data = T.(collect(read(file, \"u\")))\n",
    "close(file)\n",
    "@show size(x_data)\n",
    "@show size(y_data)\n",
    "\n",
    "# file = matopen(\"matfile.mat\", \"w\")\n",
    "# write(file, \"varname\", variable)\n",
    "# close(file)\n",
    "\n",
    "threshold = 1\n",
    "# pf: 1, not pf: 0\n",
    "pf_labels = ones(size(x_data))\n",
    "# safe: 1, not safe: 0\n",
    "safe_labels = ones(size(x_data))\n",
    "for i in 1:size(x_data, 1)\n",
    "    for j in 1:size(x_data, 2)\n",
    "        if abs(y_data[i,j]) > threshold\n",
    "            safe_labels[i,j] = 0\n",
    "        end\n",
    "    end\n",
    "    if abs(y_data[i,end]) > threshold\n",
    "        pf_labels[i, :] .= 0\n",
    "    # else\n",
    "    #     for j in 1:size(x_data, 2)\n",
    "    #         if safe_labels[i,end+1-j] == -1\n",
    "    #             safe_labels[i,end+1-j] = 1\n",
    "    #         else\n",
    "    #             break\n",
    "    #         end\n",
    "    #     end\n",
    "    end\n",
    "end\n",
    "\n",
    "matwrite(\"data_sac_hyperbolic_1_abs.mat\", Dict(\n",
    "\t\"a\" => x_data,\n",
    "\t\"u\" => y_data,\n",
    "    \"pf\" => pf_labels,\n",
    "    \"safe\" => safe_labels\n",
    "))\n",
    "@show sum(pf_labels[:, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06bd9416-98f0-4248-9578-6c08957c75ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "T = Float32\n",
    "file = matopen(\"/data/code/PDEControlGym/examples/transportPDE/data_bcks_hyperbolic.mat\")\n",
    "    \n",
    "x_data = T.(collect(read(file, \"a\")))\n",
    "y_data = T.(collect(read(file, \"u\")))\n",
    "close(file)\n",
    "@show size(x_data)\n",
    "@show size(y_data)\n",
    "\n",
    "# file = matopen(\"matfile.mat\", \"w\")\n",
    "# write(file, \"varname\", variable)\n",
    "# close(file)\n",
    "\n",
    "threshold = 1\n",
    "# pf: 1, not pf: 0\n",
    "pf_labels = ones(size(x_data))\n",
    "# safe: 1, not safe: 0\n",
    "safe_labels = -ones(size(x_data))\n",
    "for i in 1:size(x_data, 1)\n",
    "    for j in 1:size(x_data, 2)\n",
    "        if abs(y_data[i,j]) > threshold\n",
    "            safe_labels[i,j] = 0\n",
    "        end\n",
    "    end\n",
    "    if abs(y_data[i,end]) > threshold\n",
    "        pf_labels[i, :] .= 0\n",
    "    else\n",
    "        for j in 1:size(x_data, 2)\n",
    "            if safe_labels[i,end+1-j] == -1\n",
    "                safe_labels[i,end+1-j] = 1\n",
    "            else\n",
    "                break\n",
    "            end\n",
    "        end\n",
    "    end\n",
    "end\n",
    "\n",
    "matwrite(\"data_bcks_hyperbolic_1_abs_minus.mat\", Dict(\n",
    "\t\"a\" => x_data,\n",
    "\t\"u\" => y_data,\n",
    "    \"pf\" => pf_labels,\n",
    "    \"safe\" => safe_labels\n",
    "))\n",
    "@show sum(pf_labels[:, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6b0b796-c4d2-4e7a-9a89-1d96950f4307",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "T = Float32\n",
    "file = matopen(\"/data/code/PDEControlGym/examples/transportPDE/data_ppo_hyperbolic.mat\")\n",
    "    \n",
    "x_data = T.(collect(read(file, \"a\")))\n",
    "y_data = T.(collect(read(file, \"u\")))\n",
    "close(file)\n",
    "@show size(x_data)\n",
    "@show size(y_data)\n",
    "\n",
    "# file = matopen(\"matfile.mat\", \"w\")\n",
    "# write(file, \"varname\", variable)\n",
    "# close(file)\n",
    "\n",
    "threshold = 1\n",
    "bias = 0\n",
    "# pf: 1, not pf: 0\n",
    "pf_labels = ones(size(x_data))\n",
    "# safe: 1, not safe: 0\n",
    "safe_labels = -ones(size(x_data))\n",
    "not_pf = 0\n",
    "pf_training = 0\n",
    "for i in 1:size(x_data, 1)\n",
    "    for j in 1:size(x_data, 2)\n",
    "        if abs(y_data[i,j]+bias) > threshold+bias\n",
    "            safe_labels[i,j] = 0\n",
    "        end\n",
    "    end\n",
    "    for j in 1:size(x_data, 2)\n",
    "        if safe_labels[i,end+1-j] == -1\n",
    "            safe_labels[i,end+1-j] = 1\n",
    "        else\n",
    "            break\n",
    "        end\n",
    "    end\n",
    "    # if abs(y_data[i,end]) > threshold\n",
    "    #     pf_labels[i, :] .= 0\n",
    "    # else\n",
    "\n",
    "    # end\n",
    "    if abs(y_data[i,end]+bias) > threshold+bias\n",
    "        # not pf\n",
    "        pf_labels[i, :] .= 0\n",
    "        not_pf += 1\n",
    "        # @show \"not pf\", y_data[i,end-10:end]\n",
    "    elseif any(abs.(y_data[i,end-5:end] .+bias) .> threshold+bias)\n",
    "        # pf but not used for training\n",
    "        pf_labels[i, :] .= 2\n",
    "        # @show \"pf but not used for training\",y_data[i,end-10:end]\n",
    "    else\n",
    "        # pf and used for training\n",
    "        pf_labels[i, :] .= 1\n",
    "        pf_training += 1\n",
    "        # @show \"pf and used for training\",y_data[i,end-10:end]\n",
    "    end\n",
    "    \n",
    "    # i > 2 && break\n",
    "end\n",
    "\n",
    "# matwrite(\"data_sac_hyperbolic_1_abs_minus_5pf.mat\", Dict(\n",
    "# \t\"a\" => x_data,\n",
    "# \t\"u\" => y_data,\n",
    "#     \"pf\" => pf_labels,\n",
    "#     \"safe\" => safe_labels\n",
    "# ))\n",
    "@show not_pf, pf_training, 50000-not_pf-pf_training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e69fd6c-0c4c-42b3-81fe-4d511320757c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# @show a = rand(10)\n",
    "# any(abs.(a .-0.5) .> 0.4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d06ddcef-e107-4b27-8102-c7117f360fe3",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "T = Float32\n",
    "file = matopen(\"/data/code/PDEControlGym/examples/transportPDE/data_bcks_hyperbolic.mat\")\n",
    "    \n",
    "x_data = T.(collect(read(file, \"a\")))\n",
    "y_data = T.(collect(read(file, \"u\")))\n",
    "close(file)\n",
    "@show size(x_data)\n",
    "@show size(y_data)\n",
    "\n",
    "# file = matopen(\"matfile.mat\", \"w\")\n",
    "# write(file, \"varname\", variable)\n",
    "# close(file)\n",
    "\n",
    "threshold = 1\n",
    "# pf: 1, not pf: 0\n",
    "pf_labels = ones(size(x_data))\n",
    "# safe: 1, not safe: 0\n",
    "safe_labels = -ones(size(x_data))\n",
    "not_pf = 0\n",
    "pf_training = 0\n",
    "for i in 1:size(x_data, 1)\n",
    "    for j in 1:size(x_data, 2)\n",
    "        if y_data[i,j] > threshold\n",
    "            safe_labels[i,j] = 0\n",
    "        end\n",
    "    end\n",
    "    for j in 1:size(x_data, 2)\n",
    "        if safe_labels[i,end+1-j] == -1\n",
    "            safe_labels[i,end+1-j] = 1\n",
    "        else\n",
    "            break\n",
    "        end\n",
    "    end\n",
    "    if y_data[i,end] > threshold\n",
    "        pf_labels[i, :] .= 0\n",
    "        not_pf += 1\n",
    "        # @show \"not pf\", y_data[i,end-10:end]\n",
    "    elseif any(y_data[i,end-5:end] .> threshold)\n",
    "        # pf but not used for training\n",
    "        pf_labels[i, :] .= 2\n",
    "        # @show \"pf but not used for training\",y_data[i,end-10:end]\n",
    "    else\n",
    "        # pf and used for training\n",
    "        pf_labels[i, :] .= 1\n",
    "        pf_training += 1\n",
    "        # @show \"pf and used for training\",y_data[i,end-10:end]\n",
    "    end\n",
    "end\n",
    "\n",
    "matwrite(\"data_bcks_hyperbolic_1_minus_pf5.mat\", Dict(\n",
    "\t\"a\" => x_data,\n",
    "\t\"u\" => y_data,\n",
    "    \"pf\" => pf_labels,\n",
    "    \"safe\" => safe_labels\n",
    "))\n",
    "@show not_pf, pf_training, 50000-not_pf-pf_training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26952d60-47d3-438b-bdb4-08351d12e57a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.9.4",
   "language": "julia",
   "name": "julia-1.9"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.9.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
