{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import torch.multiprocessing as mp\n",
    "# mp.set_start_method('spawn')\n",
    "\n",
    "import torch\n",
    "import numpy as np \n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from scipy.spatial import Delaunay\n",
    "from scipy.sparse import lil_matrix\n",
    "from scipy.sparse.linalg import spsolve\n",
    "\n",
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.tri as mtri\n",
    "\n",
    "# Fall back to the CPU when CUDA is unavailable so the notebook still runs.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# File name of this notebook; a later cell passes it to `jupyter nbconvert`.\n",
    "filename = 'ABCvsMCMC_Time_Accuracy.ipynb'\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from GABI.solver.heat_rect_unstruct import make_sln_graph, plot_slngraph\n",
    "    \n",
    "    \n",
    "# Build an example solution graph.\n",
    "data = make_sln_graph(1., 1., res=20)\n",
    "# data = make_sln_graph(2*np.pi, 2*np.pi)\n",
    "\n",
    "# Plot the first node-feature column with the helper from the solver module.\n",
    "first_feature = data.x[:, 0].cpu().detach().numpy()\n",
    "plot_slngraph(data, first_feature)\n",
    "\n",
    "# Filled contour of sin(x) * sin(y / 2) over the graph's triangulation.\n",
    "# The figure is closed immediately afterwards.\n",
    "node_xy = data.pos.numpy()\n",
    "triangles = data.face.numpy().T\n",
    "plt.tricontourf(node_xy[:, 0], node_xy[:, 1], triangles,\n",
    "                np.sin(node_xy[:, 0]) * np.sin(0.5 * node_xy[:, 1]), levels=20)\n",
    "plt.gca().set_aspect('equal')\n",
    "plt.colorbar()\n",
    "plt.close()\n",
    "\n",
    "print(data.ubc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import subprocess\n",
    "\n",
    "# NOTE: `dir` shadows the Python builtin, but later cells read it, so the name is kept.\n",
    "dir = './models/RectHeat_GABI_5/'\n",
    "dir_save = './models/RectHeat_GABI_5_CAT/'\n",
    "dir_plt = dir+'plt/'\n",
    "os.makedirs(dir, exist_ok=True)\n",
    "os.makedirs(dir_plt, exist_ok=True)\n",
    "\n",
    "pwd = os.getcwd()\n",
    "print(pwd)\n",
    "\n",
    "# Export this notebook as a script into the model directory. An argument list\n",
    "# (instead of an interpolated shell string) keeps paths containing spaces or\n",
    "# shell metacharacters from breaking or altering the command.\n",
    "try:\n",
    "    subprocess.run(\n",
    "        ['jupyter', 'nbconvert', f'{pwd}/{filename}', '--to', 'python',\n",
    "         '--output', f'{pwd}/{dir}run_file.py'],\n",
    "        check=False,  # best effort: a failed export must not stop the notebook\n",
    "    )\n",
    "except FileNotFoundError:\n",
    "    print('jupyter executable not found; skipping notebook export')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "# Graphs used for training are read from the pickle stored in the model directory.\n",
    "with open(dir + 'data_list.pkl', mode='rb') as f:\n",
    "    data_list = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GCNConv, MLP\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.nn import pool\n",
    "\n",
    "from torch.nn import ModuleList\n",
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.nn.dense import Linear\n",
    "\n",
    "\n",
    "# Feature widths shared by the encoder and decoder defined below:\n",
    "# dim_u is the decoder's output width, dim_x enters the encoder's input width.\n",
    "dim_u, dim_x = 1, 2\n",
    "dim_node_features = dim_u + dim_x\n",
    "class GCN_E(torch.nn.Module):\n",
    "    \"\"\"Graph encoder: node values and positions -> one latent vector per graph.\n",
    "\n",
    "    Before every hidden layer the node features are concatenated with their\n",
    "    per-graph mean, which is why those layers take twice the feature width.\n",
    "\n",
    "    Args:\n",
    "        dim_z: width of the latent vector returned for each graph.\n",
    "        dim_c: number of hidden channels.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,  dim_z, dim_c):\n",
    "        super().__init__()\n",
    "\n",
    "        self.dim_z = dim_z\n",
    "        self.c = dim_c\n",
    "\n",
    "        # Input layer, four identical hidden layers, then the projection to dim_z.\n",
    "        layers = [GCNConv(dim_node_features*2, self.c)]\n",
    "        layers += [GCNConv(self.c*2, self.c) for _ in range(4)]\n",
    "        layers.append(GCNConv(self.c, self.dim_z))\n",
    "        self.convs = ModuleList(layers)\n",
    "\n",
    "    def forward(self, u, graph):\n",
    "        \"\"\"Encode node values `u` on `graph`; returns the mean-pooled output of the last layer.\"\"\"\n",
    "        edge_index, edge_attr = graph.edge_index, graph.edge_attr\n",
    "        h = torch.concat([u, graph.pos], dim=1)\n",
    "\n",
    "        *hidden_layers, output_layer = self.convs\n",
    "        for layer in hidden_layers:\n",
    "            # Broadcast each graph's mean feature vector back onto its nodes.\n",
    "            graph_mean = pool.global_mean_pool(h, graph.batch)\n",
    "            h = torch.cat([h, graph_mean[graph.batch]], dim=1)\n",
    "            h = F.silu(layer(h, edge_index, edge_weight=edge_attr))\n",
    "\n",
    "        h = output_layer(h, edge_index, edge_weight=edge_attr)\n",
    "        return pool.global_mean_pool(h, graph.batch)\n",
    "\n",
    "\n",
    "class GCN_D(torch.nn.Module):\n",
    "    \"\"\"Graph decoder: one latent vector per graph -> node values on that graph.\n",
    "\n",
    "    The latent is copied onto every node of its graph and concatenated with the\n",
    "    node positions. Before every hidden layer the features are further\n",
    "    concatenated with their per-graph mean, hence the doubled input widths.\n",
    "\n",
    "    Args:\n",
    "        dim_z: width of the latent vector.\n",
    "        dim_c: number of hidden channels.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,  dim_z, dim_c):\n",
    "        super(GCN_D, self).__init__()\n",
    "        self.dim_z = dim_z\n",
    "        self.c = dim_c\n",
    "        self.convs = ModuleList([\n",
    "            # Position width comes from dim_x (as in the encoder) instead of a literal 2.\n",
    "            GCNConv((self.dim_z + dim_x)*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c*2, self.c),\n",
    "            GCNConv(self.c, dim_u),\n",
    "        ])\n",
    "\n",
    "    def forward(self, z, graph):\n",
    "        \"\"\"Decode latent `z` on `graph`.\n",
    "\n",
    "        `z` is indexed with `graph.batch`, so it must hold one row per graph\n",
    "        in the batch. Returns one row of width `dim_u` per node.\n",
    "        \"\"\"\n",
    "        z_expanded = z[graph.batch]\n",
    "\n",
    "        x = torch.cat([z_expanded,  graph.pos], dim=1)\n",
    "\n",
    "        edge_index, edge_attr = graph.edge_index, graph.edge_attr\n",
    "\n",
    "        for conv in self.convs[:-1]:\n",
    "            # Broadcast each graph's mean feature vector back onto its nodes.\n",
    "            x_global = pool.global_mean_pool(x, graph.batch)\n",
    "            global_x_expanded = x_global[graph.batch]  # [num_nodes, hidden_dim]\n",
    "\n",
    "            x = torch.cat([x, global_x_expanded], dim=1)  # [num_nodes, hidden_dim * 2]\n",
    "\n",
    "            x = F.silu(conv(x, edge_index, edge_weight=edge_attr))\n",
    "\n",
    "        x = self.convs[-1](x, edge_index, edge_weight=edge_attr)\n",
    "\n",
    "        return x\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move every training graph to the compute device and batch them.\n",
    "data_list = [graph.to(device) for graph in data_list]\n",
    "loader = DataLoader(data_list, batch_size=100, shuffle=True)\n",
    "\n",
    "# `data` becomes one collated batch drawn from the loader. The former\n",
    "# `data = data.to(device)` was overwritten right away and only made this cell\n",
    "# depend on an earlier cell's `data`, so it is gone.\n",
    "data = next(iter(loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_epochs = 10_000\n",
    "def train_and_save_model(dim_z, dim_c, dir_save):\n",
    "    \"\"\"Train a GCN_E / GCN_D pair and write weights, optimizer states and loss history.\n",
    "\n",
    "    The loss is the mean squared reconstruction error plus an `MMDLoss` term\n",
    "    computed between the encoded latents and standard-normal draws of the same\n",
    "    shape. Training runs for the module-level `n_epochs` over the module-level\n",
    "    `loader`.\n",
    "\n",
    "    Args:\n",
    "        dim_z: latent dimension passed to both networks.\n",
    "        dim_c: hidden channel count passed to both networks.\n",
    "        dir_save: path prefix (including trailing separator) for all output files.\n",
    "    \"\"\"\n",
    "    import time\n",
    "\n",
    "    encoder = GCN_E(dim_c=dim_c, dim_z=dim_z).to(device)\n",
    "    decoder = GCN_D(dim_c=dim_c, dim_z=dim_z).to(device)\n",
    "\n",
    "    opt_encoder = torch.optim.Adam(encoder.parameters(), lr=0.001)\n",
    "    opt_decoder = torch.optim.Adam(decoder.parameters(), lr=0.001)\n",
    "\n",
    "    from GABI.stat import MMDLoss\n",
    "    mmd = MMDLoss()\n",
    "\n",
    "    def loss_function(batch):\n",
    "        \"\"\"Return (total loss, (reconstruction loss, latent-matching loss)) for one batch.\"\"\"\n",
    "        latent = encoder(batch.x, batch)\n",
    "        recon = decoder(latent, batch.clone())\n",
    "        recon_loss = ((batch.x - recon) ** 2.).mean()\n",
    "        prior_draw = torch.randn_like(latent).to(device)\n",
    "        latent_loss = mmd(latent, prior_draw)\n",
    "        return recon_loss + latent_loss, (recon_loss, latent_loss)\n",
    "\n",
    "    def count_parameters(model):\n",
    "        \"\"\"Number of trainable scalars in `model`.\"\"\"\n",
    "        total = 0\n",
    "        for param in model.parameters():\n",
    "            if param.requires_grad:\n",
    "                total += param.numel()\n",
    "        return total\n",
    "\n",
    "    for net in (encoder, decoder):\n",
    "        print(count_parameters(net))\n",
    "\n",
    "    encoder.train()\n",
    "    decoder.train()\n",
    "\n",
    "    # Per-step histories: total loss, latent-matching term, reconstruction term.\n",
    "    total_hist, latent_hist, recon_hist = [], [], []\n",
    "\n",
    "    time_train_start = time.time()\n",
    "    for epoch in range(n_epochs):\n",
    "        start_time_b = time.time()\n",
    "        for batch in loader:\n",
    "            opt_encoder.zero_grad()\n",
    "            opt_decoder.zero_grad()\n",
    "            batch.to(device)\n",
    "            loss, (recon_loss, latent_loss) = loss_function(batch)\n",
    "            total_hist.append(loss.cpu().detach().numpy())\n",
    "            latent_hist.append(latent_loss.cpu().detach().numpy())\n",
    "            recon_hist.append(recon_loss.cpu().detach().numpy())\n",
    "            loss.backward()\n",
    "            opt_encoder.step()\n",
    "            opt_decoder.step()\n",
    "        # Wall time of the full pass over the loader, then the last batch's loss.\n",
    "        print(f'time batch = {time.time() - start_time_b:.3f}s')\n",
    "        print('Epoch: ', epoch, 'Loss: ', loss.item())\n",
    "\n",
    "    time_train = time.time() - time_train_start\n",
    "\n",
    "    encoder.eval()\n",
    "    decoder.eval()\n",
    "\n",
    "    # File names are fixed; the evaluation cells load 'gcnE.model' and 'gcnD.model'.\n",
    "    torch.save(encoder.state_dict(), dir_save+'gcnE.model')\n",
    "    torch.save(opt_encoder.state_dict(), dir_save+'gcnE.opt')\n",
    "    torch.save(decoder.state_dict(), dir_save+'gcnD.model')\n",
    "    torch.save(opt_decoder.state_dict(), dir_save+'gcnD.opt')\n",
    "\n",
    "    print(f'Training completed in {time_train:.2f} seconds')\n",
    "    print('saved models to', dir_save)\n",
    "\n",
    "    loss_dict = {\n",
    "        'LOSS': total_hist, 'LOSS_L': recon_hist, 'LOSS_D': latent_hist,\n",
    "        'time_train': time_train, 'dim_z': dim_z, 'dim_c': dim_c,\n",
    "    }\n",
    "\n",
    "    with open(dir_save+'loss_dict.pkl', 'wb') as f:\n",
    "        pickle.dump(loss_dict, f)\n",
    "    print('saved loss info to', dir_save)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One [dim_z, dim_c] pair per experiment: latent size doubles from 8 to 512, 100 hidden channels.\n",
    "tests = [[latent_dim, 100] for latent_dim in (8, 16, 32, 64, 128, 256, 512)]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Matplotlib styling: LaTeX text rendering with a Computer Modern serif font.\n",
    "# All settings are applied in a single update; the resulting rcParams are the\n",
    "# same as setting sans-serif first and overriding it with serif afterwards.\n",
    "plt.rcParams.update({\n",
    "    'text.usetex': True,\n",
    "    'text.latex.preamble': r'\\usepackage{amsfonts, amsmath, amssymb}',\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': ['Computer Modern Roman'],\n",
    "    'font.sans-serif': ['Helvetica'],\n",
    "    'axes.labelsize': 18,\n",
    "    'axes.titlesize': 18,\n",
    "    'xtick.labelsize': 16,\n",
    "    'ytick.labelsize': 16,\n",
    "    'legend.fontsize': 14,\n",
    "})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# Train and save one encoder/decoder pair per [dim_z, dim_c] configuration.\n",
    "for dim_z, dim_c in tests:\n",
    "    dir_save = f'./models/RectHeat_GABI_5_CAT/model_{dim_z}_{dim_c}/'\n",
    "    os.makedirs(dir_save, exist_ok=True)\n",
    "    train_and_save_model(dim_z, dim_c, dir_save)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "with open(dir+'data_test_list.pkl', 'rb') as f:\n",
    "    data_test_list = pickle.load(f)\n",
    "\n",
    "# Evaluate on at most 100 test graphs. Capping at the list length avoids an\n",
    "# IndexError in the evaluation loops when fewer graphs are available.\n",
    "N_test = min(100, len(data_test_list))\n",
    "\n",
    "sigma = 0.01  # scale of the Gaussian noise added to the observed node values\n",
    "n_obs = 10    # number of observed nodes drawn per test graph\n",
    "sigma_tc = torch.FloatTensor([sigma])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def run_models(dir, dim_z, dim_c):\n",
    "    \"\"\"Benchmark ABC (rejection sampling in latent space) for the models stored under `dir`.\n",
    "\n",
    "    For each of the first `N_test` graphs in `data_test_list`, `n_obs` noisy node\n",
    "    values are observed, standard-normal latents are decoded by `gcnD`, and the\n",
    "    draws whose decoded observations are closest to the data form the posterior\n",
    "    sample. Per-graph error, coverage and wall-time statistics are printed and\n",
    "    pickled to `dir + 'pred_stats_ABC.pkl'`.\n",
    "\n",
    "    Args:\n",
    "        dir: path prefix (including trailing separator) holding 'gcnE.model'\n",
    "            and 'gcnD.model'. The name shadows the builtin but is part of the\n",
    "            existing signature.\n",
    "        dim_z: latent dimension the stored models were built with.\n",
    "        dim_c: hidden channel count the stored models were built with.\n",
    "    \"\"\"\n",
    "    gcnE = GCN_E(dim_c=dim_c, dim_z=dim_z).to(device)\n",
    "    gcnD = GCN_D(dim_c=dim_c, dim_z=dim_z).to(device)\n",
    "\n",
    "    # Weights are loaded once here instead of once per test graph; map_location\n",
    "    # lets checkpoints saved on another device load onto `device`.\n",
    "    gcnE.load_state_dict(torch.load(dir+'gcnE.model', map_location=device))\n",
    "    gcnD.load_state_dict(torch.load(dir+'gcnD.model', map_location=device))\n",
    "    gcnE.eval()\n",
    "    gcnD.eval()\n",
    "\n",
    "    def run_pred(data_test_graph):\n",
    "        \"\"\"Run ABC on one test graph.\n",
    "\n",
    "        Returns (MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken).\n",
    "        \"\"\"\n",
    "        n_nodes = data_test_graph.pos.shape[0]\n",
    "        ObsIdx = np.random.choice(range(n_nodes), size=(n_obs,), replace=False)\n",
    "        ObsIdx = torch.tensor(ObsIdx)\n",
    "        # Build the noisy observations on the CPU: the test graph may already be\n",
    "        # on the GPU, and adding CPU noise to a GPU tensor would raise.\n",
    "        u_flat = (data_test_graph.x).reshape(-1,).cpu()\n",
    "        y_n = u_flat[ObsIdx] + sigma_tc * torch.randn(ObsIdx.shape[0])\n",
    "        y_n = y_n.to(device)\n",
    "\n",
    "        time_start = time.time()\n",
    "        batch_size = 100\n",
    "        graph_test = data_test_graph\n",
    "        # Replicate the graph so one decoder call evaluates batch_size latents.\n",
    "        graph_test_batch = [graph_test.clone() for _ in range(batch_size)]\n",
    "        loader_test = DataLoader(graph_test_batch, batch_size=batch_size)\n",
    "        graph_test_batch_loaded = next(iter(loader_test)).to(device)\n",
    "\n",
    "        # Node k of replica b sits at row b * n_nodes + k of the collated batch.\n",
    "        graph_offsets = torch.arange(batch_size) * graph_test.x.shape[0]  # [B]\n",
    "        global_indices = graph_offsets[:, None] + ObsIdx[None, :]  # [B, K]\n",
    "        flat_indices = global_indices.reshape(-1)  # [B*K]\n",
    "\n",
    "        n_loops = 100  # prior batches simulated: n_loops * batch_size draws in total\n",
    "        n_acc = 100    # number of closest draws kept as the ABC sample\n",
    "\n",
    "        norms_chunks = []\n",
    "        z_chunks = []\n",
    "        for _ in range(n_loops):\n",
    "            z_prior = torch.randn(batch_size, dim_z).to(device)\n",
    "            u_decode = gcnD(z_prior, graph_test_batch_loaded)\n",
    "            y_selected = u_decode[flat_indices]  # [B*K, F]\n",
    "            y_selected = y_selected.view(batch_size, n_obs)\n",
    "            y_selected = y_selected + sigma_tc.to(device) * torch.randn_like(y_selected).to(device)  # [B, K]\n",
    "            norms = torch.norm(y_selected - y_n[None, :], dim=-1)\n",
    "            norms_chunks.append(norms.detach().cpu())\n",
    "            z_chunks.append(z_prior.detach())\n",
    "        # Concatenate once after the loop instead of growing tensors per iteration.\n",
    "        norms_list = torch.cat(norms_chunks, dim=0)\n",
    "        z_prior_list = torch.cat(z_chunks, dim=0)\n",
    "\n",
    "        sorted_index = torch.argsort(norms_list, dim=0)\n",
    "        z_abc = z_prior_list[sorted_index[:n_acc]]\n",
    "\n",
    "        loader_decode_ABC = DataLoader([graph_test], batch_size=1)\n",
    "        graph_ABC_loaded = next(iter(loader_decode_ABC)).to(device)\n",
    "        u_abc_decode = torch.vmap(gcnD, in_dims=(0, None))(z_abc[:, None, :], graph_ABC_loaded)\n",
    "\n",
    "        samples_decode_np = u_abc_decode.cpu().detach().numpy()[:, :, 0]\n",
    "        samples_decode_mean = np.mean(samples_decode_np, axis=0)\n",
    "        samples_decode_std = np.std(samples_decode_np, axis=0)\n",
    "\n",
    "        # Decode of the single closest draw; kept inside the timed region as the\n",
    "        # alternative point estimate selected by the commented line below.\n",
    "        u_min_norm = gcnD(z_prior_list[sorted_index[:1]], graph_ABC_loaded)\n",
    "        u_min_norm = u_min_norm.cpu().detach().numpy()[:, 0]\n",
    "\n",
    "        u_pred = samples_decode_mean\n",
    "        # u_pred = u_min_norm\n",
    "\n",
    "        time_taken = time.time() - time_start\n",
    "        print('time_taken:', time_taken)\n",
    "\n",
    "        u_true = data_test_graph.x[:,0].cpu().detach().numpy()\n",
    "\n",
    "        abs_err = np.abs(u_pred - u_true)\n",
    "        MAE = np.mean(abs_err)\n",
    "        MaxAE = np.max(abs_err)\n",
    "\n",
    "        # Fraction of nodes whose error is below 1x / 2x the sample std at that node.\n",
    "        prec_in_1std = np.mean(abs_err < samples_decode_std)\n",
    "        prec_in_2std = np.mean(abs_err < 2*samples_decode_std)\n",
    "\n",
    "        print(f'MAE = {MAE:.4f}')\n",
    "\n",
    "        print(prec_in_1std, prec_in_2std)\n",
    "\n",
    "        return MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken\n",
    "\n",
    "    # Order matches the tuple returned by run_pred.\n",
    "    stat_names = ('MAE', 'MaxAE', 'prec_in_1std', 'prec_in_2std', 'time_taken')\n",
    "    pred_stats = {name: [] for name in stat_names}\n",
    "    for i in range(N_test):\n",
    "\n",
    "        with torch.no_grad():\n",
    "            print(f'Running prediction for test data {i+1}/{N_test}')\n",
    "            values = run_pred(data_test_list[i])\n",
    "        for name, value in zip(stat_names, values):\n",
    "            pred_stats[name].append(value)\n",
    "\n",
    "    for name in stat_names:\n",
    "        pred_stats[name] = np.array(pred_stats[name])\n",
    "        print(name + ': ', np.mean(pred_stats[name]), '±', np.std(pred_stats[name]))\n",
    "\n",
    "    with open(dir+'pred_stats_ABC.pkl', 'wb') as f:\n",
    "        pickle.dump(pred_stats, f)\n",
    "\n",
    "# tests = [[5, 100], [10, 100], [50, 100], [100, 100], [500, 500] ]\n",
    "# Evaluate the stored models of every [dim_z, dim_c] configuration.\n",
    "for dim_z, dim_c in tests:\n",
    "    dir_save = f'./models/RectHeat_GABI_5_CAT/model_{dim_z}_{dim_c}/'\n",
    "    run_models(dir_save, dim_z, dim_c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "\n",
    "# Module-level scratch state updated by log_posterior through `global`:\n",
    "# lowest potential seen so far, the latent that produced it, and a call counter.\n",
    "max_prop_to_date = 9e9\n",
    "max_prior_z = None\n",
    "ctr = 0\n",
    "    \n",
    "# sigma = 0.01\n",
    "# n_obs = 10\n",
    "# sigma_tc = torch.FloatTensor([sigma])\n",
    "\n",
    "\n",
    "from pyro.infer import MCMC, NUTS\n",
    "\n",
    "def run_models(dir, dim_z, dim_c):\n",
    "    \"\"\"Benchmark NUTS posterior sampling in latent space for the models stored under `dir`.\n",
    "\n",
    "    For each of the first `N_test` graphs in `data_test_list`, `n_obs` noisy node\n",
    "    values are observed and a single NUTS chain (100 warm-up steps, 100 samples)\n",
    "    is run on the latent `z`. The samples are decoded with `gcnD`; error,\n",
    "    coverage, wall-time and call-count statistics are printed and pickled to\n",
    "    `dir + 'pred_stats_NUTS.pkl'`.\n",
    "\n",
    "    NOTE(review): another `run_models` is defined earlier in this notebook;\n",
    "    executing this cell rebinds the name to this NUTS version.\n",
    "\n",
    "    Args:\n",
    "        dir: path prefix (including trailing separator) holding 'gcnE.model'\n",
    "            and 'gcnD.model'. The name shadows the builtin but is part of the\n",
    "            existing signature.\n",
    "        dim_z: latent dimension the stored models were built with.\n",
    "        dim_c: hidden channel count the stored models were built with.\n",
    "    \"\"\"\n",
    "    gcnE = GCN_E(dim_c=dim_c, dim_z=dim_z).to(device)\n",
    "    gcnD = GCN_D(dim_c=dim_c, dim_z=dim_z).to(device)\n",
    "\n",
    "    # map_location lets checkpoints saved on another device load onto `device`.\n",
    "    gcnE.load_state_dict(torch.load(dir+'gcnE.model', map_location=device))\n",
    "    gcnD.load_state_dict(torch.load(dir+'gcnD.model', map_location=device))\n",
    "\n",
    "    gcnE.eval()\n",
    "    gcnD.eval()\n",
    "\n",
    "    def run_pred(data_test):\n",
    "        \"\"\"Run NUTS on one test graph.\n",
    "\n",
    "        Returns (MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken, ctr).\n",
    "        \"\"\"\n",
    "        data_test = data_test.to(device)\n",
    "\n",
    "        sigma_tc = torch.FloatTensor([sigma]).to(device)\n",
    "        loader_test = DataLoader([data_test], batch_size=1)\n",
    "        data_test_graph = next(iter(loader_test)).to(device)\n",
    "        data_test_graph = data_test_graph.clone()\n",
    "\n",
    "        ObsIdx = np.random.choice(range(data_test_graph.pos.shape[0]), size=(n_obs,), replace=False)\n",
    "        ObsIdx = torch.tensor(ObsIdx).to(device)\n",
    "        y_n = (data_test_graph.x).reshape(-1,)[ObsIdx] + sigma_tc * torch.randn(ObsIdx.shape[0]).to(device)\n",
    "        y_n = y_n.to(device)\n",
    "\n",
    "        # Reset the module-level bookkeeping for this test graph.\n",
    "        global max_prop_to_date, max_prior_z, ctr\n",
    "        max_prop_to_date = 9e9\n",
    "        max_prior_z = None\n",
    "        ctr = 0\n",
    "\n",
    "        def log_posterior(param):\n",
    "            \"\"\"Potential passed to NUTS: 0.5*|z|^2 + 0.5/sigma^2 * |y_n - decoded obs|^2.\n",
    "\n",
    "            Despite the local names, both terms are negative log densities (up\n",
    "            to constants), which is what `potential_fn` is given here.\n",
    "            NOTE(review): the counter and best-so-far globals are updated only\n",
    "            when this Python body actually runs; if the kernel traces the\n",
    "            function (jit_compile=True), confirm that `ctr` still counts every\n",
    "            evaluation.\n",
    "            \"\"\"\n",
    "            z = param['z']\n",
    "            log_prior = 0.5 * torch.sum(z ** 2, dim=1)\n",
    "\n",
    "            u_z = gcnD(z, data_test_graph)\n",
    "\n",
    "            hat_y = (u_z).reshape(-1,)[ObsIdx]\n",
    "            log_like  = 0.5/sigma_tc[0]**2. * torch.sum((y_n - hat_y)**2)\n",
    "            log_prob = log_prior + log_like\n",
    "\n",
    "            global max_prop_to_date, max_prior_z, ctr\n",
    "\n",
    "            if log_prob < max_prop_to_date:\n",
    "                # Keep detached copies so this evaluation's autograd graph can be freed.\n",
    "                max_prop_to_date = log_prob.detach()\n",
    "                max_prior_z = z.detach().clone()\n",
    "\n",
    "            ctr += 1\n",
    "\n",
    "            return log_prob\n",
    "\n",
    "        kernel = NUTS(potential_fn=log_posterior, step_size=0.1, full_mass=False, jit_compile=True, max_tree_depth=10)\n",
    "\n",
    "        mcmc = MCMC(kernel,\n",
    "                    initial_params =  {'z':torch.randn(1, dim_z).to(device)},\n",
    "                    num_samples=100,\n",
    "                    warmup_steps=100,\n",
    "                    # num_samples=3,\n",
    "                    # warmup_steps=3,\n",
    "                    num_chains=1,\n",
    "                    # mp_context='spawn'\n",
    "                    )\n",
    "\n",
    "        time_start = time.time()\n",
    "\n",
    "        mcmc.run()\n",
    "\n",
    "        samples = mcmc.get_samples()['z']\n",
    "        samples = samples.reshape(-1, dim_z)\n",
    "\n",
    "        print('num evals = ', ctr)\n",
    "\n",
    "        samples_decode = samples\n",
    "\n",
    "        new_u_pred = []\n",
    "        # Decoding the posterior samples needs no gradients.\n",
    "        with torch.no_grad():\n",
    "            for i in range(samples_decode.shape[0]):\n",
    "                u_decode = gcnD(samples_decode[i,:][None,...], data_test_graph)\n",
    "                new_u_pred.append(u_decode)\n",
    "        new_u_pred = torch.stack(new_u_pred, dim=0)\n",
    "\n",
    "        samples_decode_np = new_u_pred.cpu().detach().numpy()[:, :, 0]\n",
    "        samples_decode_mean = np.mean(samples_decode_np, axis=0)\n",
    "        samples_decode_std = np.std(samples_decode_np, axis=0)\n",
    "\n",
    "        u_pred = samples_decode_mean\n",
    "\n",
    "        time_taken = time.time() - time_start\n",
    "        print('time_taken:', time_taken)\n",
    "\n",
    "        u_true = data_test.x[:,0].cpu().detach().numpy()\n",
    "\n",
    "        abs_err = np.abs(u_pred - u_true)\n",
    "        MAE = np.mean(abs_err)\n",
    "        MaxAE = np.max(abs_err)\n",
    "\n",
    "        # Fraction of nodes whose error is below 1x / 2x the sample std at that node.\n",
    "        prec_in_1std = np.mean(abs_err < samples_decode_std)\n",
    "        prec_in_2std = np.mean(abs_err < 2*samples_decode_std)\n",
    "\n",
    "        print(f'MAE = {MAE:.4f}')\n",
    "\n",
    "        print(prec_in_1std, prec_in_2std)\n",
    "\n",
    "        return MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken, ctr\n",
    "\n",
    "    # Order matches the tuple returned by run_pred.\n",
    "    stat_names = ('MAE', 'MaxAE', 'prec_in_1std', 'prec_in_2std', 'time_taken', 'ctr')\n",
    "    pred_stats = {name: [] for name in stat_names}\n",
    "    for i in range(N_test):\n",
    "\n",
    "        # No torch.no_grad() here: run_pred runs NUTS on a potential that is evaluated with autograd enabled.\n",
    "        print(f'Running prediction for test data {i+1}/{N_test}')\n",
    "        values = run_pred(data_test_list[i])\n",
    "        for name, value in zip(stat_names, values):\n",
    "            pred_stats[name].append(value)\n",
    "\n",
    "    for name in stat_names:\n",
    "        pred_stats[name] = np.array(pred_stats[name])\n",
    "        print(name + ': ', np.mean(pred_stats[name]), '±', np.std(pred_stats[name]))\n",
    "\n",
    "    with open(dir+'pred_stats_NUTS.pkl', 'wb') as f:\n",
    "        pickle.dump(pred_stats, f)\n",
    "\n",
    "# tests = [[5, 100], [10, 100], [50, 100], [100, 100], [500, 500] ]\n",
    "# Evaluate the stored models of every [dim_z, dim_c] configuration.\n",
    "for dim_z, dim_c in tests:\n",
    "    dir_save = f'./models/RectHeat_GABI_5_CAT/model_{dim_z}_{dim_c}/'\n",
    "    run_models(dir_save, dim_z, dim_c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "# Initialize data containers\n",
    "dims = []\n",
    "\n",
    "# (prefix used in `results`, key in the pickled pred_stats dict)\n",
    "stat_sources = [('MAE', 'MAE'), ('MaxAE', 'MaxAE'),\n",
    "                ('prec1', 'prec_in_1std'), ('prec2', 'prec_in_2std'),\n",
    "                ('time', 'time_taken')]\n",
    "\n",
    "# results[method]['<prefix>_mean' / '<prefix>_std'] -> one entry per dim_z.\n",
    "results = {\n",
    "    method: {f'{prefix}_{kind}': [] for prefix, _ in stat_sources for kind in ('mean', 'std')}\n",
    "    for method in ('ABC', 'NUTS')\n",
    "}\n",
    "\n",
    "# Read data\n",
    "for dim_z, dim_c in tests:\n",
    "    dims.append(dim_z)\n",
    "    dir_path = f'./models/RectHeat_GABI_5_CAT/model_{dim_z}_{dim_c}/'\n",
    "    \n",
    "    # for method in ['ABC', 'ABC']:\n",
    "    for method in ['ABC', 'NUTS']:\n",
    "\n",
    "        file_path = os.path.join(dir_path, f'pred_stats_{method}.pkl')\n",
    "        \n",
    "        if os.path.exists(file_path):\n",
    "            with open(file_path, 'rb') as f:\n",
    "                pred_stats = pickle.load(f)\n",
    "            \n",
    "            # Compute and store statistics. No variable here is named `time`,\n",
    "            # so the `time` module used by the benchmark functions stays intact.\n",
    "            for prefix, stat_name in stat_sources:\n",
    "                values = np.array(pred_stats[stat_name])\n",
    "                results[method][f'{prefix}_mean'].append(np.mean(values))\n",
    "                results[method][f'{prefix}_std'].append(np.std(values))\n",
    "        else:\n",
    "            print(f\"Warning: File not found {file_path}\")\n",
    "            # Append NaNs to keep alignment\n",
    "            for key in results[method]:\n",
    "                results[method][key].append(np.nan)\n",
    "\n",
    "dir_save_plt = './models/RectHeat_GABI_5_CAT/plt/'\n",
    "os.makedirs(dir_save_plt, exist_ok=True)\n",
    "\n",
    "# Labels containing LaTeX are raw strings: sequences such as \\m, \\, and \\% are\n",
    "# invalid escapes in ordinary literals and raise SyntaxWarning on Python 3.12+.\n",
    "\n",
    "## Plot 1: Inference time\n",
    "plt.figure(figsize=(8, 5))\n",
    "plt.errorbar(dims, results['ABC']['time_mean'], yerr=results['ABC']['time_std'],\n",
    "             fmt='-o', capsize=3, label='ABC')\n",
    "plt.errorbar(dims, results['NUTS']['time_mean'], yerr=results['NUTS']['time_std'],\n",
    "             fmt='-s', capsize=3, label='NUTS')\n",
    "plt.xlabel(r'$\\mathrm{dim}\\,\\mathbf{z}$')\n",
    "plt.ylabel('Inference time (s)')\n",
    "plt.xscale('log')\n",
    "plt.yscale('log')\n",
    "# plt.title('Inference time vs dim_z')\n",
    "plt.legend()\n",
    "plt.grid(True)\n",
    "# plt.tight_layout()\n",
    "plt.savefig(dir_save_plt+'timeCAT.pdf')\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "# Plot 2: MAE\n",
    "plt.figure(figsize=(8, 5))\n",
    "plt.errorbar(dims, results['ABC']['MAE_mean'], yerr=results['ABC']['MAE_std'],\n",
    "             fmt='-o', capsize=3, label='ABC')\n",
    "plt.errorbar(dims, results['NUTS']['MAE_mean'], yerr=results['NUTS']['MAE_std'],\n",
    "             fmt='-s', capsize=3, label='NUTS')\n",
    "plt.xlabel(r'dim $\\mathbf{z}$')\n",
    "plt.ylabel('MAE')\n",
    "# plt.title('MAE ± 1 std vs dim_z')\n",
    "plt.xscale('log')\n",
    "plt.ylim(bottom=-0.001)\n",
    "plt.legend()\n",
    "plt.grid(True)\n",
    "# plt.tight_layout()\n",
    "plt.savefig(dir_save_plt+'maeCAT.pdf')\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "# Plot 3: prec_in_1std and prec_in_2std\n",
    "plt.figure(figsize=(8, 5))\n",
    "plt.plot(dims, results['ABC']['prec1_mean'], '-o', label=r'ABC $\\%$ 1 std')\n",
    "plt.plot(dims, results['NUTS']['prec1_mean'], '-s', label=r'NUTS $\\%$ 1 std')\n",
    "plt.plot(dims, results['ABC']['prec2_mean'], '-o', label=r'ABC $\\%$ 2 std')\n",
    "plt.plot(dims, results['NUTS']['prec2_mean'], '-s', label=r'NUTS $\\%$ 2 std')\n",
    "plt.xlabel(r'dim $\\mathbf{z}$')\n",
    "plt.ylabel('Percentage')\n",
    "# plt.title('Precision in 1 std and 2 std vs dim_z')\n",
    "plt.xscale('log')\n",
    "plt.legend(loc='lower right')\n",
    "plt.grid(True)\n",
    "# plt.tight_layout()\n",
    "plt.savefig(dir_save_plt+'perc_in_std_CAT.pdf')\n",
    "plt.show()\n",
    "plt.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gabi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
