{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Vv1gJ53dEnMw"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "from sklearn.neighbors import NearestNeighbors\n",
        "from sklearn.gaussian_process import GaussianProcessRegressor\n",
        "from sklearn.gaussian_process.kernels import Matern, WhiteKernel\n",
        "import matplotlib.pyplot as plt\n",
        "import gpytorch\n",
        "import torch\n",
        "from scipy.stats import norm"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "base_path = \"/content/drive/MyDrive/Experiments/1d_function/\"\n",
        "\n",
        "# Fail fast with an actionable message when the data folder is missing\n",
        "# (e.g. Google Drive not mounted after a runtime restart).\n",
        "if not os.path.isdir(base_path):\n",
        "    raise FileNotFoundError(\n",
        "        f\"Data directory not found: {base_path!r}. Mount Google Drive \"\n",
        "        \"(from google.colab import drive; drive.mount('/content/drive')) or update base_path.\"\n",
        "    )\n",
        "\n",
        "# All CSVs are header-less; inputs stay 2-D (rows x columns), targets are flattened to 1-D.\n",
        "X_train = pd.read_csv(os.path.join(base_path, \"x_train_1d.csv\"), header=None).to_numpy()\n",
        "y_train = pd.read_csv(os.path.join(base_path, \"y_train_1d.csv\"), header=None).to_numpy().flatten()\n",
        "\n",
        "X_test = pd.read_csv(os.path.join(base_path, \"x_test_1d.csv\"), header=None).to_numpy()\n",
        "y_test = pd.read_csv(os.path.join(base_path, \"y_test_1d.csv\"), header=None).to_numpy().flatten()\n",
        "\n",
        "# Sanity checks: inputs and targets line up, and train/test share a column count.\n",
        "assert X_train.shape[0] == y_train.shape[0], \"x_train/y_train row counts differ\"\n",
        "assert X_test.shape[0] == y_test.shape[0], \"x_test/y_test row counts differ\"\n",
        "assert X_train.shape[1] == X_test.shape[1], \"train/test column counts differ\"\n",
        "\n",
        "# float32 tensors for GPyTorch.\n",
        "X_train_tensor = torch.from_numpy(X_train).float()\n",
        "y_train_tensor = torch.from_numpy(y_train).float()\n",
        "X_test_tensor = torch.from_numpy(X_test).float()\n",
        "y_test_tensor = torch.from_numpy(y_test).float()"
      ],
      "metadata": {
        "id": "WydAjjLMHAVU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Quick look at the raw data: train and test points on shared axes.\n",
        "fig, ax = plt.subplots(figsize=(10, 6))\n",
        "for xs, ys, colour, name in ((X_train, y_train, 'blue', \"Training Data\"),\n",
        "                              (X_test, y_test, 'red', \"Test Data\")):\n",
        "    ax.scatter(xs, ys, color=colour, label=name, alpha=0.5, s=10)\n",
        "ax.set_xlabel(\"X\")\n",
        "ax.set_ylabel(\"Y\")\n",
        "ax.set_title(\"Training and Test Data Scatter Plot\")\n",
        "ax.legend()\n",
        "plt.show()"
      ],
      "metadata": {
        "id": "IVEWsnVjG00N"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class SKIGPModel(gpytorch.models.ExactGP):\n",
        "    \"\"\"Exact GP with a SKI / KISS-GP (structured kernel interpolation) covariance.\n",
        "\n",
        "    A Matern base kernel is placed on a regular grid and interpolated to the\n",
        "    inputs by GridInterpolationKernel; a ScaleKernel on top learns the output scale.\n",
        "\n",
        "    Args:\n",
        "        train_x: training inputs, (n,) or (n, d) tensor (this notebook passes (n, 1)).\n",
        "        train_y: training targets, (n,) tensor.\n",
        "        likelihood: GPyTorch likelihood; ExactGP.__init__ stores it as self.likelihood.\n",
        "        grid_size: grid points per input dimension (default 20; gpytorch's\n",
        "            gpytorch.utils.grid.choose_grid_size can suggest one from the data).\n",
        "        nu: Matern smoothness, one of 0.5, 1.5, 2.5 (default 1.5).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, train_x, train_y, likelihood, grid_size=20, nu=1.5):\n",
        "        super().__init__(train_x, train_y, likelihood)\n",
        "\n",
        "        # Derive the input dimensionality from the data instead of hard-coding 1,\n",
        "        # so the same class also works for multi-dimensional inputs.\n",
        "        num_dims = train_x.size(-1) if train_x.dim() > 1 else 1\n",
        "\n",
        "        base_kernel = gpytorch.kernels.MaternKernel(nu=nu)\n",
        "        self.kern = gpytorch.kernels.ScaleKernel(\n",
        "            gpytorch.kernels.GridInterpolationKernel(\n",
        "                base_kernel,\n",
        "                grid_size=grid_size,\n",
        "                num_dims=num_dims,\n",
        "            )\n",
        "        )\n",
        "        self.mean_module = gpytorch.means.ConstantMean()\n",
        "        # No need to assign self.likelihood here: ExactGP.__init__ already registers it.\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Prior at x: MultivariateNormal(constant mean, scaled SKI covariance).\"\"\"\n",
        "        mean = self.mean_module(x)\n",
        "        cov = self.kern(x)\n",
        "        return gpytorch.distributions.MultivariateNormal(mean, cov)\n",
        "\n",
        "\n",
        "\n",
        "# The likelihood is created first because the model constructor takes it.\n",
        "likelihood = gpytorch.likelihoods.GaussianLikelihood()\n",
        "model = SKIGPModel(\n",
        "    X_train_tensor,\n",
        "    y_train_tensor,\n",
        "    likelihood,\n",
        ")"
      ],
      "metadata": {
        "id": "TYg0XnMSGIom"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Training mode for the model and its likelihood.\n",
        "model.train()\n",
        "likelihood.train()\n",
        "\n",
        "# Use the GPU if available: move the model, the likelihood and every data tensor.\n",
        "if torch.cuda.is_available():\n",
        "    model = model.cuda()\n",
        "    likelihood = likelihood.cuda()\n",
        "    X_train_tensor = X_train_tensor.cuda()\n",
        "    y_train_tensor = y_train_tensor.cuda()\n",
        "    X_test_tensor = X_test_tensor.cuda()\n",
        "    y_test_tensor = y_test_tensor.cuda()\n",
        "\n",
        "# Adam over all model hyperparameters. The likelihood's noise is expected to be\n",
        "# included via model.parameters() since ExactGP keeps the likelihood as a submodule.\n",
        "optimizer = torch.optim.Adam(model.parameters(), lr=0.1)\n",
        "\n",
        "# Exact marginal log likelihood; training minimises its negative.\n",
        "mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n",
        "\n",
        "# Training loop\n",
        "def train(num_epochs=10, log_every=10):\n",
        "    \"\"\"Fit the GP hyperparameters by minimising the negative exact marginal log likelihood.\n",
        "\n",
        "    Uses the notebook-level model, likelihood, optimizer, mll, X_train_tensor and\n",
        "    y_train_tensor. One epoch is one full-batch optimizer step.\n",
        "\n",
        "    Args:\n",
        "        num_epochs: number of optimisation steps.\n",
        "        log_every: print the loss every `log_every` steps; 0 or None disables printing.\n",
        "\n",
        "    Returns:\n",
        "        List of per-step losses (each recorded before that step's parameter update).\n",
        "    \"\"\"\n",
        "    # Re-enter training mode on every call: predict() switches to eval mode, and an\n",
        "    # ExactGP in eval mode returns the posterior predictive instead of the prior,\n",
        "    # which is the wrong distribution for the marginal log likelihood.\n",
        "    model.train()\n",
        "    likelihood.train()\n",
        "\n",
        "    losses = []\n",
        "    for epoch in range(1, num_epochs + 1):\n",
        "        optimizer.zero_grad()\n",
        "        output = model(X_train_tensor)\n",
        "        loss = -mll(output, y_train_tensor)\n",
        "        loss.backward()\n",
        "        loss_value = loss.item()\n",
        "        losses.append(loss_value)\n",
        "        optimizer.step()\n",
        "\n",
        "        if log_every and epoch % log_every == 0:\n",
        "            print(f'Epoch {epoch}/{num_epochs} - Loss: {loss_value:.3f}')\n",
        "    return losses"
      ],
      "metadata": {
        "id": "UlygodYLMFuA"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def predict(x_new):\n",
        "    \"\"\"Posterior predictive mean and confidence band at `x_new`.\n",
        "\n",
        "    Leaves the model and likelihood in eval mode and runs without autograd.\n",
        "    The band is `confidence_region()` of likelihood(model(x_new)), i.e. of the\n",
        "    predictive distribution including observation noise: mean +/- 2 standard deviations.\n",
        "\n",
        "    Returns:\n",
        "        (mean, lower, upper) tensors.\n",
        "    \"\"\"\n",
        "    for module in (model, likelihood):\n",
        "        module.eval()\n",
        "\n",
        "    with torch.no_grad():\n",
        "        predictive = likelihood(model(x_new))\n",
        "        centre = predictive.mean\n",
        "        lower, upper = predictive.confidence_region()\n",
        "\n",
        "    return centre, lower, upper\n"
      ],
      "metadata": {
        "id": "C2UZh7d4MFv-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def _to_numpy(values):\n",
        "    \"\"\"Convert a torch tensor (any device, with or without grad) or an array-like to a NumPy array.\"\"\"\n",
        "    if torch.is_tensor(values):\n",
        "        return values.detach().cpu().numpy()\n",
        "    return np.asarray(values)\n",
        "\n",
        "\n",
        "def plot_results(x_train, y_train, x_test, y_test, mean, lower, upper):\n",
        "    \"\"\"Plot train/test data together with the predictive mean and confidence band.\n",
        "\n",
        "    Every argument may be a torch tensor (CPU or GPU, even one that requires grad)\n",
        "    or a NumPy array / array-like. Inputs are treated as 1-D: x_test, mean, lower\n",
        "    and upper are flattened and sorted by x_test so the curve and band draw left to right.\n",
        "    \"\"\"\n",
        "    fig, ax = plt.subplots(figsize=(12, 6))\n",
        "\n",
        "    # Observed data\n",
        "    ax.scatter(_to_numpy(x_train), _to_numpy(y_train), color='blue', label='Training Data', alpha=0.5)\n",
        "    ax.scatter(_to_numpy(x_test), _to_numpy(y_test), color='red', label='Test Data', alpha=0.5)\n",
        "\n",
        "    # Sort test points for smooth plotting\n",
        "    x_test_np = _to_numpy(x_test).flatten()\n",
        "    mean_np = _to_numpy(mean).flatten()\n",
        "    lower_np = _to_numpy(lower).flatten()\n",
        "    upper_np = _to_numpy(upper).flatten()\n",
        "    order = np.argsort(x_test_np)\n",
        "\n",
        "    # Predictions along the sorted inputs\n",
        "    ax.plot(x_test_np[order], mean_np[order], 'k', label='Predicted Mean', linewidth=2)\n",
        "    ax.fill_between(x_test_np[order], lower_np[order], upper_np[order],\n",
        "                    alpha=0.2, color='k', label='95% Confidence')\n",
        "\n",
        "    ax.set_xlabel('X')\n",
        "    ax.set_ylabel('Y')\n",
        "    ax.set_title('SKI Predictions')\n",
        "    ax.legend()\n",
        "    plt.show()"
      ],
      "metadata": {
        "id": "nDNffQwAMFyE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Fit the hyperparameters; `losses` receives one negative-MLL value per epoch.\n",
        "num_epochs = 100\n",
        "losses = train(num_epochs=num_epochs)"
      ],
      "metadata": {
        "id": "pNqc8g_0MF0I"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Take the x-axis length from the recorded losses rather than `num_epochs`, so the\n",
        "# plot cannot fail with a length mismatch if `num_epochs` is edited without retraining.\n",
        "epochs = list(range(1, len(losses) + 1))\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(8, 5))\n",
        "ax.plot(epochs, losses, marker='o', linestyle='-')\n",
        "ax.set_xlabel('Epoch')\n",
        "ax.set_ylabel('Loss')\n",
        "ax.set_title('Training Loss per Epoch')\n",
        "ax.grid(True)\n",
        "plt.show()"
      ],
      "metadata": {
        "id": "G4I3xhrmg2S5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Posterior predictive on the test inputs, then the combined data/prediction plot.\n",
        "mean, lower, upper = predict(X_test_tensor)\n",
        "\n",
        "plot_results(\n",
        "    X_train_tensor, y_train_tensor,\n",
        "    X_test_tensor, y_test_tensor,\n",
        "    mean, lower, upper,\n",
        ")"
      ],
      "metadata": {
        "id": "EQ4HI9lwMF2W"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Predictive distribution on the test set for the metrics below. Switch to eval\n",
        "# mode here instead of relying on predict() having run earlier, and skip autograd\n",
        "# because nothing in this cell is differentiated.\n",
        "model.eval()\n",
        "likelihood.eval()\n",
        "with torch.no_grad():\n",
        "    pred_dist = likelihood(model(X_test_tensor))\n",
        "    mean = pred_dist.mean\n",
        "    std = pred_dist.stddev\n",
        "\n",
        "mean_np = mean.detach().cpu().numpy()\n",
        "std_np = std.detach().cpu().numpy()\n",
        "\n",
        "# Get learned noise variance\n",
        "noise_variance = likelihood.noise.detach().cpu().numpy()\n",
        "print(f\"Noise Variance: {noise_variance}\")"
      ],
      "metadata": {
        "id": "VDwtsEy-eIQj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def crps_gaussian(y, mu, sigma):\n",
        "    z = (y - mu) / sigma\n",
        "    crps = sigma * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))\n",
        "    return crps\n",
        "\n",
        "# Point accuracy of the predictive mean.\n",
        "residuals = mean_np - y_test\n",
        "rmse = np.sqrt(np.mean(residuals ** 2))\n",
        "print(f\"RMSE: {rmse}\")\n",
        "\n",
        "# Probabilistic accuracy: per-point Gaussian CRPS, then its average.\n",
        "crps_values = crps_gaussian(y_test, mean_np, std_np)\n",
        "mean_crps = np.mean(crps_values)\n",
        "print(f\"Mean CRPS: {mean_crps}\")"
      ],
      "metadata": {
        "id": "9EYtiZxPcShy"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}