"""
Implements the method from
D. Andrade and A. Takeda, "Robust Gaussian process regression with the trimmed
marginal likelihood", UAI 2023.
Follows the authors' reference implementation, simplified to cover only the
paper's primary method, with the base GPyTorch models replaced by BoTorch models
and with support for double (float64) as well as float (float32) tensors.
"""

import math
from typing import Tuple

import gpytorch
import numpy as np
import scipy
import torch
from sklearn.model_selection import KFold


def residualNuTrimmedGP(
    full_X: torch.Tensor,
    full_y: torch.Tensor,
    maxNrOutlierSamples: int,
    base_model_constructor,
) -> Tuple[int, float]:
    """Estimate the noise level and a refined outlier count from CV residuals.

    Runs 10-fold cross-validation (fixed shuffle seed). For each fold a trimmed
    GP is trained on the training split, allowing a proportionally scaled
    number of outliers, and its predictive mean is used to compute absolute
    residuals on the held-out split. The noise standard deviation is then
    estimated from the ``FULL_N - maxNrOutlierSamples`` smallest absolute
    residuals. Every sample whose residual exceeds twice that estimate is
    counted as an outlier.

    Args:
        full_X: Inputs, indexed as ``full_X[ids, :]``.
        full_y: Targets with the same number of rows as ``full_X``
            (presumably 1-D; it is indexed as ``full_y[ids]``).
        maxNrOutlierSamples: Upper bound on the number of outliers in the
            full data set; must satisfy ``0 <= maxNrOutlierSamples < FULL_N``.
        base_model_constructor: Forwarded to ``trainTrimmedGP``.

    Returns:
        ``(maxNrOutlierSamples_new, sigmaEstimate)``: the number of samples
        whose absolute CV residual exceeds ``2 * sigmaEstimate``, and the
        estimated noise standard deviation.

    Raises:
        ValueError: If ``maxNrOutlierSamples`` leaves no inlier or is negative.
    """
    assert full_X.shape[0] == full_y.shape[0]
    FULL_N = full_X.shape[0]
    if not 0 <= maxNrOutlierSamples < FULL_N:
        raise ValueError(
            f"maxNrOutlierSamples must be in [0, {FULL_N}), got {maxNrOutlierSamples}"
        )
    NU = maxNrOutlierSamples / FULL_N
    MIN_NR_INLIERS = FULL_N - maxNrOutlierSamples
    NR_FOLDS = 10

    # A training split holds only (1 - 1/NR_FOLDS) of the data but may still
    # contain every outlier, so the allowed outlier fraction is scaled up.
    NU_FOR_CV = NU / (1.0 - (1.0 / NR_FOLDS))

    cv = KFold(n_splits=NR_FOLDS, random_state=4323, shuffle=True)
    residuals_abs = torch.zeros(FULL_N, dtype=full_X.dtype, device=full_X.device)

    for train_index, valid_index in cv.split(full_X):
        cv_maxNrOutliers = int(train_index.shape[0] * NU_FOR_CV)
        gpModel = trainTrimmedGP(
            full_X[train_index, :],
            full_y[train_index],
            cv_maxNrOutliers,
            optimizationMode="projectedGradient",
            base_model_constructor=base_model_constructor,
        )

        mean_predictions = getModelMeanPrediction(
            model=gpModel, X=full_X[valid_index, :]
        )
        assert valid_index.shape[0] == mean_predictions.shape[0]
        residuals_abs[valid_index] = torch.abs(mean_predictions - full_y[valid_index])

    # torch.sort returns a (values, indices) pair: trim the *values* to the
    # MIN_NR_INLIERS smallest residuals. Slicing the returned pair instead would
    # keep all FULL_N residuals and silently skip the trimming.
    inlierAbsDiff = torch.sort(residuals_abs).values[:MIN_NR_INLIERS]
    sigmaEstimate = getAsymptoticCorrectedSigma(inlierAbsDiff, FULL_N)

    # Residuals larger than two estimated standard deviations count as outliers.
    maxNrOutlierSamples_new = int(torch.sum(residuals_abs > sigmaEstimate * 2).item())

    return maxNrOutlierSamples_new, sigmaEstimate


