
import torch
from botorch.optim import optimize_acqf
from botorch import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from botorch.utils.transforms import normalize, unnormalize
from botorch.models.transforms import Standardize

from utils import device

# d = 20

def get_fitted_model(train_x, train_obj, d, state_dict=None, bounds=None):
    """Build a SingleTaskGP on ``(train_x, train_obj)`` and fit its hyperparameters.

    Inputs are scaled into the unit cube with BoTorch's ``normalize`` before
    the model is built. The hyperparameters are then fitted by maximizing the
    exact marginal log likelihood via ``fit_gpytorch_mll``.

    Args:
        train_x: Training inputs with ``d`` columns in the raw search space.
            Moved to ``device`` before use.
        train_obj: Training targets. Moved to ``device`` before use.
        d: Input dimensionality.
        state_dict: Optional model state dict. It is loaded before fitting,
            so fitting starts from these parameter values.
        bounds: Optional ``(2, d)`` tensor of lower/upper bounds of the raw
            search space. It is used to normalize ``train_x``. Defaults to
            the unit cube, in which case the inputs are left unchanged.

    Returns:
        The fitted ``SingleTaskGP``.

    Raises:
        ValueError: If ``bounds`` is given and is not of shape ``(2, d)``.
    """
    # Move the inputs once, so every later use (including mll.to below)
    # refers to the same device-resident tensor.
    train_x = train_x.to(device)

    if bounds is None:
        bounds = torch.stack(
            [
                torch.zeros(d, device=device),
                torch.ones(d, device=device),
            ]
        )
    else:
        # Match train_x's device and dtype so normalization loses no precision.
        bounds = bounds.to(train_x)
        if bounds.shape != (2, d):
            raise ValueError(
                f"bounds must have shape (2, {d}), got {tuple(bounds.shape)}"
            )

    model = SingleTaskGP(
        train_X=normalize(train_x, bounds),
        train_Y=train_obj.to(device),
    ).to(device)

    if state_dict is not None:
        model.load_state_dict(state_dict)

    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    # Module.to(tensor) adopts that tensor's device and dtype. Using the
    # device-resident train_x keeps the model (and its training data) on
    # `device` instead of moving it back to wherever the caller's tensor lived.
    mll.to(train_x)
    fit_gpytorch_mll(mll)
    return model

def optimize_acqf_and_get_observation(
    acq_func, d, BATCH_SIZE, NUM_RESTARTS, RAW_SAMPLES, bounds=None
):
    """Optimize ``acq_func`` over the unit cube and return the new candidates.

    Despite the name, no objective is evaluated here. Only the candidate
    inputs are returned, and observing them is left to the caller.

    Args:
        acq_func: BoTorch acquisition function. It is optimized over
            ``[0, 1]^d``, so its model is presumably trained on unit-cube
            inputs (TODO confirm for each caller).
        d: Input dimensionality.
        BATCH_SIZE: Number of candidates to generate jointly (``q``).
        NUM_RESTARTS: Number of restarts for ``optimize_acqf``.
        RAW_SAMPLES: Number of raw samples used to initialize the restarts.
        bounds: Optional ``(2, d)`` tensor of lower/upper bounds of the raw
            search space. Candidates are mapped from the unit cube back into
            these bounds with ``unnormalize``. Defaults to the unit cube, in
            which case candidates are returned unchanged.

    Returns:
        Detached tensor of candidates from ``optimize_acqf``, expressed in
        the space described by ``bounds``.

    Raises:
        ValueError: If ``bounds`` is given and is not of shape ``(2, d)``.
    """
    unit_cube = torch.stack(
        [
            torch.zeros(d, device=device),
            torch.ones(d, device=device),
        ]
    )

    # Validate before the (potentially expensive) optimization runs.
    if bounds is not None and bounds.shape != (2, d):
        raise ValueError(
            f"bounds must have shape (2, {d}), got {tuple(bounds.shape)}"
        )

    candidates, _ = optimize_acqf(
        acq_function=acq_func,
        bounds=unit_cube,
        q=BATCH_SIZE,
        num_restarts=NUM_RESTARTS,
        raw_samples=RAW_SAMPLES,
    )
    candidates = candidates.detach()

    # Against the unit cube, unnormalize is the identity. Otherwise map the
    # candidates into the caller's search space, matching device and dtype.
    target_bounds = unit_cube if bounds is None else bounds.to(candidates)
    return unnormalize(candidates, bounds=target_bounds)