# Constructor functions for robust GP models for benchmarking

import math
from functools import partial
from time import perf_counter
from typing import List, Mapping, Optional, Tuple, Union

import gpytorch
import numpy as np

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP, SingleTaskVariationalGP
from botorch.models.transforms.input import InputTransform, Normalize
from botorch.models.transforms.outcome import OutcomeTransform, Standardize

from gpytorch.kernels import MaternKernel, ScaleKernel

from gpytorch.likelihoods import GaussianLikelihood, Likelihood, StudentTLikelihood
from gpytorch.likelihoods.noise_models import HomoskedasticNoise
from gpytorch.means import ConstantMean
from gpytorch.mlls import VariationalELBO
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.priors import NormalPrior
from gpytorch.priors.torch_priors import GammaPrior, LogNormalPrior
from gpytorch.variational import VariationalStrategy
from sklearn.preprocessing import PowerTransformer
from torch import Tensor

from ..constraints import LogTransformedInterval
from ..relevance_pursuit import (
    backward_relevance_pursuit,
    forward_relevance_pursuit,
    get_posterior_over_support,
)
from .sparse_outlier_noise import SparseOutlierGaussianLikelihood, SparseOutlierNoise
from .trimmed_mll import getFinalModelForPrediction, residualNuTrimmedGP


def get_gp_modules(dim: int) -> Tuple[ConstantMean, ScaleKernel]:
    """Create the prior mean and covariance modules used by the benchmark GPs.

    Args:
        dim: Number of input dimensions. Sets the number of ARD lengthscales and
            shifts the location of the lengthscale prior.

    Returns:
        ``(mean_module, covar_module)``:
        - ``mean_module`` is a ``ConstantMean`` with a ``NormalPrior(0, 1)`` on the
          constant.
        - ``covar_module`` is a ``ScaleKernel`` wrapping an ARD ``MaternKernel``,
          with a Gamma prior on the outputscale.
    """
    mean_module = ConstantMean(constant_prior=NormalPrior(0.0, 1.0))

    # Lengthscales are constrained to [0.1, 100] and initialized at 1.0.
    # The log-normal prior's location grows with log(dim) / 2, so
    # higher-dimensional inputs a priori favor longer lengthscales.
    matern = MaternKernel(
        ard_num_dims=dim,
        lengthscale_constraint=LogTransformedInterval(
            0.1, 100.00, initial_value=1.0
        ),
        lengthscale_prior=LogNormalPrior(
            loc=math.sqrt(2) + math.log(dim) / 2, scale=math.sqrt(3)
        ),
    )
    outputscale_prior = GammaPrior(concentration=2.0, rate=0.15)
    return mean_module, ScaleKernel(matern, outputscale_prior=outputscale_prior)


def get_model(
    likelihood: Likelihood,
    X: Tensor,
    Y: Tensor,
    outcome_transform: Optional[OutcomeTransform] = None,
    input_transform: Optional[InputTransform] = None,
) -> SingleTaskGP:
    """Construct an unfitted ``SingleTaskGP`` from the shared benchmark modules.

    Args:
        likelihood: Likelihood to attach to the model.
        X: Training inputs. The size of the last dimension sets the kernel's
            number of ARD lengthscales.
        Y: Training outputs.
        outcome_transform: Optional outcome transform passed to the model.
        input_transform: Optional input transform passed to the model.

    Returns:
        The constructed ``SingleTaskGP``. No fitting is done here.
    """
    num_input_dims = X.shape[-1]
    mean_module, covar_module = get_gp_modules(dim=num_input_dims)
    gp = SingleTaskGP(
        train_X=X,
        train_Y=Y,
        likelihood=likelihood,
        mean_module=mean_module,
        covar_module=covar_module,
        input_transform=input_transform,
        outcome_transform=outcome_transform,
    )
    return gp


