# Deep Sigma Point Processes (DSPP) -- built on GPyTorch's deep GP models
import gpytorch
import tqdm
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from gpytorch.mlls import DeepPredictiveLogLikelihood

from utils.optimizer import get_optimizer


class BaseDSSP(gpytorch.models.deep_gps.dspp.DSPP):
    """
    Base class for Deep Sigma Point Process (DSPP) models.

    Subclasses must implement :meth:`forward` and are expected to provide the
    attributes read by this class: ``likelihood``, ``n_samples`` and
    ``train_x_shape``.

    Contains:
    - Prediction
    - Loss function
    - Fitting procedure

    """

    def forward(self, x):
        """
        Feed the data through the model.

        This function is model dependent and must be overridden.
        Note that :meth:`predict` calls the model with an additional
        ``mean_input`` keyword argument, so overrides need to accept it.

        :param x: input data
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def _mix_components(self, predictions):
        """
        Collapse the per-component outputs into one mean / stddev per point.

        ``self.quad_weights`` holds the quadrature weights in log space, hence
        the ``exp()``. Everything is moved to the CPU first so that weights and
        predictions always live on the same device.

        Assumes ``predictions.mean`` / ``predictions.stddev`` have the mixture
        components along dim 0 and the data points along the last dim
        -- TODO confirm for multi-output likelihoods.

        :param predictions: output of ``self.likelihood(self(x, ...))``
        :return: tuple ``(mean, stddev)`` of CPU tensors
        """
        weights = self.quad_weights.unsqueeze(-1).exp().detach().cpu()
        component_means = predictions.mean.detach().cpu()
        component_stddevs = predictions.stddev.detach().cpu()

        mean = (weights * component_means).sum(0)
        # NOTE(review): this is the weighted average of the component standard
        # deviations, not the standard deviation of the mixture itself (which
        # would also account for the spread of the component means).
        stddev = (weights * component_stddevs).sum(0)
        return mean, stddev

    def predict(self, dataloader):
        """
        Predict the labels for the given data.

        Both input forms return the same quantities: the quadrature-weighted
        mean and the quadrature-weighted standard deviation per data point,
        as CPU tensors.

        :param dataloader: either a tuple ``(x, y)`` that is evaluated in one
            go, or an iterable yielding ``(x_batch, y_batch)`` pairs. The
            labels are not used.
        :return: dict with the keys
            ``"predictions"`` (likelihood output; for batched input only the
            output of the last batch),
            ``"mean"`` and ``"stddev"``
        :raises ValueError: if a batched ``dataloader`` yields no batches
        """
        self.eval()

        if isinstance(dataloader, tuple):
            x, _ = dataloader
            with torch.no_grad(), gpytorch.settings.num_likelihood_samples(self.n_samples):
                predictions = self.likelihood(self(x, mean_input=x))
            mean, stddev = self._mix_components(predictions)
        else:
            batch_means, batch_stddevs = [], []
            predictions = None
            with gpytorch.settings.fast_computations(log_prob=False, solves=False), torch.no_grad():
                for x_batch, _ in dataloader:
                    predictions = self.likelihood(self(x_batch, mean_input=x_batch))
                    batch_mean, batch_stddev = self._mix_components(predictions)
                    batch_means.append(batch_mean)
                    batch_stddevs.append(batch_stddev)

                    # NOTE: the test log probability is not computed here.
                    # The output of a DSPP is a weighted mixture of Q Gaussians,
                    # so it could be obtained per batch as:
                    #   1) self.likelihood.log_marginal(y_batch, self(x_batch))
                    #   2) add self.quad_weights.unsqueeze(-1) (log-space weights)
                    #   3) logsumexp over the mixture dimension (dim=0)

            if predictions is None:
                raise ValueError("predict() received a dataloader that yielded no batches.")

            mean = torch.cat(batch_means, dim=-1)
            stddev = torch.cat(batch_stddevs, dim=-1)

        return {"predictions": predictions,
                "mean": mean,
                "stddev": stddev}

    def loss_func(self):
        """
        Return the training objective.

        Here, we use DeepPredictiveLogLikelihood, which differs from the
        objectives used for a DeepGP:
        1) DeepApproximateMLL(VariationalELBO(model.likelihood, model, num_data))
        2) DeepApproximateMLL(PredictiveLogLikelihood(model.likelihood, model, num_data))

        :return: objective built with ``num_data=self.train_x_shape[0]`` and ``beta=1``
        """
        return DeepPredictiveLogLikelihood(self.likelihood, self, num_data=self.train_x_shape[0], beta=1)

    def fit(self, dataloader, args=None, debug=False):
        """
        Fit (train) the model.

        :param dataloader: either a tuple ``(train_x, train_y)``, which is
            trained on as one full batch, or an iterable yielding
            ``(x_batch, y_batch)`` pairs
        :param args: arguments; must provide ``n_epochs`` and is forwarded to
            ``get_optimizer``
        :param debug: currently unused
        :return: tuple ``(final_loss, losses, None)`` where ``final_loss`` is
            the loss of the last processed batch (``None`` if ``n_epochs`` is
            0) and ``losses`` holds one value per processed batch
        :raises ValueError: if ``args`` is missing or the dataloader yields no
            batches
        """
        if args is None:
            raise ValueError("fit() requires `args` (it must provide at least `n_epochs`).")

        if isinstance(dataloader, tuple):
            train_x, train_y = dataloader
            gp_train_dataset = TensorDataset(train_x, train_y)
            dataloader = DataLoader(gp_train_dataset, batch_size=len(train_x), shuffle=False)

        num_epochs = args.n_epochs

        opt, scheduler = get_optimizer(args, self, num_data=self.train_x_shape[0])
        optimizer = opt[0]
        # Optional second optimizer; skipped when get_optimizer returns None for it.
        ngd_optimizer = opt[1]

        self.train()
        losses = []
        final_loss = None
        mll = self.loss_func()
        epochs_iter = tqdm.tqdm(range(num_epochs), desc="Epoch")
        for _ in epochs_iter:
            loss = None
            # Within each epoch, we go over each minibatch of data
            for x_batch, y_batch in dataloader:
                with gpytorch.settings.num_likelihood_samples(self.n_samples):
                    if ngd_optimizer is not None:
                        ngd_optimizer.zero_grad()
                    optimizer.zero_grad()
                    output = self(x_batch)
                    loss = -mll(output, y_batch)
                    loss.backward()
                    if ngd_optimizer is not None:
                        ngd_optimizer.step()
                    optimizer.step()

                losses.append(loss.item())

            if loss is None:
                raise ValueError("fit() received a dataloader that yielded no batches.")

            final_loss = loss.item()
            # The scheduler is driven by the loss of the last batch of the epoch.
            scheduler.step(loss)
            epochs_iter.set_postfix(loss=final_loss)

            # TODO: optional early stopping once the loss no longer changes.

        return final_loss, losses, None
