{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from gpytorch.kernels import RBFKernel, ScaleKernel\n",
    "from gpytorch.priors import GammaPrior, UniformPrior\n",
    "import pyro\n",
    "from pyro.infer.mcmc import NUTS, MCMC\n",
    "import gpytorch\n",
    "from copy import deepcopy\n",
    "import plotly.graph_objects as go\n",
    "\n",
    "# Numerics and reproducibility: work in float64 and seed the global torch RNG once.\n",
    "torch.set_default_dtype(torch.float64)\n",
    "seed = 777\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "def HYX(X, model, n_mc=10000):\n",
    "    \"\"\"Monte-Carlo estimate of the predictive entropy of Y at the input X.\n",
    "\n",
    "    Draws ``n_mc`` latent samples f ~ p(f | X), pushes each through the likelihood to get\n",
    "    y ~ p(y | f), and returns the negated mean of log p(y | X) under the predictive.\n",
    "\n",
    "    Args:\n",
    "        X: input tensor, reshaped with ``view(-1, 1)`` (a single point in this notebook).\n",
    "        model: GPyTorch ExactGP, expected to be in eval mode; its own ``model.likelihood``\n",
    "            is used rather than a notebook global.\n",
    "        n_mc: number of Monte-Carlo draws (default 10000, the previously hard-coded value).\n",
    "\n",
    "    Returns:\n",
    "        float: the entropy estimate.\n",
    "\n",
    "    NOTE(review): if ``model`` carries a batch of MCMC hyperparameter samples, log_prob is\n",
    "    evaluated per batch member, so this averages per-sample entropies rather than computing\n",
    "    the entropy of the mixture predictive -- confirm this is the intended quantity.\n",
    "    \"\"\"\n",
    "    lik = model.likelihood\n",
    "    # Pure estimate: no autograd graph over the n_mc samples is needed.\n",
    "    with torch.no_grad():\n",
    "        # Evaluate the posterior once and reuse it for both sampling and scoring.\n",
    "        f_dist = model(X.view(-1, 1))\n",
    "        y = lik(f_dist.sample(torch.Size([n_mc]))).sample()\n",
    "        return -lik(f_dist).log_prob(y).mean().item()\n",
    "\n",
    "def HYXT(X, model, theta, n_mc=10000):\n",
    "    \"\"\"Predictive entropy of Y at X, averaged over the lengthscale values in ``theta``.\n",
    "\n",
    "    For each entry of ``theta`` the lengthscale of ``model.covar_module.base_kernel.kernels[0]``\n",
    "    is overwritten and the predictive entropy at X is estimated with the same Monte-Carlo\n",
    "    scheme as ``HYX``; the mean over all entries is returned.\n",
    "\n",
    "    Args:\n",
    "        X: input tensor, reshaped with ``view(-1, 1)``.\n",
    "        model: GPyTorch ExactGP. It is MUTATED in place (lengthscale and train/eval mode),\n",
    "            so pass a copy if the original must stay untouched.\n",
    "        theta: sequence of lengthscale values; one entropy estimate is made per entry.\n",
    "        n_mc: Monte-Carlo draws per entropy estimate (default 10000).\n",
    "\n",
    "    Returns:\n",
    "        float: mean entropy estimate over ``theta``.\n",
    "    \"\"\"\n",
    "    # Use the (possibly deep-copied) model's own likelihood, not the notebook global.\n",
    "    lik = model.likelihood\n",
    "    x = X.view(-1, 1)\n",
    "    theta_kernel = model.covar_module.base_kernel.kernels[0]\n",
    "    # Sized from theta itself rather than from the global num_samples.\n",
    "    htheta = torch.zeros(len(theta))\n",
    "    # Avoid retaining an autograd graph across every theta iteration.\n",
    "    with torch.no_grad():\n",
    "        for i, ls in enumerate(theta):\n",
    "            theta_kernel.lengthscale = ls\n",
    "            # GPyTorch's ExactGP caches its training-data solve in eval mode and drops it\n",
    "            # only on train(); a bare eval() would keep predicting with stale hyperparameters.\n",
    "            model.train()\n",
    "            model.eval()\n",
    "            f_dist = model(x)\n",
    "            y = lik(f_dist.sample(torch.Size([n_mc]))).sample()\n",
    "            htheta[i] = -lik(f_dist).log_prob(y).mean()\n",
    "    return htheta.mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth structure of the synthetic target: a ScaleKernel-wrapped RBF (theta) plus a bare RBF (phi).\n",
    "theta_amp, theta_ls = torch.tensor([1.0, 5.0])\n",
    "phi1_ls = torch.tensor([2.5])\n",
    "# Multiplies the standard-normal draws added to the observed targets.\n",
    "noise = 0.0\n",
    "\n",
    "ktheta = ScaleKernel(RBFKernel(ard_num_dims=1))\n",
    "bphi1 = RBFKernel(ard_num_dims=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the declared amplitude rather than a duplicated literal 1.\n",
    "ktheta.outputscale = theta_amp\n",
    "ktheta.base_kernel.lengthscale = theta_ls\n",
    "bphi1.lengthscale = phi1_ls\n",
    "# bphi1 is a bare RBFKernel: it has no outputscale parameter, so its amplitude is\n",
    "# implicitly 1 and assigning bphi1.outputscale would only attach an unused attribute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "gridsize = 101\n",
    "Xgrid = torch.linspace(0, 50, gridsize)\n",
    "jitter = 1e-8\n",
    "# Half of the jitter on each covariance, so the summed process carries `jitter` on its diagonal.\n",
    "half_jitter_eye = (jitter / 2) * torch.eye(gridsize)\n",
    "evalktheta = ktheta(Xgrid).to_dense() + half_jitter_eye\n",
    "evalkphi = bphi1(Xgrid).to_dense() + half_jitter_eye\n",
    "\n",
    "zero_mean = torch.zeros(gridsize)\n",
    "GPtheta = torch.distributions.MultivariateNormal(zero_mean, covariance_matrix=evalktheta)\n",
    "GPphi = torch.distributions.MultivariateNormal(zero_mean, covariance_matrix=evalkphi)\n",
    "# Keep the draw order (theta, phi, observation indices, noise) so the seeded results are stable.\n",
    "sampletheta = GPtheta.sample()\n",
    "samplephi = GPphi.sample()\n",
    "sample = sampletheta + samplephi\n",
    "\n",
    "nobs = 5\n",
    "# nobs distinct grid indices chosen at random as the initial observations.\n",
    "idx_obs = torch.randperm(gridsize)[:nobs]\n",
    "\n",
    "X_1 = Xgrid[idx_obs]\n",
    "Y_1 = sample[idx_obs] + noise * torch.randn(nobs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "##### INITIAL PRIORS\n",
    "# Same Gamma(concentration=cc, rate=1/scale) prior for both model lengthscales,\n",
    "# built as two separate objects so each kernel registers its own prior instance.\n",
    "cc = 3\n",
    "scale = 1.25\n",
    "\n",
    "def _make_ls_prior():\n",
    "    return GammaPrior(concentration=cc, rate=scale**(-1))\n",
    "\n",
    "theta_ls_prior = _make_ls_prior()\n",
    "bphi1_ls_prior = _make_ls_prior()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### MODEL\n",
    "\n",
    "class ExactGPModel(gpytorch.models.ExactGP):\n",
    "    \"\"\"Exact GP with a ConstantMean and a caller-supplied covariance module.\"\"\"\n",
    "\n",
    "    def __init__(self, train_x, train_y, likelihood, covar_module):\n",
    "        super().__init__(train_x, train_y, likelihood)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = covar_module\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return N(mean(x), K(x, x)) as a GPyTorch MultivariateNormal.\"\"\"\n",
    "        return gpytorch.distributions.MultivariateNormal(\n",
    "            self.mean_module(x), self.covar_module(x)\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "### FIT\n",
    "# Fully Bayesian fit: put priors on the hyperparameters and sample their posterior with NUTS.\n",
    "\n",
    "kmodeltheta = RBFKernel(ard_num_dims=1)\n",
    "kmodelphi = RBFKernel(ard_num_dims=1)\n",
    "\n",
    "# Additive kernel under one shared outputscale; kernels[0] is the theta part, kernels[1] the phi part.\n",
    "kmodel = ScaleKernel(kmodeltheta + kmodelphi)\n",
    "kmodel.outputscale = 1.\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.Positive())\n",
    "model = ExactGPModel(X_1, Y_1, likelihood, covar_module=kmodel)\n",
    "\n",
    "# Gamma priors (from the INITIAL PRIORS cell) on the two lengthscales, Uniform(0.01, 0.05) on the noise.\n",
    "# NOTE(review): no prior is registered on the outputscale or the constant mean, so presumably they stay\n",
    "# at their initial values during NUTS -- confirm this is intended.\n",
    "model.covar_module.base_kernel.kernels[0].register_prior(\"lengthscale_prior\", theta_ls_prior, \"lengthscale\")\n",
    "model.covar_module.base_kernel.kernels[1].register_prior(\"lengthscale_prior\", bphi1_ls_prior, \"lengthscale\")\n",
    "\n",
    "likelihood.register_prior(\"noise_prior\", UniformPrior(0.01, 0.05), \"noise\")\n",
    "# NOTE(review): mll is never used in this notebook; the fit below is done by NUTS.\n",
    "mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n",
    "\n",
    "def pyro_model(x, y):\n",
    "    # Pyro model for NUTS: draw one hyperparameter setting from the registered priors,\n",
    "    # then condition the GP's marginal distribution of y on the observations.\n",
    "    with gpytorch.settings.fast_computations(False, False, False):\n",
    "        sampled_model = model.pyro_sample_from_prior()\n",
    "        output = sampled_model.likelihood(sampled_model(x))\n",
    "        pyro.sample(\"obs\", output, obs=y)\n",
    "    return y\n",
    "\n",
    "# NUTS settings: keep num_samples draws after warmup_steps warm-up iterations.\n",
    "num_samples = 500\n",
    "warmup_steps = 50\n",
    "\n",
    "nuts_kernel = NUTS(pyro_model)\n",
    "mcmc_run = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps)\n",
    "mcmc_run.run(X_1, Y_1)\n",
    "# Load the posterior draws back into the model (presumably as a batch of num_samples\n",
    "# hyperparameter settings -- the later cells rely on that batch layout).\n",
    "model.pyro_load_from_samples(mcmc_run.get_samples())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put both the likelihood and the model into eval mode before computing predictions.\n",
    "for _module in (likelihood, model):\n",
    "    _module.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# theta was previously undefined (NameError on Restart & Run All). Use the kernels[0] lengthscale\n",
    "# values loaded into the model by pyro_load_from_samples (presumably one per MCMC draw -- confirm).\n",
    "# HYXT overwrites that lengthscale, which is why it is handed a deep copy of the model.\n",
    "theta = model.covar_module.base_kernel.kernels[0].lengthscale.detach().clone()\n",
    "\n",
    "HY = torch.tensor([HYX(X, model) for X in Xgrid])\n",
    "HY_theta = torch.tensor([HYXT(X, deepcopy(model), theta) for X in Xgrid])\n",
    "etig = HY - HY_theta\n",
    "etig /= etig.max()\n",
    "# Normalise out of place: `etsig = HY_theta` followed by `/=` would also rescale HY_theta itself.\n",
    "etsig = HY_theta / HY_theta.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two max-normalised acquisition curves over the grid, with the initial observations marked at y = 0.\n",
    "traces = [\n",
    "    go.Scatter(x=Xgrid.ravel(), y=etig, mode=\"lines\", name=\"ETIG\", line=dict(color=\"green\")),\n",
    "    go.Scatter(x=Xgrid.ravel(), y=etsig, mode=\"lines\", name=\"ELIG\", line=dict(color=\"orange\")),\n",
    "    go.Scatter(\n",
    "        x=X_1, y=0.*X_1, mode=\"markers\",\n",
    "        marker=dict(color=\"black\", symbol=\"star\"), name=\"Initial observations\",\n",
    "    ),\n",
    "]\n",
    "fig = go.Figure(data=traces)\n",
    "(\n",
    "    fig.update_xaxes(title=r\"$\\boldsymbol{x}$\")\n",
    "    .update_yaxes(title=\"Acquisition function values\")\n",
    "    .update_layout(\n",
    "        height=500., width=500., legend=dict(y=.99, x=.99, xanchor=\"right\"),\n",
    "        template=\"plotly_white\",\n",
    "    )\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
