#!/usr/bin/env python3

from typing import Optional, Tuple

import torch
from linear_operator.operators import LinearOperator, MatmulLinearOperator, SumLinearOperator
from torch import Tensor
from torch.distributions.kl import kl_divergence

from ..distributions import Delta, MultivariateNormal
from ..models import ApproximateGP
from ..utils.errors import CachingError
from ..utils.memoize import pop_from_cache_ignore_args
from ._variational_distribution import _VariationalDistribution
from .delta_variational_distribution import DeltaVariationalDistribution
from .variational_strategy import VariationalStrategy


class BatchDecoupledVariationalStrategy(VariationalStrategy):
    r"""
    A VariationalStrategy that uses a different set of inducing points for the
    variational mean and variational covar.  It follows the "decoupled" model
    proposed by `Jankowiak et al. (2020)`_ (which is roughly based on the strategies
    proposed by `Cheng et al. (2017)`_).

    Let :math:`\mathbf Z_\mu` and :math:`\mathbf Z_\sigma` be the mean/variance
    inducing points. The variational distribution for an input :math:`\mathbf
    x` is given by:

    .. math::

        \begin{align*}
            \mathbb E[ f(\mathbf x) ] &= \mathbf k_{\mathbf Z_\mu \mathbf x}^\top
            \mathbf K_{\mathbf Z_\mu \mathbf Z_\mu}^{-1} \mathbf m
            \\
            \text{Var}[ f(\mathbf x) ] &= k_{\mathbf x \mathbf x} - \mathbf k_{\mathbf Z_\sigma \mathbf x}^\top
            \mathbf K_{\mathbf Z_\sigma \mathbf Z_\sigma}^{-1}
            \left( \mathbf K_{\mathbf Z_\sigma \mathbf Z_\sigma} - \mathbf S \right)
            \mathbf K_{\mathbf Z_\sigma \mathbf Z_\sigma}^{-1}
            \mathbf k_{\mathbf Z_\sigma \mathbf x}
        \end{align*}

    where :math:`\mathbf m` and :math:`\mathbf S` are the variational parameters.
    Unlike the original proposed implementation, :math:`\mathbf Z_\mu` and :math:`\mathbf Z_\sigma`
    have **the same number of inducing points**, which allows us to perform batched operations.

    Additionally, you can use a different set of kernel hyperparameters for the mean and the variance function.
    We recommend using this feature only with the :obj:`~gpytorch.mlls.PredictiveLogLikelihood` objective function
    as proposed in "Parametric Gaussian Process Regressors" (`Jankowiak et al. (2020)`_).
    Use the mean_var_batch_dim to indicate which batch dimension corresponds to the different mean/var
    kernels.

    .. note::
        We recommend using the "right-most" batch dimension (i.e. ``mean_var_batch_dim=-1``) for the dimension
        that corresponds to the different mean/variance kernel parameters.

        Assuming you want `b1` many independent GPs, the :obj:`~gpytorch.variational._VariationalDistribution`
        objects should have a batch shape of `b1`, and the mean/covar modules
        of the GP should have a batch shape of `b1 x 2`.
        (The 2 corresponds to the mean/variance hyperparameters.)

    .. seealso::
        :obj:`~gpytorch.variational.OrthogonallyDecoupledVariationalStrategy` (a variant proposed by
        `Salimbeni et al. (2018)`_ that uses orthogonal projections.)

    :param model: Model this strategy is applied to.
        Typically passed in when the VariationalStrategy is created in the
        __init__ method of the user defined model.
    :param inducing_points: Tensor containing a set of inducing
        points to use for variational inference. A 1D tensor is treated as a
        set of 1-dimensional inducing points.
    :param variational_distribution: A
        VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`.
        :obj:`~gpytorch.variational.DeltaVariationalDistribution` is not supported.
    :param learn_inducing_locations: (Default True): Whether or not
        the inducing point locations :math:`\mathbf Z` should be learned (i.e. are they
        parameters of the model).
    :param mean_var_batch_dim: (Default `None`):
        Set this parameter (ideally to `-1`) to indicate which dimension corresponds to different
        kernel hyperparameters for the mean/variance functions. Must be a negative index.
        When `None`, the mean/variance dimension is inserted as the right-most batch dimension.
    :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability

    .. _Cheng et al. (2017):
        https://arxiv.org/abs/1711.10127

    .. _Salimbeni et al. (2018):
        https://arxiv.org/abs/1809.08820

    .. _Jankowiak et al. (2020):
        https://arxiv.org/abs/1910.07123

    Example (**different** hypers for mean/variance):
        >>> class MeanFieldDecoupledModel(gpytorch.models.ApproximateGP):
        >>>     '''
        >>>     A batch of 3 independent MeanFieldDecoupled PPGPR models.
        >>>     '''
        >>>     def __init__(self, inducing_points):
        >>>         # inducing_points has shape ... x M x D; size(-2) is the number of inducing points M.
        >>>         # The variational parameters have a batch_shape of [3]
        >>>         variational_distribution = gpytorch.variational.MeanFieldVariationalDistribution(
        >>>             inducing_points.size(-2), batch_shape=torch.Size([3]),
        >>>         )
        >>>         variational_strategy = gpytorch.variational.BatchDecoupledVariationalStrategy(
        >>>             self, inducing_points, variational_distribution, learn_inducing_locations=True,
        >>>             mean_var_batch_dim=-1
        >>>         )
        >>>
        >>>         # The mean/covar modules have a batch_shape of [3, 2]
        >>>         # where the last batch dim corresponds to the mean & variance hyperparameters
        >>>         super().__init__(variational_strategy)
        >>>         self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([3, 2]))
        >>>         self.covar_module = gpytorch.kernels.ScaleKernel(
        >>>             gpytorch.kernels.RBFKernel(batch_shape=torch.Size([3, 2])),
        >>>             batch_shape=torch.Size([3, 2]),
        >>>         )

    Example (**shared** hypers for mean/variance):
        >>> class MeanFieldDecoupledModel(gpytorch.models.ApproximateGP):
        >>>     '''
        >>>     A batch of 3 independent MeanFieldDecoupled PPGPR models.
        >>>     '''
        >>>     def __init__(self, inducing_points):
        >>>         # inducing_points has shape ... x M x D; size(-2) is the number of inducing points M.
        >>>         # The variational parameters have a batch_shape of [3]
        >>>         variational_distribution = gpytorch.variational.MeanFieldVariationalDistribution(
        >>>             inducing_points.size(-2), batch_shape=torch.Size([3]),
        >>>         )
        >>>         variational_strategy = gpytorch.variational.BatchDecoupledVariationalStrategy(
        >>>             self, inducing_points, variational_distribution, learn_inducing_locations=True,
        >>>         )
        >>>
        >>>         # The mean/covar modules have a batch_shape of [3, 1]
        >>>         # where the singleton dimension corresponds to the shared mean/variance hyperparameters
        >>>         super().__init__(variational_strategy)
        >>>         self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([3, 1]))
        >>>         self.covar_module = gpytorch.kernels.ScaleKernel(
        >>>             gpytorch.kernels.RBFKernel(batch_shape=torch.Size([3, 1])),
        >>>             batch_shape=torch.Size([3, 1]),
        >>>         )
    """

    def __init__(
        self,
        model: ApproximateGP,
        inducing_points: Tensor,
        variational_distribution: _VariationalDistribution,
        learn_inducing_locations: bool = True,
        mean_var_batch_dim: Optional[int] = None,
        jitter_val: Optional[float] = None,
    ):
        """
        Validate arguments, duplicate the inducing points into a mean set and a
        variance set, and hand off to :class:`VariationalStrategy`.

        :raises NotImplementedError: if ``variational_distribution`` is a
            :obj:`DeltaVariationalDistribution`.
        :raises ValueError: if ``mean_var_batch_dim`` is non-negative.
        """
        if isinstance(variational_distribution, DeltaVariationalDistribution):
            raise NotImplementedError(
                "BatchDecoupledVariationalStrategy does not work with DeltaVariationalDistribution"
            )

        # Only negative indices are accepted: every offset below (``dim - 2`` for
        # matrices, ``dim - 1`` for vectors) counts from the right-hand side.
        if mean_var_batch_dim is not None and mean_var_batch_dim >= 0:
            raise ValueError(f"mean_var_batch_dim should be negative indexed, got {mean_var_batch_dim}")
        self.mean_var_batch_dim = mean_var_batch_dim

        # Promote 1D inducing points (M,) to (M, 1) so they have an explicit input dim.
        if inducing_points.dim() == 1:
            inducing_points = inducing_points.unsqueeze(-1)

        # We're going to create two sets of inducing points:
        # index 0 is used for computing the mean, index 1 for computing the variance.
        # The stack dim is offset by -2 to skip the trailing (M, D) matrix dims.
        if self.mean_var_batch_dim is not None:
            inducing_points = torch.stack([inducing_points, inducing_points], dim=(self.mean_var_batch_dim - 2))
        else:
            inducing_points = torch.stack([inducing_points, inducing_points], dim=-3)
        super().__init__(
            model, inducing_points, variational_distribution, learn_inducing_locations, jitter_val=jitter_val
        )

    def _expand_inputs(self, x: Tensor, inducing_points: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Insert a singleton mean/var dimension into ``x`` (at the same position used
        to stack the inducing points in ``__init__``) so that it broadcasts against
        the two inducing-point sets, then defer to the parent implementation.
        """
        # If we haven't explicitly marked a dimension as batch, add the corresponding batch dimension to the input
        if self.mean_var_batch_dim is None:
            x = x.unsqueeze(-3)
        else:
            x = x.unsqueeze(self.mean_var_batch_dim - 2)
        return super()._expand_inputs(x, inducing_points)

    def forward(
        self,
        x: Tensor,
        inducing_points: Tensor,
        inducing_values: Tensor,
        variational_inducing_covar: Optional[LinearOperator] = None,
        **kwargs,
    ) -> MultivariateNormal:
        r"""
        Compute the predictive distribution :math:`q(\mathbf f)` at ``x``.

        The model prior is evaluated jointly on ``[inducing_points, x]`` (concatenated
        along dim -2). Slice 0 of the mean/var batch dimension supplies the
        cross-covariances and prior mean used for the predictive mean; slice 1
        supplies the data covariance and cross-covariances used for the predictive
        covariance.

        :param x: Test inputs, concatenated with ``inducing_points`` along dim -2.
        :param inducing_points: Stacked mean/variance inducing points.
        :param inducing_values: Variational mean, combined with the mean-slice
            interpolation term.
        :param variational_inducing_covar: Optional variational covariance; when
            given, the prior covariance is subtracted from it to form the middle term.
        :param kwargs: Forwarded to ``self.model.forward``.
        :return: The predictive :obj:`MultivariateNormal`.
        """
        # We'll compute the covariance, and cross-covariance terms for both the
        # pred-mean and pred-covar, using their different inducing points (and maybe kernel hypers)

        # ``None`` means the mean/var dim was inserted at -3 (i.e. -1 - 2) in __init__.
        mean_var_batch_dim = self.mean_var_batch_dim or -1

        # Compute full prior distribution
        full_inputs = torch.cat([inducing_points, x], dim=-2)
        full_output = self.model.forward(full_inputs, **kwargs)
        full_covar = full_output.lazy_covariance_matrix

        # Covariance terms
        num_induc = inducing_points.size(-2)
        test_mean = full_output.mean[..., num_induc:]
        induc_induc_covar = full_covar[..., :num_induc, :num_induc].add_jitter(self.jitter_val)
        induc_data_covar = full_covar[..., :num_induc, num_induc:].to_dense()
        data_data_covar = full_covar[..., num_induc:, num_induc:]

        # Compute interpolation terms
        # L^{-1} K_ZX, where L is the Cholesky factor of K_ZZ
        L = self._cholesky_factor(induc_induc_covar)
        if L.shape != induc_induc_covar.shape:
            # Aggressive caching can cause nasty shape incompatibilities when evaluating with different batch shapes
            # TODO: Use a hook to make this cleaner
            try:
                pop_from_cache_ignore_args(self, "cholesky_factor")
            except CachingError:
                # Nothing was cached yet, so there is nothing to evict.
                pass
            L = self._cholesky_factor(induc_induc_covar)
        # Solve in the Cholesky factor's own dtype (rather than hard-coding double)
        # so the triangular solve never sees mismatched dtypes; cast back afterwards.
        interp_term = L.solve(induc_data_covar.to(L.dtype)).to(full_inputs.dtype)
        # Matrix-shaped tensors: the mean/var dim sits 2 to the left of mean_var_batch_dim.
        mean_interp_term = interp_term.select(mean_var_batch_dim - 2, 0)
        var_interp_term = interp_term.select(mean_var_batch_dim - 2, 1)

        # Compute the mean of q(f)
        # k_XZ K_ZZ^{-1/2} m + \mu_X
        # Here we're using the terms that correspond to the mean's inducing points.
        # ``test_mean`` is vector-shaped, so its mean/var dim is offset by only 1.
        predictive_mean = torch.add(
            torch.matmul(mean_interp_term.transpose(-1, -2), inducing_values.unsqueeze(-1)).squeeze(-1),
            test_mean.select(mean_var_batch_dim - 1, 0),
        )

        # Compute the covariance of q(f)
        # K_XX + k_XZ K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} k_ZX
        middle_term = self.prior_distribution.lazy_covariance_matrix.mul(-1)
        if variational_inducing_covar is not None:
            middle_term = SumLinearOperator(variational_inducing_covar, middle_term)
        predictive_covar = SumLinearOperator(
            data_data_covar.add_jitter(self.jitter_val).to_dense().select(mean_var_batch_dim - 2, 1),
            MatmulLinearOperator(var_interp_term.transpose(-1, -2), middle_term @ var_interp_term),
        )

        return MultivariateNormal(predictive_mean, predictive_covar)

    def kl_divergence(self) -> Tensor:
        """
        Return the decoupled KL term.

        It is the sum of two parts, each taken against ``self.prior_distribution``:

        * the KL of a :obj:`Delta` placed at the variational mean, and
        * the KL of a zero-mean :obj:`MultivariateNormal` with the variational covariance.
        """
        variational_dist = self.variational_distribution
        prior_dist = self.prior_distribution

        mean_dist = Delta(variational_dist.mean)
        covar_dist = MultivariateNormal(
            torch.zeros_like(variational_dist.mean), variational_dist.lazy_covariance_matrix
        )
        return kl_divergence(mean_dist, prior_dist) + kl_divergence(covar_dist, prior_dist)
