#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""Utilities for fitting and manipulating models."""

from __future__ import annotations

from re import Pattern
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)
from warnings import warn

import torch
from botorch.exceptions.warnings import BotorchWarning
from botorch.models.gpytorch import GPyTorchModel
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader, TensorDataset


class TorchAttr(NamedTuple):
    r"""Named tuple bundling a shape, a dtype, and a device.

    These are the metadata attributes a torch tensor exposes as `.shape`,
    `.dtype`, and `.device`.
    """

    # A `torch.Size` describing the dimensions.
    shape: torch.Size
    # The element data type, e.g. `torch.float64`.
    dtype: torch.dtype
    # The device, e.g. `torch.device("cpu")`.
    device: torch.device


def _get_extra_mll_args(
    mll: MarginalLogLikelihood,
) -> Union[List[Tensor], List[List[Tensor]]]:
    r"""Obtain extra arguments for MarginalLogLikelihood objects.

    Get extra arguments (beyond the model output and training targets) required
    for the particular type of MarginalLogLikelihood for a forward pass.

    This helper is marked for deprecation and emits a `DeprecationWarning` on
    every call.

    Args:
        mll: The MarginalLogLikelihood module.

    Returns:
        Extra arguments for the MarginalLogLikelihood:
        - for an `ExactMarginalLogLikelihood`, `mll.model.train_inputs` as a list;
        - for a `SumMarginalLogLikelihood`, a list holding one list per entry
          of `mll.model.train_inputs`;
        - an empty list if the mll type is unknown.
    """
    # `stacklevel=2` attributes the warning to the caller rather than to this
    # module. Python's default filters only display DeprecationWarning when it
    # is attributed to `__main__`, so without it users would never see it.
    warn(
        "`_get_extra_mll_args` is marked for deprecation.",
        DeprecationWarning,
        stacklevel=2,
    )
    if isinstance(mll, ExactMarginalLogLikelihood):
        return list(mll.model.train_inputs)
    elif isinstance(mll, SumMarginalLogLikelihood):
        return [list(x) for x in mll.model.train_inputs]
    return []


def get_data_loader(
    model: GPyTorchModel, batch_size: int = 1024, **kwargs: Any
) -> DataLoader:
    r"""Wrap a model's training data in a `DataLoader`.

    Args:
        model: A model exposing `train_inputs` (an iterable of tensors) and
            `train_targets`.
        batch_size: Requested number of examples per batch. It is capped at
            `len(model.train_targets)`.
        **kwargs: Additional keyword arguments forwarded to `DataLoader`.

    Returns:
        A `DataLoader` over a `TensorDataset` whose tensors are the entries of
        `model.train_inputs` followed by `model.train_targets`.
    """
    inputs = model.train_inputs
    targets = model.train_targets
    dataset = TensorDataset(*inputs, targets)
    # Never request batches larger than the number of training targets.
    capped_batch_size = min(batch_size, len(targets))
    return DataLoader(dataset=dataset, batch_size=capped_batch_size, **kwargs)


def get_parameters(
    module: Module,
    requires_grad: Optional[bool] = None,
    name_filter: Optional[Callable[[str], bool]] = None,
) -> Dict[str, Tensor]:
    r"""Collect a module's named parameters, optionally filtered.

    Only parameters are returned; no bounds or ranges are computed here.

    Args:
        module: The module whose `named_parameters()` are inspected.
        requires_grad: If not `None`, keep only parameters whose
            `requires_grad` attribute equals this value.
        name_filter: If provided, keep only parameters for which
            `name_filter(name)` is truthy.

    Returns:
        A dictionary mapping parameter names to parameters, in the order
        produced by `module.named_parameters()`.
    """
    return {
        name: param
        for name, param in module.named_parameters()
        if (requires_grad is None or param.requires_grad == requires_grad)
        and (not name_filter or name_filter(name))
    }