def getAsymptoticCorrectedSigma(inlierAbsDiff, n) -> float:
    """Estimate the noise standard deviation from the retained absolute residuals.

    Args:
        inlierAbsDiff: Tensor of the ``m`` retained (smallest) absolute
            residuals; ``m`` is read from its first dimension.
        n: Total number of residuals before trimming.

    Returns:
        If nothing was trimmed (``m == n``), the root-mean-square of
        ``inlierAbsDiff``. Otherwise, the largest retained residual divided by
        ``sqrt(chi2(df=1).ppf(m / n))``.
    """
    m = inlierAbsDiff.shape[0]
    if m == n:
        # No trimming happened: plain root-mean-square.
        return torch.square(inlierAbsDiff).mean().sqrt().item()

    # If r ~ N(0, sigma^2) then (r / sigma)^2 ~ chi2(1), so the q-quantile of
    # |r| equals sigma * sqrt(chi2_1.ppf(q)). Matching it to the empirical
    # q-quantile (the largest retained residual, q = m / n) gives sigma.
    chi2Quantile = scipy.stats.chi2.ppf(m / n, df=1.0)
    largestInlier = torch.max(inlierAbsDiff).item()
    return largestInlier * math.sqrt(1.0 / chi2Quantile)


def trainTrimmedGP(
    X: torch.Tensor,
    y: torch.Tensor,
    maxNrOutlierSamples: int,
    optimizationMode: str,
    base_model_constructor,
):
    """Fit a GP with the trimmed marginal likelihood (Algorithm 1 of the paper).

    Each iteration alternates between two steps:
      1. Select the inlier set by hard thresholding under the current
         hyperparameters.
      2. Take one Adam step on the negative marginal log-likelihood of that set.

    A changed inlier set is accepted only if it lowers the loss compared with
    the previous iteration's (pre-step) loss. Otherwise the previous set is kept.
    The loop stops as soon as an Adam step fails to decrease the loss,
    including when the loss becomes NaN.

    Args:
        X: Training inputs, indexed as ``X[ids, :]``.
        y: Training targets (presumably 1-D); passed to the constructor as
            ``y.unsqueeze(-1)`` and otherwise indexed as ``y[ids]``.
        maxNrOutlierSamples: Exact number of samples excluded as outliers.
        optimizationMode: Only ``"projectedGradient"`` is supported.
        base_model_constructor: Called as
            ``base_model_constructor(likelihood, X, y.unsqueeze(-1))``; the
            returned model must expose ``likelihood``, ``set_train_data``,
            ``parameters()`` and ``forward``.

    Returns:
        The model, still in training mode, with its training data set to the
        selected inlier subset.
    """
    assert optimizationMode == "projectedGradient"
    assert X.shape[0] == y.shape[0]

    # Initialize likelihood and model; start from a large noise level.
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = base_model_constructor(likelihood, X, y.unsqueeze(-1))
    model.likelihood.noise = 10.0

    model.train()
    model.likelihood.train()

    def getOptimizer(model):
        # model.parameters() includes the GaussianLikelihood parameters.
        return torch.optim.Adam(model.parameters(), lr=0.1)

    optimizer = getOptimizer(model)

    # Note: mll returns the marginal log-likelihood divided by the number of samples.
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)

    # ***** Proposed algorithm for joint optimization (Algorithm 1 in the paper) *****
    previous_inlier_ids = None
    previous_loss = math.inf  # always a plain Python float

    while True:
        assert model.training and model.likelihood.training
        allInlierSamplesIds, allOutlierSampleIds = (
            getInlierAndOutliersBasedOnMarginalLikelihood_hardthresholding(
                model, model.likelihood, X, y, maxNrOutlierSamples
            )
        )

        allInlierSamplesIds = torch.sort(allInlierSamplesIds)[0]
        assert allOutlierSampleIds.shape[0] == maxNrOutlierSamples

        if (previous_inlier_ids is None) or not torch.equal(
            allInlierSamplesIds, previous_inlier_ids
        ):
            new_loss = set_new_data_and_get_loss(model, mll, X, y, allInlierSamplesIds)
            if new_loss >= previous_loss:
                # The new set does not improve the marginal likelihood:
                # fall back on the previous inliers.
                assert previous_inlier_ids is not None
                allInlierSamplesIds = previous_inlier_ids
            else:
                previous_inlier_ids = allInlierSamplesIds
                # Reset Adam's moment estimates for the new objective.
                optimizer = getOptimizer(model)

            model.set_train_data(
                inputs=X[allInlierSamplesIds, :],
                targets=y[allInlierSamplesIds],
                strict=False,
            )

        inlier_X = X[allInlierSamplesIds, :]
        inlier_y = y[allInlierSamplesIds]

        optimizer.zero_grad()
        loss = -mll(model(inlier_X), inlier_y)
        loss.backward()
        optimizer.step()
        # Store the pre-step loss as a float so the autograd graph is not kept alive.
        previous_loss = loss.item()

        with torch.no_grad():
            new_loss = -mll(model(inlier_X), inlier_y).item()

        # NaN-safe stop: any comparison with NaN is False, so `new_loss >=
        # previous_loss` would never trigger once the loss turns NaN.
        if not new_loss < previous_loss:
            break

    return model


