# Base class for Gaussian Processes w/ exact inference
import gpytorch
import torch
import numpy as np
from tqdm import tqdm_notebook

from utils.optimizer import get_optimizer


class BaseExactGPModel(gpytorch.models.ExactGP):
    """
    Base class for an ExactGP.
    Contains:
    - Prediction
    - Loss functions
    - Fitting procedure

    Subclasses must implement ``forward``. The methods below read ``self.likelihood``.
    """

    def forward(self, x):
        """
        Function that feeds the data through the model.
        This function is dependent on the model and must be overridden.

        :param x: input data
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def predict(self, dataloader):
        """
        Function that predicts the labels for the given inputs.

        The model is switched to evaluation mode and is left in that mode.

        :param dataloader: tuple ``(x, y)``; only ``x`` is used. Anything that is not a tuple is
            treated as a batched data loader, which is not supported.
        :return: dict with the keys
            'predictions' (likelihood applied to the model output on ``x``),
            'mean' (``.mean`` of the predictions) and
            'stddev' (detached ``.stddev`` of the predictions)
        :raises NotImplementedError: if ``dataloader`` is not a tuple
        """
        with gpytorch.settings.max_cholesky_size(10000), gpytorch.settings.cg_tolerance(0.01):
            self.eval()
            if not isinstance(dataloader, tuple):
                # Batched prediction has not been implemented
                raise NotImplementedError
            x, _ = dataloader
            with torch.no_grad(), gpytorch.settings.fast_pred_var():
                predictions_f = self(x)
                predictions_y = self.likelihood(predictions_f)

        # The output is deliberately assembled outside the context managers, exactly as before,
        # so that the point at which mean / stddev are evaluated does not change.
        output = {'predictions': predictions_y,
                  'mean': predictions_y.mean,
                  'stddev': predictions_y.stddev.detach(),
                  }
        return output

    def loss_func(self):
        """
        Function that return the loss function.
        For a GP w/ exact inference the loss function is given by the ExactMarginalLogLikelihood

        :return: ``gpytorch.mlls.ExactMarginalLogLikelihood`` built from ``self.likelihood`` and the model
        """
        return gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self)

    def fit(self, train_data, args=None, debug=False, initialization=None, test_data=None):
        """
        Function that fits (train) the model on the data (x,y).

        The negative marginal log likelihood is minimised for ``args.n_runs`` runs of
        ``args.n_epochs`` iterations each. The loss trace of the run containing the lowest loss is
        returned, truncated at the iteration where that lowest loss occurred.

        :param train_data: tuple with (features / input data, label / output data)
        :param args: required settings object; the attributes ``n_runs``, ``n_epochs`` and
            ``outputs`` are read here, and the object is passed on to ``get_optimizer``
        :param debug: if True, the learning rates are recorded and the debug tuple is returned
        :param initialization: ``None`` keeps the current hyperparameters, ``0`` initializes in the
            low noise regime, ``1`` in the high noise regime (applied at the start of every run)
        :param test_data: optional tuple (features, labels). Test metrics are evaluated every
            iteration but are currently not part of the return value.
        :return: ``(best_loss, losses, outs)``, or, if ``debug`` is True,
            ``(best_loss, losses, noises, lengthscales, outputscales, lr)``.
            Hyperparameter traces are not recorded at the moment: ``outs`` holds NaN-filled
            placeholder tensors, and the debug traces ``noises``, ``lengthscales`` and
            ``outputscales`` are empty lists.
        :raises ValueError: if ``args`` is None
        :raises NotImplementedError: if ``initialization`` is not one of None, 0, 1
        :raises RuntimeError: if no loss was recorded (no iterations were run, or the loss never
            improved on +inf, e.g. because it was NaN)
        """
        if args is None:
            raise ValueError("fit() requires `args` (n_runs, n_epochs, outputs and the optimizer settings).")

        x, y = train_data

        def initialize_hyperparameters(model, iteration=None):
            """
            Initialize hyperparameters of a GP
            NB: Only works for ExactGP?
            """
            if iteration is None:
                # Keep whatever hyperparameters the model currently has
                return
            elif iteration == 0:
                print("ExactGP: Initialization in low noise regime")
                # Low noise, small lengthscale: prior for high signal-to-noise-ratio
                init_noise_prior = torch.tensor([0.1])
                init_lengthscale_prior = torch.tensor([0.1])
            elif iteration == 1:
                print("ExactGP: Initialization in high noise regime")
                # High noise, long lengthscale: prior for low signal-to-noise-ratio
                init_noise_prior = torch.tensor([1.])
                init_lengthscale_prior = torch.tensor([1.])
            else:
                raise NotImplementedError("Some is wrong with the initialization of the model.")

            opt_hypers = {
                'likelihood.noise_covar.noise': init_noise_prior,
                'mean_module.constant': torch.tensor(0.),  # zero mean
                'covar_module.outputscale': torch.tensor(1.),  # unit variance
                'covar_module.base_kernel.lengthscale': init_lengthscale_prior
            }
            model.initialize(**opt_hypers)

        def mse(preds, targets):
            # Mean squared error along the first dimension
            return torch.mean(torch.pow(preds - targets, 2), dim=0)

        def rrse(preds, targets):
            # Root relative squared error: MSE relative to predicting the mean of the targets
            return torch.sqrt(mse(preds, targets) / mse(torch.mean(targets, dim=0), targets))

        # Settings
        n_runs = args.n_runs
        training_iter = args.n_epochs
        opt, scheduler = get_optimizer(args, self, num_data=y.size(0))
        optimizer = opt[0]
        ngd_optimizer = opt[1]  # may be None

        if test_data is not None:
            x_test, y_test = test_data

        # Placeholders for the hyperparameter traces. Nothing records into them yet, so they are
        # filled with NaN instead of being left as uninitialised memory.
        nan = float('nan')
        outputscales = torch.full([training_iter, args.outputs], nan)
        lengthscales = torch.full([training_iter, args.outputs, x.shape[1]], nan)
        means = torch.full([training_iter, args.outputs], nan)
        noises = torch.full([training_iter, args.outputs], nan)
        noises2 = torch.full([training_iter], nan)

        # Put model into training mode
        self.train()

        # Fit the model
        min_loss = float('inf')
        best_run_losses = None  # loss list of the run that contains the lowest loss so far
        best_length = 0  # number of entries of that list up to and including the lowest loss
        best_run_lr = None  # learning rates of that run (only filled when debug is True)
        mll = self.loss_func()
        for _ in range(n_runs):
            tmp_losses = []
            tmp_losses_valid, tmp_losses_valid_rmse, tmp_losses_valid2 = [], [], []
            tmp_lr = []

            initialize_hyperparameters(self, iteration=initialization)

            for i in range(training_iter):
                # Zero gradients from previous iteration
                if ngd_optimizer is not None:
                    ngd_optimizer.zero_grad()
                optimizer.zero_grad()
                with gpytorch.settings.max_cholesky_size(10000), gpytorch.settings.cg_tolerance(0.01):
                    output = self(x)
                    loss = -mll(output, y)
                    loss.backward()
                if ngd_optimizer is not None:
                    ngd_optimizer.step()
                optimizer.step()

                if test_data is not None:
                    # Test metrics are collected for inspection only; they are not returned.
                    self.eval()
                    with torch.no_grad():
                        with gpytorch.settings.fast_pred_var():
                            output_test = self(x_test)
                        loss_test = -mll(output_test, y_test)
                        loss_test2 = -mll(self.likelihood(output_test), y_test)
                        rmse_loss_test = rrse(self.likelihood(output_test).mean, y_test)
                    self.train()
                    tmp_losses_valid.append(loss_test.item())
                    tmp_losses_valid_rmse.append(rmse_loss_test.detach().cpu().numpy())
                    tmp_losses_valid2.append(loss_test2.detach().cpu().numpy())

                # Learning rate scheduler
                scheduler.step(loss)

                tmp_losses.append(loss.item())

                if debug:
                    for param_group in optimizer.param_groups:
                        tmp_lr.append(param_group['lr'])

                # Remember where the lowest loss so far occurred
                # NOTE(review): only the loss trace is tracked. The model itself is not restored to
                # the state with the lowest loss; it keeps the parameters of the last iteration.
                if tmp_losses[-1] < min_loss:
                    min_loss = tmp_losses[-1]
                    best_run_losses = tmp_losses
                    best_length = len(tmp_losses)
                    best_run_lr = tmp_lr

        if best_run_losses is None:
            raise RuntimeError(
                "fit() did not record any loss: check that args.n_runs and args.n_epochs are "
                "positive and that the loss is not NaN.")

        # Loss trace of the best run, cut off at the iteration with the lowest loss
        losses = best_run_losses[:best_length]

        if debug:
            # Hyperparameter traces are not recorded; empty lists keep the tuple layout stable
            return losses[-1], losses, [], [], [], best_run_lr

        outs = {'outputscales': outputscales,
                'lengthscales': lengthscales,
                'means': means,
                'noises': noises,
                'noises2': noises2}

        return losses[-1], losses, outs
