import torch

from loguru import logger


def get_available_device() -> torch.device:
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU.

    The selected backend is reported through ``logger.success``.

    Returns:
        ``torch.device("cuda")`` when CUDA is available, otherwise
        ``torch.device("mps")`` when the MPS backend is available,
        otherwise ``torch.device("cpu")``.
    """
    if torch.cuda.is_available():
        logger.success("Using CUDA")
        return torch.device("cuda")
    # `torch.backends.mps.is_available()` is the documented availability check
    # (PyTorch >= 1.12); the `torch.mps` namespace only exists from 2.0 onward.
    # getattr keeps builds that lack the MPS backend module from raising
    # AttributeError; they simply fall through to CPU.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        logger.success("Using MPS")
        return torch.device("mps")
    logger.success("Using CPU")
    return torch.device("cpu")
