{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd976a67-54d7-433f-8a9c-7a1250b12ff0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import time\n",
    "import logging\n",
    "import math # Import math\n",
    "import numpy as np\n",
    "import torch\n",
    "import gpytorch\n",
    "from scipy.stats import norm\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edb8bb57-15c5-42aa-8f26-c7c7a90e66e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure root logging for the notebook. force=True (Python 3.8+) first removes\n",
    "# any handlers already attached to the root logger (if any, e.g. from an earlier run\n",
    "# of this cell); without it basicConfig is silently a no-op in that case and the\n",
    "# INFO messages logged by later cells may never appear.\n",
    "logging.basicConfig(\n",
    "    level=logging.INFO,\n",
    "    format=\"%(asctime)s %(levelname)s %(message)s\",\n",
    "    datefmt=\"%Y-%m-%d %H:%M:%S\",\n",
    "    force=True,\n",
    ")\n",
    "logger = logging.getLogger(__name__)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "968ba4ce-2f79-4f61-a758-2067347ba922",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data locations. The defaults are the original NERSC paths; set the MNIST_DIR /\n",
    "# MNIST_DIST_DIR environment variables to point the notebook somewhere else.\n",
    "base_path = os.environ.get(\"MNIST_DIR\", '/global/u1/v/vtek/MNIST')\n",
    "base_path_dist = os.environ.get(\"MNIST_DIST_DIR\", '/pscratch/sd/v/vtek')\n",
    "\n",
    "# Load failures are logged and re-raised rather than calling exit(): inside a\n",
    "# Jupyter kernel exit() shuts the kernel down (discarding all state) instead of\n",
    "# just stopping this cell with a traceback.\n",
    "logger.info(f\"Loading MNIST data from base path: {base_path} ...\")\n",
    "try:\n",
    "    # NOTE: pickle.load can execute arbitrary code; only load label files you trust.\n",
    "    with open(os.path.join(base_path, \"y_train_MNIST.pkl\"), \"rb\") as f:\n",
    "        y_train_np = np.array(pickle.load(f), dtype=np.float32).flatten()\n",
    "    with open(os.path.join(base_path, \"y_test_MNIST.pkl\"), \"rb\") as f:\n",
    "        y_test_np = np.array(pickle.load(f), dtype=np.float32).flatten()\n",
    "    logger.info(\"MNIST data loaded successfully.\")\n",
    "except FileNotFoundError:\n",
    "    logger.error(f\"Error: Data files not found in {base_path}. Please check the path.\")\n",
    "    raise\n",
    "except Exception as e:\n",
    "    logger.error(f\"An error occurred during data loading: {e}\")\n",
    "    raise\n",
    "\n",
    "logger.info(f\"Loading pre-computed L1 distance matrix from: {base_path_dist} ...\")\n",
    "try:\n",
    "    l1 = np.load(os.path.join(base_path_dist, \"distance_l1.npy\"))\n",
    "    logger.info(\"L1 distance matrix loaded successfully.\")\n",
    "except FileNotFoundError:\n",
    "    logger.error(f\"Error: L1 distance file not found in {base_path_dist}. Please check the path.\")\n",
    "    raise\n",
    "except Exception as e:\n",
    "    logger.error(f\"An error occurred loading the L1 distance matrix: {e}\")\n",
    "    raise\n",
    "\n",
    "n_train = y_train_np.shape[0]\n",
    "n_test = y_test_np.shape[0]\n",
    "n_total = n_train + n_test\n",
    "\n",
    "# The kernel addresses this matrix by point id (train ids 0..n_train-1, then\n",
    "# test ids n_train..n_total-1), so it must be square over all points.\n",
    "if l1.shape != (n_total, n_total):\n",
    "    msg = f\"L1 distance matrix shape mismatch. Expected ({n_total},{n_total}), got {l1.shape}\"\n",
    "    logger.error(msg)\n",
    "    raise ValueError(msg)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "092cf47d-c570-4106-917b-d50048c1c1d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the compute device once; every tensor created below lives on it.\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "logger.info(f\"Using device: {device}\")\n",
    "\n",
    "# GP inputs are point ids into the precomputed distance matrix, stored as float32\n",
    "# column vectors of shape (n, 1) (float32 holds integers exactly up to 2**24):\n",
    "# train ids are [0, n_train), test ids are [n_train, n_total).\n",
    "index_opts = {\"dtype\": torch.float32, \"device\": device}\n",
    "train_idx = torch.arange(0, n_train, **index_opts)[:, None]\n",
    "test_idx = torch.arange(n_train, n_total, **index_opts)[:, None]\n",
    "\n",
    "\n",
    "def _as_f32_on_device(array):\n",
    "    \"\"\"Convert a NumPy array to a contiguous float32 tensor on `device`.\n",
    "\n",
    "    May share memory with `array` when no dtype conversion or device move is needed.\n",
    "    \"\"\"\n",
    "    return torch.from_numpy(array).float().contiguous().to(device)\n",
    "\n",
    "\n",
    "y_train_t, y_test_t, l1_t = map(_as_f32_on_device, (y_train_np, y_test_np, l1))\n",
    "logger.info(\"Created index and target tensors.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ce94a12-8c03-4a95-a82d-362c0eeafeec",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class PrecomputedL1ExpKernel(gpytorch.kernels.Kernel):\n",
    "    \"\"\"\n",
    "    Exponential kernel over precomputed L1 distances: K(i, j) = exp(-d_ij / lengthscale).\n",
    "\n",
    "    Kernel inputs are not feature vectors but point ids (integer values stored in a\n",
    "    floating-point tensor of shape (*batch, n, 1)) that index rows/columns of l1_full.\n",
    "    Assumes the input l1_full matrix contains DIRECT L1 distances (not squared).\n",
    "\n",
    "    Args:\n",
    "        l1_full: 2-D tensor of pairwise distances, indexed by point id.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if l1_full is not 2-D.\n",
    "        IndexError: from forward(), if an input id falls outside l1_full.\n",
    "    \"\"\"\n",
    "    # NOTE(review): kept as-is; the original comment states GPyTorch needs this flag\n",
    "    # even though the kernel is not stationary in id space -- confirm before changing.\n",
    "    is_stationary = True\n",
    "\n",
    "    def __init__(self, l1_full):\n",
    "        super().__init__()\n",
    "        if l1_full.dim() != 2:\n",
    "            raise ValueError(\"l1_full must be a 2D matrix\")\n",
    "        # A buffer (not a Parameter): it moves with .to(device) but is never optimised.\n",
    "        self.register_buffer(\"l1_full\", l1_full)\n",
    "        self.register_parameter(\n",
    "            name=\"raw_lengthscale\", parameter=torch.nn.Parameter(torch.zeros(1))\n",
    "        )\n",
    "        # The Positive() constraint maps the unconstrained raw value to a lengthscale > 0.\n",
    "        self.register_constraint(\"raw_lengthscale\", gpytorch.constraints.Positive())\n",
    "\n",
    "    @property\n",
    "    def lengthscale(self):\n",
    "        \"\"\"Positive lengthscale obtained by transforming raw_lengthscale.\"\"\"\n",
    "        return self.raw_lengthscale_constraint.transform(self.raw_lengthscale)\n",
    "\n",
    "    @lengthscale.setter\n",
    "    def lengthscale(self, value):\n",
    "        self._set_lengthscale(value)\n",
    "\n",
    "    def _set_lengthscale(self, value):\n",
    "        \"\"\"Store `value` by writing its inverse-transformed raw parameter.\"\"\"\n",
    "        if not torch.is_tensor(value):\n",
    "            value = torch.as_tensor(value).to(self.raw_lengthscale)\n",
    "        self.initialize(raw_lengthscale=self.raw_lengthscale_constraint.inverse_transform(value))\n",
    "\n",
    "    @staticmethod\n",
    "    def _to_ids(x, size):\n",
    "        \"\"\"Turn a (*batch, n, 1) id tensor into long ids (*batch, n), checked against [0, size).\"\"\"\n",
    "        ids = x.squeeze(-1).long()\n",
    "        # Reject out-of-range ids instead of clamping them: clamping would silently\n",
    "        # return distances belonging to the wrong points.\n",
    "        if ids.numel() > 0 and (ids.min() < 0 or ids.max() >= size):\n",
    "            raise IndexError(\n",
    "                f\"Kernel input ids must lie in [0, {size - 1}]; \"\n",
    "                f\"got min {ids.min().item()}, max {ids.max().item()}\"\n",
    "            )\n",
    "        return ids\n",
    "\n",
    "    def forward(self, x1, x2=None, diag=False, **params):\n",
    "        \"\"\"\n",
    "        Evaluate the kernel between the points whose ids are x1 and x2.\n",
    "\n",
    "        Args:\n",
    "            x1: ids, shape (*batch, n, 1).\n",
    "            x2: ids, shape (*batch, m, 1); defaults to x1.\n",
    "            diag: if True, return only the paired values k(x1[..., k, :], x2[..., k, :]).\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape (*batch, n, m), or (*batch, n) when diag is True.\n",
    "        \"\"\"\n",
    "        if x2 is None:\n",
    "            x2 = x1\n",
    "        i1 = self._to_ids(x1, self.l1_full.shape[0])\n",
    "        i2 = self._to_ids(x2, self.l1_full.shape[1])\n",
    "\n",
    "        if diag:\n",
    "            # Paired lookup, so the diagonal always agrees with the full matrix\n",
    "            # (it is exactly 1 wherever d_ii == 0).\n",
    "            d = self.l1_full[i1, i2]\n",
    "        else:\n",
    "            # Broadcast advanced indexing yields (*batch, n, m) directly, without first\n",
    "            # copying the full rows l1_full[i1] of shape (*batch, n, N).\n",
    "            d = self.l1_full[i1.unsqueeze(-1), i2.unsqueeze(-2)]\n",
    "\n",
    "        # Clamp negative distances just in case (shouldn't happen with L1)\n",
    "        d = torch.clamp(d, min=0)\n",
    "        ls = self.lengthscale + 1e-9  # epsilon keeps the division finite\n",
    "        return torch.exp(-d / ls)\n",
    "\n",
    "class ExactGPModel(gpytorch.models.ExactGP):\n",
    "    \"\"\"\n",
    "    Exact GP whose inputs are point ids: constant mean plus a scaled\n",
    "    PrecomputedL1ExpKernel over the full L1 distance matrix.\n",
    "\n",
    "    Args:\n",
    "        train_x: training ids (in this notebook, shape (n, 1)).\n",
    "        train_y: training targets (in this notebook, shape (n,)).\n",
    "        likelihood: GPyTorch likelihood passed to ExactGP.\n",
    "        l1_full: pairwise L1 distance matrix handed to the base kernel.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, train_x, train_y, likelihood, l1_full):\n",
    "        super().__init__(train_x, train_y, likelihood)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        base_kernel = PrecomputedL1ExpKernel(l1_full=l1_full)\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(base_kernel)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Prior distribution N(mean(x), K(x, x)) at the ids x.\"\"\"\n",
    "        return gpytorch.distributions.MultivariateNormal(\n",
    "            self.mean_module(x),\n",
    "            self.covar_module(x),\n",
    "        )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74e6e1de-cc23-4e32-8c02-a2cf45a3fb11",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Observation noise is confined to this interval during training.\n",
    "NOISE_LOWER, NOISE_UPPER = 1e-9, 1e-4\n",
    "initial_noise_value = 1e-5\n",
    "# Strict inequalities: the Interval constraint squashes its raw parameter through a\n",
    "# sigmoid, so the bounds themselves correspond to an infinite raw value.\n",
    "if not (NOISE_LOWER < initial_noise_value < NOISE_UPPER):\n",
    "    initial_noise_value = (NOISE_LOWER + NOISE_UPPER) / 2\n",
    "    logger.warning(f\"Initial noise value was outside the new constraint interval. Resetting to {initial_noise_value:.1e}\")\n",
    "\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood(\n",
    "    noise_constraint=gpytorch.constraints.Interval(NOISE_LOWER, NOISE_UPPER),\n",
    ").to(device)\n",
    "# GaussianLikelihood's constructor does not apply a `noise=` keyword (extra keyword\n",
    "# arguments are accepted but unused), so set the starting value explicitly.\n",
    "likelihood.noise = initial_noise_value\n",
    "logger.info(f\"Initial likelihood noise: {likelihood.noise.item():.1e}\")\n",
    "\n",
    "model = ExactGPModel(train_idx, y_train_t, likelihood, l1_t).to(device)\n",
    "model.train()\n",
    "likelihood.train()\n",
    "LEARNING_RATE = 0.15\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
    "mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b07da207-c1e1-48ec-b745-012f949404f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs, log_every = 200, 2\n",
    "logger.info(f\"Beginning ExactGP hyperparameter training for {epochs} epochs...\")\n",
    "start = time.time()\n",
    "\n",
    "\n",
    "def _current_hyperparameters():\n",
    "    \"\"\"Return the current (lengthscale, outputscale, noise) of `model` as Python floats.\"\"\"\n",
    "    return (\n",
    "        model.covar_module.base_kernel.lengthscale.item(),\n",
    "        model.covar_module.outputscale.item(),\n",
    "        likelihood.noise.item(),\n",
    "    )\n",
    "\n",
    "\n",
    "for epoch in range(1, epochs + 1):\n",
    "    optimizer.zero_grad()\n",
    "    output = model(train_idx)\n",
    "    loss = -mll(output, y_train_t)\n",
    "\n",
    "    # Stop before backward(): stepping on a NaN/Inf loss would corrupt the parameters.\n",
    "    if not torch.isfinite(loss):\n",
    "        logger.error(f\"Epoch {epoch}: Loss is NaN or Inf! Stopping training.\")\n",
    "        try:\n",
    "            lengthscale, outputscale, noise = _current_hyperparameters()\n",
    "            logger.info(f\" Params at failure: LS={lengthscale:.3e}, OS={outputscale:.3e}, N={noise:.3e}\")\n",
    "        except AttributeError:\n",
    "            logger.info(\"Could not retrieve hyperparameters at failure.\")\n",
    "        break\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if epoch % log_every == 0 or epoch == 1:\n",
    "        try:\n",
    "            lengthscale_val, outputscale_val, noise_val = _current_hyperparameters()\n",
    "            # The noise constraint caps the noise at 1e-4, so fixed-point .3f would\n",
    "            # always print 0.000; use scientific notation for it.\n",
    "            logger.info(\n",
    "                f\"Epoch {epoch}/{epochs} | Loss (Neg MLL): {loss.item():.4f} | \"\n",
    "                f\"LS: {lengthscale_val:.3f} | OS: {outputscale_val:.3f} | Noise: {noise_val:.3e}\"\n",
    "            )\n",
    "        except AttributeError as e:\n",
    "            logger.warning(f\"Epoch {epoch}: Could not log hyperparameters - {e}\")\n",
    "\n",
    "logger.info(f\"Finished hyperparameter training in {time.time() - start:.1f}s\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6d9f6b1-97a1-41fa-bf6a-b758a87c2990",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "logger.info(\"Starting NNGP prediction...\")\n",
    "model.eval()\n",
    "likelihood.eval()\n",
    "\n",
    "num_neighbors = 100\n",
    "num_neighbors = min(num_neighbors, n_train)  # topk needs k <= number of training points\n",
    "logger.info(f\"Using {num_neighbors} nearest neighbors for prediction.\")\n",
    "\n",
    "nngp_means = torch.zeros(n_test, device=device)\n",
    "nngp_vars = torch.zeros(n_test, device=device)\n",
    "\n",
    "start_pred = time.time()\n",
    "log_interval = 100\n",
    "\n",
    "# Each test point gets a local GP conditioned only on its k nearest training points,\n",
    "# using the hyperparameters learned above. The trained model is re-pointed at each\n",
    "# neighbour set with set_train_data instead of building a new ExactGPModel per point:\n",
    "# a freshly built model keeps its parameters on the CPU (load_state_dict copies values\n",
    "# in place, it does not move them), which fails when the data live on CUDA, and the\n",
    "# state_dict round-trip also copies the full l1_full buffer on every iteration.\n",
    "# The full training set is restored afterwards so the training cell can be re-run.\n",
    "try:\n",
    "    with torch.no_grad():\n",
    "        for i in range(n_test):\n",
    "            if i % log_interval == 0:\n",
    "                logger.info(f\"Processing point {i}/{n_test} ({i/n_test*100:.1f}%)\")\n",
    "\n",
    "            current_test_idx_scalar = n_train + i\n",
    "            current_test_idx_tensor = test_idx[i].unsqueeze(0)  # Shape [1, 1]\n",
    "            dists_to_train = l1_t[current_test_idx_scalar, :n_train]\n",
    "            _, neighbor_indices = torch.topk(dists_to_train, k=num_neighbors, largest=False)\n",
    "\n",
    "            neighbor_train_idx = train_idx[neighbor_indices]  # Shape [k, 1]\n",
    "            neighbor_y_train = y_train_t[neighbor_indices]    # Shape [k]\n",
    "\n",
    "            # strict=False: the neighbour set is smaller than the full training set.\n",
    "            model.set_train_data(inputs=neighbor_train_idx, targets=neighbor_y_train, strict=False)\n",
    "            predictive_dist = likelihood(model(current_test_idx_tensor))\n",
    "            nngp_means[i] = predictive_dist.mean.squeeze()\n",
    "            nngp_vars[i] = predictive_dist.variance.squeeze()\n",
    "finally:\n",
    "    model.set_train_data(inputs=train_idx, targets=y_train_t, strict=False)\n",
    "\n",
    "# Log completion\n",
    "logger.info(f\"Processed all {n_test} points\")\n",
    "logger.info(f\"Finished NNGP prediction in {time.time() - start_pred:.1f}s\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c372fad6-12b7-4d26-846c-754177248fd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "logger.info(\"Calculating evaluation metrics...\")\n",
    "# Predictive variances can dip marginally below zero from round-off; clamp them so\n",
    "# sqrt does not turn them into NaN standard deviations.\n",
    "nngp_stddev = torch.sqrt(torch.clamp(nngp_vars, min=0))\n",
    "rmse = torch.sqrt(torch.mean((nngp_means - y_test_t) ** 2)).item()\n",
    "\n",
    "\n",
    "def crps_gaussian(y, mu, sigma):\n",
    "    \"\"\"\n",
    "    Calculate CRPS for Gaussian predictions N(mu, sigma^2) observed at y.\n",
    "\n",
    "    Uses the closed form\n",
    "        CRPS = sigma * (z * (2 * Phi(z) - 1) + 2 * phi(z) - 1 / sqrt(pi)),  z = (y - mu) / sigma,\n",
    "    with Phi / phi the standard normal CDF / PDF.\n",
    "\n",
    "    Args:\n",
    "        y, mu, sigma: torch tensors of matching shape (any device; copied to CPU).\n",
    "\n",
    "    Returns:\n",
    "        NumPy array of per-point CRPS values (lower is better).\n",
    "    \"\"\"\n",
    "    y_np = y.cpu().numpy()\n",
    "    mu_np = mu.cpu().numpy()\n",
    "    sigma_np = np.maximum(sigma.cpu().numpy(), 1e-9)  # Clamp sigma for stability\n",
    "    z = (y_np - mu_np) / sigma_np\n",
    "    term1 = z * (2 * norm.cdf(z) - 1)\n",
    "    term2 = 2 * norm.pdf(z)\n",
    "    term3 = 1 / np.sqrt(np.pi)\n",
    "    return sigma_np * (term1 + term2 - term3)\n",
    "\n",
    "\n",
    "crps_values = crps_gaussian(y_test_t, nngp_means, nngp_stddev)\n",
    "crps_values = np.nan_to_num(crps_values, nan=np.nanmean(crps_values))  # Handle potential NaNs\n",
    "mean_crps = np.mean(crps_values)\n",
    "\n",
    "predicted_classes = torch.round(nngp_means).long()\n",
    "predicted_classes = torch.clamp(predicted_classes, 0, 9)  # Ensure valid class range (digits 0-9)\n",
    "correct_predictions = torch.sum(predicted_classes == y_test_t.long()).item()\n",
    "accuracy = correct_predictions / n_test\n",
    "\n",
    "print(f\"\\n--- Results (NNGP with Precomputed L1 Kernel) ---\")\n",
    "print(f\"Number of Neighbors (k): {num_neighbors}\")\n",
    "print(f\"RMSE (treating classes as regression targets): {rmse:.4f}\")\n",
    "print(f\"Mean CRPS (treating classes as regression targets): {mean_crps:.4f}\")\n",
    "print(f\"Accuracy (by rounding mean prediction): {accuracy:.4f}\")\n",
    "\n",
    "try:\n",
    "    final_lengthscale = model.covar_module.base_kernel.lengthscale.item()\n",
    "    final_outputscale = model.covar_module.outputscale.item()\n",
    "    final_noise = likelihood.noise.item()\n",
    "    # The noise is constrained to at most 1e-4, so print it in scientific notation.\n",
    "    print(f\"Trained Params: Lengthscale={final_lengthscale:.3f}, Outputscale={final_outputscale:.3f}, Noise={final_noise:.3e}\")\n",
    "except AttributeError:\n",
    "    print(\"Could not retrieve final hyperparameters.\")\n",
    "\n",
    "logger.info(\"Script finished.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f47e8049-ce36-4940-8f80-4d2090be8ff6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f01017b6-156b-4595-a94a-befadd80f758",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d104b91-3459-4b78-9181-6b800293320a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4d5f6ec-a279-4149-b694-b057b8152d02",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gpcam_env",
   "language": "python",
   "name": "gpcam_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
