{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Vv1gJ53dEnMw"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "from sklearn.gaussian_process import GaussianProcessRegressor\n",
    "from sklearn.gaussian_process.kernels import Matern, WhiteKernel\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WydAjjLMHAVU"
   },
   "outputs": [],
   "source": [
    "base_path = \"/content/drive/MyDrive/Experiments/1d_function/\" # change for your drive\n",
    "\n",
    "X_train = pd.read_csv(base_path + \"x_train_1d.csv\", header=None).values  # Convert to NumPy array\n",
    "y_train = pd.read_csv(base_path + \"y_train_1d.csv\", header=None).values.flatten()\n",
    "\n",
    "X_test = pd.read_csv(base_path + \"x_test_1d.csv\", header=None).values\n",
    "y_test = pd.read_csv(base_path + \"y_test_1d.csv\", header=None).values.flatten()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IVEWsnVjG00N"
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 6))\n",
    "plt.scatter(X_train, y_train, color='blue', label=\"Training Data\", alpha=0.5, s=10)\n",
    "plt.scatter(X_test, y_test, color='red', label=\"Test Data\", alpha=0.5, s=10)\n",
    "plt.xlabel(\"X\")\n",
    "plt.ylabel(\"Y\")\n",
    "plt.title(\"Training and Test Data Scatter Plot\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "TYg0XnMSGIom"
   },
   "outputs": [],
   "source": [
    "def nearest_neighbors_gp_predict(X_train, y_train, X_test, k=50):\n",
    "    nbrs = NearestNeighbors(n_neighbors=k, algorithm='auto').fit(X_train)\n",
    "    y_pred = np.zeros(len(X_test))\n",
    "    y_std = np.zeros(len(X_test))\n",
    "\n",
    "    for i, x_star in enumerate(X_test):\n",
    "        distances, indices = nbrs.kneighbors([x_star])\n",
    "        X_sub = X_train[indices[0]]\n",
    "        y_sub = y_train[indices[0]]\n",
    "\n",
    "        # Use WhiteKernel(noise_level=0.04) assuming similar noise structure\n",
    "        kernel = Matern(length_scale=1.0) + WhiteKernel(noise_level=0.04) #define noise to be 0.04\n",
    "        gp = GaussianProcessRegressor(kernel=kernel, alpha=0.0, optimizer=None)\n",
    "        gp.fit(X_sub, y_sub)\n",
    "\n",
    "        # Extract scalars explicitly\n",
    "        mu, std = gp.predict(x_star.reshape(1, -1), return_std=True)\n",
    "        y_pred[i] = mu[0]\n",
    "        y_std[i] = std[0]\n",
    "\n",
    "    return y_pred, y_std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "j-Fai9riGbRE"
   },
   "outputs": [],
   "source": [
    "k_neighbors = 10\n",
    "y_pred_local, y_std_local = nearest_neighbors_gp_predict(X_train, y_train, X_test, k=k_neighbors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BEMbpFY6HHyL"
   },
   "outputs": [],
   "source": [
    "# Sort the test data and corresponding predictions for smooth plotting\n",
    "test_sorted_indices = np.argsort(X_test.flatten())\n",
    "X_test_sorted = X_test[test_sorted_indices]\n",
    "y_test_sorted = y_test[test_sorted_indices]\n",
    "y_pred_sorted = y_pred_local[test_sorted_indices]\n",
    "y_std_sorted = y_std_local[test_sorted_indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4zzf2kaHHwmq"
   },
   "outputs": [],
   "source": [
    "# Plot the results with test values as discrete points\n",
    "plt.figure(figsize=(12, 6))\n",
    "\n",
    "# Training data scatter plot\n",
    "plt.scatter(X_train, y_train, s=10, alpha=0.5, label=\"Training Data\", color = 'blue')\n",
    "\n",
    "# True test values as discrete points\n",
    "plt.scatter(X_test_sorted, y_test_sorted, color='red', marker='o', s=10, alpha=0.5, label=\"Test Data\")\n",
    "\n",
    "# Continuous GP posterior mean function\n",
    "plt.plot(X_test_sorted, y_pred_sorted, 'black', label=\"Local GP Prediction\", linewidth=2)\n",
    "\n",
    "# Confidence interval\n",
    "plt.fill_between(X_test_sorted.flatten(),\n",
    "                 y_pred_sorted - 1.96 * y_std_sorted,\n",
    "                 y_pred_sorted + 1.96 * y_std_sorted,\n",
    "                 color='grey', alpha=0.3, label=\"Local GP 95% CI\")\n",
    "\n",
    "plt.xlabel(\"X\")\n",
    "plt.ylabel(\"Y\")\n",
    "plt.title(f\"Nearest Neighbors GP Predictions (k={k_neighbors})\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-WAEDJroIG7G"
   },
   "outputs": [],
   "source": [
    "def crps_gaussian(y, mu, sigma):\n",
    "    z = (y - mu) / sigma\n",
    "    crps = sigma * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))\n",
    "    return crps\n",
    "\n",
    "rmse = np.sqrt(np.mean((y_pred_local - y_test)**2))\n",
    "print(f\"RMSE: {rmse}\")\n",
    "\n",
    "crps_values = crps_gaussian(y_test, y_pred_local, y_std_local)\n",
    "mean_crps = np.mean(crps_values)\n",
    "print(f\"Mean CRPS: {mean_crps}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SZiWI-mpb2t1"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "NERSC Python",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
