#!/usr/bin/env python3

import functools
from abc import ABC, abstractproperty
from copy import deepcopy
from typing import Optional, Tuple, Union

import torch
from linear_operator.operators import LinearOperator
from torch import Tensor

from .. import settings
from ..distributions import Delta, Distribution, MultivariateNormal
from ..kernels import Kernel
from ..likelihoods import GaussianLikelihood
from ..means import Mean
from ..models import ApproximateGP, ExactGP
from ..models.exact_prediction_strategies import DefaultPredictionStrategy
from ..module import Module
from ..utils.memoize import add_to_cache, cached, clear_cache_hook
from . import _VariationalDistribution


class _BaseExactGP(ExactGP):
    """
    Minimal exact GP whose prior is defined by a user-supplied mean module and covariance module.

    Used inside this file as the exact model that is built over the inducing points
    (see ``_VariationalStrategy.amortized_exact_gp``).
    """

    def __init__(
        self,
        train_inputs: Optional[Union[Tensor, Tuple[Tensor, ...]]],
        train_targets: Optional[Tensor],
        likelihood: GaussianLikelihood,
        mean_module: Mean,
        covar_module: Kernel,
    ):
        """
        :param train_inputs: Training inputs, forwarded unchanged to ``ExactGP.__init__``.
        :param train_targets: Training targets, forwarded unchanged to ``ExactGP.__init__``.
        :param likelihood: Likelihood, forwarded unchanged to ``ExactGP.__init__``.
        :param mean_module: Module evaluated on the inputs to obtain the prior mean.
        :param covar_module: Module evaluated on the inputs to obtain the prior covariance.
        """
        super().__init__(train_inputs, train_targets, likelihood)
        # Mean first, covariance second (same attribute assignment order as before).
        self.mean_module, self.covar_module = mean_module, covar_module

    def forward(self, x: Tensor, **kwargs) -> MultivariateNormal:
        """
        Evaluate the prior at ``x``.

        :param x: Inputs at which the mean and covariance modules are evaluated.
        :param kwargs: Accepted for interface compatibility; not used here.
        :return: ``MultivariateNormal(mean_module(x), covar_module(x))``.
        """
        return MultivariateNormal(self.mean_module(x), self.covar_module(x))


def _add_cache_hook(tsr: Tensor, pred_strat: DefaultPredictionStrategy) -> Tensor:
    """
    Attach ``clear_cache_hook`` (bound to ``pred_strat``) to the autograd node of ``tsr``.

    :param tsr: Tensor whose ``grad_fn`` (if any) receives the hook.
    :param pred_strat: Object passed as the first argument to ``clear_cache_hook``.
    :return: ``tsr`` itself, unchanged. When ``tsr.grad_fn`` is ``None`` nothing is registered.
    """
    node = tsr.grad_fn
    if node is None:
        return tsr

    # update_wrapper copies the metadata of clear_cache_hook onto the partial and returns it.
    hook = functools.update_wrapper(functools.partial(clear_cache_hook, pred_strat), clear_cache_hook)
    node.register_hook(hook)
    return tsr