def get_parameters_and_bounds(
    module: Module,
    requires_grad: Optional[bool] = None,
    name_filter: Optional[Callable[[str], bool]] = None,
    default_bounds: Tuple[float, float] = (-float("inf"), float("inf")),
) -> Tuple[Dict[str, Tensor], Dict[str, Tuple[Optional[float], Optional[float]]]]:
    r"""Collect a module's parameters together with bounds from their constraints.

    If `module` has a `named_parameters_and_constraints` method, every selected
    parameter is returned. Each selected parameter whose constraint is not
    `None` also gets an entry in the bounds dictionary. That entry is built by
    iterating the constraint: each bound is passed through
    `constraint.inverse_transform`, or replaced by the matching element of
    `default_bounds` when it is `None`. Modules without that method are
    delegated to `get_parameters` and yield an empty bounds dictionary.

    Args:
        module: The target module from which parameters are to be extracted.
        requires_grad: If not `None`, keep only parameters whose
            `requires_grad` attribute equals this value.
        name_filter: If not `None`, keep only parameters for which
            `name_filter(name)` is truthy.
        default_bounds: Lower and upper values substituted for constraint
            bounds that are `None`.

    Returns:
        A `(parameters, bounds)` pair of dictionaries keyed by parameter name.
    """
    if not hasattr(module, "named_parameters_and_constraints"):
        # No constraint information available: parameters only, no bounds.
        unconstrained = get_parameters(
            module, requires_grad=requires_grad, name_filter=name_filter
        )
        return unconstrained, {}

    selected: Dict[str, Tensor] = {}
    limits: Dict[str, Tuple[Optional[float], Optional[float]]] = {}
    for name, param, constraint in module.named_parameters_and_constraints():
        if requires_grad is not None and param.requires_grad != requires_grad:
            continue
        if name_filter is not None and not name_filter(name):
            continue

        selected[name] = param
        if constraint is not None:
            limits[name] = tuple(
                fallback if raw is None else constraint.inverse_transform(raw)
                for raw, fallback in zip(constraint, default_bounds)
            )

    return selected, limits


def get_name_filter(
    patterns: Iterator[Union[Pattern, str]]
) -> Callable[[Union[str, Tuple[Any, ...]]], bool]:
    r"""Build a predicate that rejects names matching a bank of excluded patterns.

    The returned function accepts either a string or a tuple (more generally,
    any iterable) whose first element is the name string. An example is the
    `(name, parameter)` pairs yielded by `module.named_parameters()`. It
    returns `False` when that name is excluded and `True` otherwise.

    Args:
        patterns: A collection of strings and/or compiled regular expressions.
            A string excludes names exactly equal to it. A compiled pattern
            excludes any name for which `pattern.search(name)` finds a match.

    Returns:
        A predicate returning `True` for names to keep and `False` for names to
        exclude.

    Raises:
        TypeError: If an element of `patterns` is neither `str` nor `re.Pattern`.
    """
    names = set()
    _patterns = set()
    for pattern in patterns:
        if isinstance(pattern, str):
            names.add(pattern)
        elif isinstance(pattern, Pattern):
            _patterns.add(pattern)
        else:
            raise TypeError(
                "Expected `patterns` to contain `str` or `re.Pattern` typed elements, "
                f"but found {type(pattern)}."
            )

    # `Tuple[Any, ...]` (not the invalid `Tuple[str, Any, ...]`) keeps the
    # annotation resolvable by `typing.get_type_hints`; the first element is
    # expected to be the name.
    def name_filter(item: Union[str, Tuple[Any, ...]]) -> bool:
        # Accept bare names as well as `(name, ...)` tuples.
        name = item if isinstance(item, str) else next(iter(item))
        if name in names:
            return False

        return not any(pattern.search(name) for pattern in _patterns)

    return name_filter


def sample_all_priors(model: GPyTorchModel, max_retries: int = 100) -> None:
    r"""Sample from hyperparameter priors (in-place).

    For every entry yielded by `model.named_priors()`, a sample is drawn from
    the prior with the shape of `closure(module)`. It is then written back
    with `setting_closure(module, sample)`.

    If any step raises a `RuntimeError` whose message contains "out of bounds
    of its current constraints", the attempt is retried, up to `max_retries`
    attempts per prior. If sampling raises `NotImplementedError`, that prior
    is skipped with a `BotorchWarning`. If `max_retries` is not positive, no
    sampling is attempted.

    Args:
        model: A GPyTorchModel.
        max_retries: Maximum number of sampling attempts per prior.

    Raises:
        RuntimeError: If a prior has no setting closure, if no feasible value
            is obtained within `max_retries` attempts (the last constraint
            error is attached as `__cause__`), or if sampling/setting raises
            any other `RuntimeError` (re-raised unchanged).
    """
    for _, module, prior, closure, setting_closure in model.named_priors():
        if setting_closure is None:
            raise RuntimeError(
                "Must provide inverse transform to be able to sample from prior."
            )
        for i in range(max_retries):
            try:
                setting_closure(module, prior.sample(closure(module).shape))
                break
            except NotImplementedError:
                warn(
                    f"`rsample` not implemented for {type(prior)}. Skipping.",
                    BotorchWarning,
                )
                break
            except RuntimeError as e:
                if "out of bounds of its current constraints" not in str(e):
                    # Unrelated failure: propagate it untouched.
                    raise
                if i == max_retries - 1:
                    raise RuntimeError(
                        "Failed to sample a feasible parameter value "
                        f"from the prior after {max_retries} attempts."
                    ) from e
