# Utility functions for Gaussian Processes (GP)
import torch
import gpytorch
import tqdm
import numpy as np
import pyro
from gpytorch.mlls import DeepApproximateMLL
from gpytorch.mlls import VariationalELBO, PredictiveLogLikelihood
from sklearn.linear_model import Ridge
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline
from xgboost import XGBRegressor

import utils.metrics as metrics
from models.exact_gp import ExactGPModel, MultitaskGPModel, FancyGPWithPriors, BatchIndependentMultitaskGPModel
from models.approximate_gp import ApproximateGPModel, ApproximateGPModel_NGD, MultitaskApproximateGPModel
from models.deep_gp import DeepGP, MultitaskDeepGP
from models.ensemble_gp import EnsembleGP
from models.dspp import TwoLayerDSPP
from models.base_sklearn import SklearnModelWrapper
from models.pyro_gp import FBGP
from models.fbgp import FBGP_gpytorch


def get_likelihood(args):
    """
    Builds the likelihood that matches the number of outputs and the model type.

    Single output (``args.outputs == 1``):
      - "gp_prior", "fbgp_mcmc", "fbgp_mcmc_gpytorch": Gaussian likelihood with a
        positive noise constraint and a LogNormal(0, sqrt(3)) prior on the noise.
      - anything else: Gaussian likelihood with the noise constrained to be > 1e-4.

    Multiple outputs:
      - "indep_exact": multitask Gaussian likelihood without global noise, task noises
        constrained to be > 1e-4, a Normal(0, 1) noise prior and task noises set to 0.1.
      - anything else: default multitask Gaussian likelihood.

    :param args: arguments (``args.outputs`` and ``args.model_type`` are read)
    :returns: gpytorch likelihood
    """
    if args.outputs != 1:
        if args.model_type != "indep_exact":
            return gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=args.outputs)

        # Independent exact GPs: per-task noise only, initialised to 0.1
        multitask_kwargs = dict(
            num_tasks=args.outputs,
            noise_constraint=gpytorch.constraints.GreaterThan(1e-4),
            noise_prior=gpytorch.priors.NormalPrior(0, 1),
            has_global_noise=False,
        )
        likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(**multitask_kwargs)
        likelihood.task_noises = 0.1
        return likelihood

    models_with_noise_prior = ("gp_prior", "fbgp_mcmc", "fbgp_mcmc_gpytorch")
    if args.model_type not in models_with_noise_prior:
        # ExactGP can only handle Gaussian likelihoods
        return gpytorch.likelihoods.GaussianLikelihood(
            noise_constraint=gpytorch.constraints.GreaterThan(1e-4)
        )

    # Gaussian likelihood with a log-normal prior on the noise
    log_normal_scale = torch.sqrt(torch.tensor([3]))
    noise_prior = gpytorch.priors.LogNormalPrior(loc=0, scale=log_normal_scale)
    return gpytorch.likelihoods.GaussianLikelihood(
        noise_constraint=gpytorch.constraints.Positive(),
        noise_prior=noise_prior,
    )


def get_model(args, data, kernel, likelihood, iteration=0, opt_hypers=None,
              length_prior=None, noise_prior=None):
    """
    Returns the model selected by ``args.model_type`` for the given number of outputs.

    :param args: arguments; ``outputs``, ``model_type`` and ``set_prior`` are always read,
        ``hidden_dim`` is read by the DSPP / deep GP models
    :param data: data container; ``data.train_trans.x`` and ``data.train_trans.y`` are used as
        training inputs and targets (``data.train.x`` is additionally read for "fbgp_mcmc")
    :param kernel: gpytorch kernel, e.g gpytorch.kernels.RBFKernel()
    :param likelihood: gpytorch likelihood, e.g. gpytorch.likelihoods.GaussianLikelihood()
    :param iteration: only used when ``args.set_prior`` is set: at iteration 0 the model is
        initialized with ``opt_hypers``, otherwise ``opt_hypers`` is loaded as a state dict
        TODO: Not tested
    :param opt_hypers: dict of parameters (iteration 0) or state dict (later iterations)
        TODO: Not tested
    :param length_prior: currently ignored - the "fbgp_mcmc" model always builds its own
        LogNormal(0, sqrt(3)) lengthscale prior
    :param noise_prior: currently ignored - the "fbgp_mcmc" model always builds its own
        LogNormal(0, sqrt(3)) noise prior
    :returns: the constructed model
    :raises NotImplementedError: if ``args.model_type`` is not supported for ``args.outputs``
    """
    train_x = data.train_trans.x
    train_y = data.train_trans.y

    if args.outputs == 1:
        model = _build_single_output_model(args, data, train_x, train_y, kernel, likelihood)
    else:
        model = _build_multitask_model(args, train_x, train_y, kernel, likelihood)

    # Set prior
    if args.set_prior:
        if iteration == 0:
            model.initialize(**opt_hypers)
        else:  # priors from previous run
            model.load_state_dict(opt_hypers)

    return model


