#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""Assorted helper methods and objects for working with BoTorch models."""

from __future__ import annotations

import warnings
from contextlib import contextmanager, ExitStack
from typing import Iterator, List, Optional, Tuple

import torch
from botorch import settings
from botorch.exceptions import InputDataError, InputDataWarning
from botorch.settings import _Flag
from gpytorch import settings as gpt_settings
from gpytorch.module import Module
from torch import Tensor


def _make_X_full(X: Tensor, output_indices: List[int], tf: int) -> Tensor:
    r"""Build the multi-task input by inserting a task-index column into `X`.

    For every entry of `output_indices`, a copy of `X` is created in which a
    column filled with that index is inserted at position `tf` of the last
    dimension. The copies are then concatenated along the second-to-last
    dimension, in the order given by `output_indices`.

    Args:
        X: The raw input tensor (without task information).
        output_indices: The output indices to generate (passed in via `posterior`).
        tf: The task feature index, i.e. the position in the last dimension at
            which the task column is inserted.

    Returns:
        Tensor: The full input tensor for the multi-task model, including task
            indices.
    """
    # Features to the left / right of where the task column goes.
    left, right = X[..., :tf], X[..., tf:]
    column_shape = X.shape[:-1] + torch.Size([1])
    per_task_inputs = []
    for task_index in output_indices:
        # The task column shares dtype and device with `X` so it can be
        # concatenated with the feature columns.
        task_column = torch.full(
            column_shape, fill_value=task_index, device=X.device, dtype=X.dtype
        )
        per_task_inputs.append(torch.cat([left, task_column, right], dim=-1))
    return torch.cat(per_task_inputs, dim=-2)


