{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import torch.multiprocessing as mp\n",
    "# mp.set_start_method('spawn')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.tri as mtri\n",
    "\n",
    "from scipy.spatial import Delaunay\n",
    "from scipy.sparse import lil_matrix\n",
    "from scipy.sparse.linalg import spsolve\n",
    "from torch_geometric.data import Data\n",
    "\n",
    "# Fall back to CPU when no GPU is present; the model cell derives `device`\n",
    "# the same way, so both definitions agree.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# File name of this notebook, used by the nbconvert export cell.\n",
    "filename = 'heat_rect_varBC_graph_betterArch_ABC.ipynb'\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from GABI.solver.heat_rect_unstruct import make_sln_graph, plot_slngraph\n",
    "\n",
    "# Build one reference mesh/solution on the unit square and plot it.\n",
    "data = make_sln_graph(1., 1., res=20)\n",
    "# data = make_sln_graph(2*np.pi, 2*np.pi)\n",
    "\n",
    "plot_slngraph(data, data.x[:, 0].detach().cpu().numpy())\n",
    "\n",
    "# Triangulation check with an analytic field; the current figure is closed\n",
    "# right after drawing.\n",
    "pos_np = data.pos.numpy()\n",
    "analytic_field = np.sin(pos_np[:, 0]) * np.sin(0.5 * pos_np[:, 1])\n",
    "plt.tricontourf(pos_np[:, 0], pos_np[:, 1], data.face.numpy().T, analytic_field, levels=20)\n",
    "plt.gca().set_aspect('equal')\n",
    "plt.colorbar()\n",
    "plt.close()\n",
    "\n",
    "print(data.ubc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#! Making a small dataset\n",
    "N = 1_000\n",
    "\n",
    "#! x_1, x_2 sizes: lower (r1) and upper (r2) bounds of the sampled lengths\n",
    "r1 = np.array([0.1, 0.1])\n",
    "r2 = np.array([1., 1.])\n",
    "\n",
    "#! u_left, u_right, u_top, u_bottom: lower (url) / upper (urr) bounds of the boundary values\n",
    "# url, urr = np.array([0., 0.1, 0., 0.]), np.array([0., 1., 0., 0.])\n",
    "# url, urr = np.array([0., 0.1, 0.1, 0.]), np.array([0., 1., 1., 0.])\n",
    "url = np.array([0., 0.1, 0.0, 0.])\n",
    "urr = np.array([0., 1., 1., 0.])\n",
    "\n",
    "# Uniform samples of (x_1, x_2) in the box [r1, r2]\n",
    "lh_span, lh_low = (r2 - r1)[None, :], r1[None, :]\n",
    "data_lh = torch.rand(N, 2) * lh_span + lh_low\n",
    "plt.scatter(data_lh[:, 0], data_lh[:, 1])\n",
    "plt.show()\n",
    "\n",
    "# Uniform samples of the four boundary values in [url, urr]\n",
    "ubc_span, ubc_low = (urr - url)[None, :], url[None, :]\n",
    "data_ubc = np.random.rand(N, 4) * ubc_span + ubc_low\n",
    "plt.scatter(data_ubc[:, 0], data_ubc[:, 1])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data_list = [make_sln_graph(lh[0], lh[1], ubc, res=20) for lh, ubc in zip(data_lh, data_ubc)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import subprocess\n",
    "\n",
    "# Output folders for this run (`dir` shadows the builtin, but later cells use this name).\n",
    "dir = './models/RectHeat_GABI_6/'\n",
    "dir_plt = dir + 'plt/'\n",
    "os.makedirs(dir, exist_ok=True)\n",
    "os.makedirs(dir_plt, exist_ok=True)\n",
    "\n",
    "pwd = os.getcwd()\n",
    "print(pwd)\n",
    "\n",
    "# Snapshot this notebook as a script next to the model artefacts. An argument\n",
    "# list (instead of an os.system shell string) keeps paths containing spaces or\n",
    "# shell metacharacters intact.\n",
    "nbconvert_cmd = ['jupyter', 'nbconvert', f'{pwd}/{filename}', '--to', 'python',\n",
    "                 '--output', f'{pwd}/{dir}run_file.py']\n",
    "try:\n",
    "    export_status = subprocess.run(nbconvert_cmd).returncode\n",
    "except FileNotFoundError:\n",
    "    # Like os.system, do not abort the notebook when jupyter is not on PATH.\n",
    "    print('jupyter not found on PATH; skipping notebook export')\n",
    "    export_status = None\n",
    "export_status"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import pickle\n",
    "# with open(dir+'data_list.pkl', 'wb') as f:\n",
    "#     pickle.dump(data_list, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Training graphs cached by the (commented-out) generation/dump cells above.\n",
    "# Only unpickle files you created yourself.\n",
    "data_list_path = dir + 'data_list.pkl'\n",
    "with open(data_list_path, 'rb') as fh:\n",
    "    data_list = pickle.load(fh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the first few training samples into this run's plot folder (dir_plt),\n",
    "# not a hard-coded path belonging to a different run.\n",
    "for i in range(min(4, len(data_list))):\n",
    "    plot_slngraph(data_list[i], data_list[i].x[:, 0].detach().cpu().numpy(),\n",
    "                  save=f'{dir_plt}data_{i}.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GCNConv, MLP\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.nn import pool\n",
    "\n",
    "from torch.nn import ModuleList\n",
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.nn.dense import Linear\n",
    "\n",
    "# Sizes: z is the latent code, u the nodal solution, x the node coordinates.\n",
    "dim_z = 100\n",
    "# dim_z = 50\n",
    "dim_u, dim_x = 1, 2\n",
    "dim_node_features = dim_u + dim_x  # encoder input per node is [u, x]\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "loader = DataLoader(data_list, batch_size=100, shuffle=True)\n",
    "# One example batch on the training device.\n",
    "data = next(iter(loader)).to(device)\n",
    "# print(data.num_node_features)\n",
    "class GCN_E(torch.nn.Module):\n",
    "    \"\"\"Graph encoder: maps nodal values u (plus coordinates) to one dim_z vector per graph.\n",
    "\n",
    "    Every hidden GCN layer sees its input concatenated with the graph-wise mean\n",
    "    of that input (a cheap global context); the output layer is followed by a\n",
    "    global mean pool over each graph in the batch.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.c = dim_z\n",
    "        hidden = self.c\n",
    "        # Hidden-layer input widths are doubled by the global-context concat.\n",
    "        # (Two further GCNConv(hidden * 2, hidden) layers were tried and disabled.)\n",
    "        self.convs = ModuleList(\n",
    "            [GCNConv(dim_node_features * 2, hidden)]\n",
    "            + [GCNConv(hidden * 2, hidden) for _ in range(4)]\n",
    "            + [GCNConv(hidden, dim_z)]\n",
    "        )\n",
    "\n",
    "    def forward(self, u, graph):\n",
    "        edge_index, edge_weight, batch = graph.edge_index, graph.edge_attr, graph.batch\n",
    "        h = torch.concat([u, graph.pos], dim=1)\n",
    "        *hidden_convs, out_conv = self.convs\n",
    "        for conv in hidden_convs:\n",
    "            # Append the per-graph mean of the current features to every node.\n",
    "            context = pool.global_mean_pool(h, batch)[batch]\n",
    "            h = F.silu(conv(torch.cat([h, context], dim=1), edge_index, edge_weight=edge_weight))\n",
    "        h = out_conv(h, edge_index, edge_weight=edge_weight)\n",
    "        return pool.global_mean_pool(h, batch)\n",
    "\n",
    "\n",
    "class GCN_D(torch.nn.Module):\n",
    "    \"\"\"Graph decoder: maps one latent vector per graph to nodal values u.\n",
    "\n",
    "    Each graph's latent is broadcast to its nodes and concatenated with the node\n",
    "    coordinates; every hidden GCN layer additionally sees the graph-wise mean of\n",
    "    its input (same global-context scheme as GCN_E).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(GCN_D, self).__init__()\n",
    "        self.c = dim_z\n",
    "        self.convs = ModuleList([\n",
    "            # node input is [z, pos] (dim_z + dim_x wide), doubled by the global-context concat;\n",
    "            # dim_x (not a literal 2) keeps this in sync with the coordinate dimension.\n",
    "            GCNConv((dim_z + dim_x)*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            # GCNConv(self.c*2, self.c),\n",
    "            # GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c, dim_u),\n",
    "        ])\n",
    "\n",
    "    def forward(self, z, graph):\n",
    "        \"\"\"Decode z (row i = latent of graph i in graph.batch) to per-node outputs of width dim_u.\"\"\"\n",
    "        z_expanded = z[graph.batch]  # each node receives the latent of its own graph\n",
    "        x = torch.cat([z_expanded, graph.pos], dim=1)\n",
    "        edge_index, edge_attr = graph.edge_index, graph.edge_attr\n",
    "\n",
    "        for conv in self.convs[:-1]:\n",
    "            x_global = pool.global_mean_pool(x, graph.batch)\n",
    "            global_x_expanded = x_global[graph.batch]  # [num_nodes, hidden_dim]\n",
    "            x = torch.cat([x, global_x_expanded], dim=1)  # [num_nodes, hidden_dim * 2]\n",
    "            x = F.silu(conv(x, edge_index, edge_weight=edge_attr))\n",
    "\n",
    "        return self.convs[-1](x, edge_index, edge_weight=edge_attr)\n",
    "    \n",
    "# Encoder/decoder pair, each with its own Adam optimiser.\n",
    "gcnE, gcnD = GCN_E().to(device), GCN_D().to(device)\n",
    "\n",
    "learning_rate = 1e-3\n",
    "optimizerE = torch.optim.Adam(gcnE.parameters(), lr=learning_rate)\n",
    "optimizerD = torch.optim.Adam(gcnD.parameters(), lr=learning_rate)\n",
    "# Optional exponential LR decay (disabled):\n",
    "# from torch.optim.lr_scheduler import ExponentialLR\n",
    "# schedulerE = ExponentialLR(optimizerE, gamma=0.95)\n",
    "# schedulerD = ExponentialLR(optimizerD, gamma=0.95)\n",
    "\n",
    "from GABI.stat import MMDLoss\n",
    "\n",
    "# MMD criterion from GABI.stat, used in loss_function.\n",
    "mmd = MMDLoss()\n",
    "\n",
    "def loss_function(data):\n",
    "    \"\"\"Autoencoder loss: reconstruction MSE plus an MMD term on the latents.\n",
    "\n",
    "    The MMD (GABI.stat.MMDLoss) compares the encoded latents with an equally\n",
    "    shaped batch of standard-normal draws.\n",
    "\n",
    "    Returns (total_loss, (reconstruction_loss, mmd_loss)).\n",
    "    \"\"\"\n",
    "    z = gcnE(data.x, data)\n",
    "    # GCN_D only reads attributes of the graph, so no defensive deep copy\n",
    "    # (data.clone()) is needed on every step.\n",
    "    u = gcnD(z, data)\n",
    "    lossL = torch.mean((data.x - u) ** 2.)\n",
    "    # randn_like already allocates on z's device.\n",
    "    LossD = mmd(z, torch.randn_like(z))\n",
    "    loss = lossL + LossD\n",
    "    return loss, (lossL, LossD)\n",
    "\n",
    "\n",
    "def count_parameters(model):\n",
    "    \"\"\"Number of trainable scalar parameters in `model`.\"\"\"\n",
    "    trainable = filter(lambda p: p.requires_grad, model.parameters())\n",
    "    return sum(p.numel() for p in trainable)\n",
    "\n",
    "for net in (gcnE, gcnD):\n",
    "    print(count_parameters(net))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move every graph to the training device once, then rebuild the loader over them.\n",
    "data_list = [graph.to(device) for graph in data_list]\n",
    "loader = DataLoader(data_list, batch_size=100, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gcnE.train()\n",
    "gcnD.train()\n",
    "\n",
    "import time\n",
    "\n",
    "# Per-step histories of the total, MMD and reconstruction losses (Python floats).\n",
    "LOSS = []\n",
    "LOSS_D = []\n",
    "LOSS_L = []\n",
    "n_epochs = 20_000  # commented alternatives in earlier versions: 10_000, 100_000\n",
    "time_train_start = time.time()\n",
    "for epoch in range(n_epochs):\n",
    "    start_time_b = time.time()\n",
    "    for data in loader:\n",
    "        optimizerE.zero_grad()\n",
    "        optimizerD.zero_grad()\n",
    "        data = data.to(device)  # no-op when data_list was already moved\n",
    "        loss, (lossL, LossD) = loss_function(data)\n",
    "        # .item() stores plain floats instead of a new 0-d numpy array per step.\n",
    "        LOSS.append(loss.item())\n",
    "        LOSS_D.append(LossD.item())\n",
    "        LOSS_L.append(lossL.item())\n",
    "        loss.backward()\n",
    "        optimizerE.step()\n",
    "        optimizerD.step()\n",
    "    # schedulerE.step()\n",
    "    # schedulerD.step()\n",
    "    print(f'Epoch: {epoch}  Loss: {loss.item()}  (time epoch = {time.time() - start_time_b:.3f}s)')\n",
    "\n",
    "time_train = time.time() - time_train_start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train(False) is what .eval() does; the last expression still displays gcnD.\n",
    "gcnE.train(False)\n",
    "gcnD.train(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LaTeX-rendered figure text in Computer Modern serif with enlarged labels.\n",
    "# (One consolidated update; same final rcParams as setting them group by group.)\n",
    "plt.rcParams.update({\n",
    "    'text.usetex': True,\n",
    "    'text.latex.preamble': r'\\usepackage{amsfonts, amsmath, amssymb}',\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': ['Computer Modern Roman'],\n",
    "    'font.sans-serif': ['Helvetica'],\n",
    "    'axes.labelsize': 18,\n",
    "    'axes.titlesize': 18,\n",
    "    'xtick.labelsize': 16,\n",
    "    'ytick.labelsize': 16,\n",
    "    'legend.fontsize': 14,\n",
    "})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total loss per optimisation step.\n",
    "plt.semilogy(LOSS)\n",
    "plt.xlabel('Step')\n",
    "plt.ylabel('Loss')\n",
    "plt.grid()\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "# Reconstruction (L) and MMD (d) terms. Raw strings stop the TeX backslashes\n",
    "# from being parsed as (invalid) Python escape sequences.\n",
    "plt.semilogy(LOSS_L, label=r'Loss $\\mathsf{L}$')\n",
    "plt.semilogy(LOSS_D, label=r'Loss $\\mathsf{d}$')\n",
    "plt.xlabel('Step')\n",
    "plt.grid()\n",
    "plt.legend()\n",
    "plt.savefig(dir_plt + 'losses.pdf')\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SAVE MODELS: weights and optimiser states for encoder (E) and decoder (D).\n",
    "for ckpt_name, ckpt_model, ckpt_optim in (('gcnE', gcnE, optimizerE), ('gcnD', gcnD, optimizerD)):\n",
    "    torch.save(ckpt_model.state_dict(), dir + ckpt_name + '.model')\n",
    "    torch.save(ckpt_optim.state_dict(), dir + ckpt_name + '.opt')\n",
    "\n",
    "import pickle\n",
    "\n",
    "# Loss histories and wall-clock training time.\n",
    "loss_dict = dict(LOSS=LOSS, LOSS_L=LOSS_L, LOSS_D=LOSS_D, time_train=time_train)\n",
    "with open(dir + 'loss_dict.pkl', 'wb') as fh:\n",
    "    pickle.dump(loss_dict, fh)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(dir + 'loss_dict.pkl', 'rb') as fh:\n",
    "    loss_dict = pickle.load(fh)\n",
    "\n",
    "# Restore the training histories and timing.\n",
    "LOSS, LOSS_L, LOSS_D = (loss_dict[key] for key in ('LOSS', 'LOSS_L', 'LOSS_D'))\n",
    "time_train = loss_dict['time_train']\n",
    "\n",
    "print('time_train:', time_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# gcnE = GCN_E().to(device)\n",
    "# gcnD = GCN_D().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LOAD MODELS\n",
    "# map_location lets checkpoints written on a GPU be restored on a CPU-only machine.\n",
    "gcnE.load_state_dict(torch.load(dir + 'gcnE.model', map_location=device))\n",
    "gcnE.eval()\n",
    "# optimizerE.load_state_dict(torch.load(dir + 'gcnE.opt', map_location=device))\n",
    "\n",
    "gcnD.load_state_dict(torch.load(dir + 'gcnD.model', map_location=device))\n",
    "gcnD.eval()\n",
    "# optimizerD.load_state_dict(torch.load(dir + 'gcnD.opt', map_location=device))\n",
    "\n",
    "# To resume training instead:\n",
    "# gcnE.train()\n",
    "# gcnD.train()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode whatever batch is bound to `data` (the last training batch, or the\n",
    "# example batch from the model cell, depending on which cells ran) and compare\n",
    "# the latent marginal with N(0, 1). No autograd graph is needed here.\n",
    "with torch.no_grad():\n",
    "    z = gcnE(data.x, data)\n",
    "z_hist = z.ravel().cpu().numpy()\n",
    "plt.hist(z_hist, bins=30, density=True, label='Latent $z$ histogram')\n",
    "lin = np.linspace(-10, 10, 100)\n",
    "# plt.xlim(-5, 5)\n",
    "plt.plot(lin, np.exp(-lin**2/2)/np.sqrt(2*np.pi), label='Standard Gaussian')\n",
    "plt.legend()\n",
    "plt.savefig(dir_plt + 'hist_latent.pdf')\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Fresh test set drawn with the ranges r1, r2, url, urr from the sampling cell.\n",
    "N_test = 1000\n",
    "data_lh = torch.rand(N_test, 2) * (r2 - r1)[None, :] + r1[None, :]\n",
    "plt.scatter(data_lh[:, 0], data_lh[:, 1])\n",
    "plt.show()\n",
    "\n",
    "data_ubc = np.random.rand(N_test, 4) * (urr - url)[None, :] + url[None, :]\n",
    "data_test_list = []\n",
    "for lh, ubc in zip(data_lh, data_ubc):\n",
    "    data_test_list.append(make_sln_graph(lh[0], lh[1], ubc, res=20))\n",
    "\n",
    "with open(dir + 'data_test_list.pkl', 'wb') as fh:\n",
    "    pickle.dump(data_test_list, fh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "with open(dir + 'data_test_list.pkl', 'rb') as fh:\n",
    "    data_test_list = pickle.load(fh)\n",
    "N_test = len(data_test_list)  # number of cached test graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plt.plot(samples)\n",
    "# plt.show()\n",
    "# print(max_prior_z)\n",
    "# print(samples.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### ABC: rejection sampling in the latent space for a single test graph\n",
    "idx_text = 0\n",
    "data_test_graph = data_test_list[idx_text].clone()\n",
    "sigma_tc = torch.FloatTensor([0.01])  # observation noise std\n",
    "\n",
    "# ObsIdx = np.random.choice(range(data_test_graph.pos.shape[0]), size=(n_obs,), replace=False)\n",
    "ObsIdx = torch.tensor(np.array([6, 57, 81, 152, 175, 265, 302, 402, 453, 350]))\n",
    "n_obs = int(ObsIdx.shape[0])  # derived, so it cannot drift from the index list\n",
    "y_n = (data_test_graph.x).reshape(-1,)[ObsIdx] + sigma_tc * torch.randn(ObsIdx.shape[0])\n",
    "y_n = y_n.to(device)\n",
    "\n",
    "# Replicate the test graph batch_size times so each decoder call handles a batch of latents.\n",
    "batch_size = 100\n",
    "graph_test = data_test_list[idx_text]\n",
    "graph_test_batch = [graph_test.clone() for _ in range(batch_size)]\n",
    "loader_test = DataLoader(graph_test_batch, batch_size=batch_size)\n",
    "graph_test_batch_loaded = next(iter(loader_test)).to(device)\n",
    "\n",
    "# Node ids of the observed nodes inside every replica of the batched graph.\n",
    "graph_offsets = torch.arange(batch_size) * graph_test.x.shape[0]  # [B]\n",
    "global_indices = graph_offsets[:, None] + ObsIdx[None, :]  # [B, K]\n",
    "flat_indices = global_indices.reshape(-1)  # [B*K]\n",
    "\n",
    "n_loops = 1000\n",
    "n_acc = 1000\n",
    "\n",
    "# Draw batch_size * n_loops latents from N(0, I), decode them, add observation\n",
    "# noise at ObsIdx and record the distance to y_n. Chunks are concatenated once\n",
    "# at the end (torch.cat inside the loop would re-copy everything each time).\n",
    "norm_chunks, z_chunks = [], []\n",
    "with torch.no_grad():\n",
    "    for i in range(n_loops):\n",
    "        z_prior = torch.randn(batch_size, dim_z).to(device)\n",
    "        u_decode = gcnD(z_prior, graph_test_batch_loaded)\n",
    "        y_selected = u_decode[flat_indices].view(batch_size, n_obs)  # [B, K]\n",
    "        y_selected = y_selected + sigma_tc.to(device) * torch.randn_like(y_selected).to(device)\n",
    "        norm_chunks.append(torch.norm(y_selected - y_n[None, :], dim=-1).cpu())\n",
    "        z_chunks.append(z_prior)\n",
    "norms_list = torch.cat(norm_chunks, dim=0)  # [B * n_loops]\n",
    "z_prior_list = torch.cat(z_chunks, dim=0)  # [B * n_loops, dim_z]\n",
    "\n",
    "# Accept the n_acc latents with the smallest discrepancy.\n",
    "sorted_index = torch.argsort(norms_list, dim=0)\n",
    "z_abc = z_prior_list[sorted_index[:n_acc]]\n",
    "\n",
    "loader_decode_ABC = DataLoader([graph_test], batch_size=1)\n",
    "graph_ABC_loaded = next(iter(loader_decode_ABC)).to(device)\n",
    "with torch.no_grad():\n",
    "    # Decode every accepted latent on the single (unbatched) test graph.\n",
    "    u_abc_decode = torch.vmap(gcnD, in_dims=(0, None))(z_abc[:, None, :], graph_ABC_loaded)\n",
    "    # Decode of the single closest prior sample.\n",
    "    u_min_norm = gcnD(z_prior_list[sorted_index[:1]], graph_ABC_loaded)\n",
    "u_min_norm = u_min_norm.cpu().numpy()[:, 0]\n",
    "\n",
    "acc_ratio = z_abc.shape[0] / (batch_size * n_loops)\n",
    "print('acc_ratio = ', acc_ratio)\n",
    "print('num samples = ', z_abc.shape[0])\n",
    "print('best norm:', norms_list[sorted_index[0]])\n",
    "\n",
    "samples_decode_np = u_abc_decode.cpu().numpy()[:, :, 0]\n",
    "samples_decode_mean = np.mean(samples_decode_np, axis=0)\n",
    "samples_decode_std = np.std(samples_decode_np, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "u_true_np = graph_test.x[:, 0].detach().cpu().numpy()\n",
    "\n",
    "# (printed title, nodal field, extra plot_slngraph keyword arguments)\n",
    "figure_specs = [\n",
    "    ('data', u_true_np, {'ObsIdx': ObsIdx, 'save': dir + 'plt/true.png'}),\n",
    "    ('u_min_norm', u_min_norm, {'save': dir + 'plt/mode_pred.png'}),\n",
    "    ('post_mean', samples_decode_mean, {'save': dir + 'plt/mean_pred.png'}),\n",
    "    ('post_std', samples_decode_std, {'ObsIdx': ObsIdx, 'save': dir + 'plt/std_pred.png'}),\n",
    "    ('error', np.abs(samples_decode_mean - u_true_np), {'ObsIdx': ObsIdx, 'save': dir + 'plt/error.png'}),\n",
    "]\n",
    "for title, field, plot_kwargs in figure_specs:\n",
    "    print(title)\n",
    "    plot_slngraph(graph_test, field, **plot_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# u_pred = u_mode[:,0].cpu().detach().numpy()\n",
    "# Point estimate to score: True -> closest ABC sample (u_min_norm),\n",
    "# False -> posterior mean (the choice run_pred makes below).\n",
    "use_min_norm_estimate = True\n",
    "u_pred = u_min_norm if use_min_norm_estimate else samples_decode_mean\n",
    "\n",
    "u_true = graph_test.x[:, 0].cpu().detach().numpy()\n",
    "abs_err = np.abs(u_pred - u_true)\n",
    "\n",
    "MAE = np.mean(abs_err)\n",
    "MaxAE = np.max(abs_err)\n",
    "# Fraction of nodes whose error lies within 1 / 2 ABC posterior standard deviations.\n",
    "prec_in_1std = np.mean(abs_err < samples_decode_std)\n",
    "prec_in_2std = np.mean(abs_err < 2 * samples_decode_std)\n",
    "\n",
    "print(f'MAE = {MAE:.4f}')\n",
    "print(f'MaxAE = {MaxAE:.4f}')\n",
    "print(prec_in_1std, prec_in_2std)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx_density = [100, 370, 150, 275]  # nodes whose ABC marginals are shown\n",
    "colors = ['tab:orange', 'tab:red', 'tab:green', 'tab:brown']\n",
    "u_true = data_test_graph.x.detach().cpu().numpy()[:, 0]\n",
    "\n",
    "# ABC marginal of u at each selected node; the dashed line marks the true value.\n",
    "for idx, color in zip(idx_density, colors):\n",
    "    plt.hist(samples_decode_np[:, idx], density=True, color=color, bins=15)\n",
    "    plt.axvline(u_true[idx], linestyle='--', color='k', alpha=0.4)\n",
    "plt.xlabel('Predicted $u$')\n",
    "plt.tight_layout()\n",
    "plt.savefig(dir_plt + 'sln_hist.pdf', dpi=300)\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "# Mesh, true solution and the locations of the nodes above (same colours).\n",
    "points = graph_test.pos.detach().cpu().numpy()\n",
    "triangles = graph_test.face.T.detach().cpu().numpy()\n",
    "u_true_field = graph_test.x[:, 0].detach().cpu().numpy()\n",
    "tri_plot = mtri.Triangulation(points[:, 0], points[:, 1], triangles=triangles)\n",
    "\n",
    "plt.tricontourf(points[:, 0], points[:, 1], triangles, u_true_field, levels=20)\n",
    "plt.colorbar()\n",
    "contour = plt.tricontour(points[:, 0], points[:, 1], triangles, u_true_field,\n",
    "                         levels=20, colors='k', linewidths=0.5)\n",
    "plt.triplot(tri_plot, color='white', alpha=0.2)\n",
    "for idx, color in zip(idx_density, colors):\n",
    "    plt.scatter(points[idx, 0], points[idx, 1], c=color, edgecolors='k', s=75, alpha=1.)\n",
    "plt.gca().set_aspect('equal')\n",
    "plt.xlabel('$x_1$')\n",
    "plt.ylabel('$x_2$')\n",
    "plt.tight_layout()\n",
    "plt.savefig(dir_plt + 'sln_hist_locs.png')\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "sns.pairplot(pd.DataFrame(z_abc[:, :10].detach().cpu().numpy()), diag_kind='hist', kind='scatter')\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# Observation model for the batch evaluation: n_obs point values with noise std sigma.\n",
    "n_obs = 10\n",
    "sigma = 0.01\n",
    "sigma_tc = torch.FloatTensor([sigma])\n",
    "\n",
    "# Running-optimum bookkeeping (not read by the ABC run_pred below).\n",
    "max_prop_to_date, max_prior_z, ctr = 9e9, None, 0\n",
    "\n",
    "def run_pred(data_test_graph, n_loops=100, n_acc=100, batch_size=100):\n",
    "    \"\"\"Rejection-ABC prediction for one test graph; returns error and coverage stats.\n",
    "\n",
    "    Picks n_obs random nodes, observes them with sigma-scaled Gaussian noise,\n",
    "    draws n_loops * batch_size latents from N(0, I), decodes them with gcnD and\n",
    "    keeps the n_acc latents whose noisy simulated observations are closest to\n",
    "    the data. Uses the globals n_obs, sigma_tc, dim_z, gcnD and device.\n",
    "\n",
    "    Returns (MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken), scoring the\n",
    "    posterior mean and using the ABC std for the coverage fractions.\n",
    "    \"\"\"\n",
    "    ObsIdx = np.random.choice(range(data_test_graph.pos.shape[0]), size=(n_obs,), replace=False)\n",
    "    ObsIdx = torch.tensor(ObsIdx)\n",
    "    y_n = (data_test_graph.x).reshape(-1,)[ObsIdx] + sigma_tc * torch.randn(ObsIdx.shape[0])\n",
    "    y_n = y_n.to(device)\n",
    "\n",
    "    time_start = time.time()\n",
    "    graph_test = data_test_graph\n",
    "    # The graph replicated batch_size times, so one decoder call handles a batch of latents.\n",
    "    loader_test = DataLoader([graph_test.clone() for _ in range(batch_size)], batch_size=batch_size)\n",
    "    graph_test_batch_loaded = next(iter(loader_test)).to(device)\n",
    "\n",
    "    # Node ids of the observed nodes inside every replica of the batched graph.\n",
    "    graph_offsets = torch.arange(batch_size) * graph_test.x.shape[0]  # [B]\n",
    "    flat_indices = (graph_offsets[:, None] + ObsIdx[None, :]).reshape(-1)  # [B*K]\n",
    "\n",
    "    # Collect chunks and concatenate once instead of torch.cat inside the loop.\n",
    "    norm_chunks, z_chunks = [], []\n",
    "    for _ in range(n_loops):\n",
    "        z_prior = torch.randn(batch_size, dim_z).to(device)\n",
    "        u_decode = gcnD(z_prior, graph_test_batch_loaded)\n",
    "        y_selected = u_decode[flat_indices].view(batch_size, n_obs)  # [B, K]\n",
    "        y_selected = y_selected + sigma_tc.to(device) * torch.randn_like(y_selected).to(device)\n",
    "        norm_chunks.append(torch.norm(y_selected - y_n[None, :], dim=-1).detach().cpu())\n",
    "        z_chunks.append(z_prior.detach())\n",
    "    norms_list = torch.cat(norm_chunks, dim=0)\n",
    "    z_prior_list = torch.cat(z_chunks, dim=0)\n",
    "\n",
    "    # Accept the n_acc latents with the smallest discrepancy.\n",
    "    sorted_index = torch.argsort(norms_list, dim=0)\n",
    "    z_abc = z_prior_list[sorted_index[:n_acc]]\n",
    "\n",
    "    graph_ABC_loaded = next(iter(DataLoader([graph_test], batch_size=1))).to(device)\n",
    "    u_abc_decode = torch.vmap(gcnD, in_dims=(0, None))(z_abc[:, None, :], graph_ABC_loaded)\n",
    "\n",
    "    samples_decode_np = u_abc_decode.cpu().detach().numpy()[:, :, 0]\n",
    "    samples_decode_std = np.std(samples_decode_np, axis=0)\n",
    "    u_pred = np.mean(samples_decode_np, axis=0)  # posterior mean as point estimate\n",
    "\n",
    "    time_taken = time.time() - time_start\n",
    "    print('time_taken:', time_taken)\n",
    "\n",
    "    u_true = data_test_graph.x[:, 0].cpu().detach().numpy()\n",
    "    abs_err = np.abs(u_pred - u_true)\n",
    "    MAE = np.mean(abs_err)\n",
    "    MaxAE = np.max(abs_err)\n",
    "    prec_in_1std = np.mean(abs_err < samples_decode_std)\n",
    "    prec_in_2std = np.mean(abs_err < 2 * samples_decode_std)\n",
    "\n",
    "    print(f'MAE = {MAE:.4f}')\n",
    "    print(prec_in_1std, prec_in_2std)\n",
    "\n",
    "    return MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken\n",
    "\n",
    "    \n",
    "import pickle\n",
    "\n",
    "# Evaluate on at most 1000 cached test graphs; bounding by len(data_test_list)\n",
    "# avoids an IndexError when fewer graphs are available.\n",
    "N_test = min(1000, len(data_test_list))\n",
    "stat_names = ('MAE', 'MaxAE', 'prec_in_1std', 'prec_in_2std', 'time_taken')\n",
    "pred_stats = {name: [] for name in stat_names}\n",
    "for i in range(N_test):\n",
    "    with torch.no_grad():\n",
    "        print(f'Running prediction for test data {i+1}/{N_test}')\n",
    "        MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken = run_pred(data_test_list[i])\n",
    "    for name, value in zip(stat_names, (MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken)):\n",
    "        pred_stats[name].append(value)\n",
    "\n",
    "pred_stats = {name: np.array(values) for name, values in pred_stats.items()}\n",
    "for name in stat_names:\n",
    "    print(f'{name}: ', np.mean(pred_stats[name]), '±', np.std(pred_stats[name]))\n",
    "\n",
    "with open(dir + 'pred_stats_ABC.pkl', 'wb') as f:\n",
    "    pickle.dump(pred_stats, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# MAP bookkeeping, updated via `global` by log_posterior and reset inside run_pred.\n",
    "max_prop_to_date, max_prior_z, ctr = 9e9, None, 0\n",
    "\n",
    "# Observation model: n_obs point values with noise std sigma.\n",
    "sigma = 0.01\n",
    "n_obs = 10\n",
    "sigma_tc = torch.FloatTensor([sigma])\n",
    "\n",
    "def run_pred(data_test):\n",
    "    \n",
    "    data_test = data_test.to(device)\n",
    "\n",
    "   \n",
    "    # device='cuda'\n",
    "\n",
    "\n",
    "    # gcnD = gcnD.to(device)\n",
    "    # y_n = y_n.to(device)\n",
    "\n",
    "    sigma_tc = torch.FloatTensor([sigma]).to(device) \n",
    "    loader_test = DataLoader([data_test], batch_size=1)\n",
    "    data_test_graph = next(iter(loader_test)).to(device)\n",
    "    data_test_graph = data_test_graph.clone()\n",
    "    \n",
    "    ObsIdx = np.random.choice(range(data_test_graph.pos.shape[0]), size=(n_obs,), replace=False)\n",
    "    ObsIdx = torch.tensor(ObsIdx).to(device)\n",
    "    y_n = (data_test_graph.x).reshape(-1,)[ObsIdx] + sigma_tc * torch.randn(ObsIdx.shape[0]).to(device)\n",
    "    y_n = y_n.to(device)\n",
    "    # data_test_graph.x = torch.zeros( ( data_test_graph.x.shape[0], dim_z+nfreq*2) ).to(device)   \n",
    "\n",
    "    global max_prop_to_date, max_prior_z, ctr\n",
    "    max_prop_to_date = 9e9\n",
    "    max_prior_z = None\n",
    "    ctr = 0\n",
    "\n",
    "    def log_posterior(param):\n",
    "        z = param['z']\n",
    "        log_prior = 0.5 * torch.sum(z ** 2, dim=1)\n",
    "        \n",
    "        # data_test_graph_clone = data_test_graph.clone()\n",
    "        # u_z = gcnD(z, data_test_graph_clone)\n",
    "        u_z = gcnD(z, data_test_graph)\n",
    "\n",
    "        # u_z = gcnD(z, data_test_graph)\n",
    "        \n",
    "        hat_y = (u_z).reshape(-1,)[ObsIdx]\n",
    "        log_like  = 0.5/sigma_tc[0]**2. * torch.sum((y_n - hat_y)**2)\n",
    "        log_prob = log_prior + log_like\n",
    "            \n",
    "        global max_prop_to_date, max_prior_z\n",
    "        \n",
    "        if log_prob < max_prop_to_date:\n",
    "            # print(\"new map: log_prob\")\n",
    "            max_prop_to_date = log_prob\n",
    "            max_prior_z = z.clone()\n",
    "            \n",
    "        global ctr\n",
    "        ctr += 1\n",
    "        \n",
    "        return log_prob\n",
    "\n",
    "    from pyro.infer import MCMC, NUTS\n",
    "\n",
    "    kernel = NUTS(potential_fn=log_posterior, step_size=0.1, full_mass=False, jit_compile=True, max_tree_depth=10)\n",
    "\n",
    "    mcmc = MCMC(kernel, \n",
    "                initial_params =  {'z':torch.randn(1, dim_z).to(device)},\n",
    "                num_samples=100,\n",
    "                warmup_steps=100,\n",
    "                num_chains=1,\n",
    "                disable_progbar = True\n",
    "                # mp_context='spawn'\n",
    "                )\n",
    "\n",
    "    time_start = time.time()\n",
    "\n",
    "    mcmc.run()\n",
    "\n",
    "    samples = mcmc.get_samples()['z']\n",
    "    samples = samples.reshape(-1, dim_z)\n",
    "    # mcmc.summary()  \n",
    "    samples = samples.cpu().detach().numpy()\n",
    "\n",
    "    # print('num evals = ', ctr)\n",
    "    \n",
    "    samples_decode = torch.FloatTensor(samples[::]).to(device)\n",
    "    \n",
    "    new_u_pred = []\n",
    "    for i in range(samples_decode.shape[0]):\n",
    "        u_decode = gcnD(samples_decode[i,:][None,...], data_test_graph)\n",
    "        new_u_pred.append(u_decode)\n",
    "    new_u_pred = torch.stack(new_u_pred, dim=0)\n",
    "    \n",
    "    mode = max_prior_z\n",
    "    # mode = torch.FloatTensor(mode).to(device, dtype=torch.float32)\n",
    "    u_mode = gcnD(mode, data_test_graph)\n",
    "    data_test_graph_c = data_test_graph.clone()\n",
    "    data_test_graph_c.x = u_mode\n",
    "    \n",
    "    samples_decode_np = new_u_pred.cpu().detach().numpy()[:, :, 0]\n",
    "    samples_decode_mean = np.mean(samples_decode_np, axis=0)\n",
    "    samples_decode_std = np.std(samples_decode_np, axis=0)\n",
    "    \n",
    "    # u_pred = u_mode[:,0].cpu().detach().numpy()\n",
    "    u_pred = samples_decode_mean\n",
    "    \n",
    "    time_taken = time.time() - time_start\n",
    "    # print('time_taken:', time_taken)\n",
    "    # u_pred = samples_decode_mean\n",
    "\n",
    "    u_true =  data_test.x[:,0].cpu().detach().numpy()\n",
    "\n",
    "    MAE = np.mean( np.abs(u_pred - u_true) )\n",
    "    MaxAE = np.max( np.abs(u_pred - u_true) )\n",
    "\n",
    "    prec_in_1std = np.mean( np.abs(u_pred - u_true) < samples_decode_std )\n",
    "    prec_in_2std = np.mean( np.abs(u_pred - u_true) < 2*samples_decode_std )\n",
    "\n",
    "    # print(f'MAE = {MAE:.4f}')\n",
    "\n",
    "    # print(prec_in_1std, prec_in_2std)\n",
    "    \n",
    "    return MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken, ctr\n",
    "\n",
    "    \n",
    "# run_pred()\n",
    "\n",
    "N_test = 100\n",
    "pred_stats = {'MAE': [], 'MaxAE': [], 'prec_in_1std': [],\n",
    "              'prec_in_2std': [], 'time_taken': [], 'ctr':[]}\n",
    "for i in range(N_test):\n",
    "\n",
    "    # with torch.no_grad():\n",
    "    print(f'Running prediction for test data {i+1}/{N_test}')\n",
    "    MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken, ctr = run_pred(data_test_list[i])\n",
    "    pred_stats['MAE'].append(MAE)\n",
    "    pred_stats['MaxAE'].append(MaxAE)\n",
    "    pred_stats['prec_in_1std'].append(prec_in_1std)\n",
    "    pred_stats['prec_in_2std'].append(prec_in_2std)\n",
    "    pred_stats['time_taken'].append(time_taken)\n",
    "    pred_stats['ctr'].append(ctr)\n",
    "\n",
    "    \n",
    "pred_stats['MAE'] = np.array(pred_stats['MAE'])\n",
    "pred_stats['MaxAE'] = np.array(pred_stats['MaxAE'])\n",
    "pred_stats['prec_in_1std'] = np.array(pred_stats['prec_in_1std'])\n",
    "pred_stats['prec_in_2std'] = np.array(pred_stats['prec_in_2std'])\n",
    "pred_stats['time_taken'] = np.array(pred_stats['time_taken'])\n",
    "pred_stats['ctr'] = np.array(pred_stats['ctr'])\n",
    "\n",
    "with open(dir + 'pred_stats_NUTS.pkl', 'wb') as f:\n",
    "    pickle.dump(pred_stats, f)\n",
    "\n",
    "print(\"MAE: \", np.mean(pred_stats['MAE']), \"±\", np.std(pred_stats['MAE']))\n",
    "print(\"MaxAE: \", np.mean(pred_stats['MaxAE']), \"±\", np.std(pred_stats['MaxAE']))\n",
    "print(\"prec_in_1std: \", np.mean(pred_stats['prec_in_1std']), \"±\", np.std(pred_stats['prec_in_1std']))\n",
    "print(\"prec_in_2std: \", np.mean(pred_stats['prec_in_2std']), \"±\", np.std(pred_stats['prec_in_2std']))\n",
    "print(\"time_taken: \", np.mean(pred_stats['time_taken']), \"±\", np.std(pred_stats['time_taken']))\n",
    "print(\"ctr: \", np.mean(pred_stats['ctr']), \"±\", np.std(pred_stats['ctr']))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "sigma = 0.01   # observation-noise std (used to corrupt both observations and simulations)\n",
    "n_obs = 30     # number of randomly observed nodes per test graph\n",
    "sigma_tc = torch.FloatTensor([sigma])\n",
    "\n",
    "\n",
    "def run_pred(data_test_graph):\n",
    "    \"\"\"Approximate the latent posterior of one test graph by rejection ABC.\n",
    "\n",
    "    Draws `n_loops * batch_size` latent codes from the N(0, I) prior, decodes them, adds\n",
    "    observation noise at the `n_obs` observed nodes and keeps the `n_acc` draws closest (L2)\n",
    "    to the noisy observations. The point prediction is the decode of the single closest draw;\n",
    "    the spread of the accepted decodes provides the uncertainty estimate.\n",
    "\n",
    "    Returns:\n",
    "        (MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken)\n",
    "    \"\"\"\n",
    "    # Build the observations on the graph's own device and only then move them to `device`,\n",
    "    # so this works whether or not an earlier cell already moved the graph (PyG `.to` is in place).\n",
    "    x_true = data_test_graph.x\n",
    "    ObsIdx = np.random.choice(range(data_test_graph.pos.shape[0]), size=(n_obs,), replace=False)\n",
    "    ObsIdx = torch.tensor(ObsIdx)   # stays on CPU: also used to build `flat_indices` below\n",
    "    noise = sigma_tc * torch.randn(ObsIdx.shape[0])\n",
    "    y_n = x_true.reshape(-1,)[ObsIdx.to(x_true.device)].to(device) + noise.to(device)\n",
    "\n",
    "    time_start = time.time()\n",
    "    batch_size = 100   # prior draws decoded per forward pass\n",
    "    n_loops = 100      # forward passes, i.e. n_loops * batch_size prior draws in total\n",
    "    n_acc = 100        # number of closest draws accepted as ABC samples\n",
    "\n",
    "    graph_test = data_test_graph\n",
    "    graph_test_batch = [graph_test.clone() for _ in range(batch_size)]\n",
    "    loader_test = DataLoader(graph_test_batch, batch_size=batch_size)\n",
    "    graph_test_batch_loaded = next(iter(loader_test)).to(device)\n",
    "\n",
    "    # Positions of the observed nodes inside the flattened [batch_size * n_nodes] batch.\n",
    "    graph_offsets = torch.arange(batch_size) * graph_test.x.shape[0]       # [B]\n",
    "    flat_indices = (graph_offsets[:, None] + ObsIdx[None, :]).reshape(-1)  # [B*K]\n",
    "\n",
    "    # Gather per-loop chunks and concatenate once (torch.cat inside the loop is quadratic).\n",
    "    norms_chunks, z_chunks = [], []\n",
    "    for _ in range(n_loops):\n",
    "        z_prior = torch.randn(batch_size, dim_z).to(device)\n",
    "        u_decode = gcnD(z_prior, graph_test_batch_loaded)\n",
    "        y_selected = u_decode[flat_indices].view(batch_size, n_obs)   # [B, K]\n",
    "        y_selected = y_selected + sigma_tc.to(device) * torch.randn_like(y_selected).to(device)\n",
    "        norms_chunks.append(torch.norm(y_selected - y_n[None, :], dim=-1).detach().cpu())\n",
    "        z_chunks.append(z_prior.detach())\n",
    "    norms_all = torch.cat(norms_chunks, dim=0)   # [n_loops * B]\n",
    "    z_prior_all = torch.cat(z_chunks, dim=0)     # [n_loops * B, dim_z]\n",
    "\n",
    "    sorted_index = torch.argsort(norms_all, dim=0)   # draws ordered by distance, closest first\n",
    "    z_abc = z_prior_all[sorted_index[:n_acc]]\n",
    "\n",
    "    loader_decode_ABC = DataLoader([graph_test], batch_size=1)\n",
    "    graph_ABC_loaded = next(iter(loader_decode_ABC)).to(device)\n",
    "    u_abc_decode = torch.vmap(gcnD, in_dims=(0, None))(z_abc[:, None, :], graph_ABC_loaded)\n",
    "\n",
    "    samples_decode_np = u_abc_decode.cpu().detach().numpy()[:, :, 0]\n",
    "    samples_decode_std = np.std(samples_decode_np, axis=0)\n",
    "\n",
    "    # Point prediction: decode of the single closest prior draw.\n",
    "    u_min_norm = gcnD(z_prior_all[sorted_index[:1]], graph_ABC_loaded)\n",
    "    u_pred = u_min_norm.cpu().detach().numpy()[:, 0]\n",
    "\n",
    "    time_taken = time.time() - time_start\n",
    "    print('time_taken:', time_taken)\n",
    "\n",
    "    u_true = x_true[:, 0].cpu().detach().numpy()\n",
    "    abs_err = np.abs(u_pred - u_true)\n",
    "    MAE = np.mean(abs_err)\n",
    "    MaxAE = np.max(abs_err)\n",
    "    prec_in_1std = np.mean(abs_err < samples_decode_std)\n",
    "    prec_in_2std = np.mean(abs_err < 2 * samples_decode_std)\n",
    "\n",
    "    print(f'MAE = {MAE:.4f}')\n",
    "    print(prec_in_1std, prec_in_2std)\n",
    "\n",
    "    return MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken\n",
    "\n",
    "\n",
    "N_test = 1000\n",
    "pred_keys = ('MAE', 'MaxAE', 'prec_in_1std', 'prec_in_2std', 'time_taken')\n",
    "pred_stats = {key: [] for key in pred_keys}\n",
    "for i in range(N_test):\n",
    "    with torch.no_grad():\n",
    "        print(f'Running prediction for test data {i+1}/{N_test}')\n",
    "        results = run_pred(data_test_list[i])\n",
    "    for key, value in zip(pred_keys, results):\n",
    "        pred_stats[key].append(value)\n",
    "pred_stats = {key: np.array(values) for key, values in pred_stats.items()}\n",
    "\n",
    "for key in pred_keys:\n",
    "    print(f\"{key}: \", np.mean(pred_stats[key]), \"±\", np.std(pred_stats[key]))\n",
    "\n",
    "# Derive the file name from n_obs so the saved statistics always match this run's configuration\n",
    "# (previously hard-coded as '_50obs' while n_obs was 30).\n",
    "with open(dir + f'pred_stats_ABC_{n_obs}obs.pkl', 'wb') as f:\n",
    "    pickle.dump(pred_stats, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "from pyro.infer import MCMC, NUTS\n",
    "\n",
    "# MAP tracker shared with `log_posterior` (reset at the start of every `run_pred` call).\n",
    "# Despite its name, `max_prop_to_date` holds the LOWEST potential energy seen so far.\n",
    "max_prop_to_date = 9e9\n",
    "max_prior_z = None   # latent z attaining `max_prop_to_date`\n",
    "ctr = 0              # number of potential-energy evaluations\n",
    "\n",
    "sigma = 0.01   # observation-noise std (used to corrupt the observations and in the likelihood)\n",
    "n_obs = 50     # number of randomly observed nodes per test graph\n",
    "sigma_tc = torch.FloatTensor([sigma])\n",
    "\n",
    "\n",
    "def run_pred(data_test):\n",
    "    \"\"\"Infer the latent code of one test graph with NUTS from `n_obs` noisy nodal observations.\n",
    "\n",
    "    The posterior-mean decoded field is scored against `data_test.x[:, 0]`.\n",
    "\n",
    "    Returns:\n",
    "        (MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken, ctr): the `prec_in_*` entries are\n",
    "        the fractions of nodes whose absolute error is below 1 / 2 standard deviations of the\n",
    "        decoded posterior samples, and `ctr` is the number of potential evaluations.\n",
    "    \"\"\"\n",
    "    global max_prop_to_date, max_prior_z, ctr\n",
    "\n",
    "    # Move a copy: PyG's `Data.to` works in place and would otherwise relocate the shared\n",
    "    # entry of `data_test_list` to `device` as a hidden side effect on later cells.\n",
    "    data_test = data_test.clone().to(device)\n",
    "    sigma_tc = torch.FloatTensor([sigma]).to(device)\n",
    "    loader_test = DataLoader([data_test], batch_size=1)\n",
    "    data_test_graph = next(iter(loader_test)).to(device)\n",
    "\n",
    "    # Noisy observations y_n = u(ObsIdx) + sigma * eps at randomly chosen nodes.\n",
    "    ObsIdx = np.random.choice(range(data_test_graph.pos.shape[0]), size=(n_obs,), replace=False)\n",
    "    ObsIdx = torch.tensor(ObsIdx).to(device)\n",
    "    y_n = (data_test_graph.x).reshape(-1,)[ObsIdx] + sigma_tc * torch.randn(ObsIdx.shape[0]).to(device)\n",
    "\n",
    "    max_prop_to_date = 9e9\n",
    "    max_prior_z = None\n",
    "    ctr = 0\n",
    "\n",
    "    def log_posterior(param):\n",
    "        \"\"\"Potential energy for NUTS: negative log-posterior of z, up to an additive constant.\n",
    "\n",
    "        Standard-normal prior on z plus an isotropic Gaussian likelihood of `y_n`. As a side\n",
    "        effect, records the lowest-energy z seen so far and counts evaluations.\n",
    "        \"\"\"\n",
    "        global max_prop_to_date, max_prior_z, ctr\n",
    "        z = param['z']\n",
    "        # Sum over every element so the potential is a 0-d scalar.\n",
    "        log_prior = 0.5 * torch.sum(z ** 2)\n",
    "        u_z = gcnD(z, data_test_graph)\n",
    "        hat_y = u_z.reshape(-1,)[ObsIdx]\n",
    "        log_like = 0.5 / sigma_tc[0] ** 2. * torch.sum((y_n - hat_y) ** 2)\n",
    "        log_prob = log_prior + log_like\n",
    "\n",
    "        if log_prob < max_prop_to_date:\n",
    "            # Detach so the tracked optimum does not keep this evaluation's autograd graph alive.\n",
    "            max_prop_to_date = log_prob.detach()\n",
    "            max_prior_z = z.detach().clone()\n",
    "\n",
    "        ctr += 1\n",
    "        return log_prob\n",
    "\n",
    "    kernel = NUTS(potential_fn=log_posterior, step_size=0.1, full_mass=False, jit_compile=True, max_tree_depth=10)\n",
    "    mcmc = MCMC(kernel,\n",
    "                initial_params={'z': torch.randn(1, dim_z).to(device)},\n",
    "                num_samples=100,\n",
    "                warmup_steps=100,\n",
    "                num_chains=1,\n",
    "                disable_progbar=True)\n",
    "\n",
    "    time_start = time.time()\n",
    "    mcmc.run()\n",
    "\n",
    "    # Decode every posterior sample; gradients are not needed from here on.\n",
    "    samples = mcmc.get_samples()['z'].reshape(-1, dim_z).detach().to(device, dtype=torch.float32)\n",
    "    with torch.no_grad():\n",
    "        new_u_pred = torch.stack([gcnD(z_s[None, ...], data_test_graph) for z_s in samples], dim=0)\n",
    "\n",
    "    samples_decode_np = new_u_pred.cpu().numpy()[:, :, 0]\n",
    "    samples_decode_mean = np.mean(samples_decode_np, axis=0)\n",
    "    samples_decode_std = np.std(samples_decode_np, axis=0)\n",
    "    u_pred = samples_decode_mean   # posterior-mean prediction\n",
    "    time_taken = time.time() - time_start\n",
    "\n",
    "    u_true = data_test.x[:, 0].cpu().detach().numpy()\n",
    "    abs_err = np.abs(u_pred - u_true)\n",
    "    MAE = np.mean(abs_err)\n",
    "    MaxAE = np.max(abs_err)\n",
    "    prec_in_1std = np.mean(abs_err < samples_decode_std)\n",
    "    prec_in_2std = np.mean(abs_err < 2 * samples_decode_std)\n",
    "\n",
    "    return MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken, ctr\n",
    "\n",
    "\n",
    "N_test = 100\n",
    "pred_keys = ('MAE', 'MaxAE', 'prec_in_1std', 'prec_in_2std', 'time_taken', 'ctr')\n",
    "pred_stats = {key: [] for key in pred_keys}\n",
    "for i in range(N_test):\n",
    "    print(f'Running prediction for test data {i+1}/{N_test}')\n",
    "    for key, value in zip(pred_keys, run_pred(data_test_list[i])):\n",
    "        pred_stats[key].append(value)\n",
    "pred_stats = {key: np.array(values) for key, values in pred_stats.items()}\n",
    "\n",
    "# File name derived from n_obs so it always matches this run's configuration.\n",
    "with open(dir + f'pred_stats_NUTS_{n_obs}obs.pkl', 'wb') as f:\n",
    "    pickle.dump(pred_stats, f)\n",
    "\n",
    "for key in pred_keys:\n",
    "    print(f\"{key}: \", np.mean(pred_stats[key]), \"±\", np.std(pred_stats[key]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gabi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
