#!/usr/bin/env python3

from typing import Tuple, Union

import torch
from linear_operator.operators import CholLinearOperator, TriangularLinearOperator
from torch import Tensor
from torch.autograd.function import FunctionCtx

from ..distributions import Distribution, MultivariateNormal, MultivariateQExponential
from .natural_variational_distribution import (
    _NaturalToMuVarSqrt,
    _NaturalVariationalDistribution,
    _phi_for_cholesky_,
    _triangular_inverse,
)


class TrilNaturalVariationalDistribution(_NaturalVariationalDistribution):
    r"""A multivariate normal -- or, when a ``power`` keyword argument is supplied, a multivariate
    q-exponential -- :obj:`~gpytorch.variational._VariationalDistribution`, parameterized by the
    natural vector and a triangular decomposition of the natural matrix (which is not the Cholesky).

    .. note::
       The :obj:`~gpytorch.variational.TrilNaturalVariationalDistribution` should only
       be used with :obj:`gpytorch.optim.NGD`, or other optimizers
       that follow exactly the gradient direction.

    .. seealso::
        The `natural gradient descent tutorial
        <examples/04_Variational_and_Approximate_GPs/Natural_Gradient_Descent.ipynb>`_
        for use instructions.

        The :obj:`~gpytorch.variational.NaturalVariationalDistribution`, which
        needs fewer iterations to make variational regression converge, at the
        cost of introducing numerical instability.

    .. note::
        The relationship of the parameter :math:`\mathbf \Theta_\text{tril_mat}`
        to the natural parameter :math:`\mathbf \Theta_\text{mat}` from
        :obj:`~gpytorch.variational.NaturalVariationalDistribution` is
        :math:`\mathbf \Theta_\text{mat} = -1/2 {\mathbf \Theta_\text{tril_mat}}^T {\mathbf \Theta_\text{tril_mat}}`.
        Note that this is not the form of the Cholesky decomposition of :math:`\boldsymbol \Theta_\text{mat}`.

    :param int num_inducing_points: Size of the variational distribution. This implies that the variational mean
        should be this size, and the variational covariance matrix should have this many rows and columns.
    :param batch_shape: Specifies an optional batch size
        for the variational parameters. This is useful for example when doing additive variational inference.
    :type batch_shape: :obj:`torch.Size`, optional
    :param float mean_init_std: (Default: 1e-3) Standard deviation of the Gaussian noise that
        :meth:`initialize_variational_distribution` adds to the natural vector.
        The noise is Gaussian even when ``power`` is set.
    :param power: (Optional keyword argument.) If supplied, it is stored as ``self.power``, and
        :meth:`forward` returns a :obj:`MultivariateQExponential` with this power.
        Otherwise :meth:`forward` returns a :obj:`MultivariateNormal`.
        Any other extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        num_inducing_points: int,
        batch_shape: torch.Size = torch.Size([]),
        mean_init_std: float = 1e-3,
        **kwargs,
    ):
        super().__init__(num_inducing_points=num_inducing_points, batch_shape=batch_shape, mean_init_std=mean_init_std)
        # Start from a zero natural vector and an identity triangular factor, tiled over batch_shape.
        scaled_mean_init = torch.zeros(num_inducing_points).repeat(*batch_shape, 1)
        neg_prec_init = torch.eye(num_inducing_points, num_inducing_points).repeat(*batch_shape, 1, 1)

        # eta1 and tril_dec(eta2) parameterization of the variational distribution
        self.register_parameter(name="natural_vec", parameter=torch.nn.Parameter(scaled_mean_init))
        self.register_parameter(name="natural_tril_mat", parameter=torch.nn.Parameter(neg_prec_init))

        # Only create the attribute when explicitly requested: ``forward`` relies on
        # ``hasattr(self, "power")`` to choose between the Gaussian and q-exponential families.
        if "power" in kwargs:
            self.power = kwargs.pop("power")

    def forward(self) -> Distribution:
        """Return the variational distribution implied by the current natural parameters.

        :return: A :obj:`MultivariateNormal`, or a :obj:`MultivariateQExponential` if ``power`` was
            supplied at construction. Its covariance is stored as ``L L^T``, with ``L`` the
            lower-triangular factor produced by :class:`_TrilNaturalToMuVarSqrt`.
        """
        mean, chol_covar = _TrilNaturalToMuVarSqrt.apply(self.natural_vec, self.natural_tril_mat)
        covar = CholLinearOperator(TriangularLinearOperator(chol_covar))
        if not hasattr(self, "power"):
            return MultivariateNormal(mean, covar)
        return MultivariateQExponential(mean, covar, power=self.power)

    def initialize_variational_distribution(
        self, prior_dist: Union[MultivariateNormal, MultivariateQExponential]
    ) -> None:
        """Initialize the variational parameters from ``prior_dist``.

        With prior covariance ``K = chol chol^T`` and prior mean ``m``:

        * ``natural_tril_mat`` is set to ``inv(chol)``, so that ``-1/2 tril^T tril = -1/2 inv(K)``.
        * ``natural_vec`` is set to ``inv(K) m`` plus Gaussian noise with standard deviation
          ``self.mean_init_std``.

        :param prior_dist: Prior distribution whose mean and covariance are matched.
        """
        prior_cov = prior_dist.lazy_covariance_matrix
        chol = prior_cov.cholesky().to_dense()
        tril_mat = _triangular_inverse(chol, upper=False)

        natural_vec = prior_cov.solve(prior_dist.mean.unsqueeze(-1)).squeeze(-1)
        noise = torch.randn_like(natural_vec).mul_(self.mean_init_std)

        self.natural_vec.data.copy_(natural_vec.add_(noise))
        self.natural_tril_mat.data.copy_(tril_mat)


class _TrilNaturalToMuVarSqrt(torch.autograd.Function):
    """Map ``(nat_mean, tril_nat_covar)`` to ``(mu, L)``.

    With ``C = tril_nat_covar`` and natural matrix ``-1/2 C^T C``:

    * ``L = inv(C)`` (lower triangular), so the covariance is ``inv(C^T C) = L L^T``.
    * ``mu = L L^T nat_mean``.

    The custom ``backward`` does not perform an ordinary reverse-mode VJP through this map.
    It treats the incoming natural-matrix sensitivity as a tangent and pushes it forward to a
    sensitivity of ``C`` (see the derivation inside ``backward``).
    """

    @staticmethod
    def _forward(nat_mean: Tensor, tril_nat_covar: Tensor) -> Tuple[Tensor, Tensor]:
        # L = inv(C): the lower-triangular square root of the covariance.
        L = _triangular_inverse(tril_nat_covar, upper=False)
        # mu = Sigma @ nat_mean with Sigma = L L^T, evaluated right-to-left to avoid forming Sigma.
        mu = L @ (L.transpose(-1, -2) @ nat_mean.unsqueeze(-1))
        return mu.squeeze(-1), L

    @staticmethod
    def forward(ctx: FunctionCtx, nat_mean: Tensor, tril_nat_covar: Tensor) -> Tuple[Tensor, Tensor]:
        mu, L = _TrilNaturalToMuVarSqrt._forward(nat_mean, tril_nat_covar)
        ctx.save_for_backward(mu, L, tril_nat_covar)
        return mu, L

    @staticmethod
    def backward(ctx: FunctionCtx, dout_dmu: Tensor, dout_dL: Tensor) -> Tuple[Tensor, Tensor]:
        mu, L, C = ctx.saved_tensors
        dout_dnat1, dout_dnat2 = _NaturalToMuVarSqrt._backward(dout_dmu, dout_dL, mu, L, C)
        # Next, dout_dnat2 is pushed *forward* (Jacobian-vector product) through the map
        #     C = inv(chol(inv(-2 theta_cov))),   equivalently   C^T C = -2 theta_cov.
        # Sensitivities are written .x.; we start from .theta_cov. = dout_dnat2 and end with .C.
        #
        # With B = inv(-2 theta_cov):
        #     .B. = d inv(-2 theta_cov)/dtheta_cov * .theta_cov. = -B (-2 .theta_cov.) B
        #
        # With L = chol(B), B = L L^T
        # (https://homepages.inf.ed.ac.uk/imurray2/pub/16choldiff/choldiff.pdf):
        #     .L. = L phi(L^{-1} .B. (L^{-1})^T) = L phi(2 L^T .theta_cov. L)
        #
        # With C = inv(L):
        #     .C. = -C .L. C = phi(-2 L^T .theta_cov. L) C
        A = L.transpose(-2, -1) @ dout_dnat2 @ L
        phi = _phi_for_cholesky_(A.mul_(-2))
        dout_dtril = phi @ C
        return dout_dnat1, dout_dtril
