{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ff153539",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import multiprocessing\n",
    "from joblib import Parallel, delayed\n",
    "\n",
    "# Tensor options splatted via **tkwargs throughout the notebook: float64 on the CPU.\n",
    "tkwargs = dict(\n",
    "    dtype=torch.double,\n",
    "    device=torch.device(\"cpu\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "85a49017",
   "metadata": {},
   "outputs": [],
   "source": [
    "from botorch.models.gp_regression import FixedNoiseGP\n",
    "from botorch.models.model_list_gp_regression import ModelListGP\n",
    "from botorch.models.transforms.outcome import Standardize\n",
    "from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood\n",
    "from botorch.utils.transforms import unnormalize, normalize\n",
    "from botorch.utils.sampling import draw_sobol_samples\n",
    "from botorch.acquisition import AcquisitionFunction\n",
    "from botorch.models import SingleTaskGP\n",
    "from gpytorch.mlls import ExactMarginalLogLikelihood\n",
    "from botorch import fit_gpytorch_mll\n",
    "from botorch.test_functions.multi_objective import DTLZ2, GMM, ZDT2, VehicleSafety\n",
    "# Benchmark problem. negate=True flips the sign of the objectives, because the\n",
    "# BoTorch multi-objective utilities used below (qEHVI, is_non_dominated) maximize.\n",
    "# Alternatives:\n",
    "#problem = DTLZ2(num_objectives=3, dim=4, negate=True).to(**tkwargs)\n",
    "#problem = ZDT2(dim=6, negate=True).to(**tkwargs)\n",
    "#problem = VehicleSafety(negate=True).to(**tkwargs)\n",
    "problem = GMM(negate=True).to(**tkwargs)\n",
    "\n",
    "# Observation-noise standard deviation, one entry per objective (0 => noiseless).\n",
    "# Sized from the problem so that switching benchmarks cannot cause a shape mismatch.\n",
    "NOISE_SE = torch.zeros(problem.num_objectives, **tkwargs)\n",
    "\n",
    "def initialize_model(train_x, train_obj, min_noise_var=1e-6):\n",
    "    \"\"\"Build one independent fixed-noise GP per objective and their summed MLL.\n",
    "\n",
    "    Args:\n",
    "        train_x: ``n x d`` training inputs.\n",
    "        train_obj: ``n x m`` observed objectives; column ``i`` trains model ``i``.\n",
    "        min_noise_var: lower bound on the per-point noise variance given to\n",
    "            ``FixedNoiseGP``. With ``NOISE_SE == 0`` the kernel matrix becomes\n",
    "            singular as soon as two training inputs coincide (likely once\n",
    "            candidates are snapped onto a finite pool), so a small floor keeps\n",
    "            the Cholesky factorization stable. Pass ``0.0`` to use ``NOISE_SE``\n",
    "            exactly.\n",
    "\n",
    "    Returns:\n",
    "        ``(mll, model)``: a ``SumMarginalLogLikelihood`` and the ``ModelListGP``\n",
    "        it wraps, ready for ``fit_gpytorch_mll``.\n",
    "    \"\"\"\n",
    "    models = []\n",
    "    for i in range(train_obj.shape[-1]):\n",
    "        train_y = train_obj[..., i : i + 1]\n",
    "        noise_var = max(float(NOISE_SE[i]) ** 2, min_noise_var)\n",
    "        train_yvar = torch.full_like(train_y, noise_var)\n",
    "        models.append(\n",
    "            FixedNoiseGP(\n",
    "                train_x, train_y, train_yvar, outcome_transform=Standardize(m=1)\n",
    "            )\n",
    "        )\n",
    "    model = ModelListGP(*models)\n",
    "    mll = SumMarginalLogLikelihood(model.likelihood, model)\n",
    "    return mll, model\n",
    "\n",
    "def generate_initial_data(n=10, seed=42):\n",
    "    \"\"\"Draw ``n`` Sobol points inside ``problem.bounds`` and evaluate them.\n",
    "\n",
    "    Returns ``(x, noisy_obj, true_obj)``; ``noisy_obj`` is ``true_obj`` plus\n",
    "    Gaussian noise with per-objective standard deviation ``NOISE_SE``.\n",
    "    ``seed`` fixes the Sobol draw only; the noise uses torch's global RNG.\n",
    "    \"\"\"\n",
    "    sobol = draw_sobol_samples(bounds=problem.bounds, n=n, q=1, seed=seed)\n",
    "    x = sobol.squeeze(1)  # drop the q=1 axis -> n x d\n",
    "    exact = problem(x)\n",
    "    noise = torch.randn_like(exact) * NOISE_SE\n",
    "    return x, exact + noise, exact"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "47e97107",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from botorch.models.model import Model\n",
    "from typing import Any, Callable, Optional\n",
    "from botorch.optim.optimize import optimize_acqf, optimize_acqf_list\n",
    "from botorch.acquisition.objective import GenericMCObjective\n",
    "from botorch.utils.multi_objective.scalarization import get_chebyshev_scalarization\n",
    "from botorch.utils.multi_objective.box_decompositions.non_dominated import (\n",
    "    FastNondominatedPartitioning,\n",
    ")\n",
    "from botorch.utils.multi_objective.box_decompositions.dominated import (\n",
    "    DominatedPartitioning,\n",
    ")\n",
    "from botorch.acquisition.multi_objective.monte_carlo import (\n",
    "    qExpectedHypervolumeImprovement,\n",
    ")\n",
    "from botorch.utils.sampling import sample_simplex\n",
    "from botorch.acquisition.multi_objective.utils import (\n",
    "    sample_optimal_points,\n",
    "    random_search_optimizer,\n",
    "    compute_sample_box_decomposition\n",
    ")\n",
    "from botorch.sampling.normal import SobolQMCNormalSampler\n",
    "from torch import Tensor\n",
    "from botorch.utils.multi_objective.pareto import is_non_dominated\n",
    "\n",
    "\n",
    "# optimize_acqf settings: number of optimizer restarts, and number of raw\n",
    "# samples from which their starting points are selected.\n",
    "NUM_RESTARTS, RAW_SAMPLES = 10, 512\n",
    "\n",
    "\n",
    "def _nearest_unused_indices(candidates, X_test, train_x):\n",
    "    \"\"\"Index of the closest admissible row of ``X_test`` for each candidate.\n",
    "\n",
    "    A row is admissible if it does not already occur in ``train_x`` and has not\n",
    "    been claimed by an earlier candidate of the same batch, so no design is\n",
    "    evaluated twice. If every row is excluded, the plain nearest row is used.\n",
    "    \"\"\"\n",
    "    candidates = candidates.detach().to(X_test)\n",
    "    dist = (X_test[None, :, :] - candidates[:, None, :]).norm(dim=-1)  # q x N\n",
    "    # Rows of X_test that exactly match an already-evaluated input.\n",
    "    used = (X_test[:, None, :] == train_x.to(X_test)[None, :, :]).all(dim=-1).any(dim=-1)\n",
    "    chosen = []\n",
    "    for row in dist:\n",
    "        masked = row.masked_fill(used, float(\"inf\"))\n",
    "        best = masked if torch.isfinite(masked).any() else row\n",
    "        idx = int(torch.argmin(best))\n",
    "        chosen.append(idx)\n",
    "        used[idx] = True\n",
    "    return torch.tensor(chosen, dtype=torch.long, device=X_test.device)\n",
    "\n",
    "\n",
    "def optimize_qehvi_and_get_observation(model, train_x, train_obj, sampler, X_test, q):\n",
    "    \"\"\"Optimize qEHVI, snap the batch onto the finite pool ``X_test`` and evaluate it.\n",
    "\n",
    "    Args:\n",
    "        model: fitted multi-output model (here a ``ModelListGP``).\n",
    "        train_x: ``n x d`` inputs evaluated so far.\n",
    "        train_obj: observed objectives for ``train_x``; not used, kept for API compatibility.\n",
    "        sampler: MC sampler for qEHVI.\n",
    "        X_test: ``N x d`` pool of admissible designs.\n",
    "        q: batch size.\n",
    "\n",
    "    Returns:\n",
    "        ``(new_x, new_obj, new_obj_true)``: the ``q x d`` selected rows of ``X_test``,\n",
    "        their noisy observations and their exact objective values.\n",
    "    \"\"\"\n",
    "    ref_point = torch.as_tensor(problem.ref_point, **tkwargs).reshape(-1)\n",
    "    # Partition the non-dominated space using the posterior mean at the training\n",
    "    # inputs rather than the raw observations.\n",
    "    with torch.no_grad():\n",
    "        pred = model.posterior(train_x).mean\n",
    "    partitioning = FastNondominatedPartitioning(ref_point=ref_point, Y=pred)\n",
    "    acq_func = qExpectedHypervolumeImprovement(\n",
    "        model=model,\n",
    "        ref_point=ref_point,\n",
    "        partitioning=partitioning,\n",
    "        sampler=sampler,\n",
    "    )\n",
    "    candidates, _ = optimize_acqf(\n",
    "        acq_function=acq_func,\n",
    "        bounds=problem.bounds,\n",
    "        q=q,\n",
    "        num_restarts=NUM_RESTARTS,\n",
    "        raw_samples=RAW_SAMPLES,\n",
    "        options={\"batch_limit\": 5, \"maxiter\": 200},\n",
    "        sequential=False,\n",
    "        # Optional linear constraints, e.g. to force a composition to sum to 1:\n",
    "        #equality_constraints=[(indices, coefficients, rhs)],\n",
    "        #inequality_constraints=...,\n",
    "    )\n",
    "\n",
    "    # optimize_acqf searches the continuous box; only rows of X_test may be\n",
    "    # evaluated, so move each candidate to its nearest not-yet-used pool row.\n",
    "    X_test = torch.as_tensor(X_test).to(train_x)\n",
    "    nearest_indices = _nearest_unused_indices(candidates, X_test, train_x)\n",
    "    new_x = X_test[nearest_indices]\n",
    "    new_obj_true = problem(new_x)\n",
    "    # Same observation model as generate_initial_data.\n",
    "    new_obj = new_obj_true + torch.randn_like(new_obj_true) * NOISE_SE\n",
    "    return new_x, new_obj, new_obj_true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4bfb272c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def distance_XX(X, X_star):\n",
    "    \"\"\"Mean over the rows of ``X_star`` of the Euclidean distance to the closest row of ``X``.\n",
    "\n",
    "    Used below as a decision-space convergence measure: how far each\n",
    "    reference (Pareto-optimal) design is from the nearest sampled design.\n",
    "    Raises ZeroDivisionError when ``X_star`` is empty.\n",
    "    \"\"\"\n",
    "    closest = [torch.norm(X - target, dim=1).min() for target in X_star]\n",
    "    return sum(closest) / len(X_star)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b763fe2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import time\n",
    "import gpytorch\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy.stats import norm\n",
    "from pyDOE import lhs\n",
    "from copy import deepcopy\n",
    "import os\n",
    "import shutil\n",
    "from multiprocessing import Pool\n",
    "import multiprocessing\n",
    "from joblib import Parallel, delayed\n",
    "import random\n",
    "import warnings\n",
    "\n",
    "# NOTE(review): this silences every warning, including GP-fitting/numerical ones.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# ---- Experiment configuration ----\n",
    "itr=40            # BO iterations per repetition\n",
    "N_dim=2           # input dimension (must match the problem's input dimension)\n",
    "N_test=1500       # size of the LHS pool of admissible designs\n",
    "N_alt=100\n",
    "N_samp=1\n",
    "N_obj=2           # number of objectives (must match problem.num_objectives)\n",
    "MC_SAMPLES = 10   # QMC samples used by qEHVI\n",
    "BATCH_SIZE = 2    # candidates proposed per iteration (q)\n",
    "opt_imp=[]\n",
    "# NOTE(review): N_alt, N_samp and opt_imp do not influence the results below.\n",
    "# Acquisition strategy; only 'EI_Botorch' (qEHVI) is implemented below.\n",
    "#chosen_acq='EI'\n",
    "chosen_acq='EI_Botorch'\n",
    "#chosen_acq='PES_Botorch'\n",
    "#chosen_acq='MFDS_Botorch'\n",
    "#chosen_acq='TPE_Optuna'\n",
    "#chosen_acq='RS_Botorch'\n",
    "verbose = True\n",
    "rep=20            # independent repetitions; repetition j uses seed j\n",
    "hv_total=[]\n",
    "edmin_total=[]\n",
    "\n",
    "# Only the qEHVI branch exists; fail fast instead of silently saving trivial results.\n",
    "if chosen_acq != 'EI_Botorch':\n",
    "    raise NotImplementedError(\n",
    "        f\"chosen_acq={chosen_acq!r} is not implemented; only 'EI_Botorch' is available.\"\n",
    "    )\n",
    "\n",
    "for j in range(rep):\n",
    "    # Seed every RNG used in this repetition (pyDOE's lhs draws from NumPy's\n",
    "    # global RNG) so that repetition j is reproducible.\n",
    "    random.seed(j)\n",
    "    np.random.seed(j)\n",
    "    torch.manual_seed(j)\n",
    "\n",
    "    # Initial design and models\n",
    "    train_x_qehvi, train_y_qehvi, _ = generate_initial_data(seed=j)\n",
    "    train_x_qehvi = train_x_qehvi.to(**tkwargs)\n",
    "    train_y_qehvi = train_y_qehvi.to(**tkwargs)\n",
    "    data_x_qehvi = train_x_qehvi.detach().cpu().numpy()\n",
    "    data_y_qehvi = train_y_qehvi.detach().cpu().numpy()\n",
    "    mll_qehvi, model_qehvi = initialize_model(train_x_qehvi, train_y_qehvi)\n",
    "\n",
    "    # Discrete pool of admissible designs and its ground-truth Pareto front / HV\n",
    "    X_test_all = torch.as_tensor(lhs(N_dim, N_test)).to(**tkwargs)\n",
    "    #X_test_all = unnormalize(X_test_all, problem.bounds)\n",
    "    Y_test_all = problem(X_test_all)\n",
    "    pareto_mask_test_all = is_non_dominated(Y_test_all)\n",
    "    Y_pf = Y_test_all[pareto_mask_test_all]\n",
    "    bd_test_all = DominatedPartitioning(ref_point=problem.ref_point, Y=Y_pf)\n",
    "    volume_test_all = bd_test_all.compute_hypervolume().item()\n",
    "\n",
    "    # edmin: mean distance from each pool Pareto design to the closest sampled design\n",
    "    X_pf = X_test_all[pareto_mask_test_all]\n",
    "    edmin = distance_XX(train_x_qehvi, X_pf).reshape(1, 1)\n",
    "    if verbose:\n",
    "        print(edmin)\n",
    "\n",
    "    # Hypervolume dominated by the observed training objectives\n",
    "    pareto_mask_train = is_non_dominated(train_y_qehvi)\n",
    "    Y_pf_train = train_y_qehvi[pareto_mask_train]\n",
    "    bd_train = DominatedPartitioning(ref_point=problem.ref_point, Y=Y_pf_train)\n",
    "    hv_truth = np.array(bd_train.compute_hypervolume().item()).reshape(1, 1)\n",
    "\n",
    "    for iteration in range(1, itr + 1):\n",
    "        t0 = time.monotonic()\n",
    "\n",
    "        # Fit the models\n",
    "        fit_gpytorch_mll(mll_qehvi)\n",
    "\n",
    "        # qEHVI with a QMC sampler; candidates are snapped onto X_test_all\n",
    "        qehvi_sampler = SobolQMCNormalSampler(sample_shape=torch.Size([MC_SAMPLES]))\n",
    "        new_x_qehvi, new_y_qehvi, new_y_true_qehvi = optimize_qehvi_and_get_observation(\n",
    "            model_qehvi, train_x_qehvi, train_y_qehvi, qehvi_sampler, X_test_all, BATCH_SIZE\n",
    "        )\n",
    "\n",
    "        # Update training points\n",
    "        train_x_qehvi = torch.cat([train_x_qehvi, new_x_qehvi.reshape(-1, N_dim)])\n",
    "        train_y_qehvi = torch.cat([train_y_qehvi, new_y_qehvi.reshape(-1, N_obj)])\n",
    "        data_x_qehvi = train_x_qehvi.detach().cpu().numpy()\n",
    "        data_y_qehvi = train_y_qehvi.detach().cpu().numpy()\n",
    "\n",
    "        # Compute hypervolume\n",
    "        pareto_mask_train = is_non_dominated(train_y_qehvi)\n",
    "        Y_pf_train = train_y_qehvi[pareto_mask_train]\n",
    "        bd_train = DominatedPartitioning(ref_point=problem.ref_point, Y=Y_pf_train)\n",
    "        hv_t = np.array(bd_train.compute_hypervolume().item()).reshape(1, 1)\n",
    "        hv_truth = np.concatenate((hv_truth, hv_t))\n",
    "\n",
    "        # Compute edmin\n",
    "        ed_t = distance_XX(train_x_qehvi, X_pf).reshape(1, 1)\n",
    "        edmin = torch.cat((edmin, ed_t))\n",
    "\n",
    "        # Reinitialize the models for next iteration\n",
    "        mll_qehvi, model_qehvi = initialize_model(train_x_qehvi, train_y_qehvi)\n",
    "\n",
    "        t1 = time.monotonic()\n",
    "        if verbose:\n",
    "            print(\"Iteration:\", iteration)\n",
    "            print('new candidats:', new_x_qehvi)\n",
    "            print('new obj:', new_y_qehvi)\n",
    "            print(\"Hypervolume (qEHVI):\", hv_truth[-1])\n",
    "            print(\"Time:\", t1 - t0)\n",
    "\n",
    "        # Checkpoint after every iteration (tensors converted to NumPy for pandas)\n",
    "        pd.DataFrame(Y_pf_train.detach().cpu().numpy()).to_csv(\"y_pareto_truth.csv\", header=None, index=None)\n",
    "        pd.DataFrame(data_x_qehvi).to_csv(\"data_x\"+str(j)+\".csv\", header=None, index=None)\n",
    "        pd.DataFrame(data_y_qehvi).to_csv(\"data_y\"+str(j)+\".csv\", header=None, index=None)\n",
    "        pd.DataFrame(hv_truth).to_csv(\"hv_truth.csv\", header=None, index=None)\n",
    "\n",
    "    # Save hv\n",
    "    hv_total.append(np.ravel(hv_truth))\n",
    "    pd.DataFrame(hv_total).to_csv(\"hv_truth_total.csv\", header=None, index=None)\n",
    "\n",
    "    # Save edmin\n",
    "    edmin_total.append(np.ravel(edmin.detach().cpu().numpy()))\n",
    "    pd.DataFrame(edmin_total).to_csv(\"edmin_total.csv\", header=None, index=None)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "decf4358",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