def get_vanilla_model(
    X: Tensor,
    Y: Tensor,
    outcome_transform: Optional[OutcomeTransform] = None,
    input_transform: Optional[InputTransform] = None,
    min_noise: float = 1e-6,
    max_noise: float = 1.0,
    timeout_sec: Optional[float] = None,
) -> SingleTaskGP:
    """Fit a standard GP with a homoskedastic Gaussian likelihood.

    Args:
        X: Training inputs.
        Y: Training outputs.
        outcome_transform: Outcome transform. Defaults to
            ``Standardize(m=Y.shape[-1])``.
        input_transform: Optional input transform (none by default).
        min_noise: Lower bound of the noise-level constraint.
        max_noise: Upper bound of the noise-level constraint.
        timeout_sec: Optional time budget forwarded to ``fit_gpytorch_mll``.

    Returns:
        The ``SingleTaskGP`` after its exact marginal log likelihood was optimized.
    """
    outcome_transform = (
        Standardize(m=Y.shape[-1]) if outcome_transform is None else outcome_transform
    )
    gaussian_likelihood = GaussianLikelihood(
        noise_constraint=LogTransformedInterval(min_noise, max_noise),
        noise_prior=GammaPrior(concentration=0.9, rate=10.0),
    )
    model = get_model(
        likelihood=gaussian_likelihood,
        X=X,
        Y=Y,
        outcome_transform=outcome_transform,
        input_transform=input_transform,
    )
    fit_gpytorch_mll(
        ExactMarginalLogLikelihood(model=model, likelihood=model.likelihood),
        optimizer_kwargs={"timeout_sec": timeout_sec},
    )
    return model


