# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

""" Shared functions related to testing GPU memory sizes. """

import gc
from typing import Tuple

import torch


def find_tensor_by_shape(target_shape: Tuple, only_param: bool = True) -> bool:
    """Find a tensor from the heap

    Args:
        target_shape (tuple):
            Tensor shape to locate.
        only_param (bool):
            Only match Parameter type (e.g. for weights).

    Returns:
        (bool):
            Return True if found.
    """
    for obj in gc.get_objects():
        try:
            # Only need to check parameter type objects if asked.
            if only_param and "torch.nn.parameter.Parameter" not in str(type(obj)):
                continue
            if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)):
                if obj.shape == target_shape:
                    return True
        except Exception as e:
            pass
    return False