def multioutput_to_batch_mode_transform(
    train_X: Tensor,
    train_Y: Tensor,
    num_outputs: int,
    train_Yvar: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
    r"""Reshape multi-output training data into batched single-output form.

    Intended for multi-output models that are represented internally by a
    batched single-output model in which every output is treated as its own
    independent batch.

    Args:
        train_X: A `n x d` or `input_batch_shape x n x d` (batch mode) tensor of
            training features.
        train_Y: A `n x m` or `target_batch_shape x n x m` (batch mode) tensor of
            training observations.
        num_outputs: number of outputs
        train_Yvar: A `n x m` or `target_batch_shape x n x m` tensor of observed
            measurement noise.

    Returns:
        3-element tuple containing

        - A `input_batch_shape x m x n x d` tensor of training features (an
            expanded view of `train_X`, not a copy).
        - A `target_batch_shape x m x n` tensor of training observations.
        - A `target_batch_shape x m x n` tensor observed measurement noise
            (`None` if `train_Yvar` is `None`).
    """
    # Move the output dimension in front of the data dimension:
    # `batch_shape x n x m` -> `batch_shape x m x n`.
    batched_Y = train_Y.transpose(-1, -2)
    # Insert an output dimension into the features and broadcast it to size `m`:
    # `batch_shape x n x d` -> `batch_shape x m x n x d`.
    expanded_shape = (*train_X.shape[:-2], num_outputs, *train_X.shape[-2:])
    batched_X = train_X.unsqueeze(-3).expand(*expanded_shape)
    # The noise (if provided) is laid out like the observations.
    batched_Yvar = None if train_Yvar is None else train_Yvar.transpose(-1, -2)
    return batched_X, batched_Y, batched_Yvar


def add_output_dim(X: Tensor, original_batch_shape: torch.Size) -> Tuple[Tensor, int]:
    r"""Insert the output dimension at the correct location.

    When both `X` and the training inputs have batch dimensions, the batch
    shape of `X` must be broadcastable with the original batch shape of the
    training inputs; `X` may additionally carry extra leading batch dimensions.

    Args:
        X: A `(new_batch_shape) x (original_batch_shape) x n x d` tensor of
            features.
        original_batch_shape: the batch shape of the model's training inputs.

    Returns:
        2-element tuple containing

        - A `(new_batch_shape) x (original_batch_shape) x m x n x d` tensor of
            features.
        - The index corresponding to the output dimension.

    Raises:
        RuntimeError: If the batch shape of `X` cannot be broadcast with
            `original_batch_shape`.
    """
    eval_batch_shape = X.shape[:-2]
    both_batched = len(eval_batch_shape) > 0 and len(original_batch_shape) > 0
    if both_batched:
        # Only the compatibility of the two batch shapes matters here; the
        # broadcast result itself is discarded.
        try:
            torch.broadcast_shapes(eval_batch_shape, original_batch_shape)
        except RuntimeError:
            raise RuntimeError(
                "The trailing batch dimensions of X must match the trailing "
                "batch dimensions of the training inputs."
            )
    # The output dimension sits right after all batch dimensions, i.e. after
    # the longer of the two batch shapes.
    output_dim_idx = max(len(original_batch_shape), len(eval_batch_shape))
    # Insert the (size-one) `m` dimension in front of the `n x d` dimensions.
    return X.unsqueeze(-3), output_dim_idx


def check_no_nans(Z: Tensor) -> None:
    r"""Verify that a tensor is free of NaN values.

    Args:
        Z: The input tensor.

    Raises:
        InputDataError: If at least one element of `Z` is NaN.
    """
    has_nan = torch.isnan(Z).any().item()
    if has_nan:
        raise InputDataError("Input data contains NaN values.")


def check_min_max_scaling(
    X: Tensor,
    strict: bool = False,
    atol: float = 1e-2,
    raise_on_fail: bool = False,
    ignore_dims: Optional[List[int]] = None,
) -> None:
    r"""Check that tensor is normalized to the unit cube.

    The minimum and maximum are taken over the last dimension of `X`
    (restricted to the dimensions that are not ignored), i.e. per point. The
    check fails if any of these minima is below `-atol` or any of these maxima
    is above `1 + atol`. With `strict=True` it additionally fails if any minimum
    deviates from 0, or any maximum deviates from 1, by more than `atol`.

    Args:
        X: A `batch_shape x n x d` input tensor. Typically the training inputs
            of a model.
        strict: If True, require `X` to be scaled to the unit cube (rather than
            just to be contained within the unit cube).
        atol: The tolerance for the boundary checks. Used both for the
            containment check and for the `strict=True` scaling check.
        raise_on_fail: If True, raise an exception instead of a warning.
        ignore_dims: Subset of dimensions (indices into the last dimension of
            `X`) where the min-max scaling check is omitted.

    Raises:
        InputDataError: If the check fails and `raise_on_fail=True`. Otherwise
            an `InputDataWarning` is emitted on failure.
    """
    skipped = set(ignore_dims or [])
    dims_to_check = [d for d in range(X.shape[-1]) if d not in skipped]
    if not dims_to_check:
        # Every dimension is ignored, so there is nothing to validate.
        return None

    with torch.no_grad():
        X_sel = X[..., dims_to_check]
        lower = X_sel.min(dim=-1).values
        upper = X_sel.max(dim=-1).values

        violation = None
        if strict:
            worst_gap = max(lower.abs().max(), (upper - 1).abs().max())
            if worst_gap > atol:
                violation = "scaled"
        # Lying outside the unit cube is the stronger statement, so it takes
        # precedence over a mere "not scaled" finding.
        if torch.any(lower < -atol) or torch.any(upper > 1 + atol):
            violation = "contained"

        if violation is None:
            return None

        msg = (
            f"Input data is not {violation} to the unit cube. "
            "Please consider min-max scaling the input data."
        )
        if raise_on_fail:
            raise InputDataError(msg)
        warnings.warn(msg, InputDataWarning)


def check_standardization(
    Y: Tensor,
    atol_mean: float = 1e-2,
    atol_std: float = 1e-2,
    raise_on_fail: bool = False,
) -> None:
    r"""Check that tensor is standardized (zero mean, unit variance).

    The mean and the standard deviation (as computed by `torch.std`) are taken
    across the `n`-dimension (`dim=-2`). If `Y` holds fewer than two
    observations along that dimension, the sample standard deviation is
    undefined (`torch.std` would return NaN and emit a degrees-of-freedom
    warning), so in that case only the mean is checked.

    Args:
        Y: The input tensor of shape `batch_shape x n x m`. Typically the
            train targets of a model. Standardization is checked across the
            `n`-dimension.
        atol_mean: The tolerance for the mean check.
        atol_std: The tolerance for the std check.
        raise_on_fail: If True, raise an exception instead of a warning.

    Raises:
        InputDataError: If `Y` is not standardized and `raise_on_fail=True`.
            Otherwise an `InputDataWarning` is emitted in that case.
    """
    with torch.no_grad():
        Ymean = torch.mean(Y, dim=-2)
        not_standardized = bool(torch.abs(Ymean).max() > atol_mean)
        # Only look at the std if the mean check passed and the std is
        # actually defined (at least two observations).
        if not not_standardized and Y.shape[-2] > 1:
            Ystd = torch.std(Y, dim=-2)
            not_standardized = bool(torch.abs(Ystd - 1).max() > atol_std)
        if not_standardized:
            msg = (
                "Input data is not standardized. Please consider scaling the "
                "input to zero mean and unit variance."
            )
            if raise_on_fail:
                raise InputDataError(msg)
            warnings.warn(msg, InputDataWarning)


def validate_input_scaling(
    train_X: Tensor,
    train_Y: Tensor,
    train_Yvar: Optional[Tensor] = None,
    raise_on_fail: bool = False,
    ignore_X_dims: Optional[List[int]] = None,
) -> None:
    r"""Helper function to validate input data to models.

    This is a no-op if `settings.validate_input_scaling` is off. Otherwise the
    following is validated, in this order:
    (i) none of `train_X`, `train_Y` and (if given) `train_Yvar` contain NaN
    values
    (ii) the observed variances (`train_Yvar`), if given, are non-negative
    (iii) the training data (`train_X`) is normalized to the unit cube for all
    dimensions except those in `ignore_X_dims`
    (iv) the training targets (`train_Y`) are standardized (zero mean, unit var)

    Args:
        train_X: A `n x d` or `batch_shape x n x d` (batch mode) tensor of
            training features.
        train_Y: A `n x m` or `batch_shape x n x m` (batch mode) tensor of
            training observations.
        train_Yvar: A `n x m` or `batch_shape x n x m` (batch mode) tensor of
            observed measurement noise.
        raise_on_fail: If True, raise an error instead of emitting a warning
            (only for normalization/standardization checks, an error is always
            raised if NaN values or negative variances are present).
        ignore_X_dims: For this subset of dimensions from `{0, ..., d - 1}`,
            ignore the min-max scaling check.

    Raises:
        InputDataError: If any input contains NaN values, if `train_Yvar`
            contains negative entries, or if a scaling check fails while
            `raise_on_fail=True`.
    """
    if settings.validate_input_scaling.off():
        return

    # Hard failures first: NaNs and negative variances always raise.
    for required in (train_X, train_Y):
        check_no_nans(required)
    if train_Yvar is not None:
        check_no_nans(train_Yvar)
        if torch.any(train_Yvar < 0):
            raise InputDataError("Input data contains negative variances.")

    # Soft checks: warn (or raise, if requested) on badly scaled data.
    check_min_max_scaling(
        X=train_X, raise_on_fail=raise_on_fail, ignore_dims=ignore_X_dims
    )
    check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)


