{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "73c5b218-417e-4715-9161-bcc535602451",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import time\n",
    "import logging\n",
    "import math # Import math\n",
    "import numpy as np\n",
    "import torch\n",
    "import gpytorch\n",
    "from scipy.stats import norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "28959c99-ea80-4bb9-9876-fb90423f004e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root logging setup: INFO level, records formatted as \"<timestamp> <level> <message>\".\n",
    "logging.basicConfig(\n",
    "    datefmt=\"%Y-%m-%d %H:%M:%S\",\n",
    "    format=\"%(asctime)s %(levelname)s %(message)s\",\n",
    "    level=logging.INFO,\n",
    ")\n",
    "\n",
    "# Logger used by the rest of the notebook for progress messages.\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0ced7a4a-56f5-4765-9b57-3dc13de9cc92",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-22 17:53:54 INFO Loading MNIST data …\n",
      "2025-04-22 17:53:56 INFO Loading pre‑computed L1 distance matrix …\n"
     ]
    }
   ],
   "source": [
    "# Data locations. The defaults are the original cluster paths; set the MNIST_DIR /\n",
    "# MNIST_DIST_DIR environment variables to run the notebook from a different layout.\n",
    "base_path = os.environ.get(\"MNIST_DIR\", '/global/u1/v/vtek/MNIST')\n",
    "base_path_dist = os.environ.get(\"MNIST_DIST_DIR\", '/pscratch/sd/v/vtek')\n",
    "\n",
    "logger.info(\"Loading MNIST data …\")\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only point this at trusted files.\n",
    "with open(os.path.join(base_path, \"x_train_MNIST.pkl\"), \"rb\") as f:\n",
    "    X_train = np.array(pickle.load(f), dtype=np.float32)\n",
    "with open(os.path.join(base_path, \"y_train_MNIST.pkl\"), \"rb\") as f:\n",
    "    y_train = np.array(pickle.load(f), dtype=np.float32).flatten()\n",
    "with open(os.path.join(base_path, \"x_test_MNIST.pkl\"), \"rb\") as f:\n",
    "    X_test = np.array(pickle.load(f), dtype=np.float32)\n",
    "with open(os.path.join(base_path, \"y_test_MNIST.pkl\"), \"rb\") as f:\n",
    "    y_test = np.array(pickle.load(f), dtype=np.float32).flatten()\n",
    "\n",
    "logger.info(\"Loading pre‑computed L1 distance matrix …\")\n",
    "l1 = np.load(os.path.join(base_path_dist, \"distance_l1.npy\"))\n",
    "\n",
    "n_train = X_train.shape[0]\n",
    "n_test  = X_test.shape[0]\n",
    "\n",
    "# Fail early on inconsistent inputs: the notebook addresses samples by the integer\n",
    "# positions 0 .. n_train + n_test - 1, so every such position must exist in the matrix.\n",
    "if y_train.shape[0] != n_train or y_test.shape[0] != n_test:\n",
    "    raise ValueError(\n",
    "        f\"Label/feature count mismatch: train {y_train.shape[0]} vs {n_train}, \"\n",
    "        f\"test {y_test.shape[0]} vs {n_test}\"\n",
    "    )\n",
    "if l1.ndim != 2 or min(l1.shape) < n_train + n_test:\n",
    "    raise ValueError(\n",
    "        f\"Distance matrix of shape {l1.shape} cannot index {n_train + n_test} samples\"\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2b288986-f2cc-4572-9d3c-8495a78772dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when CUDA is available, otherwise on the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Samples are referred to by integer position, stored as float32 columns of shape (n, 1):\n",
    "# train samples get 0 .. n_train - 1, test samples n_train .. n_train + n_test - 1.\n",
    "train_idx = torch.arange(0, n_train, dtype=torch.float32, device=device)[:, None]\n",
    "test_idx = torch.arange(n_train, n_train + n_test, dtype=torch.float32, device=device)[:, None]\n",
    "\n",
    "# Targets and the distance table moved onto the chosen device.\n",
    "y_train_t = torch.as_tensor(y_train).to(device)\n",
    "y_test_t = torch.as_tensor(y_test).to(device)\n",
    "l1_t = torch.from_numpy(l1).to(torch.float32).to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "85393408-4af2-451a-a99d-32b3f32076f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PrecomputedExpKernel(gpytorch.kernels.Kernel):\n",
    "    \"\"\"Exponential kernel ``exp(-D[i, j] / l)`` read from a precomputed distance table.\n",
    "\n",
    "    Inputs are not feature vectors: each point is a single index into\n",
    "    ``l1_full``, carried in a trailing dimension of size one.\n",
    "\n",
    "    Args:\n",
    "        l1_full: 2-D tensor of pairwise distances, registered as a buffer.\n",
    "    \"\"\"\n",
    "\n",
    "    is_stationary = True\n",
    "\n",
    "    def __init__(self, l1_full):\n",
    "        super().__init__()\n",
    "        # Buffer rather than Parameter: it is not returned by parameters().\n",
    "        self.register_buffer(\"l1_full\", l1_full)\n",
    "        # The length scale is stored in log space; exp() keeps it positive.\n",
    "        self.log_lengthscale = torch.nn.Parameter(torch.zeros(1))\n",
    "\n",
    "    def forward(self, x1, x2=None, diag=False, **params):\n",
    "        \"\"\"Evaluate the kernel between two sets of indices.\n",
    "\n",
    "        Args:\n",
    "            x1: index tensor of shape ``(..., n1, 1)``.\n",
    "            x2: index tensor of shape ``(..., n2, 1)``; defaults to ``x1``.\n",
    "            diag: when true, return only the element-wise values ``k(x1[i], x2[i])``.\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape ``(..., n1, n2)``, or ``(..., n1)`` when ``diag`` is true.\n",
    "        \"\"\"\n",
    "        if x2 is None:\n",
    "            x2 = x1\n",
    "        i1 = x1.squeeze(-1).long()\n",
    "        i2 = x2.squeeze(-1).long()\n",
    "        if diag:\n",
    "            # Look the pairs up instead of assuming x1 == x2; for identical inputs this\n",
    "            # reads the table's own diagonal.\n",
    "            d = self.l1_full[i1, i2]\n",
    "        else:\n",
    "            # Broadcasted advanced indexing gathers exactly the requested entries, without\n",
    "            # first copying whole rows of the table, and also accepts batched indices.\n",
    "            d = self.l1_full[i1.unsqueeze(-1), i2.unsqueeze(-2)]\n",
    "        l = self.log_lengthscale.exp()\n",
    "        return torch.exp(-d / l)\n",
    "\n",
    "\n",
    "class SVGPModel(gpytorch.models.ApproximateGP):\n",
    "    \"\"\"Sparse variational GP over sample indices.\n",
    "\n",
    "    Uses a Cholesky variational distribution with fixed (non-learned) inducing\n",
    "    locations, a constant mean and a scaled ``PrecomputedExpKernel``.\n",
    "\n",
    "    Args:\n",
    "        inducing_idx: tensor of inducing-point indices; its first dimension is\n",
    "            the number of inducing points.\n",
    "        l1_full: distance table handed to ``PrecomputedExpKernel``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, inducing_idx, l1_full):\n",
    "        num_points = inducing_idx.size(0)\n",
    "        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(num_points)\n",
    "        variational_strategy = gpytorch.variational.VariationalStrategy(\n",
    "            self,\n",
    "            inducing_idx,\n",
    "            variational_distribution,\n",
    "            learn_inducing_locations=False,\n",
    "        )\n",
    "        super().__init__(variational_strategy)\n",
    "\n",
    "        base_kernel = PrecomputedExpKernel(l1_full)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(base_kernel)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the prior ``MultivariateNormal`` at the indices ``x``.\"\"\"\n",
    "        prior_mean = self.mean_module(x)\n",
    "        prior_covar = self.covar_module(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(prior_mean, prior_covar)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "8a162afb-b9c3-4a12-a377-3ee51c6cbcd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_inducing = 500\n",
    "\n",
    "# Fixed seed so the randomly chosen inducing subset is the same on every run.\n",
    "torch.manual_seed(0)\n",
    "perm          = torch.randperm(n_train, device=device)\n",
    "inducing_idx  = train_idx[perm[:num_inducing]].clone()\n",
    "\n",
    "model      = SVGPModel(inducing_idx, l1_t).to(device)\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)\n",
    "\n",
    "model.train()\n",
    "likelihood.train()\n",
    "\n",
    "# The likelihood is a separate module, so its parameters are not part of\n",
    "# model.parameters(); both groups must be handed to the optimizer or the\n",
    "# observation noise stays at its initial value.\n",
    "optimizer = torch.optim.Adam(\n",
    "    [\n",
    "        {\"params\": model.parameters()},\n",
    "        {\"params\": likelihood.parameters()},\n",
    "    ],\n",
    "    lr=.2,\n",
    ")\n",
    "elbo      = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=n_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "e173097d-0778-4ffb-b708-f3ca542ef06d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-22 19:56:18 INFO Beginning SVGP training …\n",
      "2025-04-22 19:56:29 INFO Epoch 20/200  |  ELBO: -1.7774\n",
      "2025-04-22 19:56:39 INFO Epoch 40/200  |  ELBO: -1.8076\n",
      "2025-04-22 19:56:50 INFO Epoch 60/200  |  ELBO: -1.7760\n",
      "2025-04-22 19:57:03 INFO Epoch 80/200  |  ELBO: -1.7829\n",
      "2025-04-22 19:57:15 INFO Epoch 100/200  |  ELBO: -1.7835\n",
      "2025-04-22 19:57:26 INFO Epoch 120/200  |  ELBO: -1.7751\n",
      "2025-04-22 19:57:36 INFO Epoch 140/200  |  ELBO: -1.7748\n",
      "2025-04-22 19:57:47 INFO Epoch 160/200  |  ELBO: -1.7918\n",
      "2025-04-22 19:57:59 INFO Epoch 180/200  |  ELBO: -1.7758\n",
      "2025-04-22 19:58:11 INFO Epoch 200/200  |  ELBO: -1.7761\n",
      "2025-04-22 19:58:11 INFO Finished training in 112.9s\n"
     ]
    }
   ],
   "source": [
    "# Number of optimisation steps and how often (in steps) progress is logged.\n",
    "epochs = 1000\n",
    "log_every = 20\n",
    "\n",
    "logger.info(\"Beginning SVGP training …\")\n",
    "start = time.time()\n",
    "for epoch in range(1, epochs + 1):\n",
    "    optimizer.zero_grad()\n",
    "    # The optimizer minimises, so the objective is the negated ELBO on the full training set.\n",
    "    loss = elbo(model(train_idx), y_train_t).neg()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    if epoch % log_every != 0:\n",
    "        continue\n",
    "    logger.info(f\"Epoch {epoch}/{epochs}  |  ELBO: {-loss.item():.4f}\")\n",
    "logger.info(f\"Finished training in {time.time() - start:.1f}s\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "e7897b6f-d017-4866-bf91-9228303d9bea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put both modules into evaluation mode before predicting.\n",
    "model.eval()\n",
    "likelihood.eval()\n",
    "\n",
    "# Predictive distribution at the test indices, computed without gradient tracking.\n",
    "with torch.no_grad():\n",
    "    preds = likelihood(model(test_idx))\n",
    "    mean, std = preds.mean, preds.stddev\n",
    "\n",
    "# Root-mean-square error of the predictive mean against the test targets.\n",
    "rmse = (mean - y_test_t).pow(2).mean().sqrt().item()\n",
    "\n",
    "def crps_gaussian(y, mu, sigma):\n",
    "    \"\"\"Closed-form CRPS of a Gaussian forecast N(mu, sigma^2) at observation y.\n",
    "\n",
    "    Evaluates sigma * (z * (2 * Phi(z) - 1) + 2 * phi(z) - 1 / sqrt(pi)) with\n",
    "    z = (y - mu) / sigma, where Phi and phi are the standard normal CDF and PDF.\n",
    "    Arguments may be scalars or arrays; arithmetic is element-wise.\n",
    "    \"\"\"\n",
    "    standardized = (y - mu) / sigma\n",
    "    cdf_term = standardized * (2 * norm.cdf(standardized) - 1)\n",
    "    pdf_term = 2 * norm.pdf(standardized)\n",
    "    inv_sqrt_pi = 1 / np.sqrt(np.pi)\n",
    "    return sigma * (cdf_term + pdf_term - inv_sqrt_pi)\n",
    "\n",
    "# Average CRPS over the test set, evaluated on CPU NumPy copies of the predictions.\n",
    "mean_crps = crps_gaussian(\n",
    "    y_test,\n",
    "    mean.cpu().numpy(),\n",
    "    std.cpu().numpy(),\n",
    ").mean()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "663d828b-55f2-49f3-bb90-cbc66bf99acc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE: 1.1912994384765625\n",
      "Mean CRPS: 0.8605195031334578\n"
     ]
    }
   ],
   "source": [
    "# Report both test metrics, one per line.\n",
    "print(f\"RMSE: {rmse}\", f\"Mean CRPS: {mean_crps}\", sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de19790c-46ff-4857-9294-1ba808445bf1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0c878a1-8f9c-413a-8dc9-52d5a9d6670c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gpcam_env",
   "language": "python",
   "name": "gpcam_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
