import numbers

import torch

# Select the default compute device when this module is imported:
# CUDA if PyTorch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Apple-silicon MPS is deliberately not selected. To opt in, uncomment the
# lines below; they would override the choice made above.
# if torch.backends.mps.is_available():
#     device = torch.device("mps")


# If not float then .cpu().numpy() will be used
def to_numpy(value):
    if isinstance(value, float):
        return value
    return value.cpu().numpy()


def save_figure(figure, filename, dpi=600, format="pdf", **kwargs):
    """
    Save a matplotlib figure to a file.

    Parameters
    ----------
    figure : plt.Figure
        The matplotlib figure object to save.
    filename : str
        The filename to save the figure to. Because ``format`` is always
        passed explicitly, matplotlib writes that format and uses the
        filename verbatim, even if its extension differs.
    dpi : int, optional
        The resolution in dots per inch. Default is 600.
    format : str, optional
        The format to save the figure in. Default is 'pdf'.
    **kwargs
        Additional keyword arguments forwarded unchanged to
        ``figure.savefig`` (e.g. ``bbox_inches="tight"``, ``transparent=True``).
    """
    # `format` shadows the builtin, but it is part of the public signature and
    # mirrors matplotlib's own keyword name, so it is kept as-is.
    figure.savefig(filename, dpi=dpi, format=format, **kwargs)
