#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
The base class for sampler modules to be used with MC-evaluated acquisition functions.
"""

from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple

import torch
from botorch.exceptions.errors import InputDataError
from botorch.posteriors import Posterior
from torch import Tensor
from torch.nn import Module


# Warning template for deprecated keyword arguments that are still tolerated;
# the placeholder is filled with the name of the offending argument.
KWARGS_DEPRECATED_MSG = (
    "The {} argument of `MCSampler`s has been deprecated "
    "and will raise an error in a future version."
)
# Error template for keyword argument values that are no longer supported; the
# placeholders are the argument name and the suggested replacement sampler.
KWARG_ERR_MSG = (
    "`MCSampler`s no longer support the `{}` argument. Consider using `{}` "
    "for similar functionality."
)


class MCSampler(Module, ABC):
    r"""Abstract base class for Samplers.

    Subclasses must implement the `forward` method.

    Example:
        The `forward` method is usually not called directly, but via the
        sampler's `__call__` method:
        >>> posterior = model.posterior(test_X)
        >>> samples = sampler(posterior)

    :meta private:
    """

    def __init__(
        self,
        sample_shape: torch.Size,
        seed: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        r"""Abstract base class for samplers.

        Args:
            sample_shape: The `sample_shape` of the samples to generate. The full shape
                of the samples is given by `posterior._extended_shape(sample_shape)`.
                Passing an `int` is deprecated; it is converted to a one-element
                `torch.Size` and a `DeprecationWarning` is issued.
            seed: An optional seed to use for sampling. If `None`, a random
                integer seed is drawn with `torch.randint`.
            **kwargs: Catch-all for deprecated kwargs. Only `resample` and
                `collapse_batch_dims` are recognized.

        Raises:
            InputDataError: If `sample_shape` is neither a `torch.Size` nor an `int`.
            RuntimeError: If `resample` is `True`, if `collapse_batch_dims` is
                `False`, or if any other keyword argument is passed.
        """
        super().__init__()
        if not isinstance(sample_shape, torch.Size):
            if not isinstance(sample_shape, int):
                raise InputDataError(
                    "Expected `sample_shape` to be a `torch.Size` object, "
                    f"got {sample_shape}."
                )
            # Legacy call style: the first positional argument used to be an
            # integer number of samples.
            sample_shape = torch.Size([sample_shape])
            # `stacklevel=2` attributes the warning to the code constructing the
            # sampler rather than to this file.
            warnings.warn(
                "The first positional argument of samplers, `num_samples`, has "
                "been deprecated and replaced with `sample_shape`, which expects "
                "a `torch.Size` object.",
                DeprecationWarning,
                stacklevel=2,
            )

        # Maps each deprecated kwarg to the single value that is no longer
        # supported and the sampler class suggested as a replacement.
        unsupported = {
            "resample": (True, "StochasticSampler"),
            "collapse_batch_dims": (False, "ForkedRNGSampler"),
        }
        for k, v in kwargs.items():
            if k not in unsupported:
                raise RuntimeError(f"Received an unknown argument {k}: {v}.")
            rejected_value, replacement = unsupported[k]
            # Identity comparison on purpose: only the literal boolean is
            # rejected, any other value merely triggers the deprecation warning.
            if v is rejected_value:
                raise RuntimeError(KWARG_ERR_MSG.format(k, replacement))
            warnings.warn(
                KWARGS_DEPRECATED_MSG.format(k), DeprecationWarning, stacklevel=2
            )

        self.sample_shape = sample_shape
        self.seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item()
        # Registered as a buffer with no value yet.
        self.register_buffer("base_samples", None)

    @abstractmethod
    def forward(self, posterior: Posterior) -> Tensor:
        r"""Draws MC samples from the posterior.

        Args:
            posterior: The posterior to sample from.

        Returns:
            The samples drawn from the posterior.
        """
        pass  # pragma no cover

    def _get_batch_range(self, posterior: Posterior) -> Tuple[int, int]:
        r"""Get the t-batch range of the posterior with an optional override.

        In rare cases, e.g., in `qMultiStepLookahead`, we may want to override the
        `batch_range` of the posterior. If this behavior is desired, one can set
        `batch_range_override` attribute on the samplers.

        Args:
            posterior: The posterior to sample from.

        Returns:
            The t-batch range to use for collapsing the base samples: the
            sampler's `batch_range_override` attribute if it has been set,
            otherwise `posterior.batch_range`.
        """
        if hasattr(self, "batch_range_override"):
            return self.batch_range_override
        return posterior.batch_range

    def _get_collapsed_shape(self, posterior: Posterior) -> torch.Size:
        r"""Get the shape of the base samples with the t-batches collapsed.

        Args:
            posterior: The posterior to sample from.

        Returns:
            The collapsed shape of the base samples expected by the posterior. The
            t-batch dimensions of the base samples are collapsed to size 1. This is
            useful to prevent sampling variance across t-batches.
        """
        base_sample_shape = posterior.base_sample_shape
        batch_start, batch_end = self._get_batch_range(posterior)
        # Replace every dimension inside the t-batch range with a singleton while
        # leaving the dimensions before and after the range untouched.
        num_batch_dims = len(base_sample_shape[batch_start:batch_end])
        base_sample_shape = (
            base_sample_shape[:batch_start]
            + torch.Size([1] * num_batch_dims)
            + base_sample_shape[batch_end:]
        )
        return self.sample_shape + base_sample_shape

    def _get_extended_base_sample_shape(self, posterior: Posterior) -> torch.Size:
        r"""Get the shape of the base samples expected by the posterior.

        Args:
            posterior: The posterior to sample from.

        Returns:
            The extended shape of the base samples expected by the posterior,
            i.e., `sample_shape` followed by `posterior.base_sample_shape`.
        """
        return self.sample_shape + posterior.base_sample_shape

    def _update_base_samples(
        self, posterior: Posterior, base_sampler: MCSampler
    ) -> None:
        r"""Update the sampler to use the original base samples for X_baseline.

        This is used in CachedCholeskyAcquisitionFunctions to ensure consistency.

        Args:
            posterior: The posterior for which the base samples are constructed.
            base_sampler: The base sampler to retrieve the base samples from.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement `_update_base_samples`."
        )

    def _instance_check(self, base_sampler: MCSampler) -> None:
        r"""Check that `base_sampler` is an instance of `self.__class__`.

        Args:
            base_sampler: The object to check.

        Raises:
            RuntimeError: If `base_sampler` is not an instance of `self.__class__`.
        """
        if not isinstance(base_sampler, self.__class__):
            raise RuntimeError(
                "Expected `base_sampler` to be an instance of "
                f"{self.__class__.__name__}. Got {base_sampler}."
            )