def _ensure_2d(x):
    """Returns ``x`` with a trailing dimension appended if it is 1-D, otherwise ``x`` unchanged."""
    return x.unsqueeze(-1) if len(x.shape) == 1 else x


def _build_single_output_model(args, data, train_x, train_y, kernel, likelihood):
    """Builds the model for ``args.outputs == 1``; raises NotImplementedError if unsupported."""
    model_type = args.model_type

    if model_type == "heteroskedastic":
        # All training inputs are used as inducing points
        return ApproximateGPModel(train_x, train_y,
                                  kernel=kernel,
                                  likelihood=likelihood,
                                  inducing_points=train_x)

    if model_type == "ngd_model":
        # All training inputs are used as inducing points
        return ApproximateGPModel_NGD(train_x, train_y,
                                      kernel=kernel,
                                      likelihood=likelihood,
                                      inducing_points=train_x)

    if model_type == "dspp":
        train_x = _ensure_2d(train_x)
        return TwoLayerDSPP(
            train_x.shape,
            inducing_points=train_x,
            num_inducing=train_x.shape[0],
            hidden_dim=args.hidden_dim,
            Q=8
        )

    if model_type == "deepgp_ngd":
        # NOTE(review): unlike "deepgp", a 1-D train_x is not expanded here - confirm this is intended
        return DeepGP(train_x.shape, num_output_dims=args.hidden_dim,
                      num_inducing=0,
                      inducing_points=train_x,
                      likelihood=likelihood, ngd=True)

    if model_type == "deepgp":
        # Make sure the first dimension is the batch size
        train_x = _ensure_2d(train_x)
        return DeepGP(train_x.shape, num_output_dims=args.hidden_dim,
                      num_inducing=0,
                      inducing_points=train_x,
                      likelihood=likelihood, ngd=False)

    if model_type == "hetero2":
        # The HeteroskedasticGPModel construction is disabled, so no model can be built.
        # Fail explicitly instead of hitting an UnboundLocalError further down.
        raise NotImplementedError(f"Specified model '{args.model_type}' is not implemented..")

    if model_type == "ensemble_exactgp":
        print("Ensemble")
        return EnsembleGP(train_x, train_y, kernel, likelihood)

    if model_type == "gp_prior":
        return FancyGPWithPriors(train_x, train_y, likelihood)

    if model_type == "exact":
        return ExactGPModel(train_x, train_y, kernel, likelihood)

    if model_type == "fbgp_mcmc":
        n_dims = train_x.shape[1]
        pyro_kernel = pyro.contrib.gp.kernels.RBF(input_dim=n_dims)
        # One LogNormal(0, sqrt(3)) lengthscale per input dimension
        pyro_kernel.lengthscale = pyro.nn.PyroSample(
            pyro.distributions.LogNormal(torch.tensor([0.]).repeat(n_dims),
                                         torch.sqrt(torch.tensor([3])).repeat(n_dims)).to_event())
        # NB: Prior must be the same as likelihood for exact_prior (pyro only defined on positive values)
        prior_scale = torch.sqrt(torch.tensor([3])).item()
        model = FBGP(args, train_x, train_y, pyro_kernel,
                     length_prior=gpytorch.priors.LogNormalPrior(loc=0, scale=prior_scale),
                     noise_prior=gpytorch.priors.LogNormalPrior(loc=0, scale=prior_scale))
        # NOTE(review): the ARD dimension is taken from data.train.x, not from train_x - confirm
        model.gpytorch_kernel = gpytorch.kernels.RBFKernel(ard_num_dims=data.train.x.shape[1])
        model.gpytorch_likelihood = likelihood
        return model

    if model_type == "fbgp_mcmc_gpytorch":
        return FBGP_gpytorch(args, train_x, train_y, likelihood)

    if model_type == "ridge_reg":
        return SklearnModelWrapper(model=Ridge(alpha=1.0))

    if model_type == "xgboost":
        return SklearnModelWrapper(model=XGBRegressor())

    raise NotImplementedError(f"Specified model '{args.model_type}' is not implemented..")


