{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from botorch.models import SingleTaskGP\n",
    "from gpytorch.mlls import ExactMarginalLogLikelihood\n",
    "from botorch.fit import fit_gpytorch_model\n",
    "from botorch.acquisition import ExpectedImprovement, qExpectedImprovement\n",
    "from botorch.optim import optimize_acqf\n",
    "from botorch.generation.sampling import MaxPosteriorSampling\n",
    "from uncertaintylearning.utils import create_network, create_optimizer\n",
    "\n",
    "from uncertaintylearning.models import DEUP, MCDropout\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from sklearn.neighbors import KernelDensity\n",
    "\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from torch.quasirandom import SobolEngine\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "from test_functions import functions\n",
    "from examples.SMO.test_functions import functions, bounds as boundsx\n",
    "from uncertaintylearning.features.feature_generator import FeatureGenerator\n",
    "from examples.SMO.buffer import Buffer\n",
    "from examples.SMO.smo import init_buffer, optimize, make_feature_generator, one_step_acquisition\n",
    "from uncertaintylearning.models.mcdropout import MCDropout\n",
    "from uncertaintylearning.models.ensemble import Ensemble\n",
    "\n",
    "from copy import deepcopy\n",
    "from itertools import product\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fct_name = 'levi_n13'\n",
    "fct_name = 'multi_optima'\n",
    "fct = functions[fct_name]\n",
    "dim, bounds = boundsx[fct_name]\n",
    "noise = 0\n",
    "f = lambda x: fct(x, 0)\n",
    "def invsoftplus(x, beta=1):\n",
    "    return 1. / beta * (torch.log((beta * x).exp() - 1))\n",
    "\n",
    "X = (bounds[1] - bounds[0]) * torch.rand(1000, dim) + bounds[0]\n",
    "Y = f(X)\n",
    "if dim == 1:\n",
    "    X = torch.arange(bounds[0], bounds[1], 0.01).reshape(-1, 1)\n",
    "    Y = f(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## COMPARE WITH MANY SEEDS TO GP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_seeds = 5\n",
    "n_steps = 50\n",
    "res_gp = np.zeros((n_seeds, 1 + n_steps))\n",
    "\n",
    "res_ensemble = np.zeros((n_seeds, 1 + n_steps))\n",
    "res_mcdropout = np.zeros((n_seeds, 1 + n_steps))\n",
    "    \n",
    "res_gpdeup = np.zeros((n_seeds, 1 + n_steps))\n",
    "use_log_unc=True\n",
    "features = 'xv'\n",
    "\n",
    "for seed in range(n_seeds):\n",
    "    torch.manual_seed(10 + seed)\n",
    "    X_init = (bounds[1] - bounds[0]) * torch.rand(6, dim) + bounds[0]\n",
    "    Y_init = f(X_init)\n",
    "    \n",
    "    print(f'Seed {seed}, Y_init_max {Y_init.max().item()}')\n",
    "    outs_gp = optimize(f, bounds, X_init, Y_init, model_type=\"gp\", plot_stuff=False, domain=X, domain_image=Y, n_steps=n_steps)\n",
    "    res_gp[seed] = outs_gp[0]\n",
    "    \n",
    "    nets = [create_network(dim, 1, 128, 'relu', False, 3) for i in range(3)]\n",
    "    opts = [create_optimizer(nets[i], 1e-3) for i in range(3)]\n",
    "    print(f'Seed {seed}, Y_init_max {Y_init.max().item()}')\n",
    "    outs_gp = optimize(f, bounds, X_init, Y_init, model_type=\"ensemble\", networks=nets, optimizers=opts, features=features,\n",
    "                   epochs=200, plot_stuff=False, domain=X, domain_image=Y, n_steps=n_steps)\n",
    "    res_ensemble[seed] = outs_gp[0]\n",
    "\n",
    "    network = create_network(dim, 1, 128, 'relu', False, 3, 0.3)\n",
    "    optimizer = create_optimizer(network, 1e-3)\n",
    "    mcdropout_model = MCDropout(X_init, Y_init, network, optimizer, batch_size=64)\n",
    "    print(f'Seed {seed}, Y_init_max {Y_init.max().item()}')\n",
    "    outs_gp = optimize(f, bounds, X_init, Y_init, model_type=\"mcdropout\", networks=network, optimizers=optimizer, features=features,\n",
    "                   epochs=200, plot_stuff=False, domain=X, domain_image=Y, n_steps=n_steps)\n",
    "    res_mcdropout[seed] = outs_gp[0]\n",
    "\n",
    "    \n",
    "    \n",
    "    print(f'Y_init_max {Y_init.max().item()}')\n",
    "    networks = {\n",
    "    'e_predictor': create_network(len(features) + (dim - 1 if 'x' in features else 0),\n",
    "                                  1, 128, 'relu', False if use_log_unc else True, 3),\n",
    "    'f_predictor': create_network(dim, 1, 128, 'relu', False, 3)\n",
    "    }\n",
    "    optimizers = {\n",
    "              'e_optimizer': create_optimizer(networks['e_predictor'], 1e-3),\n",
    "              'f_optimizer': create_optimizer(networks['f_predictor'], 1e-3)\n",
    "             }\n",
    "    outs = optimize(f, bounds, X_init, Y_init, networks=networks, optimizers=optimizers, features=features, plot_stuff=False,\n",
    "                    n_steps=n_steps, epochs=200, domain=X, domain_image=Y, print_each=100, use_log_unc=True, estimator='gp')\n",
    "    res_gpdeup[seed] = outs[0]\n",
    "    \n",
    "    plt.plot(range(1 + n_steps), res_gp[:seed+1].mean(0),  label='Average maximum value reached by GP')\n",
    "    plt.plot(range(1 + n_steps), res_ensemble[:seed+1].mean(0),  label='Average maximum value reached by Ensemble')\n",
    "    plt.plot(range(1 + n_steps), res_mcdropout[:seed+1].mean(0),  label='Average maximum value reached by MCDropout')\n",
    "    plt.plot(range(1 + n_steps), res_gpdeup[:seed+1].mean(0), label='Average maximum value reached by DEUP-gp-xv')\n",
    "    plt.legend()\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## IMPORT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "string = \"\"\"\n",
    "some information if you want\n",
    "\"\"\"\n",
    "pickle.dump({'gp': res_gp, 'gpdeup': res_gpdeup,\n",
    "             'mcdropout': res_mcdropout, 'ensemble': res_ensemble, 'string': string,\n",
    "            }, open('YOURFILENAMEHERE', 'wb'))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
