# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import functools

import torch
from torch import nn


def check_finite(tensor: torch.Tensor):
    """Validate that every element of ``tensor`` is finite.

    Args:
        tensor: the tensor to inspect.

    Raises:
        ValueError: if at least one element is reported as non-finite by
            ``tensor.isfinite()``.
    """
    all_finite = tensor.isfinite().all()
    if all_finite:
        return
    raise ValueError("Encountered a non-finite tensor.")


def _init_first(fun):
    """Decorate a method so the instance is initialized before it runs.

    The returned wrapper checks ``self.initialized`` on every call and, when
    it is falsy, calls ``self._init()`` before delegating to ``fun``.

    Args:
        fun: the method to wrap. It is called as ``fun(self, *args, **kwargs)``.

    Returns:
        A wrapper that forwards all arguments to ``fun`` and returns its result.
    """

    # functools.wraps keeps the wrapped method's __name__, __doc__ and
    # __wrapped__ so introspection still sees the original method.
    @functools.wraps(fun)
    def new_fun(self, *args, **kwargs):
        if not self.initialized:
            self._init()
        return fun(self, *args, **kwargs)

    return new_fun


class _set_missing_tolerance:
    """Context manager that temporarily switches a transform's missing tolerance.

    On entry, the transform's current ``missing_tolerance`` is saved in
    ``exit_mode`` and, if it differs from ``mode``, the transform is switched
    to ``mode`` through ``set_missing_tolerance``. On exit, the saved value is
    restored under the same condition.

    Args:
        transform: object exposing a ``missing_tolerance`` attribute and a
            ``set_missing_tolerance`` method.
        mode: the tolerance value to apply inside the ``with`` block.
    """

    def __init__(self, transform, mode):
        self.transform, self.mode = transform, mode

    def __enter__(self):
        # Remember the value to restore before touching the transform.
        self.exit_mode = self.transform.missing_tolerance
        needs_switch = self.mode != self.exit_mode
        if not needs_switch:
            return
        self.transform.set_missing_tolerance(self.mode)

    def __exit__(self, exc_type, exc_val, exc_tb):
        needs_restore = self.mode != self.exit_mode
        if not needs_restore:
            return
        self.transform.set_missing_tolerance(self.exit_mode)


def _get_reset(reset_key, tensordict):
    """Return the reset entry stored under ``reset_key``, or an all-True default.

    The "parent" container is the entry found at ``reset_key[:-1]`` when
    ``reset_key`` is a tuple and that entry exists; otherwise it is
    ``tensordict`` itself.

    Args:
        reset_key: key of the reset entry (a tuple for nested entries).
        tensordict: container queried with ``get`` for the reset entry.

    Returns:
        The reset entry if present, else a boolean all-True tensor expanded to
        the parent's ``batch_size``. If the result has more dimensions than
        the parent, the dimensions beyond the parent's are flattened and
        reduced with ``any``.
    """
    reset = tensordict.get(reset_key, None)
    # reset key must be unraveled already
    parent = tensordict
    if isinstance(reset_key, tuple):
        nested = tensordict.get(reset_key[:-1], None)
        # Keep the root container if the nested one wasn't found.
        if nested is not None:
            parent = nested
    if reset is None:
        all_true = torch.ones((), dtype=torch.bool, device=parent.device)
        reset = all_true.expand(parent.batch_size)
    batch_dims = parent.ndim
    if reset.ndim > batch_dims:
        # Collapse every dimension past the parent's batch dims into one flag.
        reset = reset.flatten(batch_dims, -1).any(-1)
    return reset


def _stateless_param(param):
    """Return a copy of ``param`` placed on the ``"meta"`` device.

    Args:
        param: a tensor or ``nn.Parameter``; its ``.data`` is moved to the
            ``"meta"`` device.

    Returns:
        The meta-device tensor. When ``param`` is an ``nn.Parameter`` the
        result is wrapped in a new ``nn.Parameter`` with
        ``requires_grad=False``.
    """
    meta_tensor = param.data.to("meta")
    if not isinstance(param, nn.Parameter):
        return meta_tensor
    return nn.Parameter(meta_tensor, requires_grad=False)