def mod_batch_shape(module: Module, names: List[str], b: int) -> None:
    r"""Helper to modify the batch shape attribute of a nested gpytorch module.

    Follows the attribute path given by `names` starting at `module` and, if
    the object at the end of the path has a non-empty `batch_shape` attribute,
    replaces the last element of that batch shape. Modifies the module in-place.

    Args:
        module: The module to be modified.
        names: The list of names to access the attribute. If the full name of
            the module is `"module.sub_module.leaf_module"`, this will be
            `["sub_module", "leaf_module"]`.
        b: The new size of the last element of the module's `batch_shape`
            attribute. If `b` is not positive, the last element is dropped
            instead of being replaced.
    """
    if not names:
        return
    # Walk down the attribute path to the leaf.
    leaf = module
    for attr_name in names:
        leaf = getattr(leaf, attr_name)
    # Leaves without a (non-empty) batch shape are left untouched.
    if not hasattr(leaf, "batch_shape") or len(leaf.batch_shape) == 0:
        return
    new_tail = [b] if b > 0 else []
    leaf.batch_shape = leaf.batch_shape[:-1] + torch.Size(new_tail)


@contextmanager
def gpt_posterior_settings():
    r"""Context manager for settings used for computing model posteriors.

    While active, the following gpytorch settings are applied:

    - `debug` is set to `False`, but only if it is currently at its default.
    - `fast_pred_var` is entered (with its default argument), but only if it
        is currently at its default.
    - `detach_test_caches` is set to the value of
        `settings.propagate_grads.off()`.

    All entered settings are exited again when the context is left.
    """
    with ExitStack() as stack:
        enter = stack.enter_context
        # Settings the user changed explicitly are respected and left alone.
        if gpt_settings.debug.is_default():
            enter(gpt_settings.debug(False))
        if gpt_settings.fast_pred_var.is_default():
            enter(gpt_settings.fast_pred_var())
        detach_caches = settings.propagate_grads.off()
        enter(gpt_settings.detach_test_caches(detach_caches))
        yield


