#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import torch
from botorch.models.gpytorch import GPyTorchModel
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.distributions import MultivariateNormal
from linear_operator.operators import IdentityLinearOperator, ZeroLinearOperator
from torch import Tensor


r"""
References

.. [wilson2020sampling]
    J. Wilson, V. Borovitskiy, A. Terenin, P. Mostowsky, and M. Deisenroth. Efficiently
    sampling functions from Gaussian process posteriors. International Conference on
    Machine Learning (2020).

.. [wilson2021pathwise]
    J. Wilson, V. Borovitskiy, A. Terenin, P. Mostowsky, and M. Deisenroth. Pathwise
    Conditioning of Gaussian Processes. Journal of Machine Learning Research (2021).
"""


class LatentKroneckerGPPosterior(GPyTorchPosterior):
    r"""
    Dummy posterior class for a LatentKroneckerGP model.
    Internally calls model._rsample_from_base_samples to draw posterior samples via
    pathwise conditioning aka Matheron's rule [wilson2020sampling, wilson2021pathwise].

    This is necessary because BoTorch instantiates the posterior object before creating
    base samples, whereas pathwise conditioning requires the base samples first to
    calculate the posterior samples. To cache expensive computations, which only have
    to be performed once for the same base samples, the results are stored in the model
    instead of the posterior object, because a new posterior object is created in each
    acquisition function call.
    """

    def __init__(
        self,
        model: GPyTorchModel,
        X: Tensor,
    ) -> None:
        r"""A dummy posterior for LatentKroneckerGP models.

        Args:
            model: The LatentKroneckerGP model to which this posterior belongs to.
            X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension
                of the feature space and `q` is the number of points considered
                jointly, on which the posterior shall be evaluated.
        """
        self._dtype = X.dtype
        self._device = X.device
        self.batch_shape = model.batch_shape
        self.output_batch_shape = torch.broadcast_shapes(
            model.batch_shape, X.shape[:-2]
        )
        self.q = X.shape[-2]
        output_dim = self.q * model.T.shape[-1]
        mean = ZeroLinearOperator(
            *self.output_batch_shape, output_dim, dtype=X.dtype, device=X.device
        )
        covar = IdentityLinearOperator(
            output_dim,
            batch_shape=self.output_batch_shape,
            dtype=X.dtype,
            device=X.device,
        )
        dummy_mvn = MultivariateNormal(mean=mean, covariance_matrix=covar)
        super().__init__(distribution=dummy_mvn)
        self.model = model
        self.X = X
        self._is_mt = True

    @property
    def base_sample_shape(self):
        r"""The shape of a base sample used for constructing posterior samples.

        Overwrites the standard `base_sample_shape` call to inform samplers that
        `n_train_full + n_train + n_test` samples are needed rather than n samples.
        """
        n_train_full = self.model.train_inputs[0].shape[-2] * self.model.T.shape[-1]
        n_train = self.model.train_targets.shape[-1]
        n_test = self.q * self.model.T.shape[-1]
        return self.batch_shape + torch.Size([n_train_full + n_train + n_test])

    @property
    def batch_range(self) -> tuple[int, int]:
        r"""The t-batch range.

        This is used in samplers to identify the t-batch component of the
        `base_sample_shape`. The base samples are expanded over the t-batches to
        provide consistency in the acquisition values, i.e., to ensure that a
        candidate produces same value regardless of its position on the t-batch.
        """
        return (0, -1)

    def _extended_shape(
        self,
        sample_shape: torch.Size = torch.Size(),  # noqa: B008
    ) -> torch.Size:
        r"""Returns the shape of the samples produced by the distribution with
        the given `sample_shape`.
        """
        time_shape = torch.Size([self.model.T.shape[-1]])
        q_shape = torch.Size([self.q])
        return sample_shape + self.output_batch_shape + q_shape + time_shape

    def rsample_from_base_samples(
        self,
        sample_shape: torch.Size,
        base_samples: Tensor,
    ) -> Tensor:
        r"""Sample from the posterior (with gradients) using base samples.

        This is intended to be used with a sampler that produces the corresponding base
        samples, and enables acquisition optimization via Sample Average Approximation.

        Since this posterior is a dummy object, call the model to perform sampling.

        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`.
            base_samples: A Tensor of `N(0, I)` base samples of shape
                `sample_shape x base_sample_shape`, typically obtained from
                a `Sampler`. This is used for deterministic optimization.

        Returns:
            Samples from the posterior, a tensor of shape
            `self._extended_shape(sample_shape=sample_shape)`.
        """
        if base_samples.shape[: len(sample_shape)] != sample_shape:
            raise RuntimeError(
                "`sample_shape` disagrees with shape of `base_samples`. "
                f"Got {sample_shape=} and {base_samples.shape=}."
            )

        return self.model._rsample_from_base_samples(self.X, base_samples)

    def rsample(
        self,
        sample_shape: torch.Size | None = None,
    ) -> Tensor:
        r"""Sample from the posterior (with gradients).

        Args:
            sample_shape: A `torch.Size` object specifying the sample shape. To
                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
                of `n` samples each, set to `torch.Size([b, n])`.

        Returns:
            Samples from the posterior, a tensor of shape
            `self._extended_shape(sample_shape=sample_shape)`.
        """
        if sample_shape is None:
            sample_shape = torch.Size([1])
        base_samples = torch.randn(
            sample_shape + self.base_sample_shape,
            dtype=self.X.dtype,
            device=self.X.device,
        )
        return self.rsample_from_base_samples(sample_shape, base_samples)
