from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.constraints import GreaterThan
from botorch.optim import optimize_acqf

from botorch.acquisition.monte_carlo import qExpectedImprovement, qNoisyExpectedImprovement, qProbabilityOfImprovement, qUpperConfidenceBound
# from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy
from botorch.sampling.samplers import SobolQMCNormalSampler
from botorch.exceptions import BadInitialCandidatesWarning

import time
import warnings
import torch
from torch.optim import SGD


# warnings.filterwarnings('ignore', category=BadInitialCandidatesWarning)
# warnings.filterwarnings('ignore', category=RuntimeWarning)

# def gp_initialize_model(train_x, train_y, device, state_dict=None):
#     train_x = torch.from_numpy(train_x).to(device)
#     train_y = torch.from_numpy(train_y).to(device)

#     model = SingleTaskGP(train_X=train_x, train_Y=train_y).to(device)
#     model.likelihood.noise_covar.register_constraint("raw_noise", GreaterThan(1e-5))

#     mll = ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
#     mll = mll.to(train_x)

#     # load state dict if it is passed
#     if state_dict is not None:
#         model.load_state_dict(state_dict)
#     fit_gpytorch_model(mll)
#     return mll.model

def gp_train_model(train_x, train_y, device, state_dict=None, num_epochs=1000, lr=0.1):
    """Fit a ``SingleTaskGP`` by minimising the negative exact marginal log likelihood with SGD.

    Args:
        train_x: NumPy array of training inputs; converted with
            ``torch.from_numpy`` and moved to ``device``.
        train_y: NumPy array of training targets; converted the same way and
            flattened with ``view(-1)`` when the loss is evaluated.
        device: Torch device the tensors and the model are moved to.
        state_dict: Optional model state dict. When given, it is loaded
            *before* the optimisation loop so that it acts as a warm start.
            (Loading it after the loop would throw away the freshly trained
            parameters.)
        num_epochs: Number of SGD steps. Defaults to 1000; the original code
            noted 300 for the superconductor task.
        lr: SGD learning rate. Defaults to 0.1.

    Returns:
        The trained ``SingleTaskGP``. ``model.train()`` is called here and the
        model is not switched back to eval mode before returning.
    """
    train_x = torch.from_numpy(train_x).to(device)
    train_y = torch.from_numpy(train_y).to(device)

    model = SingleTaskGP(train_X=train_x, train_Y=train_y).to(device)
    # Keep the learned observation noise bounded away from zero.
    model.likelihood.noise_covar.register_constraint("raw_noise", GreaterThan(1e-5))

    mll = ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
    mll = mll.to(train_x)

    # Warm start: restore previously learned parameters before training so the
    # optimisation below refines them instead of being overwritten by them.
    if state_dict is not None:
        model.load_state_dict(state_dict)

    optimizer = SGD([{'params': model.parameters()}], lr=lr)
    model.train()

    # The flattened targets do not change between iterations.
    targets = train_y.view(-1)

    for _ in range(num_epochs):
        # clear gradients
        optimizer.zero_grad()
        # forward pass through the model to obtain the output MultivariateNormal
        output = model(train_x)
        # negative marginal log likelihood is the quantity being minimised
        loss = -mll(output, targets)
        # back prop gradients and take one SGD step
        loss.backward()
        optimizer.step()

    return model

def gp_optimize_acqf_and_get_observation(acq_func, model, best_f, bounds, batch_size, device):
    """Optimize the selected acquisition function and return the new candidate points.

    Despite the function name, no observation is evaluated here; only the
    candidate locations are returned.

    Args:
        acq_func: Name of the acquisition function: ``'qei'``
            (qExpectedImprovement), ``'qpi'`` (qProbabilityOfImprovement) or
            ``'qucb'`` (qUpperConfidenceBound with ``beta=1``).
        model: Model handed to the acquisition function.
        best_f: Incumbent value; wrapped with ``torch.tensor`` and used by
            ``'qei'`` and ``'qpi'`` (it is still converted, but unused, for
            ``'qucb'``).
        bounds: NumPy array of search bounds; converted with
            ``torch.from_numpy`` and passed to ``optimize_acqf``.
            Presumably shaped ``2 x d`` as ``optimize_acqf`` expects -- confirm
            against the caller.
        batch_size: Number of candidates to generate jointly (``q``).
        device: Torch device that ``best_f`` and ``bounds`` are moved to.

    Returns:
        NumPy array holding the candidates, detached and copied to the CPU.

    Raises:
        ValueError: If ``acq_func`` is not one of the supported names.
    """
    # Fail fast with a clear message; otherwise an unknown name would only
    # surface later as an unbound local variable.
    supported = ('qei', 'qpi', 'qucb')
    if acq_func not in supported:
        raise ValueError(
            f"Unknown acq_func {acq_func!r}; expected one of {supported}"
        )

    best_f = torch.tensor(best_f).to(device)
    bounds = torch.from_numpy(bounds).to(device)

    # All Monte-Carlo acquisition functions share one quasi-MC sampler.
    qmc_sampler = SobolQMCNormalSampler(num_samples=100)

    if acq_func == 'qei':
        acq = qExpectedImprovement(
            model=model,
            best_f=best_f,
            sampler=qmc_sampler,
        )
    elif acq_func == 'qpi':
        acq = qProbabilityOfImprovement(
            model=model,
            best_f=best_f,
            sampler=qmc_sampler,
        )
    else:  # 'qucb' -- guaranteed by the validation above
        acq = qUpperConfidenceBound(
            model=model,
            beta=1,
            sampler=qmc_sampler,
        )

    candidates, _ = optimize_acqf(
        acq_function=acq,
        bounds=bounds,
        q=batch_size,
        num_restarts=30,
        raw_samples=500,  # used for initialization heuristic
        options={"batch_limit": 50, "maxiter": 200},
    )
    # Hand the candidates back as a plain NumPy array.
    new_x = candidates.detach().cpu().numpy()
    return new_x