# Base function for Gaussian Processes w/ approximate inference
import gpytorch
import torch
import numpy as np

from utils.optimizer import get_optimizer


class BaseApproximateGPModel(gpytorch.models.ApproximateGP):
    """
    Base class for an ApproximateGP.
    Contains:
    - Loss functions
    - Fitting procedure

    Subclasses must implement ``forward``. The methods below read
    ``self.likelihood``, ``self.covar_module.base_kernel`` and (for the
    default ``loss_func``) ``self.train_y``; ``fit`` in debug mode also
    touches ``self.mean_module.constant`` and ``self.covar_module.outputscale``.
    """

    def forward(self, x):
        """
        Function that feeds the data through the model.
        This function is dependent on the model.

        :param x: input data
        """
        raise NotImplementedError

    def predict(self, dataloader):
        """
        Predict the output distribution for the inputs in ``dataloader``.

        Only full-batch prediction is supported: ``dataloader`` must be an
        ``(x, y)`` tuple, of which ``y`` is ignored. The model is switched to
        eval mode and the GP output is passed through ``self.likelihood``.

        :param dataloader: ``(x, y)`` tuple of input data and (unused) labels
        :return: dict with the predictive distribution under ``'predictions'``,
            its mean under ``'mean'`` and its detached standard deviation
            under ``'stddev'``
        :raises NotImplementedError: if ``dataloader`` is not a tuple
            (mini-batch prediction is not implemented)
        """
        self.eval()
        if not isinstance(dataloader, tuple):
            # Mini-batch (DataLoader) prediction is not supported yet.
            raise NotImplementedError

        x, _ = dataloader
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            predictions = self.likelihood(self(x))

        return {'predictions': predictions,
                'mean': predictions.mean,
                'stddev': predictions.stddev.detach()}

    def loss_func(self, num_data=None):
        """
        Return the (marginal log likelihood style) objective to maximize.

        For a GP w/ approximate inference the loss function can be given by two different losses
        1) gpytorch.mlls.VariationalELBO(model.likelihood, model, num_data)
        2) gpytorch.mlls.PredictiveLogLikelihood(model.likelihood, model, num_data)
        This implementation uses 2).

        :param num_data: number of training points the objective is scaled by;
            defaults to ``self.train_y.size(0)``
        :return: a ``gpytorch.mlls.PredictiveLogLikelihood`` bound to this model
        """
        if num_data is None:
            num_data = self.train_y.size(0)
        return gpytorch.mlls.PredictiveLogLikelihood(self.likelihood, self, num_data=num_data)

    def fit(self, train_data, args=None, debug=False, initialization=None):
        """
        Fit (train) the model on ``train_data``.

        Training is repeated ``args.n_runs`` times, each for at most
        ``args.n_epochs`` iterations (a run stops early once the loss equals
        the loss 10 iterations earlier). Whenever a run ends with a lower
        final loss than every previous run, the state dict is written to
        ``best_state_dict_{model_type}_{k_samples}_{seed}.pth``; that best
        state is loaded back into the model at the end.

        :param train_data: ``(x, y)`` tuple of features and labels
        :param args: arguments object; reads ``n_runs``, ``n_epochs``,
            ``model_type``, ``k_samples``, ``seed`` (and whatever
            ``get_optimizer`` needs)
        :param debug: if True, overwrite ``mean_module.constant`` and
            ``covar_module.outputscale`` with fixed values every iteration,
            record learning rates and return the extra traces
        :param initialization: None (no re-initialization), 0 (low-noise
            regime) or 1 (high-noise regime); applied at the start of every run
        :return: ``(final_loss, losses)``, or in debug mode
            ``(final_loss, losses, noises, lengthscales, lr)``; all traces
            belong to the best run
        :raises ValueError: if ``args`` is None or ``n_runs``/``n_epochs`` < 1
        :raises NotImplementedError: for an unknown ``initialization`` value
        """
        if args is None:
            raise ValueError("fit() requires `args` (n_runs, n_epochs, model_type, k_samples, seed).")
        x, y = train_data

        def initialize_hyperparameters(model, iteration=None):
            """
            Initialize noise and lengthscale of an ApproximateGP.

            :param model: the model whose hyperparameters are overwritten
            :param iteration: None (do nothing), 0 (low noise, short
                lengthscale) or 1 (high noise, long lengthscale)
            """
            if iteration is None:
                return

            if iteration == 0:
                print("ApproximateGP: Initialization in low noise regime")
                # Low noise, small lengthscale: prior for high signal-to-noise-ratio
                init_noise_prior = torch.Tensor([0.1])
                init_lengthscale_prior = torch.Tensor([0.1])
            elif iteration == 1:
                print("ApproximateGP: Initialization in high noise regime")
                # High noise, long lengthscale: prior for low signal-to-noise-ratio
                init_noise_prior = torch.Tensor([1.])
                init_lengthscale_prior = torch.Tensor([1.])
            else:
                raise NotImplementedError(
                    f"Unknown initialization scheme {iteration!r} (expected None, 0 or 1).")

            model.initialize(**{
                'likelihood.noise_covar.noise': init_noise_prior,
                'covar_module.base_kernel.lengthscale': init_lengthscale_prior,
            })

        # Settings
        n_runs = args.n_runs
        training_iter = args.n_epochs
        if n_runs < 1 or training_iter < 1:
            # With zero runs/iterations there is no loss and no checkpoint to return.
            raise ValueError(
                f"fit() needs n_runs >= 1 and n_epochs >= 1, got n_runs={n_runs}, n_epochs={training_iter}.")
        opt, scheduler = get_optimizer(args, self, num_data=y.size(0))
        optimizer = opt[0]
        ngd_optimizer = opt[1]
        checkpoint_path = f'best_state_dict_{args.model_type}_{args.k_samples}_{args.seed}.pth'

        # Put model into training mode
        self.train()

        # Fit the model
        min_loss = float('inf')
        losses = None
        lr, noises, lengthscales = [], [], []
        mll = self.loss_func()
        for _ in range(n_runs):
            tmp_losses, tmp_lr = [], []
            tmp_noises, tmp_lengthscales = [], []

            initialize_hyperparameters(self, iteration=initialization)

            for i in range(training_iter):
                if debug:
                    # Debug hook: pin mean constant and outputscale every iteration.
                    self.initialize(**{
                        'mean_module.constant': torch.tensor(0.57),
                        'covar_module.outputscale': torch.tensor(1.1),
                    })

                # NOTE(review): traces the *raw* (unconstrained) parameters, and only
                # when y has a single row — confirm this condition is intended.
                if y.shape[0] == 1:
                    tmp_lengthscales.append(self.covar_module.base_kernel.raw_lengthscale.item())
                    tmp_noises.append(self.likelihood.noise_covar.raw_noise.item())

                # Zero gradients from previous iteration
                if ngd_optimizer is not None:
                    ngd_optimizer.zero_grad()
                optimizer.zero_grad()
                output = self(x)
                loss = -mll(output, y)
                loss.backward()
                if ngd_optimizer is not None:
                    ngd_optimizer.step()
                optimizer.step()

                # Learning rate scheduler
                scheduler.step()

                tmp_losses.append(loss.item())

                if debug:
                    tmp_lr.extend(param_group['lr'] for param_group in optimizer.param_groups)

                # Stop, if the loss doesn't change anymore
                if i > 11 and tmp_losses[i] == tmp_losses[i - 10]:
                    break

            # Keep the run with the lowest final loss. A NaN final loss ranks as
            # +inf so it can never block a later finite run, and the first run is
            # always kept so that a checkpoint (and `losses`) always exists.
            final_loss = tmp_losses[-1]
            score = float('inf') if np.isnan(final_loss) else final_loss
            if losses is None or score < min_loss:
                min_loss = score
                torch.save(self.state_dict(), checkpoint_path)
                losses = tmp_losses
                lr = tmp_lr
                noises = tmp_noises
                lengthscales = tmp_lengthscales

        self.load_state_dict(torch.load(checkpoint_path))

        if debug:
            return losses[-1], losses, noises, lengthscales, lr

        return losses[-1], losses
