import copy
import os

import gpytorch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyro
import pyro.contrib.gp as gp
import torch
from pyro.infer.mcmc import NUTS
from pyro.infer.mcmc.api import MCMC
from sklearn.neighbors import KernelDensity

import models.exact_gp


class BaseFBGP_gpytorch(gpytorch.models.ExactGP):
    """
    Base class for a Fully Bayesian GP (FBGP) whose hyperparameters are sampled with NUTS.
    Contains:
    - Prediction (with the KDE mode of the hyperparameter posterior)
    - Loss function (none: MCMC does not optimise a loss)
    - Fitting procedure (NUTS via Pyro)
    - Construction of a batch model with one GP per posterior sample

    Subclasses must implement ``forward`` and provide the attributes read by the methods
    below: ``train_x``, ``train_y``, ``covar_module``, ``likelihood``, ``length_prior``,
    ``noise_prior`` and ``args`` (``args.predict_mcmc``).
    """

    # Sample-site names of the per-dimension lengthscales when the posterior was produced
    # with a pyro.contrib.gp kernel instead of ``pyro_model`` below (presumably a nested
    # combination kernel, judging by the kern0/kern1 names). Keyed by input dimension and
    # listed in the order they are stacked into the ARD lengthscale vector.
    # NOTE(review): the 6-D order runs from the outermost ``kern1`` inwards, the reverse of
    # the 3-D order (innermost first) -- confirm this matches how the 6-D kernel is built.
    _CONTRIB_GP_LENGTHSCALE_KEYS = {
        1: ('kernel.lengthscale',),
        2: ('kernel.kern0.lengthscale',
            'kernel.kern1.lengthscale'),
        3: ('kernel.kern0.kern0.lengthscale',
            'kernel.kern0.kern1.lengthscale',
            'kernel.kern1.lengthscale'),
        6: ('kernel.kern1.lengthscale',
            'kernel.kern0.kern1.lengthscale',
            'kernel.kern0.kern0.kern1.lengthscale',
            'kernel.kern0.kern0.kern0.kern1.lengthscale',
            'kernel.kern0.kern0.kern0.kern0.kern1.lengthscale',
            'kernel.kern0.kern0.kern0.kern0.kern0.lengthscale'),
    }

    def forward(self, x):
        """
        Function that feeds the data through the model.
        This function is dependent on the model.

        :param x: input data
        """
        raise NotImplementedError

    @staticmethod
    def _posterior_hyper_frame(posterior_samples):
        """
        Flatten the MCMC lengthscale and noise samples into one column per hyperparameter.

        :param posterior_samples: dict returned by ``MCMC.get_samples()``; must contain
            ``'covar_module.lengthscale_prior'`` (indexed as rank 3, last axis = input
            dimension) and ``'likelihood.noise_covar.noise_prior'``. It is not modified.
        :return: DataFrame with columns ``lengthscale0 .. lengthscale{d-1}`` followed by
            ``noise``, one row per sample
        """
        lengthscales = posterior_samples['covar_module.lengthscale_prior']
        noise = posterior_samples['likelihood.noise_covar.noise_prior']
        columns = {}
        for dim in range(lengthscales.shape[2]):
            columns['lengthscale' + str(dim)] = lengthscales[:, :, dim].reshape(-1).detach().cpu().numpy()
        columns['noise'] = noise.reshape(-1).detach().cpu().numpy()
        return pd.DataFrame(columns)

    @staticmethod
    def _kde_mode(frame, bandwidth=0.3):
        """
        Return the sample (row) of ``frame`` with the highest Gaussian-KDE log-density.

        The KDE is fitted on all rows, so this approximates the joint posterior mode by
        the most "central" posterior sample.

        :param frame: DataFrame with one row per sample and one column per hyperparameter
        :param bandwidth: bandwidth of the Gaussian kernel density estimate
        :return: the selected row as a list of Python floats, in column order
        """
        kde = KernelDensity(kernel='gaussian', bandwidth=bandwidth)
        kde.fit(frame)
        log_density = kde.score_samples(frame)
        return frame.iloc[np.argmax(log_density)].values.tolist()

    def predict(self, dataloader):
        """
        Predict with a point-estimate GP built from the MCMC hyperparameter posterior.

        The lengthscale and noise samples are flattened into one column per
        hyperparameter. Their joint mode is approximated by the sample with the highest
        Gaussian-KDE density, and an ``ExactGPModel`` with those hyperparameters is used
        for prediction. That model is stored in ``self.pred_model``; this model's own
        kernel and likelihood are left untouched.

        :param dataloader: data to predict on; forwarded unchanged to ``ExactGPModel.predict``
        :return: whatever ``ExactGPModel.predict`` returns
        :raises NotImplementedError: if ``self.args.predict_mcmc`` is not ``"mode"``
            (``"posterior"`` and ``"moments"`` are planned but not implemented)
        """
        if self.args.predict_mcmc != "mode":
            raise NotImplementedError

        hyper_frame = self._posterior_hyper_frame(self.mcmc.get_samples())
        mode = self._kde_mode(hyper_frame)
        lengthscale_mode, noise_mode = mode[:-1], mode[-1]

        # Build the prediction model on *copies* of this model's kernel and likelihood.
        # Sharing them would overwrite this model's hyperparameters and register an extra
        # noise prior on self.likelihood, which pyro_model would then also sample in any
        # later call to fit().
        pred_model = models.exact_gp.ExactGPModel(self.train_x, self.train_y,
                                                  copy.deepcopy(self.covar_module),
                                                  copy.deepcopy(self.likelihood))
        # NB: make sure ExactGPModel uses a zero mean and outputscale=1; only the
        # lengthscales and the noise are set here.
        pred_model.covar_module.register_prior("lengthscale_prior", self.length_prior, "lengthscale")
        pred_model.likelihood.register_prior("noise_prior", self.likelihood.noise_covar.noise_prior, "noise")
        pred_model.covar_module.lengthscale = torch.tensor(lengthscale_mode)
        pred_model.likelihood.noise = torch.tensor(noise_mode)

        self.pred_model = pred_model
        return pred_model.predict(dataloader)

    def loss_func(self):
        """
        Return the training loss function.

        Hyperparameters are sampled with MCMC rather than optimised, so there is no loss
        and this returns ``None``.
        """
        return None

    def pyro_model(self, x, y):
        """
        Pyro model used by NUTS.

        It samples a model from the registered hyperparameter priors and conditions its
        predictive distribution on the observed targets.

        :param x: training inputs
        :param y: training targets, observed at the ``"obs"`` sample site
        :return: ``y`` unchanged
        """
        # Disable gpytorch's fast approximate computations (covariance root decomposition,
        # log_prob and solves) while evaluating the likelihood for MCMC.
        with gpytorch.settings.fast_computations(False, False, False):
            sampled_model = self.pyro_sample_from_prior()
            output = sampled_model.likelihood(sampled_model(x))
            pyro.sample("obs", output, obs=y)
        return y

    def fit(self, train_data, args=None, debug=False, initialization=None):
        """
        Sample the hyperparameter posterior with NUTS.

        The ``MCMC`` object is stored in ``self.mcmc`` and its samples in
        ``self.mcmc_samples``.

        :param train_data: tuple ``(x, y)`` of training inputs and targets; any other type
            is treated as batched data, which is not supported
        :param args: run configuration providing ``num_samples``, ``warmup_steps`` and
            ``num_chains``; falls back to ``self.args`` when omitted
        :param debug: if True, save a histogram of the noise samples to
            ``output/test/crap_noise2.pdf`` (the directory is created if needed)
        :param initialization: unused
        :return: placeholder tuple ``(-1, np.linspace(-1, 1, 10), None)``; MCMC has no
            training loss
        :raises ValueError: if neither ``args`` nor ``self.args`` is available
        :raises NotImplementedError: if ``train_data`` is not a tuple
        """
        if args is None:
            args = getattr(self, 'args', None)
        if args is None:
            raise ValueError("fit() needs `args` providing num_samples, warmup_steps and num_chains")
        if not isinstance(train_data, tuple):
            # Mini-batch (dataloader) training is not supported for MCMC.
            raise NotImplementedError
        x, y = train_data

        pyro.clear_param_store()
        kernel = NUTS(self.pyro_model, target_accept_prob=0.8)
        self.mcmc = MCMC(kernel, num_samples=args.num_samples, warmup_steps=args.warmup_steps,
                         num_chains=args.num_chains, disable_progbar=False, mp_context="spawn")
        self.mcmc.run(x, y)

        self.mcmc.summary()
        self.mcmc_samples = self.mcmc.get_samples()

        if debug:
            # Quick visual check of the noise posterior.
            os.makedirs("output/test", exist_ok=True)
            plt.hist(self.mcmc_samples['likelihood.noise_covar.noise_prior'].reshape(-1).tolist(), bins=30)
            plt.savefig("output/test/crap_noise2.pdf", bbox_inches="tight")
            plt.close()

        # No optimisation loss exists for MCMC; return placeholders of the expected shape.
        final_loss, losses = -1, np.linspace(-1, 1, 10)
        return final_loss, losses, None

    def set_batch_model(self):
        """
        Build a batch ``ExactGPModel`` with one GP per posterior sample.

        The lengthscale and noise samples are stored under the
        ``covar_module.lengthscale_prior`` / ``likelihood.noise_prior`` site names of a
        fresh ARD RBF model, which is then loaded with ``pyro_load_from_samples``. Samples
        may use either the site names produced by ``pyro_model`` (gpytorch naming) or
        pyro.contrib.gp names (``_CONTRIB_GP_LENGTHSCALE_KEYS`` and ``'noise'``). When both
        are present, the pyro.contrib.gp names take precedence.

        Sets ``self.batch_model`` (in eval mode) and ``self.batch_mll``.

        :raises NotImplementedError: if the input dimension has no pyro.contrib.gp key
            mapping and the samples carry no gpytorch-named lengthscale site
        :raises KeyError: if required sample sites are missing
        """
        posterior_samples = self.mcmc.get_samples()
        num_dims = self.train_x.shape[1]
        lengthscale_key = "covar_module.lengthscale_prior"

        contrib_keys = self._CONTRIB_GP_LENGTHSCALE_KEYS.get(num_dims)
        use_contrib = contrib_keys is not None and (
            all(key in posterior_samples for key in contrib_keys)
            or lengthscale_key not in posterior_samples)
        if use_contrib:
            # Stack the per-dimension samples into shape (S, 1, d), S = number of samples.
            posterior_samples[lengthscale_key] = torch.cat(
                [posterior_samples[key].view(-1, 1, 1) for key in contrib_keys], dim=2)
        elif lengthscale_key in posterior_samples:
            # Samples from pyro_model already use the gpytorch site name.
            posterior_samples[lengthscale_key] = posterior_samples[lengthscale_key].reshape(-1, 1, num_dims)
        else:
            raise NotImplementedError

        noise_key = 'noise' if 'noise' in posterior_samples else 'likelihood.noise_covar.noise_prior'
        posterior_samples["likelihood.noise_prior"] = posterior_samples[noise_key].reshape(-1, 1)

        # Turn a fresh model (own kernel and likelihood, not self's) into a batch model
        # in order to draw mean functions.
        likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.Positive())
        pred_kernel = gpytorch.kernels.RBFKernel(ard_num_dims=num_dims)
        batch_model = models.exact_gp.ExactGPModel(self.train_x, self.train_y, pred_kernel, likelihood)
        # The priors are registered under the names filled in above so that
        # pyro_load_from_samples can find the samples.
        batch_model.covar_module.register_prior("lengthscale_prior", self.length_prior, "lengthscale")
        batch_model.likelihood.register_prior("noise_prior", self.noise_prior, "noise")
        batch_model.pyro_load_from_samples(posterior_samples)

        batch_model.eval()
        self.batch_model = batch_model
        self.batch_mll = gpytorch.mlls.ExactMarginalLogLikelihood(batch_model.likelihood, batch_model)