def _smallestEigenvalue(K: torch.Tensor) -> float:
    """Return the smallest eigenvalue of the symmetric matrix ``K`` as a float.

    Tries ``gpytorch.diagonalization`` first. If torch's linear-algebra backend
    fails, falls back to ARPACK (``eigsh``) and then to a dense ``np.linalg.eigh``,
    both on a CPU copy so CUDA tensors are supported.
    """
    import scipy.sparse.linalg  # make the submodule dependency explicit

    try:
        eig_values, _ = gpytorch.diagonalization(K)
        return float(eig_values.min())
    except torch.linalg.LinAlgError:
        pass

    K_np = K.detach().cpu().numpy()
    try:
        eig_values, _ = scipy.sparse.linalg.eigsh(K_np, k=1, which="SM")
    except scipy.sparse.linalg.ArpackNoConvergence:
        eig_values, _ = np.linalg.eigh(K_np)
    return float(np.min(eig_values))


def getInlierAndOutliersBasedOnMarginalLikelihood_hardthresholding(
    model, likelihood, X, y, maxNrOutlierSamples
):
    """Split samples into inliers and outliers by projected gradient descent.

    Let ``K`` be the detached covariance matrix of
    ``likelihood(model.forward(X))``. The function runs up to 200
    projected-gradient steps on ``b`` for the objective
    ``(y + b)^T K^{-1} (y + b)``, with step size ``1 / L``, where
    ``L = 2 / lambda_min(K)``. After every step only the
    ``maxNrOutlierSamples`` entries of ``b`` with largest magnitude are kept
    (see ``project_b``). Samples whose ``b`` entry is zeroed are the inliers.

    Args:
        model: Model in training mode; ``model.forward(X)`` gives the prior.
        likelihood: Likelihood in training mode, applied to the prior.
        X: Inputs (rows are samples).
        y: Targets, presumably 1-D with one entry per row of ``X``.
        maxNrOutlierSamples: Number of samples to flag as outliers.

    Returns:
        ``(inlier_ids, outlier_ids)``: integer index tensors on ``X.device``.
        The outlier ids are in ascending order. The inlier ids are ordered by
        increasing ``|b|``.
    """
    assert model.training and likelihood.training

    covMatrix = likelihood(model.forward(X)).covariance_matrix
    orig_K = covMatrix.detach()

    n = orig_K.shape[0]
    invOrigK = gpytorch.root_inv_decomposition(orig_K)

    smallest_eigenvalue_origK = _smallestEigenvalue(orig_K)

    # The gradient 2 K^{-1} (y + b) is Lipschitz with constant 2 / lambda_min(K).
    lipschitzConstant = 2.0 * (1.0 / smallest_eigenvalue_origK)
    invLipschitzConstant = 1.0 / lipschitzConstant
    b = -y

    previous_b = torch.zeros(n, dtype=X.dtype, device=X.device)

    for _ in range(200):
        grad_b = 2.0 * (invOrigK @ (y + b))

        b = b - invLipschitzConstant * grad_b
        # project_b zeroes the small entries of b in place and returns their ids.
        allInlierSamplesIds = project_b(b, maxNrOutlierSamples)

        diff_to_previous = torch.sum(torch.square(previous_b - b))
        previous_b = torch.clone(b)

        if diff_to_previous < 0.00000001:
            break

    allInlierSamplesIds = allInlierSamplesIds.to(device=X.device)

    # Outliers are all indices that are not inliers, in ascending order.
    isOutlier = torch.ones(n, dtype=torch.bool, device=X.device)
    isOutlier[allInlierSamplesIds] = False
    allOutlierSampleIds = torch.arange(n, device=X.device)[isOutlier]

    assert allInlierSamplesIds.shape[0] == n - maxNrOutlierSamples
    assert allOutlierSampleIds.shape[0] == maxNrOutlierSamples

    return allInlierSamplesIds, allOutlierSampleIds


