# Deep Gaussian Processes
import gpytorch
import tqdm
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from gpytorch.mlls import DeepApproximateMLL
from gpytorch.mlls import VariationalELBO, PredictiveLogLikelihood

from utils.optimizer import get_optimizer


class BaseDeepGP(gpytorch.models.deep_gps.DeepGP):
    """
    Base class for a DeepGP.

    Contains:
    - Prediction
    - Loss functions
    - Fitting procedure

    The methods below read ``self.likelihood``, ``self.n_samples`` and
    ``self.train_x_shape``. None of them is assigned in this class, so
    subclasses are expected to set them and to implement ``forward``.
    """

    def forward(self, x):
        """
        Function that feeds the data through the model.
        This function is dependent on the model and must be overridden.

        :param x: input data
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def predict(self, dataloader):
        """
        Function that predicts the labels for the given data.

        The model is switched to eval mode and the forward passes run under
        ``torch.no_grad()`` with ``self.n_samples`` likelihood samples.

        :param dataloader: either a tuple ``(x, y)`` (only ``x`` is used) or an
            iterable yielding ``(x_batch, y_batch)`` pairs
        :return: dict with the keys ``"predictions"``, ``"mean"`` and
            ``"stddev"``.

            * tuple input: ``"predictions"`` is a single distribution
              (converted with ``to_data_independent_dist()`` when ``x`` has
              more than one dimension); ``"mean"`` / ``"stddev"`` are averaged
              over dim 0.
            * batched input: ``"predictions"`` is a list with one distribution
              per batch; ``"mean"`` / ``"stddev"`` are the per-batch values
              concatenated along the last dimension.
        """
        self.eval()

        if isinstance(dataloader, tuple):
            x, _ = dataloader
            with torch.no_grad(), gpytorch.settings.num_likelihood_samples(self.n_samples):
                predictions = self.likelihood(self(x))
            if len(x.shape) > 1:
                predictions = predictions.to_data_independent_dist()
            # dim 0 is presumably the DeepGP sample dimension -- TODO confirm.
            mean = predictions.mean.mean(0)
            stddev = predictions.stddev.mean(0).detach()
        else:
            mus = []
            stddevs = []
            predictions = []
            with torch.no_grad(), gpytorch.settings.num_likelihood_samples(self.n_samples):
                # The targets are not needed for prediction.
                for x_batch, _ in dataloader:
                    # NOTE(review): open question from the original author --
                    # should to_data_independent_dist() be applied here too?
                    preds = self.likelihood(self(x_batch))
                    mus.append(preds.mean)
                    stddevs.append(preds.stddev)
                    predictions.append(preds)
            # NOTE(review): unlike the tuple branch, nothing is averaged over
            # dim 0 here. Kept as-is because callers may rely on this shape.
            mean = torch.cat(mus, dim=-1)
            stddev = torch.cat(stddevs, dim=-1)

        return {"predictions": predictions,
                "mean": mean,
                "stddev": stddev}

    def loss_func(self):
        """
        Function that returns the loss function.
        Like for an ApproximateGP, there are two different losses for a DeepGP:
        1) DeepApproximateMLL(VariationalELBO(model.likelihood, model, num_data))
        2) DeepApproximateMLL(PredictiveLogLikelihood(model.likelihood, model, num_data))

        This implementation uses option 2 with
        ``num_data = self.train_x_shape[0]``.

        :return: the objective to maximise (``fit`` minimises its negative)
        """
        return DeepApproximateMLL(
            PredictiveLogLikelihood(self.likelihood, self, self.train_x_shape[0])
        )

    def fit(self, dataloader, args=None, debug=False):
        """
        Function that fits (trains) the model on the data.

        :param dataloader: either a tuple ``(train_x, train_y)``, which is
            wrapped into a single full-size batch, or an iterable yielding
            ``(x_batch, y_batch)`` pairs
        :param args: arguments; ``args.n_epochs`` is read here and ``args`` is
            passed on to ``get_optimizer``. Required despite the ``None``
            default.
        :param debug: currently unused; kept for interface compatibility
        :return: tuple ``(final_loss, losses, None)`` where ``final_loss`` is
            the loss of the last processed batch (``nan`` if no batch was
            processed) and ``losses`` holds one value per processed batch
        :raises ValueError: if ``args`` is None
        """
        if args is None:
            raise ValueError("fit() requires `args` (it must provide at least `n_epochs`).")

        if isinstance(dataloader, tuple):
            train_x, train_y = dataloader
            # One batch containing the whole data set, in the original order.
            dataloader = DataLoader(
                TensorDataset(train_x, train_y),
                batch_size=len(train_x),
                shuffle=False,
            )

        num_epochs = args.n_epochs

        opt, scheduler = get_optimizer(args, self, num_data=self.train_x_shape[0])
        optimizer = opt[0]
        # The second optimizer is optional (None disables it); judging by the
        # name it is a natural-gradient optimizer -- TODO confirm.
        ngd_optimizer = opt[1]

        self.train()
        mll = self.loss_func()
        losses = []
        loss = None  # stays None if no batch is ever processed
        epochs_iter = tqdm.tqdm(range(num_epochs), desc="Epoch")
        for _ in epochs_iter:
            # Within each epoch, go over each minibatch of data.
            for x_batch, y_batch in dataloader:
                with gpytorch.settings.num_likelihood_samples(self.n_samples):
                    if ngd_optimizer is not None:
                        ngd_optimizer.zero_grad()
                    optimizer.zero_grad()
                    output = self(x_batch)
                    loss = -mll(output, y_batch)
                    loss.backward()
                    if ngd_optimizer is not None:
                        ngd_optimizer.step()
                    optimizer.step()
                losses.append(loss.item())

            if loss is None:
                # Empty dataloader: there is no loss to report or schedule on.
                continue

            # The scheduler receives the loss of the last batch of the epoch.
            scheduler.step(loss)
            epochs_iter.set_postfix(loss=losses[-1])

            # TODO: optional early stopping once the rounded loss has not
            # changed over the last 10 recorded values.

        final_loss = losses[-1] if losses else float("nan")
        return final_loss, losses, None
