{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7fc4d0481af0>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "import math\n",
     "\n",
     "import gpytorch\n",
     "import plotly.graph_objects as go\n",
     "import torch\n",
     "from gpytorch.kernels import RBFKernel, ScaleKernel\n",
     "from gpytorch.priors import GammaPrior\n",
    "\n",
    "torch.set_default_dtype(torch.double)\n",
    "seed = 777\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "cc = 3\n",
    "scale = 1.25\n",
    "\n",
    "theta_ls_prior =  GammaPrior(concentration=cc, rate=scale**(-1))\n",
    "bphi1_ls_prior =  GammaPrior(concentration=cc, rate=scale**(-1))\n",
    "\n",
    "Xgrid = torch.linspace(0, 50, 101)\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.Positive())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_GP(kernel, target_ls, nuisance_ls, x):\n",
    "    kernel.base_kernel.kernels[0].lengthscale = target_ls\n",
    "    kernel.base_kernel.kernels[1].lengthscale = nuisance_ls\n",
    "    evalkernel = kernel(x).to_dense()\n",
    "    return torch.distributions.multivariate_normal.MultivariateNormal(\n",
    "        torch.zeros(len(x)), evalkernel\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def GP_sample(theta, phi, x, eval_p=False):\n",
    "    kernel = ScaleKernel(RBFKernel(ard_num_dims=1) +  RBFKernel(ard_num_dims=1))\n",
    "    kernel.outputscale = 1\n",
    "    res = torch.zeros((len(theta), len(x)))\n",
    "    if eval_p:\n",
    "        pres = torch.zeros((len(theta),))\n",
    "    for i, (t,p) in enumerate(zip(theta, phi)):\n",
    "        GPtheta = set_GP(kernel, t, p, x)\n",
    "        res[i] = GPtheta.sample()\n",
    "        if eval_p:\n",
    "            pres[i] = GPtheta.log_prob(res[i])\n",
    "    if eval_p:\n",
    "        return res, pres\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 10000\n",
    "M = 100\n",
    "X = Xgrid[[50,52]]\n",
    "theta_mc = theta_ls_prior.sample([N])\n",
    "phi_mc = bphi1_ls_prior.sample([N])\n",
    "FNS, pFNS = GP_sample(theta_mc, phi_mc, X, eval_p=True)\n",
    "rT = torch.zeros(N)\n",
    "noiselevel = .01\n",
    "likelihood.noise = .01\n",
    "for n in range(N):\n",
    "    YMC = FNS[n] + noiselevel * torch.randn((M, len(X)))\n",
    "    phiprime = bphi1_ls_prior.sample([M])\n",
    "    FNSprime = GP_sample(torch.ones(M) * theta_mc[n], phiprime, X)\n",
    "    pyjtheta = torch.exp(torch.tensor(likelihood(FNSprime.reshape(M,1, len(X))).log_prob(YMC))).mean(dim=0)\n",
    "    pyj = torch.exp(torch.tensor(likelihood(FNS.reshape(N, 1, len(X))).log_prob(YMC))).mean(dim=0)\n",
    "    rT[n] = (torch.log(pyjtheta.prod(dim=1)) - torch.log(pyj.prod(dim=1))).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "(rT < 0.).to(float).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "ppsi = pFNS.clone()\n",
    "ppsi += bphi1_ls_prior.log_prob(phi_mc)\n",
    "ppsi = torch.exp(ppsi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = go.Figure(go.Scatter(\n",
    "    x=ppsi[rT >= 0.].detach().numpy(), y=rT[rT >= 0.].detach().numpy(), mode=\"markers\", \n",
    "    name=\"Positive interference\", marker=dict(color=\"green\")\n",
    "))\n",
    "fig.add_trace(go.Scatter(\n",
    "    x=ppsi[(rT < 0.) & (rT > -10.)].detach().numpy(), y=rT[(rT < 0.) & (rT > -10.)].detach().numpy(), \n",
    "    mode=\"markers\", \n",
    "    name=\"Negative interference\", marker=dict(color=\"orange\")\n",
    "))\n",
    "fig.update_xaxes(title=r\"$p(\\boldsymbol{\\psi} | \\boldsymbol{\\theta}^\\star)$\")\n",
    "fig.update_yaxes(title=r\"$r^\\star(\\boldsymbol{x})$\")\n",
    "fig.update_layout(\n",
    "    height=500., width=500., legend=dict(xanchor=\"right\", x=.99, y=.99),\n",
    "    template=\"plotly_white\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
