{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import torch.multiprocessing as mp\n",
    "# mp.set_start_method('spawn')\n",
    "\n",
    "import torch\n",
    "import numpy as np \n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from scipy.spatial import Delaunay\n",
    "from scipy.sparse import lil_matrix\n",
    "from scipy.sparse.linalg import spsolve\n",
    "\n",
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.tri as mtri\n",
    "\n",
    "from GABI.solver.heat_rect_unstruct import make_sln_graph, plot_slngraph\n",
    "\n",
    "\n",
    "device = 'cuda' \n",
    "\n",
    "filename = 'supervised.ipynb'\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "dir_data = './models/RectHeat_GABI_5/'\n",
    "dir_save = './models/RectHeat_SUPER_2_VARNOISE/'\n",
    "dir_plt = dir_save+'plt/'\n",
    "os.makedirs(dir_save, exist_ok=True)\n",
    "os.makedirs(dir_plt, exist_ok=True)\n",
    "\n",
    "import os\n",
    "pwd = os.getcwd()\n",
    "print(pwd)\n",
    "os.system(f\"jupyter nbconvert {pwd}/{filename} --to python --output {pwd}/{dir_save}run_file.py\" )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "with open(dir_data+'data_list.pkl', 'rb') as f:\n",
    "    data_list = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(4):\n",
    "    plot_slngraph(data_list[i], data_list[i].x[:,0].cpu().detach().numpy(),\n",
    "                  save=f\"models/RectHeat_GABI_5/plt/data_{i}.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GCNConv, MLP\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.nn import pool\n",
    "\n",
    "from torch.nn import ModuleList\n",
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.nn.dense import Linear\n",
    "\n",
    "loader = DataLoader(data_list, batch_size=100, shuffle=True)\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "data = next(iter(loader)).to(device)\n",
    "# print(data.num_node_features)\n",
    "\n",
    "dim_z = 100\n",
    "# dim_z = 50\n",
    "dim_y = 1\n",
    "dim_u = 1\n",
    "dim_x = 2\n",
    "dim_node_features = dim_x + dim_u\n",
    "\n",
    "\n",
    "class GCN(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(GCN, self).__init__()\n",
    "        self.c = dim_z\n",
    "        self.convs = ModuleList([\n",
    "            GCNConv((dim_y + 2)*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),  \n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c, dim_u),\n",
    "        ])\n",
    "        \n",
    "    def forward(self, graph):\n",
    "                \n",
    "        x = torch.cat([graph.x,  graph.pos], dim=1)\n",
    "\n",
    "        edge_index, edge_attr = graph.edge_index, graph.edge_attr\n",
    "\n",
    "        for ctr, conv in enumerate(self.convs[:-1]):\n",
    "            x_global = pool.global_mean_pool(x, graph.batch)\n",
    "            global_x_expanded = x_global[graph.batch]  # [num_nodes, hidden_dim]\n",
    "            \n",
    "            x = torch.cat([x, global_x_expanded], dim=1)  # [num_nodes, hidden_dim * 2]\n",
    "        \n",
    "            x = F.silu(conv(x, edge_index, edge_weight=edge_attr))\n",
    "            \n",
    "        x = self.convs[-1](x, edge_index, edge_weight=edge_attr)\n",
    "\n",
    "        return x\n",
    "    \n",
    "gcn = GCN().to(device)\n",
    "\n",
    "# from torch.optim.lr_scheduler import ExponentialLR\n",
    "optimizer = torch.optim.Adam(gcn.parameters(), lr=0.001)\n",
    "# scheduler = ExponentialLR(optimizer, gamma=0.95)\n",
    "\n",
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "print(count_parameters(gcn))\n",
    "print(count_parameters(gcn))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_rnd_stddev():\n",
    "    return torch.exp(torch.randn(1,) - 4.) + 1e-3\n",
    "\n",
    "samples = torch.stack([get_rnd_stddev() for _ in range(10_000)]).reshape(-1,)\n",
    "plt.hist(samples, bins=100, density=True)\n",
    "plt.show()\n",
    "plt.close()\n",
    "print(torch.mean(samples), torch.std(samples))\n",
    "\n",
    "plt.hist(torch.log(samples-1e-3)+4., bins=50, density=True)\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_obs = 10\n",
    "sigma_tc = torch.tensor([0.01]).to(device)\n",
    "\n",
    "# def loss_function(graph):\n",
    "\n",
    "#         graph_copy = graph.clone()\n",
    "#         perm = torch.randperm(graph_copy.x.shape[0])\n",
    "#         idx = perm[:n_obs * (torch.max(graph_copy.batch)+1)]\n",
    "#         graph_copy.x *= 0.\n",
    "#         y = graph.x[idx].clone()\n",
    "#         y += torch.randn_like(y) * sigma_tc\n",
    "#         graph_copy.x[idx] = y\n",
    "        \n",
    "#         u = gcn(graph_copy)\n",
    "#         loss = torch.mean((graph.x - u)**2.)\n",
    "  \n",
    "#         return loss, (None)\n",
    "\n",
    "def loss_function(graph):\n",
    "\n",
    "    graph_copy = graph.clone()\n",
    "\n",
    "    num_graphs = graph_copy.batch.max().item() + 1\n",
    "    idx_list = []\n",
    "\n",
    "    for graph_idx in range(num_graphs):\n",
    "        # Find node indices belonging to this graph\n",
    "        node_idx = (graph_copy.batch == graph_idx).nonzero(as_tuple=False).view(-1)\n",
    "        # Randomly permute and select n_obs (or fewer if graph has fewer nodes)\n",
    "        k = min(n_obs, node_idx.size(0))\n",
    "        perm = torch.randperm(node_idx.size(0), device=graph.x.device)[:k]\n",
    "        idx_list.append(node_idx[perm])\n",
    "\n",
    "    # Concatenate selected indices\n",
    "    idx = torch.cat(idx_list, dim=0)\n",
    "\n",
    "    # Zero out features\n",
    "    graph_copy.x.zero_()\n",
    "\n",
    "    # Copy and add noise\n",
    "    y = graph.x[idx].clone()\n",
    "    sigma_tc = get_rnd_stddev().to(device)\n",
    "    y += torch.randn_like(y) * sigma_tc\n",
    "\n",
    "    # Place noisy observations\n",
    "    graph_copy.x[idx] = y\n",
    "\n",
    "    u = gcn(graph_copy)\n",
    "    loss = torch.mean((graph.x - u)**2.)\n",
    "\n",
    "    return loss, (None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "loader_train = DataLoader(data_list, batch_size=100, shuffle=True)\n",
    "data_train = next(iter(loader_train))\n",
    "\n",
    "gcn.train()\n",
    "\n",
    "LOSS = []\n",
    "import time\n",
    "time_train_start = time.time()\n",
    "# for epoch in range(100_000):\n",
    "for epoch in range(20_000):\n",
    "\n",
    "    start_time_b = time.time()\n",
    "    for data in loader_train:\n",
    "        optimizer.zero_grad()\n",
    "        data.to(device)\n",
    "        loss, aux = loss_function(data)\n",
    "        LOSS.append(loss.cpu().detach().numpy())\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    # scheduler.step()\n",
    "    print(f'time epoch = {time.time() - start_time_b:.3f}s', )\n",
    "        \n",
    "    print(\"Epoch: \", epoch, \"Loss: \", loss.item())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time_train = time.time() - time_train_start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(time_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gcn.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SAVE MODELS\n",
    "torch.save(gcn.state_dict(), dir_save+'gcn.model')\n",
    "torch.save(optimizer.state_dict(), dir_save+'gcn.opt')\n",
    "\n",
    "loss_dict = {'LOSS': LOSS, 'time_train':time_train,}\n",
    "import pickle\n",
    "with open(dir_save+'loss_dict.pkl', 'wb') as f:\n",
    "    pickle.dump(loss_dict, f)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.semilogy(LOSS)\n",
    "plt.grid()\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LOAD MODELS\n",
    "gcn.load_state_dict(torch.load(dir_save+'gcn.model'))\n",
    "gcn.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "with open(dir_data+'data_test_list.pkl', 'rb') as f:\n",
    "    data_test_list = pickle.load(f)\n",
    "data_test_list = data_test_list[:1000]\n",
    "N_test = len(data_test_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx_text = 0\n",
    "graph_test = data_test_list[idx_text]\n",
    "loader_decode_ABC = DataLoader([graph_test], batch_size=1)\n",
    "graph_loaded = next(iter(loader_decode_ABC)).to(device)\n",
    "# u_abc_decode = torch.vmap(gcnD, in_dims=(0, None), chunk_size=100)(z_grid[:, None, :], graph_ABC_loaded)\n",
    "\n",
    "n_obs = 10\n",
    "sigma = torch.FloatTensor([0.01])\n",
    "\n",
    "ObsIdx = np.random.choice(range(graph_test.pos.shape[0]), size=(n_obs,), replace=False)\n",
    "ObsIdx = torch.tensor(ObsIdx)\n",
    "y_n = (graph_test.x).reshape(-1,)[ObsIdx] + sigma * torch.randn(ObsIdx.shape[0])\n",
    "y_n = y_n.to(device)\n",
    "\n",
    "graph_copy = graph_loaded.clone()\n",
    "perm = torch.randperm(graph_copy.x.shape[0])\n",
    "idx = perm[:n_obs * (torch.max(graph_copy.batch)+1)]\n",
    "graph_copy.x *= 0.\n",
    "y = graph_loaded.x[idx].clone()\n",
    "sigma_tc = 1e-2\n",
    "y += torch.randn_like(y) * sigma_tc\n",
    "graph_copy.x[idx] = y\n",
    "\n",
    "u_decode = gcn(graph_copy)\n",
    "\n",
    "print('data')\n",
    "plot_slngraph(graph_test, graph_test.x[:,0].cpu().detach().numpy(),\n",
    "              ObsIdx=ObsIdx,\n",
    "              save=dir_save+'plt/true.png')\n",
    "\n",
    "print('u_min_norm')\n",
    "plot_slngraph(graph_test, u_decode[:,0].detach().cpu().numpy(),\n",
    "              save=dir_save+'plt/mode_pred.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import time\n",
    "\n",
    "# Number of runs\n",
    "num_runs = 100\n",
    "n_obs = 10\n",
    "sigma = 0.01\n",
    "\n",
    "mae_list = []\n",
    "\n",
    "# Start timer\n",
    "start_time = time.time()\n",
    "\n",
    "for i in range(num_runs):\n",
    "    # Clone test graph\n",
    "    graph_copy = graph_loaded.clone()\n",
    "\n",
    "    # Random observation indices\n",
    "    ObsIdx = np.random.choice(range(graph_copy.pos.shape[0]), size=(n_obs,), replace=False)\n",
    "    ObsIdx = torch.tensor(ObsIdx, device=graph_copy.x.device)\n",
    "\n",
    "    # Add noise to true values at observation points\n",
    "    y_obs = graph_loaded.x[ObsIdx].clone()\n",
    "    sigma = get_rnd_stddev().to(device)\n",
    "    y_noisy = y_obs + sigma * torch.randn_like(y_obs)\n",
    "\n",
    "    # Zero out all node features\n",
    "    graph_copy.x *= 0.0\n",
    "\n",
    "    # Place noisy observations at sampled locations\n",
    "    graph_copy.x[ObsIdx] = y_noisy\n",
    "\n",
    "    # Decode/predict with your GCN\n",
    "    u_decode = gcn(graph_copy)\n",
    "\n",
    "    # Compute MAE for this run\n",
    "    true = graph_loaded.x.reshape(-1)\n",
    "    pred = u_decode.reshape(-1)\n",
    "    mae = torch.mean(torch.abs(pred - true)).item()\n",
    "    mae_list.append(mae)\n",
    "\n",
    "# End timer\n",
    "total_time = time.time() - start_time\n",
    "\n",
    "# Compute final stats\n",
    "mean_mae = np.mean(mae_list)\n",
    "std_mae = np.std(mae_list)\n",
    "\n",
    "print(f\"Mean MAE over {num_runs} runs: {mean_mae:.6f}\")\n",
    "print(f\"Stddev of MAE: {std_mae:.6f}\")\n",
    "print(f\"Evaluation time: {total_time/100:.5f} seconds\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "n_obs = 10\n",
    "# sigma = 0.01\n",
    "\n",
    "\n",
    "mae_list = []\n",
    "\n",
    "# Start timer\n",
    "start_time = time.time()\n",
    "\n",
    "loader_test = DataLoader(data_test_list, batch_size=1, shuffle=True)\n",
    "# loader_test = DataLoader([data_test_list[4]], batch_size=1, shuffle=True)\n",
    "\n",
    "for graph in loader_test:\n",
    "    graph.to(device)\n",
    "    graph_copy = graph.clone().to(device)\n",
    "    \n",
    "\n",
    "    num_graphs = graph_copy.batch.max().item() + 1\n",
    "    idx_list = []\n",
    "\n",
    "    for graph_idx in range(num_graphs):\n",
    "        # Find node indices belonging to this graph\n",
    "        node_idx = (graph_copy.batch == graph_idx).nonzero(as_tuple=False).view(-1)\n",
    "        # Randomly permute and select n_obs (or fewer if graph has fewer nodes)\n",
    "        k = min(n_obs, node_idx.size(0))\n",
    "        perm = torch.randperm(node_idx.size(0), device=graph.x.device)[:k]\n",
    "        idx_list.append(node_idx[perm])\n",
    "\n",
    "    # Concatenate selected indices\n",
    "    idx = torch.cat(idx_list, dim=0).to(device)\n",
    "\n",
    "    # Zero out features\n",
    "    graph_copy.x.zero_()\n",
    "\n",
    "    # Copy and add noise\n",
    "    y = graph.x[idx].clone()\n",
    "    sigma_tc = get_rnd_stddev().to(device)\n",
    "    y += torch.randn_like(y) * sigma_tc\n",
    "\n",
    "    # Place noisy observations\n",
    "    graph_copy.x[idx] = y\n",
    "\n",
    "    # Decode/predict with your GCN\n",
    "    u_decode = gcn(graph_copy)\n",
    "\n",
    "    true =   graph.x\n",
    "    pred = u_decode\n",
    "\n",
    "    batch = graph.batch\n",
    "    num_graphs = batch.max().item() + 1\n",
    "\n",
    "    for i in range(num_graphs):\n",
    "        # Mask nodes belonging to graph i\n",
    "        mask = (batch == i)\n",
    "        true_i = true[mask]\n",
    "        pred_i = pred[mask]\n",
    "        mae_i = torch.mean(torch.abs(pred_i - true_i)).item()\n",
    "        mae_list.append(mae_i)\n",
    "    \n",
    "# End timer\n",
    "total_time = time.time() - start_time\n",
    "\n",
    "# Compute final stats\n",
    "mean_mae = np.mean(mae_list)\n",
    "std_mae = np.std(mae_list)\n",
    "\n",
    "print(f\"Mean MAE over {N_test} runs: {mean_mae:.6f}\")\n",
    "print(f\"Stddev of MAE: {std_mae:.6f}\")\n",
    "print(f\"Evaluation time: {total_time/1000:.5f} seconds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gabi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
