import gpytorch
import numpy as np
import torch
from botorch.fit import fit_gpytorch_model
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.model_list_gp_regression import ModelListGP
from gpytorch.constraints.constraints import Interval
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood


def train_gp(_observed_x, _observed_f) -> ModelListGP:
    """Fit one independent GP per objective and return them as a ``ModelListGP``.

    Each objective gets a ``SingleTaskGP`` with a Matern-5/2 ARD kernel wrapped
    in a ``ScaleKernel`` and a Gaussian likelihood whose noise is constrained to
    ``[1e-6, 1e-3]``. The GPs are trained on the NEGATED observations
    (``-_observed_f``) to be consistent with the qHVI acquisition function, so
    posterior means from the returned model are on the ``-f`` scale.

    Args:
        _observed_x: Array-like with ``.shape``/``.reshape`` (e.g. NumPy array);
            the last axis is the input dimension. Every value must lie in the
            unit cube ``[0, 1]``.
        _observed_f: Array-like of objective values. The last axis indexes the
            objectives; a 1-D array is treated as a single objective.

    Returns:
        The fitted ``ModelListGP`` with one sub-model per objective.

    Raises:
        ValueError: if there are no observations, or if any input lies outside
            ``[0, 1]`` (NaN inputs are rejected as well).
    """
    dim = _observed_x.shape[-1]
    # A 1-D objective array is a single objective, not ``len(f)`` objectives.
    num_objectives = 1 if _observed_f.ndim == 1 else _observed_f.shape[-1]
    train_x = torch.tensor(_observed_x.reshape(-1, dim), dtype=torch.double)
    train_obj = torch.tensor(_observed_f.reshape(-1, num_objectives), dtype=torch.double)

    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if train_x.numel() == 0:
        raise ValueError("train_gp needs at least one observation")
    # Written as ``not (in range)`` so that NaN inputs are rejected too.
    if not (train_x.min() >= 0.0 and train_x.max() <= 1.0):
        raise ValueError("_observed_x must lie in the unit cube [0, 1]")

    models = []
    for i in range(num_objectives):
        # Negated to be consistent with the (maximizing) qHVI acquisition function.
        train_y = -train_obj[..., i : i + 1]
        likelihood = GaussianLikelihood(noise_constraint=Interval(1e-6, 1e-3))
        base_kernel = gpytorch.kernels.MaternKernel(
            nu=2.5,
            ard_num_dims=dim,
            lengthscale_constraint=Interval(np.sqrt(1e-3), np.sqrt(1e3)),
        )
        covar_module = gpytorch.kernels.ScaleKernel(
            base_kernel=base_kernel,
            outputscale_constraint=Interval(np.sqrt(1e-3), np.sqrt(1e3)),
        )
        model = SingleTaskGP(
            train_X=train_x.clone(),
            train_Y=train_y.clone(),
            likelihood=likelihood,
            covar_module=covar_module,
        )
        # Start every hyperparameter from a value inside its constraint interval.
        model.covar_module.base_kernel.lengthscale = torch.ones((1, dim), dtype=torch.double)
        model.covar_module.outputscale = torch.tensor(1.0, dtype=torch.double)
        model.likelihood.noise_covar.noise = torch.full((1,), 1e-4, dtype=torch.double)
        models.append(model)

    model_list = ModelListGP(*models)
    mll = SumMarginalLogLikelihood(model_list.likelihood, model_list)
    # Jointly optimize the hyperparameters of all sub-models.
    fit_gpytorch_model(mll)
    return model_list

def estimate_mean_and_std(_x, _model_list: ModelListGP):
    """Return the posterior mean and standard deviation of ``_model_list`` at ``_x``.

    The mean is negated before returning, which maps models trained on ``-f``
    (as ``train_gp`` does) back to the original objective scale. The standard
    deviation is unaffected by the sign.

    Args:
        _x: A single point (1-D array-like) or a batch of points (2-D, one point
            per row). Lists, scalars, NumPy arrays and CPU tensors are accepted.
        _model_list: A fitted model list, e.g. the result of ``train_gp``.

    Returns:
        ``(mean, std)`` as NumPy arrays. When ``_x`` is a single 1-D point and
        the posterior has exactly one row, that leading row axis is dropped.
    """
    # ``np.ndim`` works for lists/scalars too; ``_x.ndim`` would raise on them.
    is_single_point = np.ndim(_x) == 1
    x_torch = torch.tensor(np.atleast_2d(_x), dtype=torch.double)
    # NOTE(review): 1e-1 is a large Cholesky jitter relative to the 1e-3 noise
    # upper bound used in training; presumably chosen for robustness — confirm.
    with torch.no_grad(), gpytorch.settings.cholesky_jitter(double=1e-1):
        posterior = _model_list.posterior(x_torch)
        mean_pred = -posterior.mean.numpy()
        # Floor the variance before sqrt so round-off cannot produce NaNs.
        std_pred = posterior.variance.clamp_min(1e-9).sqrt().numpy()
    if (
        is_single_point
        and mean_pred.ndim == std_pred.ndim == 2
        and mean_pred.shape[0] == std_pred.shape[0] == 1
    ):
        return mean_pred[0], std_pred[0]
    return mean_pred, std_pred