import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

import gpytorch
import torch

import pyro
from pyro.infer.mcmc import NUTS
from pyro.infer.mcmc.api import MCMC
import pyro.contrib.gp as gp

from sklearn.neighbors import KernelDensity

import models.exact_gp


class BaseFBGP:
    """
    Base class for a Fully Bayesian GP (FBGP).
    Contains:
    - Prediction
    - Loss functions
    - Fitting procedure

    The methods below read attributes that are not assigned in this class and therefore have to
    be set elsewhere (e.g. by a subclass) before use: ``gpr``, ``args``, ``train_x``, ``train_y``,
    ``gpytorch_kernel``, ``gpytorch_likelihood``, ``length_prior`` and ``noise_prior``.
    """

    def forward(self, x):
        """
        Function that feeds the data through the model.
        This function is dependent on the model.

        :param x: input data
        """
        raise NotImplementedError

    @staticmethod
    def _mode_from_kde(data):
        """
        Return the row of `data` with the highest Gaussian-KDE density.

        The KDE is fitted on, and evaluated at, the rows of `data` itself, so the result is always
        one of the given samples (an approximation of the joint mode).

        :param data: pandas DataFrame with one row per sample and one column per hyperparameter
        :return: list with the values of the highest-density row, in column order
        """
        kde = KernelDensity(kernel='gaussian', bandwidth=0.3)
        kde.fit(data)
        log_density = kde.score_samples(data)
        return data.iloc[int(np.argmax(log_density))].values.tolist()

    def predict(self, dataloader):
        """
        Function that predicts the labels for the given data using the MCMC samples from `fit`.

        The strategy is selected by `self.args.predict_mcmc`:
        - "mode": predict with a single GP whose lengthscale(s) and noise are set to the KDE mode
          of the posterior samples; returns whatever `pred_model.predict(dataloader)` returns.
        - "moments": predict with the batch model (one GP per MCMC sample); returns a dict with the
          keys 'predictions', 'mean' and 'stddev'. Only a tuple `(x, y)` is supported as input.

        :param dataloader: data to predict on; for "moments" it must be a tuple whose first element
            is the input tensor (presumably of shape (n_points, n_dims) -- TODO confirm)
        :return: prediction output, see above
        :raises NotImplementedError: for "posterior", any unknown strategy, or non-tuple input
            with the "moments" strategy
        """
        strategy = self.args.predict_mcmc
        if strategy not in ("mode", "moments"):
            # Reject unsupported strategies before doing any (expensive) work.
            raise NotImplementedError(f"predict_mcmc={strategy!r} is not implemented")

        # Extract samples from posterior
        posterior_samples = self.mcmc.get_samples()

        # One column per lengthscale dimension followed by the noise column. The frame is built
        # explicitly so that `modes[:-1]` / `modes[-1]` below do not depend on the order (or on
        # the presence of additional keys) of the dict returned by `get_samples`.
        lengthscales = posterior_samples['kernel.lengthscale']
        lengthscales = lengthscales.reshape(lengthscales.shape[0], -1)
        columns = {
            'lengthscale' + str(dim): lengthscales[:, dim].detach().cpu().numpy()
            for dim in range(lengthscales.shape[1])
        }
        columns['noise0'] = posterior_samples['noise'].reshape(-1).detach().cpu().numpy()
        modes = self._mode_from_kde(pd.DataFrame(columns))

        # Make a ExactGPModel so we predict with that one
        # NB: Make sure that gp.ExactGPModel has zero_mean and outputscale=1
        pred_model = models.exact_gp.ExactGPModel(self.train_x, self.train_y,
                                                  self.gpytorch_kernel, self.gpytorch_likelihood)
        pred_model.covar_module.register_prior("lengthscale_prior", self.length_prior, "lengthscale")
        pred_model.likelihood.register_prior("noise_prior", self.gpytorch_likelihood.noise_covar.noise_prior, "noise")

        # Both strategies expose a point-estimate model (posterior mode) as `self.pred_model`.
        pred_model.covar_module.lengthscale = torch.tensor(modes[:-1])
        pred_model.likelihood.noise = torch.tensor(modes[-1])
        self.pred_model = pred_model

        if strategy == "mode":
            return pred_model.predict(dataloader)

        # strategy == "moments": predict with one GP per MCMC sample.
        # `batch_model` may never have been assigned, hence getattr instead of plain access.
        if getattr(self, "batch_model", None) is None:
            self.set_batch_model()

        # Now, we will do predictions with batch_model
        with gpytorch.settings.max_cholesky_size(10000), gpytorch.settings.cg_tolerance(0.01):
            if not isinstance(dataloader, tuple):
                # Batched data loaders are not supported for this strategy.
                raise NotImplementedError("'moments' prediction only supports a tuple (x, y) as input")

            x, _ = dataloader
            with torch.no_grad(), gpytorch.settings.fast_pred_var():
                n_mcmc_samples = self.args.num_chains * self.args.num_samples
                # Repeat the inputs along a new leading dimension, one copy per MCMC sample.
                expanded_x = x.repeat(n_mcmc_samples, 1, 1)
                predictions_f = self.batch_model(expanded_x)
                predictions_y = self.batch_model.likelihood(predictions_f)

        # Mean and stddev are returned per MCMC sample; they are not averaged over the samples.
        return {'predictions': predictions_y,
                'mean': predictions_y.mean,
                'stddev': predictions_y.stddev.detach()}

    def loss_func(self):
        """
        Function that return the loss function.

        Inference is done by sampling in `fit`, so no loss function is used here.

        :return: always None
        """
        return None

    def _save_noise_histogram(self):
        """
        Save a histogram of the sampled noise values as a diagnostic plot.

        The output directory is created when missing so that a missing folder cannot make `fit`
        fail after the sampling has already finished.
        """
        import os  # local import: only needed for this diagnostic output

        path = "output/test/crap_noise2.pdf"
        os.makedirs(os.path.dirname(path), exist_ok=True)
        try:
            plt.hist(self.mcmc_samples['noise'].tolist(), bins=30)
            plt.savefig(path, bbox_inches="tight")
        finally:
            # Always release the figure, also when plotting or saving fails.
            plt.close()

    def fit(self, train_data, args=None, debug=False, initialization=None):
        """
        Function that fits (train) the model by drawing MCMC (NUTS) samples from `self.gpr.model`.

        Sets `self.mcmc` and `self.mcmc_samples`, resets `self.batch_model` and writes a histogram
        of the noise samples to "output/test/crap_noise2.pdf".

        :param train_data: tuple with (features / input data, label / output data);
            not used here, the sampler runs on `self.gpr.model`
        :param debug: not used
        :param initialization: not used
        :param args: arguments with `num_samples`, `warmup_steps` and `num_chains`;
            falls back to `self.args` when None
        :return: tuple (final_loss, losses, None) with placeholder values (-1 and
            np.linspace(-1, 1, 10)), since sampling has no loss
        :raises ValueError: if neither `args` nor `self.args` is available
        """
        if args is None:
            args = getattr(self, "args", None)
        if args is None:
            raise ValueError("fit() requires `args` (num_samples, warmup_steps, num_chains) "
                             "either as argument or as `self.args`")

        pyro.clear_param_store()
        kernel = NUTS(self.gpr.model, target_accept_prob=0.8)
        self.mcmc = MCMC(kernel, num_samples=args.num_samples, warmup_steps=args.warmup_steps,
                         num_chains=args.num_chains, disable_progbar=False, mp_context="spawn")
        self.mcmc.run()
        self.mcmc_samples = self.mcmc.get_samples()

        # A batch model built from earlier samples is stale now; `predict` rebuilds it on demand.
        self.batch_model = None

        self._save_noise_histogram()

        final_loss, losses = -1, np.linspace(-1, 1, 10)
        return final_loss, losses, None

    def set_batch_model(self):
        """
        Build a batch GP with one set of hyperparameters per MCMC sample.

        The posterior samples are stored under the names of the priors registered on the new
        model and loaded with `pyro_load_from_samples`. The model is put in eval mode and stored
        as `self.batch_model`; the matching marginal log likelihood is stored as `self.batch_mll`.
        """
        # First rename posterior_samples
        posterior_samples = self.mcmc.get_samples()
        n_dims = self.train_x.shape[1]
        posterior_samples["covar_module.lengthscale_prior"] = \
            posterior_samples['kernel.lengthscale'].view(-1, 1, n_dims)
        posterior_samples["likelihood.noise_prior"] = posterior_samples['noise'].view(-1, 1)

        # Turn model into batch model in order to draw mean functions
        likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.Positive())
        pred_kernel = gpytorch.kernels.RBFKernel(ard_num_dims=n_dims)
        batch_model = models.exact_gp.ExactGPModel(self.train_x, self.train_y, pred_kernel, likelihood)
        # Add priors to turn model into batch model afterwards
        batch_model.covar_module.register_prior("lengthscale_prior", self.length_prior, "lengthscale")
        batch_model.likelihood.register_prior("noise_prior", self.noise_prior, "noise")
        batch_model.pyro_load_from_samples(posterior_samples)

        batch_model.eval()
        self.batch_model = batch_model
        self.batch_mll = gpytorch.mlls.ExactMarginalLogLikelihood(batch_model.likelihood, batch_model)