def _build_multitask_model(args, train_x, train_y, kernel, likelihood):
    """Builds the model for ``args.outputs != 1``; raises NotImplementedError if unsupported."""
    model_type = args.model_type

    if model_type == "heteroskedastic":
        # All training inputs are used as inducing points
        return MultitaskApproximateGPModel(train_x, train_y,
                                           kernel=kernel,
                                           likelihood=likelihood,
                                           inducing_points=train_x,
                                           num_latents=args.outputs,
                                           num_tasks=args.outputs)

    if model_type == "exact":
        return MultitaskGPModel(train_x, train_y, kernel, likelihood=likelihood, num_tasks=args.outputs)

    if model_type == "indep_exact":
        return BatchIndependentMultitaskGPModel(train_x, train_y, likelihood=likelihood,
                                                batch_size=args.outputs)

    if model_type == "ridge_reg":
        # Degree-2 polynomial features followed by an unregularised ridge regression
        pipeline = make_pipeline(PolynomialFeatures(2), Ridge(alpha=0))
        return SklearnModelWrapper(model=pipeline)

    if model_type == "deepgp":
        # Make sure the first dimension is the batch size
        train_x = _ensure_2d(train_x)
        n_ind_points = train_x.shape[0]
        return MultitaskDeepGP(train_x.shape, num_output_dims=args.hidden_dim,
                               num_inducing=n_ind_points, likelihood=likelihood,
                               num_tasks=args.outputs, ngd=False)

    raise NotImplementedError(f"Specified model '{args.model_type}' with multiple outputs is not implemented..")


def get_loss(args, model, num_data):
    """
    Selects the training objective ("loss" for GPs - a marginal / predictive log likelihood)
    that belongs to ``args.model_type``.

    :param args: arguments (``args.model_type`` is read)
    :param model: model; its ``likelihood`` (or ``pred_model`` for "fbgp_mcmc") is used
    :param num_data: number of data points in the training set
    :returns: loss function used to train model, or None for the sklearn-style models
    """
    kind = args.model_type

    # Approximate (variational) GPs are trained on the predictive log likelihood
    if kind == "heteroskedastic" or kind == "ngd_model":
        return gpytorch.mlls.PredictiveLogLikelihood(model.likelihood, model, num_data=num_data)

    # Deep GPs wrap the same objective in the deep approximate MLL
    if kind == "deepgp" or kind == "deepgp_ngd":
        return DeepApproximateMLL(PredictiveLogLikelihood(model.likelihood, model, num_data))

    # Ridge regression / XGBoost have no GP objective
    if kind in ['ridge_reg', 'xgboost']:
        return None

    # The MCMC model exposes its predictive gpytorch model as `pred_model`
    if kind in ['fbgp_mcmc']:
        return gpytorch.mlls.ExactMarginalLogLikelihood(model.pred_model.likelihood, model.pred_model)

    # Exact, homoscedastic, hetero2?
    return gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)