def get_robust_model(
    X: Tensor,
    Y: Tensor,
    numbers_of_outliers: Optional[Union[int, List[int]]] = None,
    fractions_of_outliers: Optional[List[float]] = None,
    fit_inlier_only_model: bool = False,
    min_noise: float = 1e-6,
    max_noise: float = 1.0,
    initial_support: Optional[List[int]] = None,
    prior_mean_of_support: Optional[int] = None,
    convex_parameterization: bool = True,
    timeout_sec: Optional[float] = None,
    input_transform: Optional[InputTransform] = None,
    outcome_transform: Optional[OutcomeTransform] = None,
    use_forward_algorithm: bool = False,
    select_posterior_mean_of_support: bool = False,
    reset_parameters: bool = True,
    reset_dense_parameters: bool = False,
) -> SingleTaskGP:
    """Trains a robust GP model.

    A ``SingleTaskGP`` with a ``SparseOutlierGaussianLikelihood`` is optimized by
    relevance pursuit over a list of sparsity levels (numbers of outliers). The
    resulting trace of models is scored by Bayesian model comparison (BMC), and
    one model of the trace is selected.

    Args:
        X: The training inputs.
        Y: The training outputs.
        numbers_of_outliers: An optional number, or list of numbers, of outliers to
            consider during the relevance pursuit algorithm. If None, the list is
            ``[int(p * n) for p in fractions_of_outliers]``. Here ``n`` is
            ``len(initial_support)`` if an initial support is given, otherwise
            ``len(Y)``. IDEA: this could be generalized to a type of bisection
            search over the number of outliers, or one could use
            exponential_sparsity_levels(n=n, base=2, reverse=True).
        fractions_of_outliers: An optional list of fractions of outliers, used if
            numbers_of_outliers is None. By default:
            [0.0, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0].
        fit_inlier_only_model: If True, a Gaussian-likelihood GP is fit to the
            inlier data only, after running the outlier detection algorithm, and
            that model is returned instead.
        min_noise: The lower bound for the noise level.
        max_noise: The upper bound for the noise level.
        initial_support: A list of indices of the initial support set.
        prior_mean_of_support: The mean value for the default exponential prior
            distribution over the support size. Defaults to ``int(0.2 * len(Y))``.
        convex_parameterization: If True, the convex parameterization of the outlier
            variances is used. This allows looser convergence tolerances, because
            the optimizer does not have to escape flat non-convex regions.
        timeout_sec: Optional total time budget in seconds. It is split evenly
            across the sparsity levels and passed to the optimizer.
        input_transform: Input transform of the robust model. Defaults to
            ``Normalize(d=X.shape[-1])``.
        outcome_transform: Outcome transform of the robust model. Defaults to
            ``Standardize(m=Y.shape[-1])``.
        use_forward_algorithm: Whether or not to use the forward algorithm instead
            of the backward one.
        select_posterior_mean_of_support: If True, we will choose the model whose
            support size is closest to the posterior mean of the support size.
            Otherwise, we will choose the maximum-a-posteriori (MAP) estimate.
        reset_parameters: If True, we will reset the parameters of the model after
            each iteration of the relevance pursuit algorithm.
        reset_dense_parameters: Forwarded to the relevance pursuit algorithm.

    Returns:
        The selected robust model. If ``fit_inlier_only_model`` is True, the GP
        fit to the detected inliers is returned instead.

    Raises:
        ValueError: If the list of sparsity levels to consider is empty.
    """
    # Normalize inputs / standardize outcomes by default. These transforms were
    # previously hard-coded, which silently ignored caller-supplied ones.
    if outcome_transform is None:
        outcome_transform = Standardize(m=Y.shape[-1])
    if input_transform is None:
        input_transform = Normalize(d=X.shape[-1])

    base_noise = HomoskedasticNoise(
        noise_constraint=LogTransformedInterval(min_noise, max_noise),
        noise_prior=GammaPrior(concentration=0.9, rate=10.0),
    )
    # One potential outlier-noise term per training point.
    likelihood = SparseOutlierGaussianLikelihood(
        base_noise=base_noise,
        dim=X.shape[-2],
        convex_parameterization=convex_parameterization,
    )
    model = get_model(
        likelihood=likelihood,
        X=X,
        Y=Y,
        input_transform=input_transform,
        outcome_transform=outcome_transform,
    )
    mll = ExactMarginalLogLikelihood(model=model, likelihood=model.likelihood)
    sparse_module = model.likelihood.noise_covar

    if numbers_of_outliers is None:
        if fractions_of_outliers is None:
            fractions_of_outliers = [
                0.0,
                0.05,
                0.1,
                0.15,
                0.2,
                0.3,
                0.4,
                0.5,
                0.75,
                1.0,
            ]
        # List from which BMC chooses. The fractions refer to the candidate set:
        # the initial support if given, otherwise all training points.
        n = len(initial_support) if initial_support is not None else len(Y)
        numbers_of_outliers = [int(p * n) for p in fractions_of_outliers]
    elif isinstance(numbers_of_outliers, int):
        numbers_of_outliers = [numbers_of_outliers]

    if not numbers_of_outliers:
        raise ValueError(
            "At least one sparsity level (number of outliers) must be considered."
        )

    if prior_mean_of_support is None:
        prior_mean_of_support = int(len(Y) * 0.2)

    tol = 1e-9
    optimizer_kwargs: dict = {
        "options": {"maxiter": 1024, "ftol": tol, "gtol": tol},
    }
    if timeout_sec is not None:
        # Split the total time budget evenly across the sparsity levels.
        optimizer_kwargs["timeout_sec"] = timeout_sec / len(numbers_of_outliers)
    optimize = (
        forward_relevance_pursuit
        if use_forward_algorithm
        else backward_relevance_pursuit
    )
    sparse_module, model_trace = optimize(
        sparse_module=sparse_module,
        mll=mll,
        sparsity_levels=numbers_of_outliers,
        reset_parameters=reset_parameters,
        reset_dense_parameters=reset_dense_parameters,
        optimizer_kwargs=optimizer_kwargs,
        initial_support=initial_support,
        record_model_trace=True,
    )

    # Bayesian model comparison over the recorded trace of models.
    support_size, bmc_probabilities = get_posterior_over_support(
        SparseOutlierNoise, model_trace, prior_mean_of_support=prior_mean_of_support
    )
    if select_posterior_mean_of_support:
        # Model whose support size is closest to the posterior mean support size.
        mean_support = (support_size * bmc_probabilities).sum()
        selected_index = int((mean_support - support_size).abs().argmin())
    else:
        # MAP: model with the highest BMC probability.
        selected_index = int(torch.argmax(bmc_probabilities))

    model = model_trace[selected_index]
    sparse_module = model.likelihood.noise_covar

    if fit_inlier_only_model:
        # Presumably `is_active` is a boolean mask over training points that marks
        # the detected outliers (TODO confirm in SparseOutlierNoise).
        X_inlier = X[~sparse_module.is_active]
        Y_inlier = Y[~sparse_module.is_active]
        likelihood = GaussianLikelihood(
            noise_constraint=LogTransformedInterval(min_noise, max_noise),
            noise_prior=GammaPrior(concentration=0.9, rate=10.0),
        )
        inlier_model = get_model(
            X=X_inlier,
            Y=Y_inlier,
            likelihood=likelihood,
            outcome_transform=Standardize(m=Y.shape[-1]),
        )
        mll = ExactMarginalLogLikelihood(
            model=inlier_model, likelihood=inlier_model.likelihood
        )
        fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs)
        model = mll.model

    return model


