{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "from pathlib import Path\n",
    "\n",
    "# Numerics and plotting\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import plotly.graph_objects as go\n",
    "from plotly.subplots import make_subplots\n",
    "\n",
    "# Tensor backend and GP library: this example uses torch + GPyTorch.\n",
    "import torch\n",
    "import gpytorch\n",
    "\n",
    "# GeometricKernels core and its torch backend.\n",
    "import geometric_kernels\n",
    "import geometric_kernels.torch\n",
    "\n",
    "# The Mesh space and the general-purpose MaternGeometricKernel.\n",
    "from geometric_kernels.spaces.mesh import Mesh\n",
    "from geometric_kernels.kernels import MaternGeometricKernel\n",
    "\n",
    "# The GPyTorch frontend of GeometricKernels.\n",
    "from geometric_kernels.frontends.gpytorch import GPyTorchGeometricKernel\n",
    "\n",
    "# Sampling routines (only needed to draw synthetic GP sample paths).\n",
    "from geometric_kernels.kernels import default_feature_map\n",
    "from geometric_kernels.sampling import sampler\n",
    "from geometric_kernels.utils.utils import make_deterministic\n",
    "\n",
    "# Project-local solver utilities (heat problem on a rectangle, judging by the module name).\n",
    "from GABI.solver.heat_rect_unstruct import make_sln_graph, plot_slngraph\n",
    "\n",
    "# NOTE: duplicate torch/numpy imports were dropped, as was `optax` (a JAX optimiser\n",
    "# library that nothing here uses), so the notebook runs in a torch-only environment.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the reference solution graph. The arguments go straight to the project\n",
    "# solver (the 4-tuple is presumably a set of boundary values - see make_sln_graph).\n",
    "data = make_sln_graph(0.5, 0.5, (0., 0.5, 0.25, 0.), res=20)\n",
    "\n",
    "# Show the first feature column of the solution on the graph.\n",
    "first_feature = data.x[:, 0].detach().cpu().numpy()\n",
    "plot_slngraph(data, first_feature)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_figure(fig):\n",
    "    \"\"\"Strip axes, margins and backgrounds from a 3D plotly figure.\n",
    "\n",
    "    Mutates ``fig`` in place and also returns it so calls can be chained.\n",
    "    \"\"\"\n",
    "    fig.update_layout(scene_aspectmode=\"cube\")\n",
    "    fig.update_scenes(xaxis_visible=False, yaxis_visible=False, zaxis_visible=False)\n",
    "    # Optional: fig.update_traces(showscale=False, hoverinfo=\"none\")\n",
    "    fig.update_layout(margin=dict(l=0, r=0, t=0, b=0))\n",
    "\n",
    "    transparent = \"rgba(0,0,0,0)\"\n",
    "    fig.update_layout(plot_bgcolor=transparent, paper_bgcolor=transparent)\n",
    "\n",
    "    # The same hidden-axis settings are applied to all three scene axes.\n",
    "    hidden_axis = dict(showbackground=False, showticklabels=False, visible=False)\n",
    "    fig.update_layout(\n",
    "        scene=dict(\n",
    "            xaxis=dict(hidden_axis),\n",
    "            yaxis=dict(hidden_axis),\n",
    "            zaxis=dict(hidden_axis),\n",
    "        )\n",
    "    )\n",
    "    return fig\n",
    "\n",
    "def plot_mesh(mesh: Mesh, vertices_colors=None, **kwargs):\n",
    "    \"\"\"Return a ``go.Mesh3d`` trace for a GeometricKernels ``Mesh``.\n",
    "\n",
    "    ``vertices_colors`` (optional) is forwarded as the per-vertex ``intensity``;\n",
    "    any extra keyword arguments are passed through to ``go.Mesh3d``.\n",
    "    \"\"\"\n",
    "    xs, ys, zs = (mesh.vertices[:, axis] for axis in range(3))\n",
    "    tri_i, tri_j, tri_k = (mesh.faces[:, corner] for corner in range(3))\n",
    "    return go.Mesh3d(\n",
    "        x=xs, y=ys, z=zs,\n",
    "        i=tri_i, j=tri_j, k=tri_k,\n",
    "        intensity=vertices_colors,\n",
    "        **kwargs\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pad the node positions with one extra zero column (presumably 2D -> 3D, since\n",
    "# the plotting helpers index vertex columns 0..2).\n",
    "data_pos_3d = torch.nn.functional.pad(data.pos, (0, 1)).detach().cpu().numpy()\n",
    "mesh_faces = data.face.T.detach().cpu().numpy()  # presumably one row per triangle\n",
    "mesh = Mesh(data_pos_3d, mesh_faces)\n",
    "\n",
    "print(mesh)\n",
    "mesh.num_vertices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Camera on the +z axis looking at the origin with y pointing up, so the\n",
    "# flat (z = 0) mesh is seen face-on.\n",
    "camera = dict(\n",
    "    up=dict(x=0, y=1, z=0),\n",
    "    center=dict(x=0, y=0, z=0),\n",
    "    eye=dict(x=0, y=0., z=2.0)\n",
    ")\n",
    "\n",
    "plot = plot_mesh(mesh)\n",
    "\n",
    "# Draw ALL graph edges as a single Scatter3d trace: consecutive segments are\n",
    "# separated by None, which plotly renders as a gap. One trace per edge would\n",
    "# create as many traces as there are edges and make the figure slow to render.\n",
    "edge_x, edge_y, edge_z = [], [], []\n",
    "for a, b in data.edge_index.T.tolist():\n",
    "    edge_x += [data_pos_3d[a, 0], data_pos_3d[b, 0], None]\n",
    "    edge_y += [data_pos_3d[a, 1], data_pos_3d[b, 1], None]\n",
    "    edge_z += [data_pos_3d[a, 2], data_pos_3d[b, 2], None]\n",
    "edge_lines = [go.Scatter3d(\n",
    "    x=edge_x, y=edge_y, z=edge_z,\n",
    "    mode='lines',\n",
    "    line=dict(color='black', width=3),\n",
    "    showlegend=False\n",
    ")]\n",
    "\n",
    "# Combine mesh surface and edges into one figure.\n",
    "fig = go.Figure(data=[plot] + edge_lines)\n",
    "fig.update_layout(\n",
    "    scene=dict(\n",
    "        xaxis_title=\"x\",\n",
    "        yaxis_title=\"y\",\n",
    "        zaxis=dict(visible=False)\n",
    "    ),\n",
    "    title=\"2D Mesh with Edges\"\n",
    ")\n",
    "update_figure(fig)\n",
    "fig.update_layout(scene_camera=camera)\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# n_obs = 60\n",
    "# sigma = 0.01\n",
    "# ObsIdx = np.random.choice(range(data.pos.shape[0]), size=(n_obs,), replace=False)\n",
    "# y_n = (data.x).reshape(-1,)[ObsIdx] + sigma * np.random.randn(ObsIdx.shape[0])\n",
    "# y_n = y_n.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_data = 20\n",
    "key = torch.Generator()\n",
    "key.manual_seed(1234)\n",
    "\n",
    "n_vertices = int(mesh.num_vertices)\n",
    "assert num_data <= n_vertices, 'cannot observe more vertices than the mesh has'\n",
    "\n",
    "# Training vertices are drawn WITHOUT replacement: torch.randint samples with\n",
    "# replacement and could pick the same vertex twice. Shape: (num_data, 1), int64.\n",
    "xs_train = torch.randperm(n_vertices, generator=key)[:num_data, None]\n",
    "# Test inputs: every vertex of the mesh, shape (n_vertices, 1).\n",
    "xs_test = torch.arange(n_vertices, dtype=torch.int64)[:, None]\n",
    "\n",
    "ObsIdx = xs_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_kernel = MaternGeometricKernel(mesh)\n",
    "\n",
    "params = base_kernel.init_params()\n",
    "params[\"lengthscale\"] = torch.tensor([5.0], dtype=torch.float64)\n",
    "params[\"nu\"] = torch.tensor([3/2], dtype=torch.float64)\n",
    "\n",
    "# Training targets: solver values at the training vertices plus zero-mean\n",
    "# Gaussian noise with std `obs_noise_std`. This matches the GaussianLikelihood\n",
    "# used below; torch.rand_like would add uniform, strictly positive (biased) noise.\n",
    "obs_noise_std = 0.01\n",
    "ys_train_clean = data.x[xs_train[:, 0]].reshape(-1).to(torch.float64)\n",
    "noise = torch.randn(ys_train_clean.shape, generator=key, dtype=torch.float64)\n",
    "ys_train = ys_train_clean + obs_noise_std * noise.to(ys_train_clean.device)\n",
    "\n",
    "# Noise-free ground truth at every vertex. (Synthetic GP data could instead be\n",
    "# drawn with make_deterministic(sampler(default_feature_map(kernel=base_kernel)), key).)\n",
    "ys_test = data.x[xs_test[:, 0]].reshape(-1).to(torch.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Geometric Matern kernel on the mesh with nu held fixed, wrapped in a\n",
    "# ScaleKernel that adds a trainable output scale (initialised to 1).\n",
    "matern_on_mesh = GPyTorchGeometricKernel(\n",
    "    base_kernel,\n",
    "    nu=params[\"nu\"],\n",
    "    lengthscale=params[\"lengthscale\"],\n",
    "    trainable_nu=False,\n",
    ")\n",
    "kernel = gpytorch.kernels.ScaleKernel(matern_on_mesh)\n",
    "kernel.outputscale = 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExactGPModel(gpytorch.models.ExactGP):\n",
    "    \"\"\"Zero-mean exact GP whose covariance is the supplied kernel.\"\"\"\n",
    "\n",
    "    def __init__(self, train_x, train_y, likelihood, kernel):\n",
    "        super().__init__(train_x, train_y, likelihood)\n",
    "        self.mean_module = gpytorch.means.ZeroMean()\n",
    "        self.covar_module = kernel\n",
    "\n",
    "    def forward(self, x):  # pylint: disable=arguments-differ\n",
    "        return gpytorch.distributions.MultivariateNormal(\n",
    "            self.mean_module(x), self.covar_module(x)\n",
    "        )\n",
    "\n",
    "\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood(\n",
    "    noise_constraint=gpytorch.constraints.GreaterThan(1e-6)\n",
    ")\n",
    "likelihood.noise = torch.tensor(1e-2)  # initial noise variance\n",
    "\n",
    "model = ExactGPModel(xs_train, ys_train, likelihood, kernel)\n",
    "\n",
    "# Work in float64 throughout.\n",
    "model.double()\n",
    "likelihood.double()\n",
    "\n",
    "print(\"Initial model:\")\n",
    "initial_report = {\n",
    "    \"kernel.base_kernel.nu =\": model.covar_module.base_kernel.nu,\n",
    "    \"kernel.base_kernel.lengthscale =\": model.covar_module.base_kernel.lengthscale,\n",
    "    \"kernel.outputscale =\": model.covar_module.outputscale,\n",
    "    \"likelihood.obs_noise =\": model.likelihood.noise,\n",
    "}\n",
    "for label, value in initial_report.items():\n",
    "    print(label, value)\n",
    "print(\"\")\n",
    "\n",
    "# ExactMarginalLogLikelihood divides by the number of data points, so its value\n",
    "# may look quite different from the marginal log likelihoods of other frontends.\n",
    "mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n",
    "\n",
    "print(\"Initial negative log marginal likelihood:\", -mll(model(xs_train), ys_train).detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training mode for both the model and its likelihood.\n",
    "model.train()\n",
    "likelihood.train()\n",
    "\n",
    "n_iter = 1000\n",
    "# Adam with a fixed learning rate over all model parameters.\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "\n",
    "print(\"Starting training...\")\n",
    "for step in range(1, n_iter + 1):\n",
    "    optimizer.zero_grad()\n",
    "    loss = -mll(model(xs_train), ys_train)\n",
    "    loss.backward()\n",
    "    # Report the first iteration, then every 10th.\n",
    "    if step == 1 or step % 10 == 0:\n",
    "        print(\"Iter %d/%d - Loss: %.5f\" % (step, n_iter, loss.item()))\n",
    "    optimizer.step()\n",
    "\n",
    "\n",
    "print(\"\")\n",
    "print(\"Final model:\")\n",
    "final_report = {\n",
    "    \"kernel.base_kernel.nu =\": model.covar_module.base_kernel.nu,\n",
    "    \"kernel.base_kernel.lengthscale =\": model.covar_module.base_kernel.lengthscale,\n",
    "    \"kernel.outputscale =\": model.covar_module.outputscale,\n",
    "    \"likelihood.obs_noise =\": model.likelihood.noise,\n",
    "}\n",
    "for label, value in final_report.items():\n",
    "    print(label, value)\n",
    "print(\"\")\n",
    "\n",
    "print(\"Final negative log marginal likelihood:\", -mll(model(xs_train), ys_train).detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch the model to prediction mode.\n",
    "model.eval()\n",
    "\n",
    "# Inference needs no gradients: torch.no_grad() avoids building (and keeping\n",
    "# alive) an autograd graph over the full test-set posterior.\n",
    "with torch.no_grad():\n",
    "    # Posterior over the latent function at every vertex.\n",
    "    latent_dist = model(xs_test)\n",
    "    posterior_mean = torch.reshape(latent_dist.mean, ys_test.shape).detach().numpy()\n",
    "    posterior_std = torch.reshape(latent_dist.stddev, ys_test.shape).detach().numpy()\n",
    "\n",
    "    # One posterior sample path.\n",
    "    sample = latent_dist.sample(sample_shape=torch.Size([1])).detach().numpy()[0, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visual comparison on the solution graph: ground truth, a posterior sample,\n",
    "# posterior mean, posterior std and absolute error. ObsIdx (the training\n",
    "# vertices) is passed for some panels - presumably plot_slngraph marks them;\n",
    "# confirm in GABI.solver.heat_rect_unstruct.\n",
    "\n",
    "print('data')\n",
    "plot_slngraph(data, ys_test, ObsIdx=ObsIdx)\n",
    "\n",
    "print('sample')\n",
    "plot_slngraph(data, sample)\n",
    "\n",
    "print('post_mean')\n",
    "plot_slngraph(data, posterior_mean)\n",
    "\n",
    "print('post_std')\n",
    "plot_slngraph(data, posterior_std,\n",
    "              ObsIdx=ObsIdx)\n",
    "\n",
    "print('error')\n",
    "# NOTE(review): ys_test is a torch tensor here unless the metrics cell below has\n",
    "# already run (it converts ys_test to numpy); it is mixed with a numpy array.\n",
    "plot_slngraph(data, np.abs(ys_test - posterior_mean),\n",
    "              ObsIdx=ObsIdx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert to numpy (ys_test may still be a torch tensor at this point).\n",
    "posterior_mean = np.array(posterior_mean)\n",
    "ys_test = np.array(ys_test)\n",
    "abs_err = np.abs(posterior_mean - ys_test)\n",
    "\n",
    "MAE = np.mean(abs_err)\n",
    "# Fraction of vertices whose error lies within 1 and 2 posterior std devs.\n",
    "prec_in_1std = np.mean(abs_err < posterior_std)\n",
    "prec_in_2std = np.mean(abs_err < 2*posterior_std)\n",
    "# Normalised squared error: ||mean - truth||^2 / ||truth||^2.\n",
    "nse = np.linalg.norm(posterior_mean - ys_test)**2 / np.linalg.norm(ys_test)**2.\n",
    "\n",
    "print(f'MAE = {MAE:.4f}')\n",
    "\n",
    "print('NSE = ', nse)\n",
    "\n",
    "print(prec_in_1std, prec_in_2std)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# N_test = 1000\n",
    "# r1, r2   = np.array([0.1, 0.1]), np.array([1., 1.,])\n",
    "# url, urr = np.array([0., 0.1, 0.1, 0.]), np.array([0., 1., 1., 0.])\n",
    "\n",
    "# data_lh =  torch.rand(N_test, 2)\n",
    "# data_lh = data_lh * (r2 - r1)[None, :] + r1[None, :]\n",
    "# plt.scatter(data_lh[:, 0], data_lh[:, 1] )\n",
    "# plt.show()\n",
    "\n",
    "# data_ubc = np.random.rand(N_test, 4)\n",
    "# data_ubc = data_ubc * (urr - url)[None, :] + url[None, :]\n",
    "# data_test_list = [make_sln_graph(lh[0], lh[1], ubc, res=20) for lh, ubc in zip(data_lh, data_ubc)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Directory with the pre-generated test problems (presumably a list of solution\n",
    "# graphs). Named `models_dir` rather than `dir` so the built-in dir() stays usable.\n",
    "models_dir = Path('./models/RectHeat_GABI_5')\n",
    "\n",
    "# SECURITY: pickle.load can execute arbitrary code - only load this file if it was\n",
    "# produced by a trusted run of this project.\n",
    "with open(models_dir / 'data_test_list.pkl', 'rb') as f:\n",
    "    data_test_list = pickle.load(f)\n",
    "N_test = len(data_test_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_rnd_stddev(generator=None):\n",
    "    \"\"\"Draw a random observation-noise standard deviation.\n",
    "\n",
    "    Returns a shape-(1,) tensor ``exp(z - 4) + 1e-3`` with ``z ~ N(0, 1)``: a\n",
    "    log-normal value with median ``exp(-4)`` (about 0.018), shifted so that it is\n",
    "    bounded below by ``1e-3``.\n",
    "\n",
    "    Pass a seeded ``torch.Generator`` as ``generator`` for reproducible draws;\n",
    "    with the default ``None`` the global torch RNG is used (previous behaviour).\n",
    "    \"\"\"\n",
    "    z = torch.randn(1, generator=generator)\n",
    "    return torch.exp(z - 4.) + 1e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "\n",
    "def run_pred(data_test, nu=3/2, num_data=10, seed=1234, n_iter=500, lr=0.1):\n",
    "    \"\"\"Fit a geometric Matern GP to noisy observations of one test problem.\n",
    "\n",
    "    The kernel is built on the global reference ``mesh``. ``num_data`` distinct\n",
    "    vertices (chosen by a generator seeded with ``seed``, hence identical across\n",
    "    calls) are observed with zero-mean Gaussian noise of standard deviation\n",
    "    ``sig_true = get_rnd_stddev()``. A zero-mean ExactGP with a\n",
    "    ScaleKernel(GPyTorchGeometricKernel(MaternGeometricKernel(mesh))) covariance\n",
    "    (``nu`` fixed, initial lengthscale 5, outputscale 1, noise variance 1e-2) is\n",
    "    trained with Adam, and its latent posterior is evaluated at every vertex.\n",
    "\n",
    "    Args:\n",
    "        data_test: solution graph whose ``x`` rows are indexed by ``mesh`` vertex id.\n",
    "        nu: Matern smoothness (not trained).\n",
    "        num_data: number of observed vertices.\n",
    "        seed: seed used to choose the observed vertices.\n",
    "        n_iter: number of Adam iterations.\n",
    "        lr: Adam learning rate.\n",
    "\n",
    "    Returns:\n",
    "        ``(MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken, MAE_sig)``: mean and\n",
    "        max absolute error of the posterior mean; fraction of vertices whose error\n",
    "        is below 1 and 2 posterior std devs; wall time of training plus posterior\n",
    "        evaluation; relative error of the learned noise std w.r.t. ``sig_true``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``data_test`` does not have one row per mesh vertex, or\n",
    "            ``num_data`` exceeds the number of vertices.\n",
    "    \"\"\"\n",
    "    n_vertices = int(mesh.num_vertices)\n",
    "    # NOTE(review): the GP lives on the global reference mesh, not on the geometry\n",
    "    # of data_test; this assumes all test graphs share its vertex numbering.\n",
    "    # At least refuse to silently evaluate on a mismatched vertex set.\n",
    "    if data_test.x.shape[0] != n_vertices:\n",
    "        raise ValueError(\n",
    "            f'data_test has {data_test.x.shape[0]} nodes but mesh has {n_vertices} vertices'\n",
    "        )\n",
    "    if num_data > n_vertices:\n",
    "        raise ValueError(f'num_data={num_data} exceeds the {n_vertices} mesh vertices')\n",
    "\n",
    "    # Observation sites: sampled WITHOUT replacement (torch.randint could return\n",
    "    # duplicate vertices) and identical for every call with the same seed.\n",
    "    key = torch.Generator()\n",
    "    key.manual_seed(seed)\n",
    "    xs_train = torch.randperm(n_vertices, generator=key)[:num_data, None]\n",
    "    xs_test = torch.arange(n_vertices, dtype=torch.int64)[:, None]\n",
    "\n",
    "    # Training targets with zero-mean Gaussian noise of std sig_true, the model the\n",
    "    # GaussianLikelihood assumes. (Uniform [0, sig_true) noise from rand_like is\n",
    "    # biased and has std sig_true/sqrt(12), so sig_true could not be recovered.)\n",
    "    sig_true = get_rnd_stddev()\n",
    "    ys_clean = data_test.x[xs_train[:, 0]].reshape(-1).to(torch.float64)\n",
    "    ys_train = ys_clean + torch.randn_like(ys_clean) * sig_true\n",
    "    ys_test = data_test.x[xs_test[:, 0]].reshape(-1).to(torch.float64)\n",
    "\n",
    "    base_kernel = MaternGeometricKernel(mesh)\n",
    "    params = base_kernel.init_params()\n",
    "    params[\"lengthscale\"] = torch.tensor([5.0], dtype=torch.float64)\n",
    "    params[\"nu\"] = torch.tensor([nu], dtype=torch.float64)\n",
    "\n",
    "    kernel = gpytorch.kernels.ScaleKernel(\n",
    "        GPyTorchGeometricKernel(\n",
    "            base_kernel,\n",
    "            nu=params[\"nu\"],\n",
    "            lengthscale=params[\"lengthscale\"],\n",
    "            trainable_nu=False,\n",
    "        )\n",
    "    )\n",
    "    kernel.outputscale = 1.0\n",
    "\n",
    "    class ExactGPModel(gpytorch.models.ExactGP):\n",
    "        def __init__(self, train_x, train_y, likelihood, kernel):\n",
    "            super().__init__(train_x, train_y, likelihood)\n",
    "            self.mean_module = gpytorch.means.ZeroMean()\n",
    "            self.covar_module = kernel\n",
    "\n",
    "        def forward(self, x):  # pylint: disable=arguments-differ\n",
    "            mean_x = self.mean_module(x)\n",
    "            covar_x = self.covar_module(x)\n",
    "            return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
    "\n",
    "    likelihood = gpytorch.likelihoods.GaussianLikelihood(\n",
    "        noise_constraint=gpytorch.constraints.GreaterThan(1e-6)\n",
    "    )\n",
    "    likelihood.noise = torch.tensor(1e-2)  # initial noise variance\n",
    "\n",
    "    model = ExactGPModel(xs_train, ys_train, likelihood, kernel)\n",
    "    # use float64:\n",
    "    model.double()\n",
    "    likelihood.double()\n",
    "\n",
    "    # ExactMarginalLogLikelihood is divided by the number of data points.\n",
    "    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n",
    "\n",
    "    model.train()\n",
    "    likelihood.train()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "\n",
    "    start_time = time.perf_counter()\n",
    "    for _ in range(n_iter):\n",
    "        optimizer.zero_grad()\n",
    "        loss = -mll(model(xs_train), ys_train)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # Latent posterior at every vertex; no autograd graph is needed for inference.\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        latent_dist = model(xs_test)\n",
    "        posterior_mean = torch.reshape(latent_dist.mean, ys_test.shape).detach().numpy()\n",
    "        posterior_std = torch.reshape(latent_dist.stddev, ys_test.shape).detach().numpy()\n",
    "    time_taken = time.perf_counter() - start_time\n",
    "\n",
    "    ys_test = ys_test.numpy()\n",
    "    abs_err = np.abs(posterior_mean - ys_test)\n",
    "    MAE = np.mean(abs_err)\n",
    "    MaxAE = np.max(abs_err)\n",
    "    prec_in_1std = np.mean(abs_err < posterior_std)\n",
    "    prec_in_2std = np.mean(abs_err < 2*posterior_std)\n",
    "\n",
    "    # GaussianLikelihood.noise is the noise VARIANCE: take its square root before\n",
    "    # comparing with the true noise standard deviation.\n",
    "    sig = np.sqrt(model.likelihood.noise.detach().cpu().numpy())\n",
    "    sig_true_np = sig_true.detach().cpu().numpy()\n",
    "    MAE_sig = (np.abs(sig - sig_true_np) / sig_true_np)[0]\n",
    "    print('sig:', sig, 'sig_true:', sig_true[0])\n",
    "\n",
    "    print(f'MAE = {MAE:.4f}')\n",
    "    print(f'MAE_sig = {MAE_sig:.4f}')\n",
    "    print('NSE = ', np.linalg.norm(posterior_mean - ys_test)**2 / np.linalg.norm(ys_test)**2.)\n",
    "    print(prec_in_1std, prec_in_2std)\n",
    "\n",
    "    return MAE, MaxAE, prec_in_1std, prec_in_2std, time_taken, MAE_sig\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the GP baseline (run_pred's default nu = 3/2) on every test problem\n",
    "# and report mean ± std of each statistic across problems.\n",
    "stat_names = ['MAE', 'MaxAE', 'prec_in_1std', 'prec_in_2std', 'time_taken', 'MAE_sig']\n",
    "pred_stats = {name: [] for name in stat_names}\n",
    "for i in range(N_test):\n",
    "    print('Running prediction for test data', i+1, '/', N_test)\n",
    "    # run_pred returns its statistics in the order of stat_names. MAE_sig measures\n",
    "    # the learned observation-noise level against the true one (not a lengthscale).\n",
    "    results = run_pred(data_test_list[i])\n",
    "    for name, value in zip(stat_names, results):\n",
    "        pred_stats[name].append(value)\n",
    "\n",
    "for name in stat_names:\n",
    "    pred_stats[name] = np.array(pred_stats[name])\n",
    "\n",
    "for name in ['MAE', 'MaxAE', 'prec_in_1std', 'prec_in_2std', 'MAE_sig', 'time_taken']:\n",
    "    print(name + \": \", np.mean(pred_stats[name]), \"±\", np.std(pred_stats[name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same evaluation with nu = infinity (the squared-exponential limit of Matern).\n",
    "stat_names = ['MAE', 'MaxAE', 'prec_in_1std', 'prec_in_2std', 'time_taken', 'MAE_sig']\n",
    "pred_stats = {name: [] for name in stat_names}\n",
    "for i in range(N_test):\n",
    "    # MAE_sig measures the learned observation-noise level, not a lengthscale.\n",
    "    results = run_pred(data_test_list[i], nu=np.inf)\n",
    "    for name, value in zip(stat_names, results):\n",
    "        pred_stats[name].append(value)\n",
    "\n",
    "for name in stat_names:\n",
    "    pred_stats[name] = np.array(pred_stats[name])\n",
    "\n",
    "for name in ['MAE', 'MaxAE', 'prec_in_1std', 'prec_in_2std', 'MAE_sig', 'time_taken']:\n",
    "    print(name + \": \", np.mean(pred_stats[name]), \"±\", np.std(pred_stats[name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same evaluation with nu = 1/2, the roughest of the smoothness values tried.\n",
    "stat_names = ['MAE', 'MaxAE', 'prec_in_1std', 'prec_in_2std', 'time_taken', 'MAE_sig']\n",
    "pred_stats = {name: [] for name in stat_names}\n",
    "for i in range(N_test):\n",
    "    # MAE_sig measures the learned observation-noise level, not a lengthscale.\n",
    "    results = run_pred(data_test_list[i], nu=0.5)\n",
    "    for name, value in zip(stat_names, results):\n",
    "        pred_stats[name].append(value)\n",
    "\n",
    "for name in stat_names:\n",
    "    pred_stats[name] = np.array(pred_stats[name])\n",
    "\n",
    "for name in ['MAE', 'MaxAE', 'prec_in_1std', 'prec_in_2std', 'MAE_sig', 'time_taken']:\n",
    "    print(name + \": \", np.mean(pred_stats[name]), \"±\", np.std(pred_stats[name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
