#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
.. [wilson2020sampling]
    J. Wilson, V. Borovitskiy, A. Terenin, P. Mostowsky, and M. Deisenroth. Efficiently
    sampling functions from Gaussian process posteriors. International Conference on
    Machine Learning (2020).

.. [wilson2021pathwise]
    J. Wilson, V. Borovitskiy, A. Terenin, P. Mostowsky, and M. Deisenroth. Pathwise
    Conditioning of Gaussian Processes. Journal of Machine Learning Research (2021).
"""

from __future__ import annotations

import torch
from botorch.exceptions.errors import UnsupportedError
from botorch.models.approximate_gp import ApproximateGPyTorchModel
from botorch.models.deterministic import GenericDeterministicModel
from botorch.models.model import ModelList
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.sampling.pathwise.paths import PathDict, PathList, SamplePath
from botorch.sampling.pathwise.prior_samplers import (
    draw_kernel_feature_paths,
    TPathwisePriorSampler,
)
from botorch.sampling.pathwise.update_strategies import gaussian_update, TPathwiseUpdate
from botorch.sampling.pathwise.utils import (
    get_output_transform,
    get_train_inputs,
    get_train_targets,
    TInputTransform,
    TOutputTransform,
)
from botorch.utils.context_managers import delattr_ctx
from botorch.utils.dispatcher import Dispatcher
from botorch.utils.transforms import is_ensemble
from gpytorch.models import ApproximateGP, ExactGP, GP
from torch import Size, Tensor

DrawMatheronPaths = Dispatcher("draw_matheron_paths")


class MatheronPath(PathDict):
    r"""Represents function draws from a GP posterior via Matheron's rule:

    .. code-block:: text

                  "Prior path"
                       v
        (f | y)(·) = f(·) + Cov(f(·), y) Cov(y, y)^{-1} (y - f(X) - ε),
                            \_______________________________________/
                                                v
                                          "Update path"

    where `=` denotes equality in distribution, :math:`f \sim GP(0, k)`,
    :math:`y \sim N(f(X), \Sigma)`, and :math:`\epsilon \sim N(0, \Sigma)`.
    For more information, see [wilson2020sampling]_ and [wilson2021pathwise]_.
    """

    def __init__(
        self,
        prior_paths: SamplePath,
        update_paths: SamplePath,
        input_transform: TInputTransform | None = None,
        output_transform: TOutputTransform | None = None,
    ) -> None:
        r"""Initializes a MatheronPath instance.

        Args:
            prior_paths: Sample paths used to represent the prior.
            update_paths: Sample paths used to represent the data.
            input_transform: An optional input transform for the module.
            output_transform: An optional output transform for the module.
        """

        super().__init__(
            join=sum,
            paths={"prior_paths": prior_paths, "update_paths": update_paths},
            input_transform=input_transform,
            output_transform=output_transform,
        )


def get_matheron_path_model(
    model: GP, sample_shape: Size | None = None
) -> GenericDeterministicModel:
    r"""Generates a deterministic model using a single Matheron path drawn
    from the model's posterior.

    The deterministic model evalutes the output of `draw_matheron_paths`,
    and reshapes it to mimic the output behavior of the model's posterior.

    Args:
        model: The model whose posterior is to be sampled.
        sample_shape: The shape of the sample paths to be drawn, if an ensemble
            of sample paths is desired. If this is specified, the resulting
            deterministic model will behave as if the `sample_shape` is prepended
            to the `batch_shape` of the model. The inputs used to evaluate the model
            must be adjusted to match.

    Returns:
        A deterministic model that evaluates the Matheron path.
    """
    sample_shape = Size() if sample_shape is None else sample_shape
    path = draw_matheron_paths(model, sample_shape=sample_shape)
    num_outputs = model.num_outputs
    if isinstance(model, ModelList) and len(model.models) != num_outputs:
        raise UnsupportedError("A model-list of multi-output models is not supported.")

    def f(X: Tensor) -> Tensor:
        r"""Reshapes the path evaluations to bring the output dimension to the end.

        Args:
            X: The input tensor of shape `batch_shape x q x d`.
                If the model is batched, `batch_shape` must be broadcastable to
                the model batch shape.

        Returns:
            The output tensor of shape `batch_shape x q x m`.
        """
        if num_outputs == 1:
            # For single-output, we lack the output dimension. Add one.
            res = path(X).unsqueeze(-1)
        elif isinstance(model, ModelList):
            # For model list, path evaluates to a list of tensors. Stack them.
            res = torch.stack(path(X), dim=-1)
        else:
            # For multi-output, path expects inputs broadcastable to
            # `model._aug_batch_shape x q x d` and returns outputs of shape
            # `model._aug_batch_shape x q`. Augmented batch shape includes the
            # `m` dimension, so we will unsqueeze that and transpose after.
            res = path(X.unsqueeze(-3)).transpose(-1, -2)
        return res

    path_model = GenericDeterministicModel(f=f, num_outputs=num_outputs)
    path_model._is_ensemble = is_ensemble(model) or len(sample_shape) > 0
    return path_model


def draw_matheron_paths(
    model: GP,
    sample_shape: Size,
    prior_sampler: TPathwisePriorSampler = draw_kernel_feature_paths,
    update_strategy: TPathwiseUpdate = gaussian_update,
) -> MatheronPath:
    r"""Generates function draws from (an approximate) Gaussian process posterior.

    When evaluted, sample paths produced by this method return Tensors with dimensions
    `sample_dims x batch_dims x [joint_dim]`, where `joint_dim` denotes the penultimate
    dimension of the input tensor. For multioutput models, outputs are returned as the
    final batch dimension.

    Args:
        model: Gaussian process whose posterior is to be sampled.
        sample_shape: Sizes of sample dimensions.
        prior_sample: A callable that takes a model and a sample shape and returns
            a set of sample paths representing the prior.
        update_strategy: A callable that takes a model and a tensor of prior process
            values and returns a set of sample paths representing the data.
    """

    return DrawMatheronPaths(
        model,
        sample_shape=sample_shape,
        prior_sampler=prior_sampler,
        update_strategy=update_strategy,
    )


@DrawMatheronPaths.register(ModelListGP)
def _draw_matheron_paths_ModelListGP(
    model: ModelListGP,
    sample_shape: Size,
    *,
    prior_sampler: TPathwisePriorSampler = draw_kernel_feature_paths,
    update_strategy: TPathwiseUpdate = gaussian_update,
):
    return PathList(
        [
            draw_matheron_paths(
                model=m,
                sample_shape=sample_shape,
                prior_sampler=prior_sampler,
                update_strategy=update_strategy,
            )
            for m in model.models
        ]
    )


@DrawMatheronPaths.register(ExactGP)
def _draw_matheron_paths_ExactGP(
    model: ExactGP,
    *,
    sample_shape: Size,
    prior_sampler: TPathwisePriorSampler,
    update_strategy: TPathwiseUpdate,
) -> MatheronPath:
    (train_X,) = get_train_inputs(model, transformed=True)
    train_Y = get_train_targets(model, transformed=True)
    with delattr_ctx(model, "outcome_transform"):
        # Generate draws from the prior
        prior_paths = prior_sampler(model=model, sample_shape=sample_shape)
        sample_values = prior_paths.forward(train_X)

        # Compute pathwise updates
        update_paths = update_strategy(
            model=model,
            sample_values=sample_values,
            target_values=train_Y,
        )

    return MatheronPath(
        prior_paths=prior_paths,
        update_paths=update_paths,
        output_transform=get_output_transform(model),
    )


@DrawMatheronPaths.register((ApproximateGP, ApproximateGPyTorchModel))
def _draw_matheron_paths_ApproximateGP(
    model: ApproximateGP | ApproximateGPyTorchModel,
    *,
    sample_shape: Size,
    prior_sampler: TPathwisePriorSampler,
    update_strategy: TPathwiseUpdate,
) -> MatheronPath:
    # Note: Inducing points are assumed to be pre-transformed
    Z = (
        model.model.variational_strategy.inducing_points
        if isinstance(model, ApproximateGPyTorchModel)
        else model.variational_strategy.inducing_points
    )
    with delattr_ctx(model, "outcome_transform"):
        # Generate draws from the prior
        prior_paths = prior_sampler(model=model, sample_shape=sample_shape)
        sample_values = prior_paths.forward(Z)  # `forward` bypasses transforms

        # Compute pathwise updates
        update_paths = update_strategy(model=model, sample_values=sample_values)

    return MatheronPath(
        prior_paths=prior_paths,
        update_paths=update_paths,
        output_transform=get_output_transform(model),
    )