def get_student_t_model(
    X: Tensor,
    Y: Tensor,
    timeout_sec: Optional[float] = None,
) -> SingleTaskVariationalGP:
    """Fit a variational GP with a Student-t likelihood.

    All training inputs serve as fixed (non-learned) inducing points. Inputs are
    normalized with ``Normalize`` and outcomes are standardized with
    ``Standardize``.

    Args:
        X: The training inputs.
        Y: The training outputs.
        timeout_sec: Optional time budget forwarded to ``fit_gpytorch_mll``.

    Returns:
        The ``SingleTaskVariationalGP`` after optimizing its variational ELBO.
    """
    likelihood = StudentTLikelihood()
    _, covar_module = get_gp_modules(dim=X.shape[-1])
    input_transform = Normalize(d=X.shape[-1])
    # The variational model is trained on *transformed* inputs, so the inducing
    # points must be given in the same normalized space, not as the raw X.
    # Assumes botorch uses a tensor `inducing_points` as-is in the model's
    # (transformed) input space; confirm for the installed botorch version.
    # Calling the transform in train mode fits its bounds on X. The model
    # constructor then transforms the same X with the same transform.
    with torch.no_grad():
        inducing_points = input_transform(X)
    vs = partial(VariationalStrategy, learn_inducing_locations=False)
    model = SingleTaskVariationalGP(
        train_X=X,
        train_Y=Y,
        # The shared ConstantMean from get_gp_modules is deliberately not used: it
        # was observed to make the model badly underfit, even with a Gaussian
        # likelihood.
        # mean_module=mean_module,
        covar_module=covar_module,
        likelihood=likelihood,
        inducing_points=inducing_points,
        variational_strategy=vs,
        learn_inducing_points=False,
        input_transform=input_transform,
        outcome_transform=Standardize(m=1),
    )
    mll = VariationalELBO(model.likelihood, model.model, num_data=X.shape[-2])
    fit_gpytorch_mll(mll, optimizer_kwargs={"timeout_sec": timeout_sec})
    return model


def get_trimmed_mll_model(
    X: torch.Tensor,
    Y: torch.Tensor,
    start_nu: float = 0.5,
    timeout_sec: Optional[float] = None,
) -> SingleTaskGP:
    """Fit a GP with the trimmed-marginal-likelihood procedure.

    Starting from ``ceil(N * start_nu)`` allowed outliers, ``residualNuTrimmedGP``
    is called repeatedly with the current bound. Iteration continues while the
    returned bound strictly decreases and the optional time budget is not
    exhausted. A final model is then built by ``getFinalModelForPrediction``.

    Args:
        X: Training inputs; ``X.shape[-2]`` is the number of data points ``N``.
        Y: Training outputs. They are flattened before being passed to the
            trimming routines.
        start_nu: Initial fraction of points allowed to be outliers.
        timeout_sec: Optional time budget for the trimming iterations. The first
            iteration always runs.

    Returns:
        The final ``SingleTaskGP``, put in eval mode.
    """
    N = X.shape[-2]
    max_outliers = math.ceil(N * start_nu)
    y_flat = Y.flatten()
    start = perf_counter()

    with gpytorch.settings.cholesky_jitter(
        float_value=1e-6, double_value=1e-6
    ), gpytorch.settings.cholesky_max_tries(
        10
    ), gpytorch.settings.variational_cholesky_jitter(
        float_value=1e-6, double_value=1e-6
    ):

        def base_model_constructor(*args, **kwargs) -> SingleTaskGP:
            # Create a fresh Standardize per model. A single shared instance would
            # be trained by one model and frozen by another model's `.eval()`.
            kwargs.setdefault("outcome_transform", Standardize(m=Y.shape[-1]))
            return get_model(*args, **kwargs)

        # Run at least one trimming pass so that `sigmaEstimate` is always
        # defined. Record the previous bound *before* updating it: the original
        # code overwrote it with the new value, so the loop never iterated.
        while True:
            previous_max_outliers = max_outliers
            max_outliers, sigmaEstimate = residualNuTrimmedGP(
                full_X=X,
                full_y=y_flat,
                maxNrOutlierSamples=max_outliers,
                base_model_constructor=base_model_constructor,
            )
            stopped_decreasing = max_outliers >= previous_max_outliers
            timed_out = (timeout_sec is not None) and (
                perf_counter() - start >= timeout_sec
            )
            if stopped_decreasing or timed_out:
                break

        gpModel = getFinalModelForPrediction(
            X,
            y_flat,
            max_outliers,
            sigmaEstimate=sigmaEstimate,
            method="projectedGradient",
            base_model_constructor=base_model_constructor,
        )
        gpModel.eval()
        gpModel.likelihood.eval()
        return gpModel


