{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work from the parent directory so the `src` / `visualizations` imports and the\n",
    "# visualizations/results output paths used below resolve. The explicit %cd magic\n",
    "# works even when IPython's automagic is disabled or `cd` is shadowed by a variable.\n",
    "%cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "import visualizations.plot_settings as plot_settings\n",
    "from src.gaussian_process import LinearGaussianProcess\n",
    "from src.utils.misc import random_fourier_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project-wide plot styling from visualizations/plot_settings (presumably sets\n",
    "# matplotlib rcParams for LaTeX/PGF output); run it before any figure is created.\n",
    "plot_settings.set_latex_settings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch's global RNG: the explicit random draws below (Fourier feature\n",
    "# frequencies, posterior samples, observation noise) all come from it.\n",
    "seed = 1\n",
    "# The trailing ';' hides the repr of the Generator returned by manual_seed.\n",
    "torch.manual_seed(seed);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the ground-truth function and observation model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def true_func(x):\n",
    "    \"\"\"Ground-truth reward sin(2*pi*x); exactly one period on [0, 1].\"\"\"\n",
    "    return torch.sin(2 * torch.pi * x)\n",
    "\n",
    "\n",
    "# Standard deviation of the additive Gaussian noise on each observation.\n",
    "noise_std = 0.01"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Gaussian process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Length scale of the RBF kernel behind the random Fourier features\n",
    "# (passed to random_fourier_features when the features are built).\n",
    "rbf_length_scale = 0.1\n",
    "\n",
    "# GP that is linear in n_features random Fourier features (presumably the\n",
    "# weight-space approximation of an RBF-kernel GP).\n",
    "# NOTE(review): `nar` is given the noise *standard deviation*; confirm that\n",
    "# LinearGaussianProcess does not expect a variance (noise_std**2).\n",
    "gp = LinearGaussianProcess(\n",
    "    n_features=10_000,\n",
    "    nar=noise_std,\n",
    "    device=torch.get_default_device(),\n",
    "    dtype=torch.float64,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define domain discretization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = torch.linspace(0, 1, 1000)  # evaluation grid on [0, 1]\n",
    "true_function = true_func(X)  # noiseless ground truth on the grid\n",
    "\n",
    "# n_features // 2 standard-normal frequencies for the random Fourier features\n",
    "# (presumably each frequency yields a cos/sin feature pair).\n",
    "rff_frequencies = torch.randn(1, gp.n_features // 2, device=X.device)\n",
    "X_features = random_fourier_features(\n",
    "    X.view(-1, 1),\n",
    "    gp.n_features,\n",
    "    ls=rbf_length_scale,\n",
    "    random_normals=rff_frequencies,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run Thompson sampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_posterior(mean, covariance):\n",
    "    \"\"\"Draw one joint sample from N(mean, covariance) via its eigendecomposition.\n",
    "\n",
    "    Eigenvalues are clamped at zero because round-off can leave tiny negative ones.\n",
    "    The symmetric square root Q diag(sqrt(L)) Q^H is applied to the standard-normal\n",
    "    draw as (Q diag(sqrt(L))) @ (Q^H @ z), so no dense n x n matrix product is formed.\n",
    "    Consumes one torch.randn_like(mean) draw from the global RNG.\n",
    "    \"\"\"\n",
    "    eigvals, eigvecs = torch.linalg.eigh(covariance)\n",
    "    scale = eigvals.clamp(min=0).sqrt()\n",
    "    z = torch.randn_like(mean)\n",
    "    return mean + (eigvecs * scale) @ (eigvecs.mH @ z)\n",
    "\n",
    "\n",
    "bo_steps = 5\n",
    "X_train = []\n",
    "y_train = []\n",
    "# bo_steps + 1 posterior updates: the final pass only draws the posterior sample\n",
    "# that is plotted below and does not query the objective.\n",
    "for step in range(bo_steps + 1):\n",
    "    mean = gp.posterior_mean(X_features)\n",
    "    covariance_matrix = gp.posterior_cov(X_features)\n",
    "    posterior_sample = sample_posterior(mean, covariance_matrix)\n",
    "    thompson_sample = torch.argmax(posterior_sample)  # grid index maximising the sample\n",
    "    if step == bo_steps:\n",
    "        break\n",
    "    X_train.append(X[thompson_sample])\n",
    "    y_train.append(true_function[thompson_sample] + noise_std * torch.randn(()))\n",
    "    # NOTE(review): add_observation presumably updates gp in place, so re-running this\n",
    "    # cell without re-running the GP definition cell keeps the earlier observations.\n",
    "    gp.add_observation(\n",
    "        X_features[thompson_sample, :],\n",
    "        obs=y_train[-1],\n",
    "        perform_marginal_likelihood_maximization=True,\n",
    "        min_obs=4,\n",
    "    )\n",
    "X_train = torch.tensor(X_train)\n",
    "y_train = torch.tensor(y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `mean` and `covariance_matrix` come from the final pass of the Thompson-sampling\n",
    "# loop. The variances are clamped at zero before the sqrt because round-off can\n",
    "# leave tiny negative diagonal entries, which would otherwise become NaN.\n",
    "posterior_mean = mean\n",
    "posterior_std = torch.diag(covariance_matrix).clamp(min=0).sqrt()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot the posterior GP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two half-width figures exported separately below: fig1 shows the posterior on\n",
    "# the full grid, fig2 the same posterior on a coarse subset of the grid.\n",
    "def to_numpy(t):\n",
    "    \"\"\"Return `t` as a NumPy array, detached and moved to the CPU for matplotlib.\"\"\"\n",
    "    return t.detach().cpu().numpy()\n",
    "\n",
    "\n",
    "figsize = (plot_settings.document_width / 2, 2.0)\n",
    "fig1, ax1 = plt.subplots(1, 1, figsize=figsize)\n",
    "fig2, ax2 = plt.subplots(1, 1, figsize=figsize)\n",
    "\n",
    "# Full grid: observations, posterior mean with a +/- 1 std band, and the posterior\n",
    "# sample together with its maximiser (the Thompson sample).\n",
    "ax1.scatter(to_numpy(X_train), to_numpy(y_train), color='red', label=\"Observations\", zorder=5, marker=\"o\", s=25, edgecolors='white', linewidths=0.1)\n",
    "ax1.plot(to_numpy(X), to_numpy(posterior_mean), 'b-', label=\"Posterior mean\")\n",
    "ax1.fill_between(\n",
    "    to_numpy(X),\n",
    "    to_numpy(posterior_mean - posterior_std),\n",
    "    to_numpy(posterior_mean + posterior_std),\n",
    "    color='blue', alpha=0.2, label=\"Confidence interval\"\n",
    ")\n",
    "ax1.plot(to_numpy(X), to_numpy(posterior_sample), 'g:', label=\"Sample from posterior\")\n",
    "ax1.scatter(to_numpy(X[thompson_sample:thompson_sample+1]), to_numpy(posterior_sample[thompson_sample:thompson_sample+1]), color='red', label=\"Thompson sample\", zorder=5, marker='+', s=100)\n",
    "\n",
    "# Coarse grid: keep every `sampling_ratio`-th grid point and re-take the argmax\n",
    "# of the subsampled posterior sample.\n",
    "sampling_ratio = 100\n",
    "X2 = X[::sampling_ratio]\n",
    "posterior_mean2 = posterior_mean[::sampling_ratio]\n",
    "posterior_std2 = posterior_std[::sampling_ratio]\n",
    "posterior_sample2 = posterior_sample[::sampling_ratio]\n",
    "thompson_sample2 = torch.argmax(posterior_sample2)\n",
    "\n",
    "_, caplines, barlinecols = ax2.errorbar(\n",
    "    to_numpy(X2),\n",
    "    to_numpy(posterior_mean2),\n",
    "    yerr=to_numpy(posterior_std2),\n",
    "    xerr=0,\n",
    "    color='blue',\n",
    "    marker='s',\n",
    "    markersize=10,\n",
    "    elinewidth=10.0,\n",
    "    alpha=0.5,\n",
    "    ls='none',\n",
    "    label=r\"Mean \\& confidence interval\",\n",
    ")\n",
    "# Make the thick error bars and their caps fainter than the square markers.\n",
    "for artist in (*caplines, *barlinecols):\n",
    "    artist.set_alpha(0.2)\n",
    "ax2.scatter(to_numpy(X2), to_numpy(posterior_sample2), color='green', label=\"Sample from posterior\", zorder=5, marker=\"D\", s=40, edgecolors='white', linewidths=0.1)\n",
    "ax2.scatter(to_numpy(X2[thompson_sample2:thompson_sample2+1]), to_numpy(posterior_sample2[thompson_sample2:thompson_sample2+1]), color='red', label=\"Thompson sample\", zorder=5, marker='+', s=100)\n",
    "\n",
    "for ax in (ax1, ax2):\n",
    "    ax.set_ylim(-1.4, 1.35)\n",
    "    # set_axis_off() below hides the axes and these labels; they only show up if\n",
    "    # that call is removed. Legends are omitted on purpose as well;\n",
    "    # ax.legend(loc='upper right') adds them back.\n",
    "    ax.set_xlabel(\"X\")\n",
    "    ax.set_ylabel(\"Reward\")\n",
    "    ax.set_axis_off()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output directory if needed, so the export also works on a fresh checkout.\n",
    "Path(\"visualizations/results\").mkdir(parents=True, exist_ok=True)\n",
    "fig1.savefig(\"visualizations/results/thompson_sampling_continuous.pgf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the coarse-grid figure (fig2) as PGF.\n",
    "fig2.savefig(\n",
    "    \"visualizations/results/thompson_sampling_discrete.pgf\",\n",
    "    bbox_inches=\"tight\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move into visualizations/, where pgf_compiler.sh is invoked from in the next cells.\n",
    "# The explicit %cd magic works even when IPython's automagic is disabled.\n",
    "%cd visualizations/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the project's pgf_compiler.sh on the exported continuous figure (presumably it\n",
    "# compiles the .pgf into a standalone document). The relative script path assumes the\n",
    "# working directory is visualizations/.\n",
    "! bash pgf_compiler.sh thompson_sampling_continuous"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the project's pgf_compiler.sh on the exported discrete figure (presumably it\n",
    "# compiles the .pgf into a standalone document). The relative script path assumes the\n",
    "# working directory is visualizations/.\n",
    "! bash pgf_compiler.sh thompson_sampling_discrete"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