def project_b(b, maxNrOutlierSamples):
    """Hard-threshold ``b`` in place, keeping its largest-magnitude entries.

    Zeroes all but the ``maxNrOutlierSamples`` entries of the 1-D tensor ``b``
    with the largest absolute value.

    Returns:
        Indices of the zeroed entries, ordered by increasing ``|b|``.
    """
    nrZeroed = b.shape[0] - maxNrOutlierSamples
    assert len(b.shape) == 1
    smallestMagnitudeIds = torch.argsort(b.abs())[:nrZeroed]
    b[smallestMagnitudeIds] = 0.0
    return smallestMagnitudeIds


def set_new_data_and_get_loss(model, mll, X, y, allInlierSamplesIds):
    """Restrict the model's training data to a subset and return its loss.

    Calls ``model.set_train_data`` with the rows ``allInlierSamplesIds`` of
    ``X`` / ``y`` (``strict=False``), then evaluates the negative ``mll`` of
    the model on that subset.

    Returns:
        The negative marginal log-likelihood as a Python float. (For an
        ``ExactMarginalLogLikelihood`` this is divided by the number of samples.)
    """
    subset_X, subset_y = X[allInlierSamplesIds, :], y[allInlierSamplesIds]
    model.set_train_data(inputs=subset_X, targets=subset_y, strict=False)

    negative_mll = -mll(model(subset_X), subset_y)
    return negative_mll.item()


def getPredictions(model, likelihood, X):
    """Return the predictive distribution ``likelihood(model(X))``.

    Side effect: puts ``model`` and ``likelihood`` into eval mode and leaves
    them there. The evaluation runs without gradient tracking.
    """
    for module in (model, likelihood):
        module.eval()

    # num_likelihood_samples sets how many samples are used to estimate
    # int_f p(y | f) p(f) df when p(y | f) is non-Gaussian (e.g. Student-t).
    with torch.no_grad(), gpytorch.settings.num_likelihood_samples(100):
        return likelihood(model(X))


def getMeanPredictions(predictive_distribution):
    """Return the detached predictive mean (``.loc``) of a distribution.

    A 2-D mean is averaged over its first dimension. Any other shape is
    returned unchanged.
    """
    mean = predictive_distribution.loc.detach()
    return mean.mean(dim=0) if mean.dim() == 2 else mean


def getModelMeanPrediction(model, X):
    """Return the predictive mean of ``model`` (with its own likelihood) at ``X``.

    A 1-D ``X`` is treated as a single input row of shape ``(1, d)``.
    """
    inputs = X.view(1, -1) if len(X.shape) == 1 else X
    return getMeanPredictions(getPredictions(model, model.likelihood, inputs))


def getFinalModelForPrediction(
    full_X: torch.Tensor,
    full_y: torch.Tensor,
    maxNrOutlierSamples: int,
    sigmaEstimate: float,
    method,
    base_model_constructor,
):
    """Train a trimmed GP on all data and fix its noise to the robust estimate.

    Args:
        full_X: Training inputs.
        full_y: Training targets.
        maxNrOutlierSamples: Number of samples excluded as outliers.
        sigmaEstimate: Estimated noise standard deviation.
        method: Passed to ``trainTrimmedGP`` as ``optimizationMode``.
        base_model_constructor: Passed through to ``trainTrimmedGP``.

    Returns:
        The trained model, with its likelihood noise variance set to
        ``sigmaEstimate**2``, floored at ``1e-4`` (i.e. sigma at 0.01).
    """
    MIN_NOISE_VARIANCE = 1.000e-04

    gpModel = trainTrimmedGP(
        full_X, full_y, maxNrOutlierSamples, method, base_model_constructor
    )

    noiseVariance = (
        sigmaEstimate**2
        if sigmaEstimate > math.sqrt(MIN_NOISE_VARIANCE)
        else MIN_NOISE_VARIANCE
    )
    gpModel.likelihood.noise = noiseVariance

    return gpModel