def get_winsorized_model(X: Tensor, Y: Tensor, winsorize_lower: bool = False):
    """Fit a vanilla GP after winsorizing one tail of the outcomes.

    Uses Tukey's fences: values beyond ``1.5 * IQR`` from the quartiles are
    clipped to the fence, on the upper side by default or the lower side if
    ``winsorize_lower`` is True. Quartiles are computed per output column, so
    multi-output ``Y`` is handled column-wise. For a single column this matches
    quartiles over all of ``Y``. The input ``Y`` is not modified.

    Args:
        X: Training inputs.
        Y: Training outputs, one column per output.
        winsorize_lower: If True, clip the lower tail instead of the upper tail.

    Returns:
        The model returned by ``get_vanilla_model`` on the winsorized outcomes.
    """
    # Shape (1, m) thresholds broadcast against Y, so each column is clipped on
    # its own scale.
    q1 = torch.quantile(Y, q=0.25, dim=0, keepdim=True)
    q3 = torch.quantile(Y, q=0.75, dim=0, keepdim=True)
    iqr = q3 - q1
    if winsorize_lower:  # From below
        threshold = q1 - 1.5 * iqr
        Y = torch.where(Y < threshold, threshold, Y)
    else:  # From above
        threshold = q3 + 1.5 * iqr
        Y = torch.where(Y > threshold, threshold, Y)
    return get_vanilla_model(
        X=X, Y=Y, outcome_transform=Standardize(m=Y.shape[-1])
    )


def get_power_transformed_model(
    X: Tensor, Y: Tensor, return_power_transformer: bool = False
) -> Union[SingleTaskGP, Tuple[SingleTaskGP, PowerTransformer]]:
    """Power transform the data using a Yeo-Johnson transform.

    The outcomes are transformed with scikit-learn's ``PowerTransformer``
    (column-wise), then a vanilla GP is fit to the transformed outcomes.

    This method can optionally return the `power_transformer`, which is needed
    to apply the inverse power transform.

    Args:
        X: Training inputs.
        Y: Training outputs, one column per output.
        return_power_transformer: If True, also return the fitted transformer.

    Returns:
        The fitted model, or ``(model, power_transformer)`` if
        ``return_power_transformer`` is True.
    """
    # detach() so that outcomes which require grad can be converted to NumPy.
    y = Y.detach().cpu().numpy()
    power_transformer = PowerTransformer(method="yeo-johnson").fit(y)
    Y_transformed = torch.tensor(
        power_transformer.transform(y), device=Y.device, dtype=Y.dtype
    )
    model = get_vanilla_model(
        X=X,
        Y=Y_transformed,
        outcome_transform=Standardize(m=Y_transformed.shape[-1]),
    )
    if return_power_transformer:
        return model, power_transformer
    return model
