# Ensemble Gaussian Processes
import gpytorch
import torch
import numpy as np

from gpytorch.variational import VariationalStrategy, CholeskyVariationalDistribution,\
    NaturalVariationalDistribution, TrilNaturalVariationalDistribution
from models.exact_gp import ExactGPModel, FancyGPWithPriors
from models.approximate_gp import ApproximateGPModel
from models.base_exact_gp import BaseExactGPModel


class EnsembleGP(BaseExactGPModel):
    """
    Ensemble of approximate GP sub-models, each built on its own subset of the data.

    At construction time ``n_estimators`` ``ApproximateGPModel`` instances are created.
    Each is built on a subset of ``(train_x, train_y)`` drawn with ``np.random.choice``:
    ``bagging_fraction`` of the rows, sampled with or without replacement according to
    ``replacement``. The chosen row indices are kept in ``self.subset_idx`` so that
    ``fit`` trains every sub-model on the same subset it was constructed with.

    Sub-model outputs are merged according to ``self.combination_rule``. Only
    ``"naive"`` is implemented: an unweighted element-wise average of the sub-models'
    means and covariance matrices.

    NOTE(review): ``self.models`` is a plain Python list, so the sub-models are NOT
    registered as child modules. ``self.parameters()``, ``.train()``/``.eval()``,
    ``.to(device)`` and ``state_dict()`` on the ensemble do not reach them. Switching
    to ``torch.nn.ModuleList`` would fix that, but it changes state_dict keys and the
    priors an ensemble-level MLL would see. Confirm before changing.
    """

    # Combination rules understood by ``forward`` / ``predict``.
    _COMBINATION_RULES = ("naive",)

    # (loc, scale) of the NormalPrior placed on each sub-model's likelihood noise.
    # Estimator ``i`` uses entry ``i % len(_NOISE_PRIOR_PARAMS)``.
    _NOISE_PRIOR_PARAMS = ((0, .001), (1, 1))

    def __init__(self, train_x, train_y, kernel, likelihood, base_estimator=ExactGPModel,
                 n_estimators=2):
        """
        Build ``n_estimators`` sub-models on random row subsets of the training data.

        Args:
            train_x, train_y: training inputs/targets. They are indexed row-wise with a
                numpy integer array to form each sub-model's subset.
            kernel: kernel object passed to every ``ApproximateGPModel``.
            likelihood: likelihood of the ensemble itself. It is passed only to the base
                class; each sub-model gets its own freshly built ``GaussianLikelihood``.
            base_estimator: stored but unused; sub-models are always ``ApproximateGPModel``.
            n_estimators: number of sub-models.
        """
        super(EnsembleGP, self).__init__(train_x, train_y, likelihood)

        # Initialization
        self.base_estimator = base_estimator  # Not in use. Always using ApproximateGPModel
        self.n_estimators = n_estimators

        # Settings
        self.combination_rule = "naive"
        self.bagging_fraction = 1
        self.replacement = False

        # Per-estimator row indices and sub-models (same order).
        self.subset_idx = []
        self.models = []

        # One fresh likelihood per estimator, so no two sub-models share a noise
        # parameter. Indexing a fixed-size list here used to fail for n_estimators > 2.
        sub_likelihoods = [self._make_sub_likelihood(i) for i in range(n_estimators)]

        n_rows = len(train_x)
        size_subset = int(n_rows * self.bagging_fraction)
        for sub_likelihood in sub_likelihoods:
            idx = np.random.choice(np.arange(0, n_rows), size=size_subset, replace=self.replacement)
            train_x_subset, train_y_subset = train_x[idx], train_y[idx]
            # NOTE(review): the same ``kernel`` object is handed to every sub-model.
            # Unless ApproximateGPModel copies it, all sub-models share (and fitting
            # overwrites) one set of kernel hyperparameters. Confirm this is intended.
            self.models.append(ApproximateGPModel(train_x_subset, train_y_subset, kernel, sub_likelihood,
                                                  inducing_points=train_x_subset))
            self.subset_idx.append(idx)

    @classmethod
    def _make_sub_likelihood(cls, index):
        """Build a new GaussianLikelihood for estimator ``index``, cycling through the noise priors."""
        loc, scale = cls._NOISE_PRIOR_PARAMS[index % len(cls._NOISE_PRIOR_PARAMS)]
        return gpytorch.likelihoods.GaussianLikelihood(
            noise_constraint=gpytorch.constraints.GreaterThan(1e-3),
            noise_prior=gpytorch.priors.NormalPrior(loc, scale),
        )

    def _check_combination_rule(self):
        """Raise ValueError if ``self.combination_rule`` is not implemented."""
        if self.combination_rule not in self._COMBINATION_RULES:
            raise ValueError(
                f"Unsupported combination_rule {self.combination_rule!r}; "
                f"expected one of {self._COMBINATION_RULES}"
            )

    @staticmethod
    def _combine_naive(means, covars):
        """Return the element-wise average of the per-model means and covariance matrices."""
        ensemble_mean_f = torch.mean(torch.stack(means, dim=0), dim=0)
        ensemble_covar_f = torch.mean(torch.stack(covars, dim=0), dim=0)
        return ensemble_mean_f, ensemble_covar_f

    def forward(self, x):
        """
        Evaluate every sub-model at ``x`` and combine their outputs.

        Returns:
            gpytorch.distributions.MultivariateNormal built from the combined mean and
            covariance matrix.

        Raises:
            ValueError: if ``self.combination_rule`` is not supported.
        """
        self._check_combination_rule()

        means, covars = [], []
        for model_i in self.models:
            # Read mean/covariance from this model's output (previously read from the list).
            output_i = model_i(x)
            means.append(output_i.mean)
            covars.append(output_i.covariance_matrix)

        ensemble_mean_f, ensemble_covar_f = self._combine_naive(means, covars)
        return gpytorch.distributions.MultivariateNormal(ensemble_mean_f, ensemble_covar_f)

    def predict(self, dataloader):
        """
        Predict with every sub-model on ``dataloader`` and combine the results.

        Each sub-model is put in eval mode and its ``predict`` output is used. The
        ``'predictions'`` distributions and ``'mean'`` entries are combined, and the
        combined latent distribution is passed through a ``GaussianLikelihood`` whose
        noise is the mean of the sub-models' (constrained) noise values.

        Returns:
            dict with keys ``'predictions'`` (ensemble predictive distribution),
            ``'mean'``, ``'stddev'`` (detached) and ``'individual_preds'`` (list of the
            sub-models' ``'predictions'``).

        Raises:
            ValueError: if ``self.combination_rule`` is not supported.
        """
        self._check_combination_rule()

        predictions, means, covars, noises = [], [], [], []
        for model_i in self.models:
            model_i.eval()
            predict_output = model_i.predict(dataloader)
            predictions.append(predict_output['predictions'])
            means.append(predict_output['mean'])
            covars.append(predict_output['predictions'].covariance_matrix)
            # Average the constrained noise, not raw_noise. The sub-models use
            # GreaterThan(1e-3) while a default GaussianLikelihood uses a different lower
            # bound, so copying an averaged raw value would yield a different noise.
            noises.append(model_i.likelihood.noise.item())

        ensemble_mean_f, ensemble_covar_f = self._combine_naive(means, covars)

        # Likelihood carrying the noise averaged across models.
        ensemble_likelihood = gpytorch.likelihoods.GaussianLikelihood()
        ensemble_likelihood.noise = torch.mean(torch.tensor(noises))

        # TODO: Is it correct to use the ensemble likelihood?
        ensemble_preds_f = ensemble_likelihood(
            gpytorch.distributions.MultivariateNormal(ensemble_mean_f, ensemble_covar_f))

        output = {'predictions': ensemble_preds_f,
                  'mean': ensemble_preds_f.mean,
                  'stddev': ensemble_preds_f.stddev.detach(),
                  'individual_preds': predictions
                  }
        return output

    def fit(self, train_data, args=None, debug=False, initialization=None):
        """
        Fit every sub-model on its stored row subset of ``train_data``.

        The subsets are the indices drawn in ``__init__`` (``self.subset_idx``), so
        ``train_data`` presumably has at least as many rows as the data the ensemble
        was built with. TODO confirm against callers.

        Args:
            train_data: ``(train_x, train_y)`` pair.
            args: forwarded to each sub-model's ``fit``.
            debug, initialization: accepted for interface compatibility but unused.
                Each sub-model receives ``initialization=i``, its own index.

        Returns:
            ``(mean nmll loss over sub-models, fit_losses of the LAST sub-model)``.
            ``fit_losses`` is ``None`` if there are no sub-models.
        """
        # TODO: subsets are drawn row-wise, so replicated queries of the same point
        # (e.g. args.repeat_sampling > 1) can be split across subsets; sampling by
        # unique point and expanding to its replicates is not implemented.
        train_x, train_y = train_data
        nmll_losses = []
        fit_losses = None
        for i, (model_i, idx) in enumerate(zip(self.models, self.subset_idx)):
            nmll_loss, fit_losses = model_i.fit((train_x[idx], train_y[idx]),
                                                args=args, initialization=i)
            nmll_losses.append(nmll_loss)

        return np.mean(nmll_losses), fit_losses