def validate_gpytorch_model2(args, gp_test_loader, model, mll):
    """
    Validates a GPyTorch model on the test set.
    Metrics:
    - Negative Marginal Log Likelihood <> compare to the test set
    - RMSE <> compare to the test set... but it would be more precise to compare to true mean,
      variance, quantiles, etc.

    The model is switched to evaluation mode as a side effect.

    :param args: arguments (``model_type`` is read; ``n_samples`` is read for the deep GP models)
    :param gp_test_loader: iterable of (x, y) batches for "deepgp" / "deepgp_ngd",
        otherwise a single (test_x, test_y) pair
    :param model: gpytorch model
    :param mll: marginal log likelihood
    :returns: (nmll, rmse); for the deep GP models these are numpy means over the per-batch
        values, otherwise they are whatever ``torch.mean`` and ``metrics.rmse`` return
    """
    # Get into evaluation (predictive posterior) mode
    model.eval()

    is_deep_gp = args.model_type == "deepgp" or args.model_type == "deepgp_ngd"

    if not is_deep_gp:
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            test_x, test_y = gp_test_loader
            predictive = model.likelihood(model(test_x))
            nmll_loss_valid = torch.mean(-mll(predictive, test_y))
            rmse_loss_valid = metrics.rmse(test_y, predictive.mean)
        return nmll_loss_valid, rmse_loss_valid

    # Validation NB: Right now, only a single batch !!! but batch_size = #test points..
    with torch.no_grad(), gpytorch.settings.num_likelihood_samples(args.n_samples):
        per_batch_nmll = []
        per_batch_rmse = []
        for x_batch, y_batch in gp_test_loader:
            # The deep model call is unpacked into two values; only the first one is used
            latent_output, _ = model(x_batch)
            predictive = model.likelihood(latent_output)
            per_batch_nmll.append(torch.mean(-mll(predictive, y_batch)).item())
            per_batch_rmse.append(metrics.rmse(y_batch, predictive.mean).item())
        nmll_loss_valid = np.mean(per_batch_nmll)
        rmse_loss_valid = np.mean(per_batch_rmse)

    return nmll_loss_valid, rmse_loss_valid


def compute_loss(args, dataloader, predictions, lst_metrics=None, mll=None):
    """
    Validation of predictions: compares the labels in the test loader with the predictions.

    :param args: arguments (``args.model_type`` is only read when the nmll is requested)
    :param dataloader: either a ``(test_x, test_y)`` tuple (non batches) or an iterable of
        ``(x_batch, y_batch)`` pairs (batches), e.g. a PyTorch dataloader
    :param predictions: predictions. For the nmll they are passed to ``mll`` (indexed per batch
        in batch mode); for rmse / mae / rae / rrse they must support ``.view``
    :param lst_metrics: metrics to calculate, any of 'mll', 'rmse', 'mae', 'rae', 'rrse'
        (default: none)
    :param mll: marginal log likelihood (must be given together with 'mll' in lst_metrics in
        order to calculate the nmll)
    :returns: dictionary with losses; 'nmll' is always present and stays -1 when it is not computed
    """
    # Avoid a shared mutable default argument
    lst_metrics = () if lst_metrics is None else lst_metrics

    # The nmll is undefined for the sklearn-style models, which have no likelihood
    want_nmll = (
        'mll' in lst_metrics
        and mll is not None
        and args.model_type not in ('ridge_reg', 'xgboost')
    )

    losses = {'nmll': -1}
    y_batches = None  # stays None in the non-batched case

    # A tuple is a single (x, y) pair; anything else is treated as an iterable of batches
    if isinstance(dataloader, tuple):
        _, test_y = dataloader
        if want_nmll:
            losses['nmll'] = torch.mean(-mll(predictions, test_y)).item()
    else:
        y_batches = []
        batch_nmlls = []
        for idx, (_, y_batch) in enumerate(dataloader):
            y_batches.append(y_batch)
            if want_nmll:
                batch_nmlls.append(torch.mean(-mll(predictions[idx], y_batch)).item())
        if want_nmll:
            losses['nmll'] = np.mean(batch_nmlls).item()

    pointwise_requested = any(name in lst_metrics for name in ('rmse', 'mae', 'rae', 'rrse'))
    if pointwise_requested:
        if y_batches is not None:
            # Gather the targets of all batches so they can be compared with the predictions
            # (assumes the batches are stacked along the first dimension - TODO confirm)
            test_y = torch.cat(y_batches) if y_batches else torch.tensor([])
        preds = predictions.view(test_y.shape)

        if 'rmse' in lst_metrics:
            losses['rmse'] = metrics.rmse(preds=preds, targets=test_y)
        if 'mae' in lst_metrics:
            losses['mae'] = metrics.mae(preds=preds, targets=test_y)
        if 'rae' in lst_metrics:
            # NOTE(review): key is 'rae' but the helper is called `rea` - verify against utils.metrics
            losses['rae'] = metrics.rea(preds=preds, targets=test_y)
        if 'rrse' in lst_metrics:
            losses['rrse'] = metrics.rrse(preds=preds, targets=test_y)

    return losses
