"""Device utilities for PyTorch."""

import torch


def get_device() -> torch.device:
    """Get the device to use for computation.

    CUDA is selected when it is available; in every other case the CPU is
    used. Apple's MPS backend is never selected, even when it is present.

    Returns:
        torch.device: ``cuda`` if CUDA is available, otherwise ``cpu``.
    """
    # NOTE(review): MPS-capable hosts fall back to CPU here; confirm whether
    # that is still intended before returning torch.device("mps") for them.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
