### Preamble ###################################################################

"""
Simple functions for working with memory and diagnosing issues.
"""

#######################################################################################################################

### Imports ###########################################################################################################

import gc
import torch
from typing import Union, Iterable, Optional, Tuple

#######################################################################################################################


def flush() -> None:
    """
    Performs a garbage collection and flushes CUDA memory.
    """

    gc.collect()
    torch.cuda.empty_cache()


def find_mem_tensors(device: Optional[str] = None) -> list:
    """
    :param device: str
        Which device to return tensors for.

    Returns a list of tensors that are currently on the specified device.
    """

    if device is not None:
        if not isinstance(device, str):
            print("'device' must be None or str")

    tensors = []

    # Find all objects the garbage collector knows about
    objects = gc.get_objects()

    # Find references to tensors on the desired device
    for obj in objects:
        if torch.is_tensor(obj):
            if device is not None:
                if str(obj.device) == device:
                    tensors.append(obj)
            else:
                tensors.append(obj)
    return tensors