class _VariationalStrategy(Module, ABC):
    """
    Abstract base class for all Variational Strategies.
    """

    has_fantasy_strategy = False

    def __init__(
        self,
        model: Union[ApproximateGP, "_VariationalStrategy"],
        inducing_points: Tensor,
        variational_distribution: _VariationalDistribution,
        learn_inducing_locations: bool = True,
        jitter_val: Optional[float] = None,
    ):
        """
        :param model: The object this strategy is attached to. It is stored with
            ``object.__setattr__``, i.e. bypassing this class's own ``__setattr__``
            (presumably so that it is not registered as a child module; confirm against ``Module``).
        :param inducing_points: Inducing point locations. The tensor is cloned; a 1-D tensor
            gets a trailing singleton dimension appended.
        :param variational_distribution: Module that produces the variational distribution when called.
        :param learn_inducing_locations: If truthy, the inducing points are registered as a
            parameter; otherwise they are registered as a buffer.
        :param jitter_val: Optional explicit jitter value; ``None`` defers to the global setting
            (see the ``jitter_val`` property).
        """
        super().__init__()

        self._jitter_val = jitter_val
        object.__setattr__(self, "model", model)

        # Work on a private copy so the caller's tensor is never touched.
        points = inducing_points.clone()
        points = points.unsqueeze(-1) if points.dim() == 1 else points
        if not learn_inducing_locations:
            self.register_buffer("inducing_points", points)
        else:
            self.register_parameter(name="inducing_points", parameter=torch.nn.Parameter(points))

        self._variational_distribution = variational_distribution
        # Flag buffer read and set by __call__: 0 until the variational distribution has been initialized.
        self.register_buffer("variational_params_initialized", torch.tensor(0))

    def _clear_cache(self) -> None:
        """
        Call ``clear_cache_hook`` on this strategy.

        Invoked from ``__call__`` whenever the module is in training mode. Presumably this drops
        the values memoized by the ``@cached`` properties; see ``utils.memoize`` to confirm.
        """
        clear_cache_hook(self)

    def _expand_inputs(self, x: Tensor, inducing_points: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Pre-processing step in __call__: broadcast ``x`` and ``inducing_points`` to a common batch shape.

        The batch shape is everything except the last two dimensions of each tensor; the last
        two dimensions are left as they are.

        :param x: Input locations.
        :param inducing_points: Inducing point locations.
        :return: ``(x, inducing_points)``, both expanded to the broadcast batch shape.
        """
        common_batch = torch.broadcast_shapes(inducing_points.shape[:-2], x.shape[:-2])
        expanded_x = x.expand((*common_batch, *x.shape[-2:]))
        expanded_inducing = inducing_points.expand((*common_batch, *inducing_points.shape[-2:]))
        return expanded_x, expanded_inducing

    @property
    def jitter_val(self) -> float:
        """
        Jitter value used by this strategy.

        Returns the explicitly configured value when one was given; otherwise the value of
        ``settings.variational_cholesky_jitter`` for the dtype of the inducing points.
        """
        explicit = self._jitter_val
        if explicit is not None:
            return explicit
        return settings.variational_cholesky_jitter.value(dtype=self.inducing_points.dtype)

    @jitter_val.setter
    def jitter_val(self, jitter_val: float):
        """Override the jitter value (stored on ``self._jitter_val``)."""
        self._jitter_val = jitter_val

    # NOTE(review): `abc.abstractproperty` is deprecated in the standard library in favour of
    # stacking `@property` on top of `@abstractmethod`; kept as-is here because the file only
    # imports `abstractproperty`.
    @abstractproperty
    @cached(name="prior_distribution_memo")
    def prior_distribution(self) -> MultivariateNormal:
        r"""
        The :func:`~gpytorch.variational.VariationalStrategy.prior_distribution` property determines how to compute
        the GP prior distribution of the inducing points, e.g. :math:`p(\mathbf u) = N(\mu(X_u), K(X_u, X_u))`.
        Most commonly, this is done simply by calling the user defined GP prior on the inducing point data directly.

        This base implementation is abstract and raises :obj:`NotImplementedError`; concrete strategies must
        override it. The value is wrapped by ``@cached`` under the name ``"prior_distribution_memo"``.

        :rtype: :obj:`~gpytorch.distributions.MultivariateNormal`
        :return: The distribution :math:`p( \mathbf u)`
        """
        raise NotImplementedError

    @property
    @cached(name="variational_distribution_memo")
    def variational_distribution(self) -> Distribution:
        """
        The distribution obtained by calling ``self._variational_distribution()``.

        Wrapped by ``@cached`` under the name ``"variational_distribution_memo"`` (presumably so
        repeated accesses reuse the same object until the cache is cleared; see ``utils.memoize``).
        """
        return self._variational_distribution()

    def forward(
        self,
        x: Tensor,
        inducing_points: Tensor,
        inducing_values: Tensor,
        variational_inducing_covar: Optional[LinearOperator] = None,
        **kwargs,
    ) -> MultivariateNormal:
        r"""
        The :func:`~gpytorch.variational.VariationalStrategy.forward` method determines how to marginalize out the
        inducing point function values. Specifically, forward defines how to transform a variational distribution
        over the inducing point values, :math:`q(u)`, into a variational distribution over the function values at
        specified locations x, :math:`q(f|x)`, by integrating :math:`\int p(f|x, u)q(u)du`

        This base implementation raises :obj:`NotImplementedError`; each concrete strategy provides its own.

        :param x: Locations :math:`\mathbf X` to get the
            variational posterior of the function values at.
        :param inducing_points: Locations :math:`\mathbf Z` of the inducing points
        :param inducing_values: Samples of the inducing function values :math:`\mathbf u`
            (or the mean of the distribution :math:`q(\mathbf u)` if q is a Gaussian).
        :param variational_inducing_covar: If
            the distribution :math:`q(\mathbf u)` is
            Gaussian, then this variable is the covariance matrix of that Gaussian.
            Otherwise, it will be None.

        :rtype: :obj:`~gpytorch.distributions.MultivariateNormal`
        :return: The distribution :math:`q( \mathbf f(\mathbf X))`
        """
        raise NotImplementedError

    def kl_divergence(self) -> Tensor:
        r"""
        Compute the KL divergence between the variational inducing distribution :math:`q(\mathbf u)`
        and the prior inducing distribution :math:`p(\mathbf u)`.

        Both distributions are fetched, and the divergence evaluated, inside a
        ``settings.max_preconditioner_size(0)`` context.

        :return: ``KL(q(u) || p(u))`` as computed by :func:`torch.distributions.kl.kl_divergence`.
        """
        with settings.max_preconditioner_size(0):
            q_u = self.variational_distribution
            p_u = self.prior_distribution
            return torch.distributions.kl.kl_divergence(q_u, p_u)

    @cached(name="amortized_exact_gp")
    def amortized_exact_gp(
        self, mean_module: Optional[Module] = None, covar_module: Optional[Module] = None
    ) -> ExactGP:
        """
        Build an exact GP over the inducing points that stands in for this variational model.

        The returned ``_BaseExactGP`` uses the detached inducing points as training inputs and
        the pseudo-target mean (from ``self.pseudo_points``) plus the mean module evaluated at
        those inputs as training targets. The mean module, covariance module and
        ``self.model.likelihood`` are deep-copied. After one throw-away evaluation in eval mode,
        the model's prediction strategy is patched in place:

        * its memoize cache is reset,
        * ``lik_train_train_covar`` is replaced by the training prior covariance plus the
          pseudo-target covariance,
        * ``"mean_cache"`` is seeded with the solve of that matrix against the mean-centred
          training labels.

        Everything runs under ``torch.no_grad()``. The result is wrapped by ``@cached`` under the
        name ``"amortized_exact_gp"``.

        :param mean_module: Mean module to use; defaults to ``self.model.mean_module``.
        :param covar_module: Covariance module to use; defaults to ``self.model.covar_module``.
        :return: The patched exact GP model over the inducing points.
        """
        mean_module = self.model.mean_module if mean_module is None else mean_module
        covar_module = self.model.covar_module if covar_module is None else covar_module

        with torch.no_grad():
            # from here on down, we refer to the inducing points as pseudo_inputs
            # pseudo_points is read as an attribute and unpacked as (covariance, mean)
            pseudo_target_covar, pseudo_target_mean = self.pseudo_points
            pseudo_inputs = self.inducing_points.detach()
            # give the pseudo inputs the leading (batch) dimensions of the pseudo-target mean
            if pseudo_inputs.ndim < pseudo_target_mean.ndim:
                pseudo_inputs = pseudo_inputs.expand(*pseudo_target_mean.shape[:-2], *pseudo_inputs.shape)
            # TODO: add flag for conditioning into SGPR after building fantasy strategy for SGPR
            new_covar_module = deepcopy(covar_module)

            # update inducing mean if necessary
            # NOTE(review): a bare .squeeze() drops *every* size-1 dimension, not only a trailing one.
            # Presumably pseudo_target_mean carries a trailing singleton dim and the addition below
            # broadcasts any dropped batch dims back; confirm for singleton batch dims / one inducing point.
            pseudo_target_mean = pseudo_target_mean.squeeze() + mean_module(pseudo_inputs)

            inducing_exact_model = _BaseExactGP(
                pseudo_inputs,
                pseudo_target_mean,
                mean_module=deepcopy(mean_module),
                covar_module=new_covar_module,
                likelihood=deepcopy(self.model.likelihood),
            )

            # now fantasize around this model
            # as this model is new, we need to compute a posterior to construct the prediction strategy
            # which uses the likelihood pseudo caches
            # (a single random test point is enough; its prediction is discarded)
            faked_points = torch.randn(
                *pseudo_target_mean.shape[:-2],
                1,
                pseudo_inputs.shape[-1],
                device=pseudo_inputs.device,
                dtype=pseudo_inputs.dtype,
            )
            inducing_exact_model.eval()
            _ = inducing_exact_model(faked_points)

            # then we overwrite the likelihood to take into account the multivariate normal term
            pred_strat = inducing_exact_model.prediction_strategy
            # drop whatever the throw-away evaluation above memoized
            pred_strat._memoize_cache = {}
            # NOTE(review): already inside the outer torch.no_grad() block, so this inner context adds nothing
            with torch.no_grad():
                updated_lik_train_train_covar = pred_strat.train_prior_dist.lazy_covariance_matrix + pseudo_target_covar
                pred_strat.lik_train_train_covar = updated_lik_train_train_covar

            # do the mean cache because the mean cache doesn't solve against lik_train_train_covar
            train_mean = inducing_exact_model.mean_module(*inducing_exact_model.train_inputs)
            train_labels_offset = (inducing_exact_model.prediction_strategy.train_labels - train_mean).unsqueeze(-1)
            mean_cache = updated_lik_train_train_covar.solve(train_labels_offset).squeeze(-1)
            # _add_cache_hook only registers a hook when mean_cache has a grad_fn
            mean_cache = _add_cache_hook(mean_cache, inducing_exact_model.prediction_strategy)
            add_to_cache(pred_strat, "mean_cache", mean_cache)
            # TODO: check to see if we need to do the covar_cache?

            inducing_exact_model.prediction_strategy = pred_strat
        return inducing_exact_model

    @property
    def pseudo_points(self) -> Tuple[Tensor, Tensor]:
        """
        Pseudo targets of the inducing points, consumed by ``amortized_exact_gp``.

        ``amortized_exact_gp`` reads this as an attribute and unpacks it as
        ``(pseudo_target_covar, pseudo_target_mean)``, so it is declared as a property here.
        As a plain method, a strategy without its own implementation would fail there with an
        unrelated "cannot unpack" ``TypeError`` instead of the error below.

        :return: A ``(covariance, mean)`` pair describing the pseudo targets.
        :raises NotImplementedError: Always, in this base class; strategies that support
            fantasization must override this property.
        """
        raise NotImplementedError("Each variational strategy must implement its own pseudo points method")

    def get_fantasy_model(
        self,
        inputs: Tensor,
        targets: Tensor,
        mean_module: Optional[Module] = None,
        covar_module: Optional[Module] = None,
        **kwargs,
    ) -> ExactGP:
        r"""
        Performs the online variational conditioning (OVC) strategy of Maddox et al, '21 to return
        an exact GP model that incorporates the inputs and targets alongside the variational model's inducing
        points and targets.

        Currently, instead of directly updating the variational parameters (and inducing points), we instead
        return an ExactGP model rather than an updated variational GP model. This is done primarily for
        numerical stability.

        Unlike the ExactGP's call for get_fantasy_model, we enable options for mean_module and covar_module
        that allow specification of the mean / covariance. We expect that either the mean and covariance
        modules are attributes of the model itself called mean_module and covar_module respectively OR that you
        pass them into this method explicitly.

        :param inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
            observations.
        :param targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
        :param mean_module: torch module describing the mean function of the GP model. Optional if
            `mean_module` is already an attribute of the variational GP.
        :param covar_module: torch module describing the covariance function of the GP model. Optional
            if `covar_module` is already an attribute of the variational GP.
        :param kwargs: Forwarded to the exact model's own `get_fantasy_model`.
        :return: An `ExactGP` model with `k + m` training examples, where the `m` fantasy examples have been added
            and all test-time caches have been updated. We assume that there are `k` inducing points in this variational
            GP. Note that we return an `ExactGP` rather than a variational GP.
        :raises NotImplementedError: If this strategy does not set `has_fantasy_strategy`, or if the model's
            likelihood is not a `GaussianLikelihood`.
        :raises ModuleNotFoundError: If no mean or covariance module is passed in and the model has no
            `mean_module` / `covar_module` attribute.

        Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,"
            Maddox, Stanton, Wilson, NeurIPS, '21
            https://papers.nips.cc/paper/2021/hash/325eaeac5bef34937cfdc1bd73034d17-Abstract.html
        """

        # currently, we only support fantasization for CholeskyVariationalDistribution and
        # whitened / unwhitened variational strategies
        if not self.has_fantasy_strategy:
            # Use the class name: instances do not normally define `__name__`, so looking it up on
            # `self` raised AttributeError before the intended error could be reported.
            raise NotImplementedError(
                f"No fantasy model support for {type(self).__name__}. "
                "Only VariationalStrategy and UnwhitenedVariationalStrategy are currently supported."
            )
        if not isinstance(self.model.likelihood, GaussianLikelihood):
            raise NotImplementedError(
                f"No fantasy model support for {type(self.model.likelihood).__name__}. "
                "Only GaussianLikelihoods are currently supported."
            )
        # we assume that either the user has given the model a mean_module and a covar_module
        # or that it will be passed into the get_fantasy_model function. we check for these.
        if mean_module is None:
            mean_module = getattr(self.model, "mean_module", None)
            if mean_module is None:
                raise ModuleNotFoundError(
                    "Either you must provide a mean_module as input to get_fantasy_model "
                    "or it must be an attribute of the model called mean_module."
                )
        if covar_module is None:
            covar_module = getattr(self.model, "covar_module", None)
            if covar_module is None:
                raise ModuleNotFoundError(
                    "Either you must provide a covar_module as input to get_fantasy_model "
                    "or it must be an attribute of the model called covar_module."
                )

        # first we construct an exact model over the inducing points with the inducing covariance
        # matrix
        inducing_exact_model = self.amortized_exact_gp(mean_module=mean_module, covar_module=covar_module)

        # then we update this model by adding in the inputs and pseudo targets
        # finally we fantasize wrt targets
        fantasy_model = inducing_exact_model.get_fantasy_model(inputs, targets, **kwargs)
        fant_pred_strat = fantasy_model.prediction_strategy

        # recompute the mean cache from the fantasy model's lik_train_train_covar:
        # do the mean cache again because the mean cache resets the likelihood forward
        train_mean = fantasy_model.mean_module(*fantasy_model.train_inputs)
        train_labels_offset = (fant_pred_strat.train_labels - train_mean).unsqueeze(-1)
        fantasy_lik_train_root_inv = fant_pred_strat.lik_train_train_covar.root_inv_decomposition()
        mean_cache = fantasy_lik_train_root_inv.matmul(train_labels_offset).squeeze(-1)
        mean_cache = _add_cache_hook(mean_cache, fant_pred_strat)
        add_to_cache(fant_pred_strat, "mean_cache", mean_cache)
        # TODO: should we update the covar_cache?

        fantasy_model.prediction_strategy = fant_pred_strat
        return fantasy_model

    def __call__(self, x: Tensor, prior: bool = False, **kwargs) -> MultivariateNormal:
        """
        Evaluate the strategy at ``x``.

        With ``prior=True`` this returns ``self.model.forward(x, **kwargs)`` and does nothing else.
        Otherwise it clears the cache when in training mode, initializes the variational
        distribution from the prior on first use, broadcasts ``x`` and the inducing points to a
        common batch shape, and delegates to the parent ``__call__`` with the variational mean
        and, for a Gaussian :math:`q(u)`, its covariance.

        :param x: Input locations.
        :param prior: If True, return the model prior at ``x`` instead of the variational posterior.
        :param kwargs: Forwarded to ``self.model.forward`` (prior mode) or to the parent ``__call__``.
        :return: The distribution over function values at ``x``.
        :raises RuntimeError: If the variational distribution is neither a ``MultivariateNormal``
            nor a ``Delta``.
        """
        # Prior mode short-circuits everything below.
        if prior:
            return self.model.forward(x, **kwargs)

        # Anything cached from a previous training step is stale.
        if self.training:
            self._clear_cache()

        # One-time initialization of the variational distribution from the prior.
        already_initialized = self.variational_params_initialized.item()
        if not already_initialized:
            self._variational_distribution.initialize_variational_distribution(self.prior_distribution)
            self.variational_params_initialized.fill_(1)

        # Bring x and the inducing points to a shared batch shape when they differ.
        inducing_points = self.inducing_points
        if x.shape[:-2] != inducing_points.shape[:-2]:
            x, inducing_points = self._expand_inputs(x, inducing_points)

        # q(u)
        variational_dist_u = self.variational_distribution

        # Only a Gaussian q(u) carries a covariance; a Delta passes None.
        if isinstance(variational_dist_u, MultivariateNormal):
            inducing_covar = variational_dist_u.lazy_covariance_matrix
        elif isinstance(variational_dist_u, Delta):
            inducing_covar = None
        else:
            raise RuntimeError(
                f"Invalid variational distribuition ({type(variational_dist_u)}). "
                "Expected a multivariate normal or a delta distribution."
            )

        # q(f)
        return super().__call__(
            x,
            inducing_points,
            inducing_values=variational_dist_u.mean,
            variational_inducing_covar=inducing_covar,
            **kwargs,
        )
