import os
import math
import torch

import gpytorch
from matplotlib import pyplot as plt


class GP(gpytorch.models.ExactGP):
    """Exact GP regression model: constant mean, scaled RBF kernel, Gaussian noise.

    Hyperparameters are either initialized from a user-supplied dictionary or
    from defaults derived from the optional hyperparameter priors.
    """

    # Lower bound enforced on the Gaussian likelihood noise variance.
    _NOISE_LOWER_BOUND = 1e-3
    # Lower bound enforced on the RBF lengthscale when priors are used.
    _LENGTHSCALE_LOWER_BOUND = 5e-2

    def __init__(
        self,
        train_x,
        train_y,
        initialization=None,
        prior=None,
        ard=True,
        fix_mean_at=None,
    ):
        """Build the GP model and initialize its hyperparameters.

        Parameters
        ----------
        train_x : tensor of shape (n, x_dim)
            training inputs
        train_y : tensor of shape (n,)
            training observations

        initialization : dictionary or None
            initial values of GP hyperparameters, forwarded to ``self.initialize``.
            When None, defaults are derived from ``prior`` (prior means) or,
            without priors, lengthscale = outputscale = 1 and mean constant = 0.

            For example,
            {
                "likelihood.noise_covar.noise": 0.01,
                "covar_module.base_kernel.lengthscale": [0.1] * x_dim, # ard = True
                "covar_module.outputscale": 1.0,
                "mean_module.constant": 0.0,
            }

        prior : dictionary or None
            hyperparameter priors with keys "lengthscale", "outputscale" and
            "noise_std" (the latter is placed on the square root of the noise
            variance). When None, no priors are registered.

            For example,
            {
                "lengthscale": gpytorch.priors.GammaPrior(0.25, 0.5),
                "outputscale": gpytorch.priors.GammaPrior(2.0, 0.15),
                "noise_std": gpytorch.priors.NormalPrior(0.0, 0.1),
            }

        ard : boolean
            True if ARD kernel is used (one lengthscale per input dimension)
            False otherwise
        fix_mean_at : float or None
            If not None, the constant GP mean is set to this value and its
            parameter is excluded from gradient updates. It is applied after
            ``initialization``, so it takes precedence over any
            "mean_module.constant" entry there.
        """
        likelihood = gpytorch.likelihoods.GaussianLikelihood(
            noise_constraint=gpytorch.constraints.GreaterThan(self._NOISE_LOWER_BOUND),
        )

        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()

        input_dim = train_x.shape[1]
        assert input_dim > 0

        ard_num_dims = input_dim if ard else None

        if prior is None:
            self.covar_module = gpytorch.kernels.ScaleKernel(
                gpytorch.kernels.RBFKernel(ard_num_dims=ard_num_dims)
            )
        else:
            self.likelihood.noise_covar.register_prior(
                "noise_std_prior",
                prior["noise_std"],
                lambda module: module.noise.sqrt(),
            )

            self.covar_module = gpytorch.kernels.ScaleKernel(
                gpytorch.kernels.RBFKernel(
                    ard_num_dims=ard_num_dims, lengthscale_prior=prior["lengthscale"]
                ),
                outputscale_prior=prior["outputscale"],
            )
            self.covar_module.base_kernel.register_constraint(
                "raw_lengthscale",
                gpytorch.constraints.GreaterThan(self._LENGTHSCALE_LOWER_BOUND),
            )

        if initialization is None:
            self._set_default_hyperparameters(input_dim, prior, ard)
        else:
            self.initialize(**initialization)

        # Applied for both initialization paths so the documented behavior of
        # ``fix_mean_at`` does not silently depend on ``initialization``.
        if fix_mean_at is not None:
            self.mean_module.constant = fix_mean_at
            # Freeze every parameter of the mean module, so this works however
            # GPyTorch stores the constant internally.
            for parameter in self.mean_module.parameters():
                parameter.requires_grad_(False)

        print("All constraints:")
        for constraint_name, constraint in self.named_constraints():
            print(f"Constraint name: {constraint_name:55} constraint = {constraint}")

    @staticmethod
    def _default_lengthscale(prior, ard, input_dim):
        """Return the initial lengthscale used when no initialization dict is given."""
        if not ard:
            return prior["lengthscale"].mean if prior is not None else 1.0
        if prior is None:
            return torch.ones(1, input_dim)
        prior_mean = prior["lengthscale"].mean
        # A scalar (or single-element) prior mean is broadcast to every input
        # dimension; otherwise the prior is assumed to already be per-dimension.
        if prior_mean.numel() == 1:
            return prior_mean.reshape(()) * torch.ones(1, input_dim)
        return prior_mean

    def _set_default_hyperparameters(self, input_dim, prior, ard):
        """Initialize lengthscale, outputscale, noise and mean without a user dict."""
        self.covar_module.base_kernel.lengthscale = self._default_lengthscale(
            prior, ard, input_dim
        )
        self.covar_module.outputscale = (
            prior["outputscale"].mean if prior is not None else 1.0
        )
        # The noise is parameterized as softplus(raw_noise) + lower_bound. Setting
        # it exactly to the lower bound maps raw_noise to -inf, where softplus has
        # zero gradient, so the noise could never be learned. Start just above it.
        self.likelihood.noise_covar.noise = 1.1 * self._NOISE_LOWER_BOUND
        self.mean_module.constant = 0.0

    def forward(self, x):
        """Return the GP prior distribution (mean and covariance) at inputs ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    @staticmethod
    def get_default_hyperparameter_prior(function_name="branin"):
        """Return the default hyperparameter priors for a benchmark objective.

        Parameters
        ----------
        function_name : string
            the name of the objective function ("branin" or "goldstein")

        Returns
        -------
        dictionary
            E.g.,
            {
                "lengthscale": gpytorch.priors.GammaPrior(0.25, 0.5),
                "outputscale": gpytorch.priors.GammaPrior(2.0, 0.15),
                "noise_std": gpytorch.priors.NormalPrior(0.0, 0.1),
            }

        Raises
        ------
        ValueError
            if no default prior is defined for ``function_name``
        """
        # Both supported objectives currently share the same priors. Fresh prior
        # modules are built on every call so separate models never share them.
        if function_name in ("branin", "goldstein"):
            return {
                "lengthscale": gpytorch.priors.GammaPrior(0.25, 0.5),
                "outputscale": gpytorch.priors.GammaPrior(2.0, 0.15),
                "noise_std": gpytorch.priors.NormalPrior(0.0, 0.1),
            }

        raise ValueError(f"Unknown default GP hyperparameter prior for {function_name}")

    def save(self, path="model_state.pth"):
        """Save the model's state dict.

        Parameters
        ----------
        path : string
            file the GP model's state dict is written to
        """
        torch.save(self.state_dict(), path)

    def plot1d(self, ax, x):
        """Plot the posterior mean +/- one std of f and the training data.

        Parameters
        ----------
        ax : matplotlib Axes
            axes to draw on
        x : tensor of shape (m, 1)
            test inputs at which the posterior is plotted

        Returns
        -------
        matplotlib Axes
            the same ``ax``
        """
        assert x.shape[1] == 1

        f_preds = GP.predict_f(self, x)

        with torch.no_grad():
            # reshape(-1) (rather than squeeze) keeps a 1-D array even for m == 1.
            f_means = f_preds.mean.reshape(-1).cpu().numpy()
            f_stds = f_preds.variance.sqrt().reshape(-1).cpu().numpy()
            x_plot = x.detach().reshape(-1).cpu().numpy()
            train_x = self.train_inputs[0].detach().reshape(-1).cpu().numpy()
            train_y = self.train_targets.detach().reshape(-1).cpu().numpy()

        ax.fill_between(
            x_plot,
            f_means - f_stds,
            f_means + f_stds,
            alpha=0.5,
        )
        ax.plot(x_plot, f_means)

        ax.scatter(train_x, train_y)
        return ax

    @staticmethod
    def load(model, path="model_state.pth", map_location=None):
        """Load a saved state dict from path into model.

        Parameters
        ----------
        model : GP object
            GP model without loaded data or hyperparameters
        path : string
            path to the saved model
        map_location : optional
            forwarded to ``torch.load`` (e.g. "cpu" to load a GPU checkpoint on CPU)
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        state_dict = torch.load(path, map_location=map_location)
        model.load_state_dict(state_dict)

    @staticmethod
    def optimize_hyperparameters(
        model, train_x, train_y, learning_rate=0.1, training_iter=50, verbose=True
    ):
        """Fit hyperparameters by maximizing the exact marginal log likelihood with Adam.

        Parameters
        ----------
        model : GP instance
            GP model to optimize hyperparameters
        train_x : tensor of size (n, x_dim)
            training inputs (presumably the same data the model was built with;
            GPyTorch checks this in debug mode)
        train_y : tensor of size (n,)
            training outputs
        learning_rate : float
            learning_rate
        training_iter : int
            number of training iterations
        verbose : bool
            True if printing messages of training procedure
            False if not printing anything
        """
        model.train()

        # Parameters frozen with requires_grad=False (e.g. a fixed mean) receive
        # no gradient and are therefore left untouched by Adam.
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)

        for i in range(training_iter):
            optimizer.zero_grad()
            output = model(train_x)
            loss = -mll(output, train_y)
            loss.backward()
            if verbose:
                print(
                    f"Iter {i+1}/{training_iter} - Loss: {loss.item():.3f} lengthscale: {model.covar_module.base_kernel.lengthscale}  noise: {model.likelihood.noise}"
                )
            optimizer.step()

    @staticmethod
    def predict_f(model, test_x):
        """Return the posterior distribution of the latent function f at test_x.

        Parameters
        ----------
        model : GP
            GP model to predict (switched to eval mode as a side effect)
        test_x : tensor of size (n, x_dim)
            testing inputs to predict

        Returns
        -------
        gpytorch MultivariateNormal
            predictive distribution of f(x) at test_x
        """
        with torch.no_grad():
            model.eval()
            return model(test_x)

    @staticmethod
    def predict_y(model, test_x):
        """Return the posterior predictive distribution of noisy observations y at test_x.

        Parameters
        ----------
        model : GP
            GP model to predict (switched to eval mode as a side effect)
        test_x : tensor of size (n, x_dim)
            testing inputs to predict

        Returns
        -------
        gpytorch MultivariateNormal
            predictive distribution of y(x) at test_x (includes likelihood noise)
        """
        with torch.no_grad():
            model.eval()
            return model.likelihood(model(test_x))
