{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6dKydLbZSTtB"
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "import numpy as np\n",
        "from sklearn.neighbors import NearestNeighbors\n",
        "from sklearn.gaussian_process import GaussianProcessRegressor\n",
        "from sklearn.gaussian_process.kernels import Matern, WhiteKernel\n",
        "from scipy.stats import norm\n",
        "import plotly.graph_objects as go"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ],
      "metadata": {
        "id": "CfKZq2gAryr7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "base_path = \"/content/drive/MyDrive/Experiments/2d_topo/\"\n",
        "\n",
        "X_train = pd.read_csv(base_path + \"x_train_2dtopo.csv\", header=None, sep='\\s+').values.astype(np.float32)\n",
        "y_train = pd.read_csv(base_path + \"y_train_2dtopo.csv\", header=None, sep='\\s+').values.astype(np.float32).flatten()\n",
        "X_test = pd.read_csv(base_path + \"x_test_2dtopo.csv\", header=None, sep='\\s+').values.astype(np.float32)\n",
        "y_test = pd.read_csv(base_path + \"y_test_2dtopo.csv\", header=None, sep='\\s+').values.astype(np.float32).flatten()"
      ],
      "metadata": {
        "id": "YGj-LyHHSlud"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def nearest_neighbors_gp_predict_2d(X_train, y_train, X_test, k=50):\n",
        "    \"\"\"\n",
        "    For each test point in X_test, find the k nearest training points,\n",
        "    fit a local Gaussian Process with a Matern + WhiteKernel,\n",
        "    and return the predicted mean and standard deviation.\n",
        "    \"\"\"\n",
        "    nbrs = NearestNeighbors(n_neighbors=k, algorithm='auto').fit(X_train)\n",
        "    y_pred = np.zeros(len(X_test))\n",
        "    y_std = np.zeros(len(X_test))\n",
        "\n",
        "    for i, x_star in enumerate(X_test):\n",
        "        # Find indices of k nearest neighbors\n",
        "        distances, indices = nbrs.kneighbors([x_star])\n",
        "        X_sub = X_train[indices[0]]\n",
        "        y_sub = y_train[indices[0]]\n",
        "\n",
        "        # Define kernel (using a Matern kernel with nu=1.5 for smoothness)\n",
        "        # and add a WhiteKernel to account for noise.\n",
        "        kernel = Matern(length_scale=1.0, nu=1.5) + WhiteKernel(noise_level=0.04)\n",
        "        gp = GaussianProcessRegressor(kernel=kernel, alpha=0.0, optimizer=None, normalize_y=True)\n",
        "        gp.fit(X_sub, y_sub)\n",
        "\n",
        "        # Predict for the test point (reshape as 1 x 2)\n",
        "        mu, std = gp.predict(x_star.reshape(1, -1), return_std=True)\n",
        "        y_pred[i] = mu[0]\n",
        "        y_std[i] = std[0]\n",
        "\n",
        "    return y_pred, y_std\n"
      ],
      "metadata": {
        "id": "9KwynsK7TTWp"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Set the number of neighbors (experiment with this value)\n",
        "k_neighbors = 25\n",
        "y_pred_local, y_std_local = nearest_neighbors_gp_predict_2d(X_train, y_train, X_test, k=k_neighbors)"
      ],
      "metadata": {
        "id": "q6kX-tyWTaST"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "fig = go.Figure()\n",
        "\n",
        "# True test data\n",
        "fig.add_trace(go.Scatter3d(\n",
        "    x=X_test[:, 0],\n",
        "    y=X_test[:, 1],\n",
        "    z=y_test,\n",
        "    mode='markers',\n",
        "    name='Test Data (True)',\n",
        "    marker=dict(size=4, color='blue', opacity=0.6),\n",
        "    hovertemplate='X: %{x}<br>Y: %{y}<br>True Density: %{z}<extra></extra>'\n",
        "))\n",
        "\n",
        "# Predicted test data\n",
        "fig.add_trace(go.Scatter3d(\n",
        "    x=X_test[:, 0],\n",
        "    y=X_test[:, 1],\n",
        "    z=y_pred_local,\n",
        "    mode='markers',\n",
        "    name='Test Data (Predicted)',\n",
        "    marker=dict(size=4, color='red', opacity=0.6),\n",
        "    hovertemplate='X: %{x}<br>Y: %{y}<br>Predicted Density: %{z}<extra></extra>'\n",
        "))\n",
        "\n",
        "fig.update_layout(\n",
        "    title=f'Nearest Neighbors GP Predictions on 2D Topological Data (k={k_neighbors})',\n",
        "    scene=dict(\n",
        "        xaxis_title='X coordinate',\n",
        "        yaxis_title='Y coordinate',\n",
        "        zaxis_title='Density',\n",
        "        camera=dict(\n",
        "            up=dict(x=0, y=0, z=1),\n",
        "            center=dict(x=0, y=0, z=0),\n",
        "            eye=dict(x=1.5, y=1.5, z=1.5)\n",
        "        )\n",
        "    ),\n",
        "    width=1000,\n",
        "    height=800,\n",
        "    margin=dict(l=0, r=0, b=0, t=30)\n",
        ")\n",
        "\n",
        "fig.show()"
      ],
      "metadata": {
        "id": "mYvSVcvITeXI"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def crps_gaussian(y, mu, sigma):\n",
        "    z = (y - mu) / sigma\n",
        "    crps = sigma * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))\n",
        "    return crps\n",
        "\n",
        "rmse = np.sqrt(np.mean((y_pred_local - y_test)**2))\n",
        "print(f\"RMSE: {rmse}\")\n",
        "\n",
        "crps_values = crps_gaussian(y_test, y_pred_local, y_std_local)\n",
        "mean_crps = np.mean(crps_values)\n",
        "print(f\"Mean CRPS: {mean_crps}\")\n"
      ],
      "metadata": {
        "id": "PNWYqzuOTiV0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "fRO27rgwV4ei"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}