{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a66e3ef3-4260-4c23-ae72-212d1326bc2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===== Cholesky Decomposition =====\n",
    "# torch is imported here so that this cell also runs first on a fresh kernel:\n",
    "# the annotations below are evaluated when the function is defined.\n",
    "import torch\n",
    "\n",
    "\n",
    "def one_step_cholesky(\n",
    "    top_left: torch.Tensor, K_Xθ: torch.Tensor, K_θθ: torch.Tensor, A_inv: torch.Tensor\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"Extend an upper Cholesky factor by one block of rows and columns.\n",
    "\n",
    "    Given the upper factor U of A (A = U^T @ U), returns the upper factor of the\n",
    "    block matrix [[A, K_Xθ], [K_Xθ^T, K_θθ]] without refactorizing A.\n",
    "\n",
    "    Args:\n",
    "        top_left: upper-triangular Cholesky factor U of A.\n",
    "        K_Xθ: off-diagonal block between the existing points and the new points.\n",
    "        K_θθ: covariance block of the new points.\n",
    "        A_inv: inverse of A.\n",
    "\n",
    "    Returns:\n",
    "        Upper-triangular factor [[U, S], [0, V]] with S = U^-T @ K_Xθ and\n",
    "        V = chol(K_θθ - S^T @ S, upper=True).\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: raised by torch.linalg.cholesky if K_θθ - S^T @ S is not\n",
    "            positive definite.\n",
    "    \"\"\"\n",
    "    # U @ A^-1 = U @ (U^T @ U)^-1 = U^-T, so this solves U^T @ S = K_Xθ.\n",
    "    top_right = top_left @ (A_inv @ K_Xθ)\n",
    "    bot_left = torch.zeros_like(top_right).transpose(-1, -2)\n",
    "    # torch.cholesky is deprecated; linalg.cholesky(..., upper=True) gives the same factor.\n",
    "    bot_right = torch.linalg.cholesky(\n",
    "        K_θθ - top_right.transpose(-1, -2) @ top_right, upper=True\n",
    "    )\n",
    "    return torch.cat(\n",
    "        [\n",
    "            torch.cat([top_left, top_right], dim=-1),\n",
    "            torch.cat([bot_left, bot_right], dim=-1),\n",
    "        ],\n",
    "        dim=-2,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "527958b4-e78b-4c0a-ba5f-a703a36f9c70",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Tuple\n",
    "import torch\n",
    "import gpytorch\n",
    "import botorch\n",
    "\n",
    "# ===== Acquisition Function to Sample Points for Gradient Information =====\n",
    "class GradientInformation(botorch.acquisition.AnalyticAcquisitionFunction):\n",
    "    \"\"\"Scores candidates by the gradient information they add at ``theta_i``.\n",
    "\n",
    "    For a candidate theta, the GP is conditioned (in closed form, via a one-step\n",
    "    Cholesky update) on one extra noisy observation at theta. The score is\n",
    "    trace(K_xX_dx @ K_inv @ K_xX_dx^T) at x = theta_i, i.e. how much the trace\n",
    "    of the posterior gradient covariance would shrink relative to the prior.\n",
    "\n",
    "    ``update_theta_i`` must be called before the first evaluation. The model must\n",
    "    expose ``D``, ``train_inputs``, ``covar_module.base_kernel.lengthscale``,\n",
    "    ``get_L_lower()``, ``get_KXX_inv()`` and a Gaussian ``likelihood``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        super().__init__(model)\n",
    "\n",
    "    def update_theta_i(self, theta_i: torch.Tensor):\n",
    "        \"\"\"Set the point at which gradient information is measured.\n",
    "\n",
    "        Non-tensor inputs are converted with ``torch.tensor``. Also refreshes the\n",
    "        cached kernel derivative with respect to the current training inputs.\n",
    "        \"\"\"\n",
    "        if not torch.is_tensor(theta_i):\n",
    "            theta_i = torch.tensor(theta_i)\n",
    "        self.theta_i = theta_i\n",
    "        self.update_K_xX_dx()\n",
    "\n",
    "    def update_K_xX_dx(self):\n",
    "        \"\"\"Recompute the cached dK(x, X)/dx for x = theta_i and X = training inputs.\"\"\"\n",
    "        X = self.model.train_inputs[0]\n",
    "        x = self.theta_i.view(-1, self.model.D)\n",
    "        self.K_xX_dx_part = self._get_KxX_dx(x, X)\n",
    "\n",
    "    def _get_KxX_dx(self, x, X) -> torch.Tensor:\n",
    "        \"\"\"Analytic derivative of the RBF kernel K(x, X) with respect to x.\n",
    "\n",
    "        Args:\n",
    "            x: (n, D) query points.\n",
    "            X: (N, D) reference points.\n",
    "\n",
    "        Returns:\n",
    "            (n, D, N) tensor whose entry [i, d, j] is dK(x_i, X_j)/dx_{i,d}.\n",
    "        \"\"\"\n",
    "        N = X.shape[0]\n",
    "        n = x.shape[0]\n",
    "        K_xX = self.model.covar_module(x, X).to_dense()\n",
    "        lengthscale = self.model.covar_module.base_kernel.lengthscale.detach()\n",
    "        # RBF kernel: dK(x, X_j)/dx = -(x - X_j) / lengthscale**2 * K(x, X_j).\n",
    "        return (\n",
    "            -torch.eye(self.model.D, device=X.device)\n",
    "            / lengthscale ** 2\n",
    "            @ (\n",
    "                (x.view(n, 1, self.model.D) - X.view(1, N, self.model.D))\n",
    "                * K_xX.view(n, N, 1)\n",
    "            ).transpose(1, 2)\n",
    "        )\n",
    "\n",
    "    @botorch.utils.transforms.t_batch_mode_transform(expected_q=1)\n",
    "    def forward(self, thetas: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Evaluate the acquisition function on a batch of candidates.\n",
    "\n",
    "        Args:\n",
    "            thetas: b x 1 x D tensor of candidates (q = 1 is enforced by the\n",
    "                ``t_batch_mode_transform`` decorator).\n",
    "\n",
    "        Returns:\n",
    "            b-dim tensor of acquisition values.\n",
    "        \"\"\"\n",
    "        D = self.model.D\n",
    "        X = self.model.train_inputs[0]\n",
    "        x = self.theta_i.view(-1, D)\n",
    "        # These depend only on the training data, not on the candidate.\n",
    "        top_left = self.model.get_L_lower().transpose(-1, -2)\n",
    "        A_inv = self.model.get_KXX_inv()\n",
    "        # top_left / A_inv come from the likelihood-applied (noisy) train covariance,\n",
    "        # so the candidate must be treated as a noisy observation as well.\n",
    "        noise = self.model.likelihood.noise.detach()\n",
    "        variances = []\n",
    "        for theta in thetas:\n",
    "            theta = theta.view(-1, D)\n",
    "            K_Xθ = self.model.covar_module(X, theta).to_dense()\n",
    "            K_θθ = self.model.covar_module(theta).to_dense() + noise * torch.eye(\n",
    "                K_Xθ.shape[-1], dtype=theta.dtype, device=theta.device\n",
    "            )\n",
    "            # Upper Cholesky factor of the covariance of the augmented inputs [X, theta].\n",
    "            L = one_step_cholesky(\n",
    "                top_left=top_left, K_Xθ=K_Xθ, K_θθ=K_θθ, A_inv=A_inv\n",
    "            )\n",
    "            K_XX_inv = torch.cholesky_inverse(L, upper=True)\n",
    "            # Kernel derivative at theta_i w.r.t. the augmented inputs [X, theta].\n",
    "            K_xθ_dx = self._get_KxX_dx(x, theta)\n",
    "            K_xX_dx = torch.cat([self.K_xX_dx_part, K_xθ_dx], dim=-1)\n",
    "            variance_d = -K_xX_dx @ K_XX_inv @ K_xX_dx.transpose(1, 2)\n",
    "            variances.append(torch.trace(variance_d.view(D, D)).view(1))\n",
    "        # Negate so that a larger variance reduction gives a larger acquisition value.\n",
    "        return -torch.cat(variances, dim=0)\n",
    "\n",
    "\n",
    "def optimize_acqf_custom_bo(\n",
    "    acq_func: botorch.acquisition.AcquisitionFunction,\n",
    "    bounds: torch.Tensor,\n",
    "    q: int,\n",
    "    num_restarts: int,\n",
    "    raw_samples: int,\n",
    ") -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"Maximize an acquisition function (e.g. GradientInformation) within bounds.\n",
    "\n",
    "    Thin wrapper around ``botorch.optim.optimize_acqf`` that optimizes the q-batch\n",
    "    jointly (``sequential=False``), keeps only the best restart and declares the\n",
    "    acquisition values non-negative.\n",
    "\n",
    "    Args:\n",
    "        acq_func: acquisition function to maximize.\n",
    "        bounds: search-space bounds passed to ``optimize_acqf``.\n",
    "        q: number of candidates to generate jointly.\n",
    "        num_restarts: number of restarts of the gradient-based optimizer.\n",
    "        raw_samples: number of raw samples used to pick starting points.\n",
    "\n",
    "    Returns:\n",
    "        The best candidates (detached from the autograd graph) and their\n",
    "        acquisition value.\n",
    "    \"\"\"\n",
    "    optimizer_options = {'nonnegative': True, 'batch_limit': 5}\n",
    "    best_candidates, best_value = botorch.optim.optimize_acqf(\n",
    "        acq_function=acq_func,\n",
    "        bounds=bounds,\n",
    "        q=q,\n",
    "        num_restarts=num_restarts,\n",
    "        raw_samples=raw_samples,\n",
    "        options=optimizer_options,\n",
    "        return_best_only=True,\n",
    "        sequential=False,\n",
    "    )\n",
    "    return best_candidates.detach(), best_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83e1760e-3e4a-4443-8a83-40548af0b775",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===== An exact Gaussian Process (GP) Model =====\n",
    "class ExactGPSEModel(gpytorch.models.ExactGP, botorch.models.gpytorch.GPyTorchModel):\n",
    "    \"\"\"Exact GP with a constant mean and a scaled RBF (squared-exponential) kernel.\n",
    "\n",
    "    Args:\n",
    "        train_x: training inputs.\n",
    "        train_y: training targets or None; a trailing singleton dimension is squeezed.\n",
    "        prior_mean: if non-zero, the constant mean is set to this value and frozen;\n",
    "            if zero, the constant is left learnable.\n",
    "    \"\"\"\n",
    "\n",
    "    _num_outputs = 1  # To inform GPyTorchModel API.\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        train_x: torch.Tensor,\n",
    "        train_y: torch.Tensor,\n",
    "        prior_mean=0,\n",
    "    ):\n",
    "        likelihood = gpytorch.likelihoods.GaussianLikelihood(\n",
    "            noise_constraint=None, noise_prior=None\n",
    "        )\n",
    "        if train_y is not None:\n",
    "            train_y = train_y.squeeze(-1)\n",
    "        super(ExactGPSEModel, self).__init__(train_x, train_y, likelihood)\n",
    "\n",
    "        # The mean module is created exactly once: re-assigning it later would\n",
    "        # silently discard the prior_mean initialization below.\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        if prior_mean != 0:\n",
    "            self.mean_module.initialize(constant=prior_mean)\n",
    "            self.mean_module.constant.requires_grad = False\n",
    "\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(\n",
    "            gpytorch.kernels.RBFKernel(\n",
    "                ard_num_dims=None,\n",
    "                lengthscale_prior=None,\n",
    "                lengthscale_constraint=None,\n",
    "            ),\n",
    "            outputscale_prior=None,\n",
    "            outputscale_constraint=None,\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the GP prior N(mean_module(x), covar_module(x)) at inputs x.\"\"\"\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
    "\n",
    "# ===== Derivative of the ExactGPSEModel =====\n",
    "class DerivativeExactGPSEModel(ExactGPSEModel):\n",
    "    \"\"\"ExactGPSEModel with an analytic posterior over its gradient.\n",
    "\n",
    "    The model starts without training data; observations are added with\n",
    "    ``append_train_data`` or replaced with ``update_train_data``. Inputs pass\n",
    "    through the ``normalize`` / ``unnormalize`` hooks, which are identity maps here.\n",
    "\n",
    "    Args:\n",
    "        D: input dimension.\n",
    "        prior_mean: forwarded to ExactGPSEModel.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        D: int,\n",
    "        prior_mean=0.0,\n",
    "    ):\n",
    "        train_x_init, train_y_init = (\n",
    "            torch.empty(0, D),\n",
    "            torch.empty(0),\n",
    "        )\n",
    "        super(DerivativeExactGPSEModel, self).__init__(\n",
    "            train_x_init,\n",
    "            train_y_init,\n",
    "            prior_mean,\n",
    "        )\n",
    "\n",
    "        self.D = D\n",
    "        self.N = 0  # number of stored training points\n",
    "        self.train_xs = train_x_init\n",
    "        self.train_ys = train_y_init\n",
    "        # Identity (un)normalization hooks.\n",
    "        self.normalize = lambda params: params\n",
    "        self.unnormalize = lambda params: params\n",
    "\n",
    "    def append_train_data(self, train_x, train_y):\n",
    "        \"\"\"Add observations in front of the stored ones and refresh the GP's training set.\"\"\"\n",
    "        self.train_xs = torch.cat([self.unnormalize(train_x), self.train_xs])\n",
    "        self.train_ys = torch.cat([train_y, self.train_ys])\n",
    "\n",
    "        self.set_train_data(\n",
    "            inputs=self.normalize(self.train_xs),\n",
    "            targets=self.train_ys,\n",
    "            strict=False,\n",
    "        )\n",
    "\n",
    "        self.N = self.train_xs.shape[0]\n",
    "\n",
    "    def update_train_data(self, train_x, train_y):\n",
    "        \"\"\"Replace all stored observations and refresh the GP's training set.\n",
    "\n",
    "        NOTE(review): unlike ``append_train_data``, ``train_x`` is stored without\n",
    "        ``unnormalize``; harmless while both hooks are identities.\n",
    "        \"\"\"\n",
    "        self.train_xs = train_x\n",
    "        self.train_ys = train_y\n",
    "\n",
    "        self.set_train_data(\n",
    "            inputs=self.normalize(self.train_xs),\n",
    "            targets=self.train_ys,\n",
    "            strict=False,\n",
    "        )\n",
    "\n",
    "        self.N = self.train_xs.shape[0]\n",
    "\n",
    "    def get_L_lower(self):\n",
    "        \"\"\"Return a detached root L of the likelihood-applied train covariance (= L @ L^T).\n",
    "\n",
    "        Requires a built prediction strategy. Callers treat L as a lower Cholesky\n",
    "        factor; NOTE(review): root_decomposition is only Cholesky-based for small\n",
    "        enough matrices -- confirm for large N.\n",
    "        \"\"\"\n",
    "        return (\n",
    "            self.prediction_strategy.lik_train_train_covar.root_decomposition()\n",
    "            .root.to_dense().detach()\n",
    "        )\n",
    "\n",
    "    def get_KXX_inv(self):\n",
    "        \"\"\"Return the detached inverse of the train covariance, built as R @ R^T\n",
    "        from GPyTorch's cached inverse root R (``prediction_strategy.covar_cache``).\n",
    "        \"\"\"\n",
    "        L_inv_upper = self.prediction_strategy.covar_cache.detach()\n",
    "        return L_inv_upper @ L_inv_upper.transpose(0, 1)\n",
    "\n",
    "    def _get_KxX_dx(self, x):\n",
    "        \"\"\"Analytic derivative of the RBF kernel K(x, X) w.r.t. x, with X = training inputs.\n",
    "\n",
    "        Args:\n",
    "            x: (n, D) query points.\n",
    "\n",
    "        Returns:\n",
    "            (n, D, N) tensor whose entry [i, d, j] is dK(x_i, X_j)/dx_{i,d}.\n",
    "        \"\"\"\n",
    "        X = self.train_inputs[0]\n",
    "        n = x.shape[0]\n",
    "        # Take N from X itself so it cannot drift from the GP's actual training set.\n",
    "        N = X.shape[0]\n",
    "        K_xX = self.covar_module(x, X).to_dense()\n",
    "        lengthscale = self.covar_module.base_kernel.lengthscale.detach()\n",
    "        # RBF kernel: dK(x, X_j)/dx = -(x - X_j) / lengthscale**2 * K(x, X_j).\n",
    "        return (\n",
    "            -torch.eye(self.D, device=x.device)\n",
    "            / lengthscale ** 2\n",
    "            @ (\n",
    "                (x.view(n, 1, self.D) - X.view(1, N, self.D))\n",
    "                * K_xX.view(n, N, 1)\n",
    "            ).transpose(1, 2)\n",
    "        )\n",
    "\n",
    "    def _get_Kxx_dx2(self):\n",
    "        \"\"\"Prior covariance of the gradient for the RBF kernel: outputscale / lengthscale**2\n",
    "        on the diagonal of a (D, D) matrix.\n",
    "        \"\"\"\n",
    "        lengthscale = self.covar_module.base_kernel.lengthscale.detach()\n",
    "        sigma_f = self.covar_module.outputscale.detach()\n",
    "        return (\n",
    "            torch.eye(self.D, device=lengthscale.device) / lengthscale ** 2\n",
    "        ) * sigma_f\n",
    "\n",
    "    def posterior_derivative(self, x):\n",
    "        \"\"\"Posterior mean and covariance of the GP gradient at the points x.\n",
    "\n",
    "        Args:\n",
    "            x: (n, D) test points.\n",
    "\n",
    "        Returns:\n",
    "            mean_d: (n, D) posterior mean of the gradient.\n",
    "            variance_d: (n, D, D) posterior covariance of the gradient; its\n",
    "                diagonal is clamped to at least 1e-9.\n",
    "        \"\"\"\n",
    "        if self.prediction_strategy is None:\n",
    "            self.posterior(x)  # Call this to update prediction strategy of GPyTorch.\n",
    "        K_xX_dx = self._get_KxX_dx(x)\n",
    "        K_XX_inv = self.get_KXX_inv()\n",
    "        # The constant prior mean has zero gradient but must still be removed from\n",
    "        # the targets: grad mu(x) = dK(x, X)/dx @ K^-1 @ (y - m(X)).\n",
    "        residual = self.train_targets - self.mean_module(self.train_inputs[0])\n",
    "        mean_d = K_xX_dx @ K_XX_inv @ residual\n",
    "        variance_d = (\n",
    "            self._get_Kxx_dx2() - K_xX_dx @ K_XX_inv @ K_xX_dx.transpose(1, 2)\n",
    "        )\n",
    "        # Only the variances must stay positive; cross-covariances may be negative,\n",
    "        # so an element-wise clamp of the whole matrix would corrupt them.\n",
    "        diag = variance_d.diagonal(dim1=-2, dim2=-1)\n",
    "        variance_d = (\n",
    "            variance_d\n",
    "            - torch.diag_embed(diag)\n",
    "            + torch.diag_embed(diag.clamp_min(1e-9))\n",
    "        )\n",
    "\n",
    "        return mean_d, variance_d\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