def detect_duplicates(
    X: Tensor,
    rtol: float = 0,
    atol: float = 1e-8,
) -> Iterator[Tuple[int, int]]:
    """Returns an iterator over index pairs `(duplicate index, original index)` for all
    duplicate entries of `X`. Supporting 2-d Tensor only.

    Two rows count as duplicates if their infinity-norm distance is smaller
    than `atol + rtol * max(|row_a|_inf, |row_b|_inf)`. For each row that has at
    least one such earlier row, the "original" is the closest earlier row
    (smallest distance minus tolerance).

    Args:
        X: the datapoints tensor with potential duplicated entries
        rtol: relative tolerance
        atol: absolute tolerance

    Returns:
        An iterator over `(duplicate index, original index)` pairs, ordered by
        duplicate index. It is empty if `X` has fewer than two rows.

    Raises:
        ValueError: If `X` is not 2-dimensional.
    """
    if len(X.shape) != 2:
        raise ValueError("X must have 2 dimensions.")

    n = X.shape[-2]
    if n < 2:
        # No pairs to compare. Returning early also avoids a min-reduction
        # over a zero-sized dimension when `X` is empty.
        return iter(())

    tols = atol
    if rtol:
        # Per-row magnitude (infinity norm); the relative tolerance of a pair
        # is scaled by the larger magnitude of its two rows.
        rval = X.abs().max(dim=-1, keepdim=True).values
        tols = tols + rtol * rval.max(rval.transpose(-1, -2))

    # `dist[i, j]` holds the distance between rows `i < j`; all other entries
    # stay at `inf` so a row is only ever matched against earlier rows.
    dist = torch.full((n, n), float("inf"), device=X.device, dtype=X.dtype)
    upper = torch.triu_indices(n, n, offset=1, device=X.device)
    dist[upper[0], upper[1]] = torch.nn.functional.pdist(X, p=float("inf"))

    # For every column `j`: the smallest tolerance-adjusted distance to an
    # earlier row and the index of that row.
    min_gaps, nearest_earlier = (dist - tols).min(dim=-2)
    return (
        (dup_idx, orig_idx)
        for dup_idx, (gap, orig_idx) in enumerate(
            zip(min_gaps.tolist(), nearest_earlier.tolist())
        )
        if gap < 0
    )


def consolidate_duplicates(
    X: Tensor, Y: Tensor, rtol: float = 0, atol: float = 1e-8
) -> Tuple[Tensor, Tensor, Tensor]:
    """Drop duplicated Xs and update the indices tensor Y accordingly.
    Supporting 2d Tensor only as in batch mode block design is not guaranteed.

    Duplicates are found via `detect_duplicates`. A row that is both a
    duplicate of some row and the original of another row is handled
    conservatively: neither it nor the rows linked to it are removed.

    Args:
        X: the datapoints tensor
        Y: the index tensor to be updated (e.g., pairwise comparisons); a 2-d
            tensor whose entries are row indices into `X`
        rtol: relative tolerance
        atol: absolute tolerance

    Returns:
        consolidated_X: the consolidated X
        consolidated_Y: the consolidated Y (e.g., pairwise comparisons indices)
        new_indices: new index of each original item in X, a `torch.long` tensor
            of size X.shape[-2] on the device of `Y`

    Raises:
        ValueError: If `X` is not 2-dimensional.
    """
    if len(X.shape) != 2:
        raise ValueError("X must have 2 dimensions.")

    n = X.shape[-2]
    dup_map = dict(detect_duplicates(X=X, rtol=rtol, atol=atol))

    # Handle edge cases conservatively
    # If a item is in both dup set and kept set, do not remove it
    ambiguous = set(dup_map).intersection(dup_map.values())
    dup_map = {
        dup: kept
        for dup, kept in dup_map.items()
        if dup not in ambiguous and kept not in ambiguous
    }

    if not dup_map:
        # Nothing to drop: the index mapping is the identity. Use `torch.long`
        # so the returned index tensor has the same dtype as in the branch
        # below, regardless of the dtype of `Y`.
        return X, Y, torch.arange(n, dtype=torch.long, device=Y.device)

    unique_indices = [idx for idx in range(n) if idx not in dup_map]

    # After dropping the duplicates,
    # the kept ones' indices may also change by being shifted up
    new_idx_map = {old: new for new, old in enumerate(unique_indices)}
    # A duplicate maps to wherever its kept original ended up. Kept originals
    # are never duplicates themselves, so these lookups only hit entries
    # created in the line above.
    for dup, kept in dup_map.items():
        new_idx_map[dup] = new_idx_map[kept]

    consolidated_X = X[unique_indices, :]
    # The reshape is a no-op for non-empty 2-d `Y`; it restores the original
    # shape when `Y` has no rows.
    consolidated_Y = torch.tensor(
        [[new_idx_map[item] for item in row] for row in Y.tolist()],
        dtype=torch.long,
        device=Y.device,
    ).reshape(Y.shape)
    new_indices = torch.tensor(
        [new_idx_map[idx] for idx in range(n)],
        dtype=torch.long,
        device=Y.device,
    )
    return consolidated_X, consolidated_Y, new_indices


class fantasize(_Flag):
    r"""A flag denoting whether we are currently in a `fantasize` context.

    NOTE(review): how the flag is switched on/off is inherited from `_Flag`
    (imported from `botorch.settings`, not visible in this file) -- confirm the
    exact semantics there.
    """
    # Class-level default value of the flag's state.
    _state: bool = False
