{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "import pandas as pd\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import gpytorch\n",
        "import torch\n",
        "import plotly.graph_objects as go\n",
        "from scipy.stats import norm\n"
      ],
      "metadata": {
        "id": "XsQyAQ_MK9gT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "base_path = \"/content/drive/MyDrive/Experiments/2d_topo/\"\n",
        "\n",
        "X_train = pd.read_csv(base_path + \"x_train_2dtopo.csv\", header=None, sep='\\s+').values.astype(np.float32)\n",
        "y_train = pd.read_csv(base_path + \"y_train_2dtopo.csv\", header=None, sep='\\s+').values.astype(np.float32).flatten()\n",
        "X_test = pd.read_csv(base_path + \"x_test_2dtopo.csv\", header=None, sep='\\s+').values.astype(np.float32)\n",
        "y_test = pd.read_csv(base_path + \"y_test_2dtopo.csv\", header=None, sep='\\s+').values.astype(np.float32).flatten()\n",
        "\n",
        "X_train_tensor = torch.from_numpy(X_train).float()\n",
        "y_train_tensor = torch.from_numpy(y_train).float()\n",
        "X_test_tensor = torch.from_numpy(X_test).float()\n",
        "y_test_tensor = torch.from_numpy(y_test).float()\n"
      ],
      "metadata": {
        "id": "Vd2IVGQ7K9eP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Create dataframes for training and test data\n",
        "df_train = pd.DataFrame({\n",
        "    'X': X_train[:, 0],\n",
        "    'Y': X_train[:, 1],\n",
        "    'Density': y_train,\n",
        "    'Type': ['Training'] * len(y_train)\n",
        "})\n",
        "\n",
        "df_test = pd.DataFrame({\n",
        "    'X': X_test[:, 0],\n",
        "    'Y': X_test[:, 1],\n",
        "    'Density': y_test,\n",
        "    'Type': ['Test'] * len(y_test)\n",
        "})\n",
        "\n",
        "# Create the interactive 3D scatter plot\n",
        "fig = go.Figure()\n",
        "\n",
        "# Add training data (blue)\n",
        "fig.add_trace(go.Scatter3d(\n",
        "    x=df_train['X'],\n",
        "    y=df_train['Y'],\n",
        "    z=df_train['Density'],\n",
        "    mode='markers',\n",
        "    name='Training Data',\n",
        "    marker=dict(\n",
        "        size=4,\n",
        "        color='blue',\n",
        "        opacity=0.6\n",
        "    ),\n",
        "    hovertemplate='X: %{x}<br>Y: %{y}<br>Density: %{z}<extra></extra>'\n",
        "))\n",
        "\n",
        "# Add test data (red)\n",
        "fig.add_trace(go.Scatter3d(\n",
        "    x=df_test['X'],\n",
        "    y=df_test['Y'],\n",
        "    z=df_test['Density'],\n",
        "    mode='markers',\n",
        "    name='Test Data',\n",
        "    marker=dict(\n",
        "        size=4,\n",
        "        color='red',\n",
        "        opacity=0.6\n",
        "    ),\n",
        "    hovertemplate='X: %{x}<br>Y: %{y}<br>Density: %{z}<extra></extra>'\n",
        "))\n",
        "\n",
        "# Update layout\n",
        "fig.update_layout(\n",
        "    title='Interactive 3D Scatter Plot of Training and Test Data',\n",
        "    scene=dict(\n",
        "        xaxis_title='X coordinate',\n",
        "        yaxis_title='Y coordinate',\n",
        "        zaxis_title='Density',\n",
        "        camera=dict(\n",
        "            up=dict(x=0, y=0, z=1),\n",
        "            center=dict(x=0, y=0, z=0),\n",
        "            eye=dict(x=1.5, y=1.5, z=1.5)\n",
        "        )\n",
        "    ),\n",
        "    width=1000,\n",
        "    height=800,\n",
        "    margin=dict(l=0, r=0, b=0, t=30)\n",
        ")\n",
        "\n",
        "# Show the plot\n",
        "fig.show()"
      ],
      "metadata": {
        "id": "-f4P55wxK9cW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class SVGP2DModel(gpytorch.models.ApproximateGP):\n",
        "    def __init__(self, inducing_points):\n",
        "        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(inducing_points.size(0))\n",
        "        variational_strategy = gpytorch.variational.VariationalStrategy(\n",
        "            self, inducing_points, variational_distribution, learn_inducing_locations=True\n",
        "        )\n",
        "        super(SVGP2DModel, self).__init__(variational_strategy)\n",
        "        self.mean_module = gpytorch.means.ConstantMean()\n",
        "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel(nu=1.5))\n",
        "    def forward(self, x):\n",
        "        mean_x = self.mean_module(x)\n",
        "        covar_x = self.covar_module(x)\n",
        "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
        "\n",
        "num_inducing = 100\n",
        "perm = torch.randperm(X_train_tensor.size(0))\n",
        "inducing_points = X_train_tensor[perm[:num_inducing]].clone()\n",
        "\n",
        "likelihood = gpytorch.likelihoods.GaussianLikelihood()\n",
        "model = SVGP2DModel(inducing_points)"
      ],
      "metadata": {
        "id": "4Hu26MlvLqRO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "model.train()\n",
        "likelihood.train()\n",
        "\n",
        "if torch.cuda.is_available():\n",
        "    model = model.cuda()\n",
        "    likelihood = likelihood.cuda()\n",
        "    X_train_tensor = X_train_tensor.cuda()\n",
        "    y_train_tensor = y_train_tensor.cuda()\n",
        "    X_test_tensor = X_test_tensor.cuda()\n",
        "    y_test_tensor = y_test_tensor.cuda()\n",
        "    inducing_points = inducing_points.cuda()\n",
        "\n",
        "optimizer = torch.optim.Adam(model.parameters(), lr=0.1)\n",
        "mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=X_train_tensor.size(0))\n",
        "\n",
        "def train_svgp(num_epochs=100):\n",
        "    for i in range(num_epochs):\n",
        "        optimizer.zero_grad()\n",
        "        output = model(X_train_tensor)\n",
        "        loss = -mll(output, y_train_tensor)\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        if (i + 1) % 10 == 0:\n",
        "            print(f'Epoch {i + 1}/{num_epochs} - Loss: {loss.item():.3f}')\n"
      ],
      "metadata": {
        "id": "nt12MgYeLvhW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def predict_svgp(x_new):\n",
        "    model.eval()\n",
        "    likelihood.eval()\n",
        "    with torch.no_grad():\n",
        "        pred_dist = likelihood(model(x_new))\n",
        "        mean = pred_dist.mean\n",
        "        std = pred_dist.stddev\n",
        "    return mean, std\n",
        "\n",
        "def plot_results(x_train, y_train, x_test, y_test, mean):\n",
        "    fig = go.Figure()\n",
        "    fig.add_trace(go.Scatter3d(\n",
        "        x=x_train[:, 0].cpu().numpy(),\n",
        "        y=x_train[:, 1].cpu().numpy(),\n",
        "        z=y_train.cpu().numpy(),\n",
        "        mode='markers',\n",
        "        name='Training Data',\n",
        "        marker=dict(size=4, color='blue', opacity=0.6),\n",
        "        hovertemplate='X: %{x}<br>Y: %{y}<br>Density: %{z}<extra></extra>'\n",
        "    ))\n",
        "    fig.add_trace(go.Scatter3d(\n",
        "        x=x_test[:, 0].cpu().numpy(),\n",
        "        y=x_test[:, 1].cpu().numpy(),\n",
        "        z=mean.cpu().numpy(),\n",
        "        mode='markers',\n",
        "        name='Predictions',\n",
        "        marker=dict(size=4, color='red', symbol='x', opacity=0.6),\n",
        "        hovertemplate='X: %{x}<br>Y: %{y}<br>Predicted Density: %{z}<extra></extra>'\n",
        "    ))\n",
        "    fig.update_layout(\n",
        "        title='SVP Predictions (3D)',\n",
        "        scene=dict(\n",
        "            xaxis_title='X coordinate',\n",
        "            yaxis_title='Y coordinate',\n",
        "            zaxis_title='Density',\n",
        "            camera=dict(\n",
        "                up=dict(x=0, y=0, z=1),\n",
        "                center=dict(x=0, y=0, z=0),\n",
        "                eye=dict(x=1.5, y=1.5, z=1.5)\n",
        "            )\n",
        "        ),\n",
        "        width=1000,\n",
        "        height=800,\n",
        "        margin=dict(l=0, r=0, b=0, t=30)\n",
        "    )\n",
        "    fig.show()\n"
      ],
      "metadata": {
        "id": "yAXJNi_qL6qR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "train_svgp(num_epochs=500)"
      ],
      "metadata": {
        "id": "2koizfFTL8yY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "pred_mean_svgp, pred_std_svgp = predict_svgp(X_test_tensor)\n"
      ],
      "metadata": {
        "id": "_p1fIKV6MAyV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def plot_results(x_test, y_test, mean):\n",
        "    fig = go.Figure()\n",
        "    fig.add_trace(go.Scatter3d(\n",
        "        x=x_test[:, 0].cpu().numpy(),\n",
        "        y=x_test[:, 1].cpu().numpy(),\n",
        "        z=y_test.cpu().numpy(),\n",
        "        mode='markers',\n",
        "        name='Test Data',\n",
        "        marker=dict(size=4, color='blue', opacity=0.6),\n",
        "        hovertemplate='X: %{x}<br>Y: %{y}<br>Density: %{z}<extra></extra>'\n",
        "    ))\n",
        "    fig.add_trace(go.Scatter3d(\n",
        "        x=x_test[:, 0].cpu().numpy(),\n",
        "        y=x_test[:, 1].cpu().numpy(),\n",
        "        z=mean.cpu().numpy(),\n",
        "        mode='markers',\n",
        "        name='Predictions',\n",
        "        marker=dict(size=4, color='red', opacity=0.6),\n",
        "        hovertemplate='X: %{x}<br>Y: %{y}<br>Predicted Density: %{z}<extra></extra>'\n",
        "    ))\n",
        "    fig.update_layout(\n",
        "        title='SVGP Predictions (3D)',\n",
        "        scene=dict(\n",
        "            xaxis_title='X coordinate',\n",
        "            yaxis_title='Y coordinate',\n",
        "            zaxis_title='Density',\n",
        "            camera=dict(\n",
        "                up=dict(x=0, y=0, z=1),\n",
        "                center=dict(x=0, y=0, z=0),\n",
        "                eye=dict(x=1.5, y=1.5, z=1.5)\n",
        "            )\n",
        "        ),\n",
        "        width=1000,\n",
        "        height=800,\n",
        "        margin=dict(l=0, r=0, b=0, t=30)\n",
        "    )\n",
        "    fig.show()\n"
      ],
      "metadata": {
        "id": "j5WSCJUNMhe6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_results(X_test_tensor, y_test_tensor, pred_mean_svgp)"
      ],
      "metadata": {
        "id": "hIPG0nDWNSxI"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def crps_gaussian(y, mu, sigma):\n",
        "    z = (y - mu) / sigma\n",
        "    crps = sigma * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))\n",
        "    return crps\n",
        "\n",
        "rmse = np.sqrt(np.mean((pred_mean_svgp.cpu().numpy() - y_test)**2))\n",
        "print(f\"RMSE: {rmse}\")\n",
        "\n",
        "crps_values = crps_gaussian(y_test, pred_mean_svgp.cpu().numpy(), pred_std_svgp.cpu().numpy())\n",
        "mean_crps = np.mean(crps_values)\n",
        "print(f\"Mean CRPS: {mean_crps}\")"
      ],
      "metadata": {
        "id": "6OOFDlg-NVbo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "-XPIJ2huW9Ny"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}