{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import sys\n",
    "\n",
    "# Third-party\n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.tri as mtri\n",
    "from scipy.spatial import Delaunay, ConvexHull\n",
    "from torch_geometric.data import Data\n",
    "\n",
    "# dataset / plotter are resolved through '../', so extend the path before importing them.\n",
    "sys.path.append('../')\n",
    "from dataset import CFDGraphsDataset\n",
    "from plotter import plot_mesh\n",
    "\n",
    "# Name of this notebook; the export cell hands it to nbconvert.\n",
    "filename = 'testing.ipynb'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import subprocess\n",
    "\n",
    "# Output locations for this run. `dir` shadows the builtin but later cells rely on the name.\n",
    "dir = './models/airfoil_gabi_1/'\n",
    "dir_plt = dir+'plt/'\n",
    "os.makedirs(dir, exist_ok=True)\n",
    "os.makedirs(dir_plt, exist_ok=True)\n",
    "\n",
    "pwd = os.getcwd()\n",
    "print(pwd)\n",
    "\n",
    "# Snapshot this notebook as a script next to the run outputs. An argument list (no shell)\n",
    "# keeps paths with spaces or shell metacharacters intact. As before, a failed export is\n",
    "# reported but does not stop the notebook.\n",
    "export_cmd = ['jupyter', 'nbconvert', os.path.join(pwd, filename), '--to', 'python',\n",
    "              '--output', os.path.join(pwd, dir, 'run_file.py')]\n",
    "try:\n",
    "    export_status = subprocess.run(export_cmd, check=False).returncode\n",
    "except FileNotFoundError:\n",
    "    export_status = None\n",
    "    print('WARNING: jupyter executable not found; notebook was not exported')\n",
    "if export_status:\n",
    "    print(f'WARNING: nbconvert exited with status {export_status}')\n",
    "export_status\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training graphs; the flags are forwarded to the project's CFDGraphsDataset (see dataset.py).\n",
    "dataset = CFDGraphsDataset(\n",
    "    r'train_dataset.zip',\n",
    "    random_masking=False,\n",
    "    zero_augmentation=False,\n",
    "    sdf_input=True,\n",
    "    airfoil_coverage=0.1,\n",
    ")\n",
    "\n",
    "# Index the sample once so the mask count describes the same object that is printed.\n",
    "# NOTE(review): __getitem__ may be stochastic (airfoil_coverage) -- unverified here.\n",
    "sample0 = dataset[0]\n",
    "print(sample0)\n",
    "print(np.sum(np.asarray(sample0.known_feature_mask)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# graph = dataset[0]\n",
    "\n",
    "# # data.x is feature\n",
    "# # copy the node target values [pressure, x-velocity, y-velocity] without the signed distance function [idx 3]\n",
    "# # data.y = data.x[:, :3].detach().clone()\n",
    "\n",
    "# print(torch.sum(dataset[0].node_type))\n",
    "\n",
    "# airf_nodes = graph.node_type == 1.0\n",
    "# print(airf_nodes)\n",
    "\n",
    "\n",
    "# plt.scatter(graph.pos[:,0], graph.pos[:,1], s=1)\n",
    "# plt.scatter(graph.pos[airf_nodes,0], graph.pos[airf_nodes, 1], s=1)\n",
    "# plt.show()\n",
    "\n",
    "# print(graph.pos.shape)\n",
    "# print(graph.x.shape)\n",
    "# print(graph.edge_index.shape)\n",
    "\n",
    "\n",
    "\n",
    "def less_first(a, b):\n",
    "    \"\"\"Return the pair [a, b] ordered so that the smaller value comes first.\"\"\"\n",
    "    if a < b:\n",
    "        return [a, b]\n",
    "    return [b, a]\n",
    "\n",
    "\n",
    "def delaunay2edges(tri):\n",
    "    \"\"\"Unique undirected edges of a triangulation together with their Euclidean lengths.\n",
    "\n",
    "    Args:\n",
    "        tri: object exposing `simplices` (triangle vertex indices) and `points`\n",
    "            (2-D coordinates), e.g. a scipy.spatial.Delaunay result.\n",
    "\n",
    "    Returns:\n",
    "        (edges, lengths): each edge appears once as [low, high] in the lexicographic\n",
    "        order produced by np.unique; lengths[k] is the length of edges[k].\n",
    "    \"\"\"\n",
    "    candidate_edges = [\n",
    "        less_first(simplex[i], simplex[j])\n",
    "        for simplex in tri.simplices\n",
    "        for i, j in ((0, 1), (1, 2), (2, 0))\n",
    "    ]\n",
    "    # An edge shared by two triangles appears twice; np.unique keeps one copy.\n",
    "    edges = np.unique(candidate_edges, axis=0)\n",
    "\n",
    "    squared = []\n",
    "    for start, end in edges:\n",
    "        (xa, ya), (xb, yb) = tri.points[start], tri.points[end]\n",
    "        squared.append((xa - xb) ** 2 + (ya - yb) ** 2)\n",
    "    lengths = np.sqrt(np.array(squared))\n",
    "    return edges, lengths\n",
    "\n",
    "import numpy as np\n",
    "from scipy.spatial import cKDTree\n",
    "\n",
    "def remove_close_points(points, u, tolerance):\n",
    "    \"\"\"Greedily thin `points` so that no two kept points lie within `tolerance`.\n",
    "\n",
    "    Points are visited in index order; each point that is still kept discards every\n",
    "    other point returned by cKDTree.query_ball_point within radius `tolerance`.\n",
    "\n",
    "    Args:\n",
    "        points: (N, D) coordinates.\n",
    "        u: per-point values; not used, kept so existing call sites keep working.\n",
    "        tolerance: thinning radius.\n",
    "\n",
    "    Returns:\n",
    "        Boolean keep-mask of length N.\n",
    "    \"\"\"\n",
    "    points = np.asarray(points)\n",
    "    to_keep = np.ones(len(points), dtype=bool)\n",
    "    if len(points) == 0:\n",
    "        # Always return a boolean mask: previously the empty points array itself was\n",
    "        # returned, which callers could not use as an index.\n",
    "        return to_keep\n",
    "\n",
    "    tree = cKDTree(points)\n",
    "    for i in range(len(points)):\n",
    "        if not to_keep[i]:\n",
    "            continue\n",
    "        # Every neighbour within the radius except i itself is dropped.\n",
    "        neighbours = [j for j in tree.query_ball_point(points[i], r=tolerance) if j != i]\n",
    "        to_keep[neighbours] = False\n",
    "    return to_keep\n",
    "\n",
    "\n",
    "\n",
    "def constrained_delaunay(points, is_special):\n",
    "    \"\"\"Delaunay-triangulate `points`, then drop triangles whose three vertices are all special.\n",
    "\n",
    "    Despite the name this is not a true constrained Delaunay triangulation. The full point\n",
    "    set is triangulated, and every triangle built only from special nodes (the airfoil nodes\n",
    "    at the call site) is discarded. NOTE(review): presumably this removes triangles spanning\n",
    "    the airfoil interior. A concave section could also lose triangles lying outside the\n",
    "    body -- confirm for the geometries in use.\n",
    "\n",
    "    Args:\n",
    "        points: (N, 2) node coordinates.\n",
    "        is_special: length-N boolean mask (array or CPU tensor) of special nodes.\n",
    "\n",
    "    Returns:\n",
    "        edge_index: (E, 2) int64 array of unique undirected edges [low, high], in\n",
    "            lexicographic order; shape (0, 2) when no triangle survives.\n",
    "        edge_lengths: (E,) Euclidean length of each edge.\n",
    "        faces: (F, 3) vertex indices of the kept triangles.\n",
    "    \"\"\"\n",
    "    points = np.asarray(points)\n",
    "    is_special = np.asarray(is_special, dtype=bool)\n",
    "\n",
    "    all_faces = Delaunay(points).simplices\n",
    "    # Keep every triangle with at least one non-special vertex.\n",
    "    faces = all_faces[is_special[all_faces].sum(axis=1) < 3]\n",
    "\n",
    "    # Three edges per triangle, oriented low -> high, so an edge shared by two triangles\n",
    "    # is spelled identically and collapses to a single entry.\n",
    "    edge_pairs = np.concatenate((faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]), axis=0)\n",
    "    edge_pairs = np.sort(edge_pairs, axis=1)\n",
    "    unique_edges = sorted({tuple(pair) for pair in edge_pairs.tolist()})\n",
    "    # reshape keeps a 2-D (0, 2) array when there are no edges, so the column indexing below works.\n",
    "    edge_index = np.array(unique_edges, dtype=np.int64).reshape(-1, 2)\n",
    "\n",
    "    edge_lengths = np.linalg.norm(points[edge_index[:, 0]] - points[edge_index[:, 1]], axis=1)\n",
    "    return edge_index, edge_lengths, faces\n",
    "\n",
    "\n",
    "def make_new_graph(points, u, airf_nodes, tol=1e-2):\n",
    "    \"\"\"Thin a CFD mesh and re-triangulate it into a torch_geometric Data graph.\n",
    "\n",
    "    Non-airfoil nodes are thinned with radius `tol` and airfoil nodes with the finer radius\n",
    "    `tol / 5` (via remove_close_points). The survivors are re-triangulated with\n",
    "    constrained_delaunay, which drops triangles made only of airfoil nodes.\n",
    "\n",
    "    Args:\n",
    "        points: (N, 2) node coordinates.\n",
    "        u: (N, C) node values -- at the call sites graph.y[:, :3], i.e.\n",
    "            [pressure, x-velocity, y-velocity]. Array or CPU tensor.\n",
    "        airf_nodes: length-N boolean mask of airfoil nodes (array or CPU tensor).\n",
    "        tol: thinning radius for non-airfoil nodes (the previously hard-coded 1e-2).\n",
    "\n",
    "    Returns:\n",
    "        Data with:\n",
    "          - pos (M, 2);\n",
    "          - x (M, C), with NaN/inf replaced by 0;\n",
    "          - edge_index (2, E), each undirected edge stored once as low -> high;\n",
    "          - edge_attr (E,), the edge lengths;\n",
    "          - face (3, F);\n",
    "          - airf_nodes (M,), as float 0/1.\n",
    "        Output nodes are ordered non-airfoil first, then airfoil.\n",
    "    \"\"\"\n",
    "    points = np.asarray(points)\n",
    "    u = np.asarray(u)\n",
    "    airf_nodes = np.asarray(airf_nodes, dtype=bool)\n",
    "\n",
    "    def _thin(mask, radius):\n",
    "        # Thin the subset selected by `mask`; asarray(bool) also copes with an empty subset.\n",
    "        sub_points, sub_u = points[mask], u[mask]\n",
    "        keep = np.asarray(remove_close_points(sub_points, sub_u, tolerance=radius), dtype=bool)\n",
    "        return sub_points[keep], sub_u[keep]\n",
    "\n",
    "    points_notairf, u_notairf = _thin(~airf_nodes, tol)\n",
    "    points_airf, u_airf = _thin(airf_nodes, tol / 5)  # finer spacing on the airfoil surface\n",
    "\n",
    "    new_points = np.concatenate((points_notairf, points_airf), axis=0)\n",
    "    new_u = np.concatenate((u_notairf, u_airf), axis=0)\n",
    "    # Airfoil flags in the *new* node order (non-airfoil block first). The former\n",
    "    # `airf_nodes[to_keep]` applied a mask built in this concatenated order to the\n",
    "    # original node order, so the flags could land on the wrong nodes.\n",
    "    airf_nodes_keep = np.concatenate((np.zeros(len(points_notairf), dtype=bool),\n",
    "                                      np.ones(len(points_airf), dtype=bool)))\n",
    "\n",
    "    x = torch.nan_to_num(torch.tensor(new_u, dtype=torch.float), nan=0.0, posinf=0.0, neginf=0.0)\n",
    "    edge_index, edge_lengths, faces = constrained_delaunay(new_points, airf_nodes_keep)\n",
    "\n",
    "    return Data(\n",
    "        pos=torch.tensor(new_points, dtype=torch.float),\n",
    "        edge_index=torch.tensor(edge_index.T, dtype=torch.int64),\n",
    "        edge_attr=torch.tensor(edge_lengths, dtype=torch.float),\n",
    "        face=torch.tensor(faces.T, dtype=torch.int64),\n",
    "        x=x,\n",
    "        airf_nodes=torch.tensor(airf_nodes_keep, dtype=torch.float),\n",
    "    )\n",
    "\n",
    "# Smoke test on the first training graph. Columns 0:3 of `y` are\n",
    "# [pressure, x-velocity, y-velocity]; node_type == 1.0 selects the airfoil nodes.\n",
    "graph = dataset[0]\n",
    "del_graph = make_new_graph(\n",
    "    points=graph.pos.numpy(),\n",
    "    u=graph.y[:, :3],\n",
    "    airf_nodes=graph.node_type == 1.0,\n",
    ")\n",
    "\n",
    "\n",
    "def plot_graph(graph, u, filename='graph'):\n",
    "    \"\"\"Filled contour of the nodal field `u` on `graph`'s triangles, with the mesh overlaid.\n",
    "\n",
    "    Draws on a fresh figure, so repeated calls no longer stack on the current axes. Saves\n",
    "    the figure to `filename + '.pdf'` and returns it; it stays open, so close it when done.\n",
    "\n",
    "    Args:\n",
    "        graph: object with `pos` (N, 2) and `face` (3, F) tensors.\n",
    "        u: length-N nodal values (tensor or array).\n",
    "        filename: output path without the '.pdf' extension.\n",
    "    \"\"\"\n",
    "    pos = graph.pos.detach().cpu().numpy()\n",
    "    face = graph.face.detach().cpu().numpy()\n",
    "    values = u.detach().cpu().numpy() if torch.is_tensor(u) else np.asarray(u)\n",
    "\n",
    "    triang = mtri.Triangulation(pos[:, 0], pos[:, 1], face.T)\n",
    "    fig, ax = plt.subplots()\n",
    "    contourf = ax.tricontourf(triang, values, levels=20, cmap='viridis')\n",
    "    ax.triplot(triang, color='black', alpha=0.5, linewidth=0.01)\n",
    "    fig.colorbar(contourf, ax=ax)\n",
    "    fig.savefig(filename + '.pdf')\n",
    "    return fig\n",
    "\n",
    "# Pressure (column 0) on the thinned mesh of the first sample.\n",
    "plot_graph(del_graph, del_graph.x[:, 0], filename=f'{dir_plt}airfoil_p0')\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import pickle\n",
    "\n",
    "# Shard 1: graphs [0, 800). Each sample is indexed once per iteration. Previously\n",
    "# dataset[i] was indexed three times, and __getitem__ may reload the sample on every call.\n",
    "thin_dataset = []\n",
    "for i in range(min(800, len(dataset))):\n",
    "    print(f'thinning graph {i}')\n",
    "    sample = dataset[i]\n",
    "    thin_dataset.append(make_new_graph(sample.pos.numpy(), sample.y[:, :3], sample.node_type == 1.0))\n",
    "\n",
    "with open(dir+'airfoil_thinned_train_1.pkl', 'wb') as f:\n",
    "    pickle.dump(thin_dataset, f)\n",
    "\n",
    "# Free the shard before another one is built.\n",
    "del thin_dataset\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import pickle\n",
    "\n",
    "# Shard 2: graphs [800, len(dataset) - 800). Each sample is indexed once per iteration\n",
    "# (previously three times; __getitem__ may reload the sample on every call).\n",
    "# NOTE(review): the last 800 graphs are excluded -- presumably held out; confirm intent.\n",
    "thin_dataset = []\n",
    "for i in range(800, len(dataset)-800):\n",
    "    print(f'thinning graph {i}')\n",
    "    sample = dataset[i]\n",
    "    thin_dataset.append(make_new_graph(sample.pos.numpy(), sample.y[:, :3], sample.node_type == 1.0))\n",
    "\n",
    "with open(dir+'airfoil_thinned_train_2.pkl', 'wb') as f:\n",
    "    pickle.dump(thin_dataset, f)\n",
    "\n",
    "# Free the shard once it is on disk.\n",
    "del thin_dataset\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph = dataset[0]\n",
    "del_graph = make_new_graph(graph.pos.numpy(), graph.y[:, :3], graph.node_type == 1.0)\n",
    "\n",
    "# Pick up to N_MEASURED airfoil nodes as simulated pressure sensors. A seeded generator\n",
    "# keeps the chosen locations reproducible; min() avoids a crash on sparse airfoils.\n",
    "N_MEASURED = 10\n",
    "rng = np.random.default_rng(0)\n",
    "airfoil_indices = np.flatnonzero(del_graph.airf_nodes.cpu().numpy() == 1)\n",
    "measured_airf_nodes = torch.as_tensor(\n",
    "    rng.choice(airfoil_indices, size=min(N_MEASURED, len(airfoil_indices)), replace=False)\n",
    ")\n",
    "airfoil_pressure = del_graph.x[measured_airf_nodes, 0]  # column 0 = pressure\n",
    "\n",
    "pos = del_graph.pos.detach().cpu().numpy()\n",
    "face = del_graph.face.detach().cpu().numpy()\n",
    "triang = mtri.Triangulation(pos[:, 0], pos[:, 1], face.T)\n",
    "measured_idx = measured_airf_nodes.numpy()\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "contourf = ax.tricontourf(triang, del_graph.x[:, 0].cpu().numpy(), levels=20, cmap='viridis')\n",
    "ax.triplot(triang, color='black', alpha=0.5, linewidth=0.01)\n",
    "# Share the contour's norm so sensor colours read off the same colour bar; without it\n",
    "# the scatter would be normalised to its own min/max.\n",
    "ax.scatter(pos[measured_idx, 0], pos[measured_idx, 1],\n",
    "           s=1, c=airfoil_pressure.detach().cpu().numpy(), cmap='viridis',\n",
    "           norm=contourf.norm, alpha=1)\n",
    "fig.colorbar(contourf, ax=ax)\n",
    "\n",
    "fig.savefig('measured_locs'+'.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gabi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
